
LangChain Integration

Integrate Koreshield security directly into your LangChain pipelines for comprehensive LLM protection.

Installation

pip install Koreshield-sdk langchain langchain-openai
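The examples below hard-code an API key for brevity. In practice, load the key from the environment and pass it to the client; the variable name KORESHIELD_API_KEY is only an example, not an SDK convention:

import os

from Koreshield_sdk import KoreshieldClient

# Keep the key out of source control by reading it from the environment
Koreshield = KoreshieldClient(api_key=os.environ["KORESHIELD_API_KEY"])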

Basic Integration

Callback Handler

Create a callback handler that scans prompts and chain inputs with Koreshield before they reach the model:

from langchain.callbacks.base import BaseCallbackHandler
from Koreshield_sdk import KoreshieldClient

class KoreshieldCallback(BaseCallbackHandler):
    # Callback exceptions are swallowed by default; raise_error makes a
    # detected threat abort the run instead of only being logged.
    raise_error = True

    def __init__(self, api_key: str, sensitivity: str = "medium"):
        self.client = KoreshieldClient(api_key=api_key)
        self.sensitivity = sensitivity

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs):
        """Scan prompts before they are sent to the LLM."""
        for prompt in prompts:
            result = self.client.scan(
                input=prompt,
                sensitivity=self.sensitivity
            )

            if result.is_threat:
                raise ValueError(
                    f"Security threat detected: {result.attack_types[0]} "
                    f"(confidence: {result.confidence:.2f})"
                )

    def on_chain_start(self, serialized: dict, inputs: dict, **kwargs):
        """Scan chain inputs."""
        for key, value in inputs.items():
            if isinstance(value, str):
                result = self.client.scan(value)
                if result.is_threat:
                    raise ValueError(f"Threat in {key}: {result.attack_types}")

# Usage
from langchain_openai import ChatOpenAI

Koreshield_callback = KoreshieldCallback(api_key="ks_prod_xxx")

llm = ChatOpenAI(
    model="gpt-4",
    callbacks=[Koreshield_callback]
)

# Prompts are automatically scanned
response = llm.invoke("What is the capital of France?")

Chain Protection

Simple Chain

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from Koreshield_sdk import KoreshieldClient

Koreshield = KoreshieldClient(api_key="ks_prod_xxx")

def secure_chain(user_input: str):
    # Scan input first
    scan_result = Koreshield.scan(user_input)

    if scan_result.is_threat:
        return {
            "error": "Security violation detected",
            "attack_type": scan_result.attack_types[0],
            "confidence": scan_result.confidence
        }

    # Safe to process
    prompt = PromptTemplate(
        input_variables=["question"],
        template="Answer this question: {question}"
    )

    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.invoke({"question": user_input})

# Use secure chain
result = secure_chain("How does photosynthesis work?")

RAG Chain with Protection

from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from Koreshield_sdk import KoreshieldClient

Koreshield = KoreshieldClient(api_key="ks_prod_xxx")

# Load vector store
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.load_local("./vectorstore", embeddings)

# Create secure RAG chain
def secure_rag_query(query: str):
    # Scan query for injection attacks
    scan_result = Koreshield.scan(query, context="rag_query")

    if scan_result.is_threat:
        return {
            "error": "Query blocked",
            "reason": scan_result.attack_types,
            "confidence": scan_result.confidence
        }

    # Query is safe
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        return_source_documents=True
    )

    result = qa_chain.invoke({"query": query})
    return result

# Use secure RAG
answer = secure_rag_query("What are the company's revenue figures?")

Agent Protection

Secure Agent

from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain.tools import Tool
from langchain import hub

# Create tools
def search_tool(query: str) -> str:
    # Scan tool input before running the tool
    result = Koreshield.scan(query)
    if result.is_threat:
        return f"Security violation: {result.attack_types[0]}"
    return f"Search results for: {query}"

tools = [
    Tool(
        name="Search",
        func=search_tool,
        description="Search for information"
    )
]

# Create secure agent
prompt = hub.pull("hwchase17/openai-functions-agent")

agent = create_openai_functions_agent(
    llm=llm,
    tools=tools,
    prompt=prompt
)

agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    callbacks=[Koreshield_callback],
    verbose=True
)

# Execute with automatic security
result = agent_executor.invoke({
    "input": "Search for Python tutorials"
})

Custom Security Layer

Wrapper Class

from typing import Any, List, Optional
from langchain.llms.base import LLM
from Koreshield_sdk import KoreshieldClient

class SecureLLM(LLM):
    """LLM wrapper with Koreshield protection"""

    llm: LLM
    Koreshield: KoreshieldClient
    sensitivity: str = "medium"
    block_on_threat: bool = True

    def __init__(self, llm: LLM, api_key: str, **kwargs):
        # Pass the declared fields through the pydantic base constructor
        # so the base class validates them.
        super().__init__(
            llm=llm,
            Koreshield=KoreshieldClient(api_key=api_key),
            **kwargs
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any
    ) -> str:
        # Scan input
        scan_result = self.Koreshield.scan(
            input=prompt,
            sensitivity=self.sensitivity
        )

        if scan_result.is_threat:
            if self.block_on_threat:
                raise ValueError(
                    f"Blocked: {scan_result.attack_types[0]} "
                    f"(confidence: {scan_result.confidence:.2%})"
                )
            else:
                # Log but allow
                print(f"Warning: Potential threat detected - {scan_result.attack_types}")

        # Call underlying LLM
        return self.llm._call(prompt, stop=stop, **kwargs)

    @property
    def _llm_type(self) -> str:
        return f"secure_{self.llm._llm_type}"

# Usage
from langchain_openai import OpenAI

base_llm = OpenAI(temperature=0.7)
secure_llm = SecureLLM(
    llm=base_llm,
    api_key="ks_prod_xxx",
    sensitivity="high"
)

response = secure_llm.invoke("Tell me about AI safety")

Multi-Tenancy Support

class TenantAwareLLM(LLM):
    """LLM with per-tenant security policies"""

    llm: LLM
    Koreshield: KoreshieldClient
    tenant_policies: dict = {
        "free": {"sensitivity": "high", "max_requests": 100},
        "pro": {"sensitivity": "medium", "max_requests": 10000},
        "enterprise": {"sensitivity": "low", "max_requests": -1}
    }

    def __init__(self, llm: LLM, api_key: str, **kwargs):
        super().__init__(
            llm=llm,
            Koreshield=KoreshieldClient(api_key=api_key),
            **kwargs
        )

    def _call(self, prompt: str, tenant_id: str = "free", **kwargs) -> str:
        policy = self.tenant_policies.get(tenant_id, self.tenant_policies["free"])

        # Apply tenant-specific security
        scan_result = self.Koreshield.scan(
            input=prompt,
            sensitivity=policy["sensitivity"],
            metadata={"tenant_id": tenant_id}
        )

        if scan_result.is_threat:
            raise ValueError(f"Threat detected for tenant {tenant_id}")

        return self.llm._call(prompt, **kwargs)

    @property
    def _llm_type(self) -> str:
        # Required by the LLM base class
        return "tenant_aware_secure_llm"
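A usage sketch for the wrapper above (not from the SDK docs): the extra tenant_id keyword passed to invoke is forwarded by the LLM base class down to _call, so each request is scanned with that tenant's configured sensitivity.

# base_llm as defined in the SecureLLM example above
tenant_llm = TenantAwareLLM(llm=base_llm, api_key="ks_prod_xxx")

# The same question is scanned at "high" sensitivity for free tenants
# and "medium" for pro tenants, per the policies defined above
free_answer = tenant_llm.invoke("Summarize the refund policy", tenant_id="free")
pro_answer = tenant_llm.invoke("Summarize the refund policy", tenant_id="pro")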

LangChain Expression Language (LCEL)

from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

Koreshield = KoreshieldClient(api_key="ks_prod_xxx")

# Security check runnable
def security_check(input_dict):
    message = input_dict.get("question", "")
    result = Koreshield.scan(message)

    if result.is_threat:
        raise ValueError(f"Security threat: {result.attack_types[0]}")

    return input_dict

# Build secure chain with LCEL
prompt = ChatPromptTemplate.from_template("Answer: {question}")
model = ChatOpenAI(model="gpt-4")

secure_chain = (
    RunnableLambda(security_check) |
    prompt |
    model
)

# Use chain
response = secure_chain.invoke({"question": "What is machine learning?"})
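Output-side scanning (listed as optional under Best Practices below) fits the same pattern: append another RunnableLambda after the model. A sketch reusing the Koreshield.scan call from above:

from langchain_core.output_parsers import StrOutputParser

def output_check(text: str) -> str:
    # Second pass over the generated text; raising here aborts the chain
    result = Koreshield.scan(text)
    if result.is_threat:
        raise ValueError(f"Security threat in output: {result.attack_types[0]}")
    return text

guarded_chain = (
    RunnableLambda(security_check) |
    prompt |
    model |
    StrOutputParser() |
    RunnableLambda(output_check)
)

response = guarded_chain.invoke({"question": "What is machine learning?"})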

Async Support

import asyncio

from langchain.callbacks.base import AsyncCallbackHandler
from Koreshield_sdk import AsyncKoreshieldClient

class AsyncSecureCallback(AsyncCallbackHandler):
    # Async hooks are only awaited on AsyncCallbackHandler subclasses;
    # raise_error ensures a detected threat aborts the run.
    raise_error = True

    def __init__(self, api_key: str):
        self.client = AsyncKoreshieldClient(api_key=api_key)

    async def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs):
        for prompt in prompts:
            result = await self.client.scan(prompt)
            if result.is_threat:
                raise ValueError(f"Threat: {result.attack_types}")

# Async usage
async def async_query(question: str):
    callback = AsyncSecureCallback(api_key="ks_prod_xxx")

    llm = ChatOpenAI(callbacks=[callback])
    response = await llm.ainvoke(question)
    return response

# Run
result = asyncio.run(async_query("How does encryption work?"))

Streaming with Protection

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

class SecureStreamingCallback(StreamingStdOutCallbackHandler):
    def __init__(self, api_key: str):
        super().__init__()
        self.Koreshield = KoreshieldClient(api_key=api_key)
        self.buffer = ""

    def on_llm_new_token(self, token: str, **kwargs):
        # Accumulate tokens
        self.buffer += token

        # Check for threats in output (optional)
        if len(self.buffer) > 100:
            result = self.Koreshield.scan(self.buffer[-100:])
            if result.is_threat:
                raise ValueError("Threat detected in output")

        super().on_llm_new_token(token, **kwargs)

# Use streaming with security
llm = ChatOpenAI(
    streaming=True,
    callbacks=[SecureStreamingCallback(api_key="ks_prod_xxx")]
)

response = llm.invoke("Explain quantum computing")

Memory Protection

from langchain.memory import ConversationBufferMemory

class SecureMemory(ConversationBufferMemory):
    def __init__(self, api_key: str, **kwargs):
        super().__init__(**kwargs)
        self.Koreshield = KoreshieldClient(api_key=api_key)

    def save_context(self, inputs: dict, outputs: dict):
        # Scan before saving to memory
        for value in inputs.values():
            if isinstance(value, str):
                result = self.Koreshield.scan(value)
                if result.is_threat:
                    # Don't save malicious content to memory
                    return

        super().save_context(inputs, outputs)

# Use secure memory
memory = SecureMemory(api_key="ks_prod_xxx")

chain = LLMChain(
    llm=llm,
    prompt=prompt,
    memory=memory
)

Testing

import pytest
from unittest.mock import Mock, patch

@pytest.fixture
def mock_Koreshield():
    # Patch where KoreshieldClient is looked up; adjust the target if your
    # application imports it into another module.
    with patch('Koreshield_sdk.KoreshieldClient') as mock:
        client = mock.return_value
        client.scan.return_value = Mock(
            is_threat=False,
            confidence=0.1,
            attack_types=[]
        )
        yield client

def test_secure_chain(mock_Koreshield):
    callback = KoreshieldCallback(api_key="test_key")
    llm = ChatOpenAI(callbacks=[callback])

    response = llm.invoke("Safe question")
    assert response is not None

def test_threat_detection(mock_Koreshield):
    mock_Koreshield.scan.return_value = Mock(
        is_threat=True,
        confidence=0.95,
        attack_types=["prompt_injection"]
    )

    callback = KoreshieldCallback(api_key="test_key")
    llm = ChatOpenAI(callbacks=[callback])

    with pytest.raises(ValueError, match="Security threat"):
        llm.invoke("Malicious input")

Best Practices

Scan at Multiple Layers

# 1. Input validation
user_input_scan = Koreshield.scan(user_input)

# 2. Chain input validation
chain_input_scan = Koreshield.scan(formatted_prompt)

# 3. Tool input validation
tool_input_scan = Koreshield.scan(tool_args)

# 4. Optional: Output validation
output_scan = Koreshield.scan(llm_response)
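A single helper keeps these layers consistent (a sketch; scan_or_raise is not part of the SDK):

def scan_or_raise(text: str, layer: str) -> None:
    # One choke point shared by all four layers above
    result = Koreshield.scan(text)
    if result.is_threat:
        raise ValueError(
            f"Security threat at {layer} layer: {result.attack_types[0]} "
            f"(confidence: {result.confidence:.2f})"
        )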

Handle Errors Gracefully

def safe_invoke(chain, input_data):
    try:
        return chain.invoke(input_data)
    except ValueError as e:
        if "Security threat" in str(e):
            return {
                "error": "Your request was blocked for security reasons",
                "support": "contact@company.com"
            }
        raise

Support