LangChain Integration
Integrate Koreshield security directly into your LangChain pipelines for comprehensive LLM protection.
Installation
pip install Koreshield-sdk langchain langchain-openai
# The RAG example additionally needs: pip install langchain-community faiss-cpu
Basic Integration
Callback Handler
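Create a Koreshield callback handler. Note that LangChain only propagates an exception raised inside a callback handler when the handler sets `raise_error = True`; otherwise the error is logged and the LLM call continues, so a handler meant to block requests must set it: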
from langchain.callbacks.base import BaseCallbackHandler
from Koreshield_sdk import KoreshieldClient
class KoreshieldCallback(BaseCallbackHandler):
    """Callback handler that scans LLM prompts and chain inputs with Koreshield.

    A detected threat raises ValueError, which aborts the LLM call or chain run.
    Scan failures (e.g. SDK/network errors) also propagate, so the handler
    fails closed.
    """

    # LangChain catches exceptions raised inside callback handlers, logs them,
    # and carries on unless the handler sets raise_error. Without this flag a
    # detected threat would be logged but never blocked.
    raise_error = True

    def __init__(self, api_key: str, sensitivity: str = "medium"):
        super().__init__()
        self.client = KoreshieldClient(api_key=api_key)
        self.sensitivity = sensitivity

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs):
        """Scan every prompt before it is sent to the LLM.

        Raises:
            ValueError: if Koreshield flags any prompt.
        """
        for prompt in prompts:
            result = self.client.scan(
                input=prompt,
                sensitivity=self.sensitivity
            )
            if result.is_threat:
                # Don't let an empty attack_types list turn the block into an IndexError.
                attack = result.attack_types[0] if result.attack_types else "unknown"
                raise ValueError(
                    f"Security threat detected: {attack} "
                    f"(confidence: {result.confidence:.2f})"
                )

    def on_chain_start(self, serialized: dict, inputs: dict, **kwargs):
        """Scan the string-valued inputs of a chain run.

        Raises:
            ValueError: if Koreshield flags any input.
        """
        # Runnables can start with a bare (non-dict) input; treat it as one value.
        items = inputs.items() if isinstance(inputs, dict) else [("input", inputs)]
        for key, value in items:
            if isinstance(value, str):
                # Use the same sensitivity as prompt scanning for a consistent policy.
                result = self.client.scan(input=value, sensitivity=self.sensitivity)
                if result.is_threat:
                    raise ValueError(f"Threat in {key}: {result.attack_types}")
# Usage: attach the handler when the model is built
from langchain_openai import ChatOpenAI

Koreshield_callback = KoreshieldCallback(api_key="ks_prod_xxx")
llm = ChatOpenAI(model="gpt-4", callbacks=[Koreshield_callback])

# Every prompt passes through the handler before the request is made
response = llm.invoke("What is the capital of France?")
Chain Protection
Simple Chain
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from Koreshield_sdk import KoreshieldClient
Koreshield = KoreshieldClient(api_key="ks_prod_xxx")

# The template does not depend on the request, so build it once.
ANSWER_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="Answer this question: {question}",
)


def secure_chain(user_input: str):
    """Answer *user_input* with an LLMChain once Koreshield has cleared it.

    Returns:
        The chain's output dict when the input is safe. When the input is
        blocked, returns a dict with "error", "attack_type" and "confidence".
    """
    # Scan input first
    scan_result = Koreshield.scan(user_input)
    if scan_result.is_threat:
        return {
            "error": "Security violation detected",
            # attack_types may be empty; don't turn a block into an IndexError.
            "attack_type": (
                scan_result.attack_types[0] if scan_result.attack_types else "unknown"
            ),
            "confidence": scan_result.confidence,
        }
    # Safe to process. `llm` is the model created in the Basic Integration example.
    chain = LLMChain(llm=llm, prompt=ANSWER_PROMPT)
    return chain.invoke({"question": user_input})


# Use secure chain
result = secure_chain("How does photosynthesis work?")
RAG Chain with Protection
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from Koreshield_sdk import KoreshieldClient
Koreshield = KoreshieldClient(api_key="ks_prod_xxx")

# Load vector store
embeddings = OpenAIEmbeddings()
# FAISS.load_local unpickles the stored docstore, and current langchain-community
# releases refuse to do that unless this flag is set explicitly. Unpickling can
# execute arbitrary code, so only load index files you created yourself.
vectorstore = FAISS.load_local(
    "./vectorstore",
    embeddings,
    allow_dangerous_deserialization=True,
)
# Create secure RAG chain
def secure_rag_query(query: str):
    """Run a retrieval-QA query after Koreshield has cleared the query text.

    Only the user's query is scanned here; retrieved documents are passed to
    the LLM without scanning.

    Returns:
        The RetrievalQA result dict (including source documents), or a dict
        with "error", "reason" and "confidence" if the query was blocked.
    """
    verdict = Koreshield.scan(query, context="rag_query")

    if not verdict.is_threat:
        retrieval_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(),
            return_source_documents=True,
        )
        return retrieval_chain.invoke({"query": query})

    return {
        "error": "Query blocked",
        "reason": verdict.attack_types,
        "confidence": verdict.confidence,
    }


# Use secure RAG
answer = secure_rag_query("What are the company's revenue figures?")
Agent Protection
Secure Agent
from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain.tools import Tool
from langchain import hub
# Create tools
def search_tool(query: str) -> str:
    """Placeholder search tool that refuses inputs Koreshield flags.

    A blocked query yields a "Security violation: ..." string as the tool
    output instead of raising.
    """
    # Scan tool input
    result = Koreshield.scan(query)
    if result.is_threat:
        # attack_types may be empty; avoid an IndexError inside the agent loop.
        attack = result.attack_types[0] if result.attack_types else "unknown"
        return f"Security violation: {attack}"
    return f"Search results for: {query}"
tools = [
    Tool(name="Search", func=search_tool, description="Search for information"),
]

# Create secure agent from the prompt published on LangChain Hub
prompt = hub.pull("hwchase17/openai-functions-agent")
agent = create_openai_functions_agent(llm=llm, tools=tools, prompt=prompt)

# The callback here covers the executor's own run; the agent's LLM calls are
# covered because `llm` was created with the same callback.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    callbacks=[Koreshield_callback],
    verbose=True,
)

# Execute with automatic security
result = agent_executor.invoke({"input": "Search for Python tutorials"})
Custom Security Layer
Wrapper Class
import logging
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLanguageModel, LLM
from Koreshield_sdk import KoreshieldClient

logger = logging.getLogger(__name__)


class SecureLLM(LLM):
    """LLM wrapper that scans every prompt with Koreshield before delegating.

    Attributes:
        llm: The wrapped model. Any LangChain language model is accepted; e.g.
            ``langchain_openai.OpenAI`` is a ``BaseLLM``, not an ``LLM``.
        Koreshield: Client used for scanning, built from ``api_key``.
        sensitivity: Sensitivity level passed to ``scan``.
        block_on_threat: When True, a detected threat raises ValueError; when
            False it is logged as a warning and the prompt is still forwarded.
    """

    llm: BaseLanguageModel
    Koreshield: KoreshieldClient
    sensitivity: str = "medium"
    block_on_threat: bool = True

    def __init__(self, llm: BaseLanguageModel, api_key: str, **kwargs: Any):
        # LLM is a pydantic model: required fields must go through
        # super().__init__(); validation fails if they are assigned afterwards.
        super().__init__(llm=llm, Koreshield=KoreshieldClient(api_key=api_key), **kwargs)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> str:
        """Scan *prompt*, then return the wrapped model's completion.

        Raises:
            ValueError: if a threat is detected and ``block_on_threat`` is True.
        """
        scan_result = self.Koreshield.scan(
            input=prompt,
            sensitivity=self.sensitivity
        )
        if scan_result.is_threat:
            if self.block_on_threat:
                attack = scan_result.attack_types[0] if scan_result.attack_types else "unknown"
                raise ValueError(
                    f"Blocked: {attack} "
                    f"(confidence: {scan_result.confidence:.2%})"
                )
            # Log but allow
            logger.warning("Potential threat detected - %s", scan_result.attack_types)
        # Go through the public API: BaseLLM and chat models do not implement
        # _call, and run_manager belongs to this run, not the wrapped model's.
        response = self.llm.invoke(prompt, stop=stop, **kwargs)
        # Completion models return str; chat models return a message.
        return response if isinstance(response, str) else response.content

    @property
    def _llm_type(self) -> str:
        return f"secure_{self.llm._llm_type}"
# Usage: wrap an OpenAI completion model
from langchain_openai import OpenAI

base_llm = OpenAI(temperature=0.7)
secure_llm = SecureLLM(llm=base_llm, api_key="ks_prod_xxx", sensitivity="high")

response = secure_llm.invoke("Tell me about AI safety")
Multi-Tenancy Support
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLanguageModel, LLM
from Koreshield_sdk import KoreshieldClient


class TenantAwareLLM(LLM):
    """LLM wrapper that applies a per-plan Koreshield policy to every prompt.

    Select the policy per call, e.g. ``llm.invoke(prompt, tenant_id="pro")``.
    Despite its name, ``tenant_id`` is looked up as a plan key ("free", "pro",
    "enterprise"); unknown values fall back to the "free" policy.
    """

    # LLM is a pydantic model, so state must be declared as fields; assigning
    # undeclared attributes in __init__ raises.
    llm: BaseLanguageModel
    Koreshield: KoreshieldClient
    # "max_requests" is informational only: nothing in this class enforces it
    # (-1 presumably means unlimited).
    tenant_policies: Dict[str, Dict[str, Any]] = {
        "free": {"sensitivity": "high", "max_requests": 100},
        "pro": {"sensitivity": "medium", "max_requests": 10000},
        "enterprise": {"sensitivity": "low", "max_requests": -1},
    }

    def __init__(self, llm: BaseLanguageModel, api_key: str):
        super().__init__(llm=llm, Koreshield=KoreshieldClient(api_key=api_key))

    def _call(
        self,
        prompt: str,
        tenant_id: str = "free",
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Scan *prompt* under the policy for *tenant_id*, then call the wrapped model.

        Raises:
            ValueError: if Koreshield flags the prompt.
        """
        policy = self.tenant_policies.get(tenant_id, self.tenant_policies["free"])
        # Apply tenant-specific security
        scan_result = self.Koreshield.scan(
            input=prompt,
            sensitivity=policy["sensitivity"],
            metadata={"tenant_id": tenant_id},
        )
        if scan_result.is_threat:
            raise ValueError(f"Threat detected for tenant {tenant_id}")
        # Public API instead of _call (absent on BaseLLM/chat models); this run's
        # run_manager is deliberately not forwarded to the wrapped model.
        response = self.llm.invoke(prompt, stop=stop, **kwargs)
        return response if isinstance(response, str) else response.content

    @property
    def _llm_type(self) -> str:
        # Abstract on LLM; without it the class cannot be instantiated.
        return f"tenant_aware_{self.llm._llm_type}"
LangChain Expression Language (LCEL)
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from Koreshield_sdk import KoreshieldClient

Koreshield = KoreshieldClient(api_key="ks_prod_xxx")


# Security check runnable
def security_check(input_dict):
    """Return *input_dict* unchanged if its "question" passes the Koreshield scan.

    A missing "question" key is scanned as an empty string.

    Raises:
        ValueError: if Koreshield flags the question.
    """
    message = input_dict.get("question", "")
    result = Koreshield.scan(message)
    if result.is_threat:
        # attack_types may be empty; keep the intended ValueError.
        attack = result.attack_types[0] if result.attack_types else "unknown"
        raise ValueError(f"Security threat: {attack}")
    return input_dict
# Build secure chain with LCEL: the check runs before the prompt is formatted
model = ChatOpenAI(model="gpt-4")
prompt = ChatPromptTemplate.from_template("Answer: {question}")
secure_chain = RunnableLambda(security_check) | prompt | model

# Use chain
response = secure_chain.invoke({"question": "What is machine learning?"})
Async Support
import asyncio
from Koreshield_sdk import AsyncKoreshieldClient
from langchain_core.callbacks import AsyncCallbackHandler


class AsyncSecureCallback(AsyncCallbackHandler):
    """Async callback handler that scans prompts with AsyncKoreshieldClient."""

    # Without this LangChain logs the ValueError below and lets the call proceed.
    raise_error = True

    def __init__(self, api_key: str):
        super().__init__()
        self.client = AsyncKoreshieldClient(api_key=api_key)

    async def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs):
        """Scan each prompt; raise ValueError on the first one flagged."""
        for prompt in prompts:
            result = await self.client.scan(prompt)
            if result.is_threat:
                raise ValueError(f"Threat: {result.attack_types}")
# Async usage: a fresh handler is attached for each query
async def async_query(question: str):
    """Ask *question* through ChatOpenAI.ainvoke with AsyncSecureCallback attached."""
    model = ChatOpenAI(callbacks=[AsyncSecureCallback(api_key="ks_prod_xxx")])
    return await model.ainvoke(question)


# Run
result = asyncio.run(async_query("How does encryption work?"))
Streaming with Protection
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from typing import Any


class SecureStreamingCallback(StreamingStdOutCallbackHandler):
    """Print streamed tokens to stdout while scanning the output with Koreshield.

    Rather than one scan per token, the output is scanned each time ``window``
    new characters have arrived, and any unscanned tail is scanned when the
    LLM finishes. Each scanned span starts ``window`` characters before the
    previous scan's end, so text split across a boundary is still seen whole.
    """

    # Needed for the ValueError to abort the stream; LangChain otherwise logs
    # and ignores exceptions raised by callback handlers.
    raise_error = True

    def __init__(self, api_key: str, window: int = 100):
        super().__init__()
        self.Koreshield = KoreshieldClient(api_key=api_key)
        self.window = window
        self.buffer = ""
        self._scanned_upto = 0  # length of self.buffer covered by the last scan

    def _reset(self) -> None:
        """Forget the previous response so the handler can be reused."""
        self.buffer = ""
        self._scanned_upto = 0

    def _scan_pending(self) -> None:
        """Scan text added since the last scan (plus overlap); raise on a threat."""
        start = max(0, self._scanned_upto - self.window)
        result = self.Koreshield.scan(self.buffer[start:])
        self._scanned_upto = len(self.buffer)
        if result.is_threat:
            self._reset()
            raise ValueError("Threat detected in output")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Accumulate tokens; scan once enough new text has arrived.
        self.buffer += token
        if len(self.buffer) - self._scanned_upto >= self.window:
            self._scan_pending()
        super().on_llm_new_token(token, **kwargs)

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        # The last (< window) characters were never scanned during streaming.
        try:
            if len(self.buffer) > self._scanned_upto:
                self._scan_pending()
        finally:
            self._reset()
        super().on_llm_end(response, **kwargs)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        self._reset()
        super().on_llm_error(error, **kwargs)
# Use streaming with security: tokens are printed as they pass the scanner
llm = ChatOpenAI(callbacks=[SecureStreamingCallback(api_key="ks_prod_xxx")], streaming=True)
response = llm.invoke("Explain quantum computing")
Memory Protection
from langchain.memory import ConversationBufferMemory
from typing import Any


class SecureMemory(ConversationBufferMemory):
    """ConversationBufferMemory that skips exchanges whose inputs Koreshield flags."""

    # Memory classes are pydantic models, so the client must be a declared
    # field; assigning an undeclared attribute in __init__ raises. Typed Any
    # so pydantic does not have to validate the SDK class.
    Koreshield: Any = None

    def __init__(self, api_key: str, **kwargs):
        super().__init__(Koreshield=KoreshieldClient(api_key=api_key), **kwargs)

    def save_context(self, inputs: dict, outputs: dict):
        """Save the exchange unless any string input is flagged as a threat.

        Only inputs are scanned; outputs are stored as-is.
        """
        for value in inputs.values():
            if isinstance(value, str) and self.Koreshield.scan(value).is_threat:
                # Deliberately drop the whole exchange so malicious input is
                # never replayed to the model from memory.
                return
        super().save_context(inputs, outputs)
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate

# Use secure memory
memory = SecureMemory(api_key="ks_prod_xxx")
# ConversationBufferMemory exposes the stored conversation as "history" by
# default; the prompt must use that variable or remembered turns never reach
# the model.
memory_prompt = PromptTemplate.from_template(
    "Conversation so far:\n{history}\n\nHuman: {question}\nAI:"
)
chain = LLMChain(
    llm=llm,
    prompt=memory_prompt,
    memory=memory
)
Testing
import pytest
from unittest.mock import Mock, patch
from langchain_core.language_models import FakeListChatModel


@pytest.fixture
def mock_Koreshield():
    """Replace the KoreshieldClient used by KoreshieldCallback with a mock.

    Yields the mocked client instance; its ``scan`` reports no threat unless a
    test overrides ``scan.return_value``.
    """
    # `from Koreshield_sdk import KoreshieldClient` copies the name into the
    # module that defines KoreshieldCallback, and that copy is what the
    # callback uses. Patching "Koreshield_sdk.KoreshieldClient" would not
    # affect it, so patch the name where it is looked up.
    target = f"{KoreshieldCallback.__module__}.KoreshieldClient"
    with patch(target) as mock:
        client = mock.return_value
        client.scan.return_value = Mock(
            is_threat=False,
            confidence=0.1,
            attack_types=[]
        )
        yield client


def _offline_llm(callback):
    """Chat model with a canned reply, so tests need no API key or network."""
    return FakeListChatModel(responses=["ok"], callbacks=[callback])


def test_secure_chain(mock_Koreshield):
    callback = KoreshieldCallback(api_key="test_key")
    llm = _offline_llm(callback)
    response = llm.invoke("Safe question")
    assert response is not None
    # The prompt must actually have gone through the scanner.
    mock_Koreshield.scan.assert_called()


def test_threat_detection(mock_Koreshield):
    mock_Koreshield.scan.return_value = Mock(
        is_threat=True,
        confidence=0.95,
        attack_types=["prompt_injection"]
    )
    callback = KoreshieldCallback(api_key="test_key")
    llm = _offline_llm(callback)
    with pytest.raises(ValueError, match="Security threat"):
        llm.invoke("Malicious input")
Best Practices
Scan at Multiple Layers
def require_safe(layer: str, text):
    """Scan *text* with Koreshield and raise if it is flagged.

    Returns the scan result so callers can still inspect e.g. confidence.

    Raises:
        ValueError: if the scan reports a threat.
    """
    result = Koreshield.scan(text)
    if result.is_threat:
        raise ValueError(f"Security threat at {layer} layer: {result.attack_types}")
    return result


# A scan only protects you if its verdict is acted on: each layer below stops
# the request as soon as something is flagged.
# 1. Input validation: the raw user text
user_input_scan = require_safe("input", user_input)
# 2. Chain input validation: the fully formatted prompt
chain_input_scan = require_safe("chain", formatted_prompt)
# 3. Tool input validation
tool_input_scan = require_safe("tool", tool_args)
# 4. Optional: Output validation, before the response reaches the user
output_scan = require_safe("output", llm_response)
Handle Errors Gracefully
def safe_invoke(chain, input_data, *, markers: tuple[str, ...] = ("threat", "blocked")):
    """Invoke *chain*, turning Koreshield block errors into a user-facing dict.

    The examples on this page raise ValueError with messages such as
    "Security threat: ...", "Threat in ...", "Threat detected for tenant ..."
    and "Blocked: ...". A ValueError whose message contains any of *markers*
    (case-insensitive) is treated as a security block. Every other exception
    is re-raised unchanged.

    Args:
        chain: Any object with an ``invoke(input)`` method.
        input_data: Passed straight to ``chain.invoke``.
        markers: Substrings that identify a security-block error message.

    Returns:
        The chain's result, or an error dict if the request was blocked.
    """
    try:
        return chain.invoke(input_data)
    except ValueError as e:
        message = str(e).casefold()
        if any(marker.casefold() in message for marker in markers):
            return {
                "error": "Your request was blocked for security reasons",
                "support": "contact@company.com"
            }
        raise
Related Documentation
Support
- Discord: discord.gg/Koreshield
- GitHub: github.com/Koreshield/Koreshield
- Email: support@Koreshield.com