Agent Framework のミドルウェアは、実行のさまざまな段階でエージェントの対話をインターセプト、変更、および強化する強力な方法を提供します。 ミドルウェアを使用すると、コア エージェントや関数ロジックを変更することなく、ログ記録、セキュリティ検証、エラー処理、結果変換などの横断的な問題を実装できます。
Agent Framework は、次の 3 種類のミドルウェアを使用してカスタマイズできます。
- エージェント実行ミドルウェア: 必要に応じて入力と出力を検査したり変更したりできるように、すべてのエージェント実行のインターセプトを許可します。
- 関数呼び出しミドルウェア: エージェントによって実行されるすべての関数呼び出しのインターセプトを許可します。これにより、入力と出力を必要に応じて検査および変更できます。
-
IChatClient ミドルウェア:
IChatClient実装への呼び出しのインターセプトを許可します。この場合、エージェントは推論呼び出しにIChatClientを使用します (たとえば、ChatClientAgentを使用する場合)。
すべての種類のミドルウェアは関数コールバックを介して実装され、同じ型の複数のミドルウェア インスタンスが登録されるとチェーンを形成し、各ミドルウェア インスタンスは、提供された nextFuncを介してチェーン内の次のインスタンスを呼び出す必要があります。
エージェントの実行と関数呼び出しのミドルウェアの種類は、エージェント ビルダーと既存のエージェント オブジェクトを使用して、エージェントに登録できます。
var middlewareEnabledAgent = originalAgent
.AsBuilder()
.Use(runFunc: CustomAgentRunMiddleware, runStreamingFunc: CustomAgentRunStreamingMiddleware)
.Use(CustomFunctionCallingMiddleware)
.Build();
Important
理想的には、 runFunc と runStreamingFunc の両方を提供する必要があります。 非ストリーミング ミドルウェアのみを提供する場合、エージェントは、ストリーミングと非ストリーミングの両方の呼び出しに使用します。 ストリーミングは、ミドルウェアの期待に応えるために、非ストリーミング モードでのみ実行されます。
注
Use(sharedFunc: ...)追加のオーバーロードがあり、ストリーミングをブロックすることなく、非ストリーミングとストリーミングに同じミドルウェアを提供できます。 ただし、共有ミドルウェアは出力をインターセプトまたはオーバーライドできません。 このオーバーロードは、エージェントに到達する前に入力を検査または変更するだけで済むシナリオに使用する必要があります。
IChatClientミドルウェアは、チャット クライアント ビルダー パターンを使用して、IChatClientで使用する前に、ChatClientAgentに登録できます。
var chatClient = new AzureOpenAIClient(new Uri("https://<myresource>.openai.azure.com"), new DefaultAzureCredential())
.GetChatClient(deploymentName)
.AsIChatClient();
var middlewareEnabledChatClient = chatClient
.AsBuilder()
.Use(getResponseFunc: CustomChatClientMiddleware, getStreamingResponseFunc: null)
.Build();
var agent = new ChatClientAgent(middlewareEnabledChatClient, instructions: "You are a helpful assistant.");
Warnung
DefaultAzureCredential は開発には便利ですが、運用環境では慎重に考慮する必要があります。 運用環境では、待機時間の問題、意図しない資格情報のプローブ、フォールバック メカニズムによる潜在的なセキュリティ リスクを回避するために、特定の資格情報 ( ManagedIdentityCredential など) を使用することを検討してください。
IChatClient SDK クライアントのヘルパー メソッドの 1 つを使用してエージェントを構築するときに、ファクトリ メソッドを使用してミドルウェアを登録することもできます。
var agent = new AzureOpenAIClient(new Uri(endpoint), new DefaultAzureCredential())
.GetChatClient(deploymentName)
.AsAIAgent("You are a helpful assistant.", clientFactory: (chatClient) => chatClient
.AsBuilder()
.Use(getResponseFunc: CustomChatClientMiddleware, getStreamingResponseFunc: null)
.Build());
エージェント実行ミドルウェア
エージェント実行のミドルウェアの例を次に示します。このミドルウェアは、エージェント実行からの入力と出力を検査または変更できます。
async Task<AgentResponse> CustomAgentRunMiddleware(
IEnumerable<ChatMessage> messages,
AgentSession? session,
AgentRunOptions? options,
AIAgent innerAgent,
CancellationToken cancellationToken)
{
Console.WriteLine(messages.Count());
var response = await innerAgent.RunAsync(messages, session, options, cancellationToken).ConfigureAwait(false);
Console.WriteLine(response.Messages.Count);
return response;
}
エージェント実行ストリーミング ミドルウェア
エージェントのストリーミング実行からの入力と出力を検査または変更できるエージェント実行ストリーミング ミドルウェアの例を次に示します。
async IAsyncEnumerable<AgentResponseUpdate> CustomAgentRunStreamingMiddleware(
IEnumerable<ChatMessage> messages,
AgentSession? session,
AgentRunOptions? options,
AIAgent innerAgent,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
Console.WriteLine(messages.Count());
List<AgentResponseUpdate> updates = [];
await foreach (var update in innerAgent.RunStreamingAsync(messages, session, options, cancellationToken))
{
updates.Add(update);
yield return update;
}
Console.WriteLine(updates.ToAgentResponse().Messages.Count);
}
関数呼び出しミドルウェア
注
現在、関数呼び出しミドルウェアは、AIAgentなどのFunctionInvokingChatClientを使用するChatClientAgentでのみサポートされています。
呼び出される関数を検査したり変更したりできる関数呼び出しミドルウェアの例と、関数呼び出しの結果を次に示します。
async ValueTask<object?> CustomFunctionCallingMiddleware(
AIAgent agent,
FunctionInvocationContext context,
Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next,
CancellationToken cancellationToken)
{
Console.WriteLine($"Function Name: {context!.Function.Name}");
var result = await next(context, cancellationToken);
Console.WriteLine($"Function Call Result: {result}");
return result;
}
指定された FunctionInvocationContext.Terminate を true に設定することで、関数呼び出しミドルウェアを使用して関数呼び出しループを終了できます。
これにより、関数呼び出しループは、関数呼び出し後に関数呼び出し結果を含む推論サービスに要求を発行できなくなります。
このイテレーション中に呼び出しに使用できる関数が複数ある場合は、残りの関数が実行されない可能性もあります。
Warnung
関数呼び出しループを終了すると、チャット履歴が一貫性のない状態のままになることがあります。たとえば、関数の結果コンテンツのない関数呼び出しコンテンツが含まれる場合があります。 これにより、チャット履歴がそれ以降の実行で使用できなくなる可能性があります。
IChatClient ミドルウェア
チャット クライアントが提供する推論サービスへの要求の入力と出力を検査または変更できるチャット クライアント ミドルウェアの例を次に示します。
async Task<ChatResponse> CustomChatClientMiddleware(
IEnumerable<ChatMessage> messages,
ChatOptions? options,
IChatClient innerChatClient,
CancellationToken cancellationToken)
{
Console.WriteLine(messages.Count());
var response = await innerChatClient.GetResponseAsync(messages, options, cancellationToken);
Console.WriteLine(response.Messages.Count);
return response;
}
ヒント
実行可能な完全な例については、 .NET サンプル を参照してください。
注
IChatClient ミドルウェアの詳細については、「Custom IChatClient ミドルウェア」を参照してください。
Agent Framework は、次の 3 種類のミドルウェアを使用してカスタマイズできます。
- エージェント ミドルウェア: エージェントの実行をインターセプトして、入力、出力、制御フローを検査および変更できるようにします。
- 関数ミドルウェア: エージェントの実行中に行われた関数 (ツール) 呼び出しをインターセプトし、入力検証、結果変換、および実行制御を有効にします。
- チャット ミドルウェア: AI モデルに送信された基になるチャット要求をインターセプトし、生のメッセージ、オプション、応答へのアクセスを提供します。
すべての型は、関数ベースの実装とクラスベースの実装の両方をサポートします。 同じ型の複数のミドルウェアが登録されると、それぞれが呼び出し可能な next を呼び出して処理を続行するチェーンを形成します。
エージェント ミドルウェア
エージェント ミドルウェアは、エージェントの実行をインターセプトして変更します。 次を含む AgentContext を使用します。
-
agent: 呼び出されるエージェント -
messages: 会話内のチャット メッセージの一覧 -
is_streaming: 応答がストリーミングされているかどうかを示すブール値 -
metadata: ミドルウェア間で追加のデータを格納するためのディクショナリ -
result: エージェントの応答 (変更可能) -
terminate: それ以降の処理を停止するフラグ -
kwargs: エージェント実行メソッドに渡される追加のキーワード引数
呼び出し可能な next は、ミドルウェア チェーンを続行するか、最後のミドルウェアである場合はエージェントを実行します。
関数ベース
async def logging_agent_middleware(
context: AgentContext,
next: Callable[[AgentContext], Awaitable[None]],
) -> None:
"""Agent middleware that logs execution timing."""
# Pre-processing: Log before agent execution
print("[Agent] Starting execution")
# Continue to next middleware or agent execution
await next(context)
# Post-processing: Log after agent execution
print("[Agent] Execution completed")
クラスベース
クラス ベースのエージェント ミドルウェアは、関数ベースのミドルウェアと同じシグネチャと動作を持つ process メソッドを使用します。
from agent_framework import AgentMiddleware, AgentContext
class LoggingAgentMiddleware(AgentMiddleware):
"""Agent middleware that logs execution."""
async def process(
self,
context: AgentContext,
next: Callable[[AgentContext], Awaitable[None]],
) -> None:
print("[Agent Class] Starting execution")
await next(context)
print("[Agent Class] Execution completed")
関数ミドルウェア
関数ミドルウェアは、エージェント内の関数呼び出しをインターセプトします。 次を含む FunctionInvocationContext を使用します。
-
function: 呼び出される関数 -
arguments: 関数の検証済み引数 -
metadata: ミドルウェア間で追加のデータを格納するためのディクショナリ -
result: 関数の戻り値 (変更可能) -
terminate: それ以降の処理を停止するフラグ -
kwargs: この関数を呼び出したチャット メソッドに渡される追加のキーワード引数
呼び出し可能な next は、次のミドルウェアに進むか、実際の関数を実行します。
関数ベース
async def logging_function_middleware(
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Function middleware that logs function execution."""
# Pre-processing: Log before function execution
print(f"[Function] Calling {context.function.name}")
# Continue to next middleware or function execution
await next(context)
# Post-processing: Log after function execution
print(f"[Function] {context.function.name} completed")
クラスベース
from agent_framework import FunctionMiddleware, FunctionInvocationContext
class LoggingFunctionMiddleware(FunctionMiddleware):
"""Function middleware that logs function execution."""
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
print(f"[Function Class] Calling {context.function.name}")
await next(context)
print(f"[Function Class] {context.function.name} completed")
チャット ミドルウェア
チャット ミドルウェアは、AI モデルに送信されたチャット要求をインターセプトします。 次を含む ChatContext を使用します。
-
chat_client: 呼び出されるチャット クライアント -
messages: AI サービスに送信されるメッセージの一覧 -
options: チャット要求のオプション -
is_streaming: これがストリーミング呼び出しであるかどうかを示すブール値 -
metadata: ミドルウェア間で追加のデータを格納するためのディクショナリ -
result: AI からのチャット応答 (変更可能) -
terminate: それ以降の処理を停止するフラグ -
kwargs: チャット クライアントに渡される追加のキーワード引数
呼び出し可能な next は、次のミドルウェアに続くか、AI サービスに要求を送信します。
関数ベース
async def logging_chat_middleware(
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
"""Chat middleware that logs AI interactions."""
# Pre-processing: Log before AI call
print(f"[Chat] Sending {len(context.messages)} messages to AI")
# Continue to next middleware or AI service
await next(context)
# Post-processing: Log after AI response
print("[Chat] AI response received")
クラスベース
from agent_framework import ChatMiddleware, ChatContext
class LoggingChatMiddleware(ChatMiddleware):
"""Chat middleware that logs AI interactions."""
async def process(
self,
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
print(f"[Chat Class] Sending {len(context.messages)} messages to AI")
await next(context)
print("[Chat Class] AI response received")
ミドルウェア デコレーター
デコレーターは、型注釈を必要とせずに、明示的なミドルウェア型宣言を提供します。 これらは、型の注釈を使用しない場合や、型の不一致を防ぐ場合に役立ちます。
from agent_framework import agent_middleware, function_middleware, chat_middleware
@agent_middleware
async def simple_agent_middleware(context, next):
print("Before agent execution")
await next(context)
print("After agent execution")
@function_middleware
async def simple_function_middleware(context, next):
print(f"Calling function: {context.function.name}")
await next(context)
print("Function call completed")
@chat_middleware
async def simple_chat_middleware(context, next):
print(f"Processing {len(context.messages)} chat messages")
await next(context)
print("Chat processing completed")
ミドルウェアの登録
ミドルウェアは、スコープと動作が異なる 2 つのレベルで登録できます。
Agent-Level ミドルウェアと Run-Level ミドルウェア
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
# Agent-level middleware: Applied to ALL runs of the agent
async with AzureAIAgentClient(async_credential=credential).as_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[
SecurityAgentMiddleware(), # Applies to all runs
TimingFunctionMiddleware(), # Applies to all runs
],
) as agent:
# This run uses agent-level middleware only
result1 = await agent.run("What's the weather in Seattle?")
# This run uses agent-level + run-level middleware
result2 = await agent.run(
"What's the weather in Portland?",
middleware=[ # Run-level middleware (this run only)
logging_chat_middleware,
]
)
# This run uses agent-level middleware only (no run-level)
result3 = await agent.run("What's the weather in Vancouver?")
主な違い:
- エージェント レベル: エージェントの作成時に 1 回構成されたすべての実行で永続的
- 実行レベル: 特定の実行にのみ適用され、要求ごとのカスタマイズが可能
- 実行順序: エージェント ミドルウェア (最も外側) → 実行ミドルウェア (最も内側) → エージェントの実行
ミドルウェアの終了
ミドルウェアは、 context.terminateを使用して早期に実行を終了できます。 これは、セキュリティ チェック、レート制限、または検証エラーに役立ちます。
async def blocking_middleware(
context: AgentContext,
next: Callable[[AgentContext], Awaitable[None]],
) -> None:
"""Middleware that blocks execution based on conditions."""
# Check for blocked content
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
if "blocked" in last_message.text.lower():
print("Request blocked by middleware")
context.terminate = True
return
# If no issues, continue normally
await next(context)
終了とは次のことを意味します。
- 処理
context.terminate = True停止するシグナルを設定する - 終了する前にカスタム結果を提供して、ユーザーにフィードバックを提供できます
- ミドルウェアの終了時にエージェントの実行が完全にスキップされる
ミドルウェアの結果のオーバーライド
ミドルウェアは、非ストリーミング シナリオとストリーミング シナリオの両方で結果をオーバーライドできるため、エージェントの応答を変更または完全に置き換えることができます。
context.resultの結果の種類は、エージェントの呼び出しがストリーミングか非ストリーミングかによって異なります。
-
非ストリーミング:
context.resultには、完全な応答を含むAgentResponseが含まれています -
ストリーミング:
context.resultには、AgentResponseUpdateチャンクを生成する非同期ジェネレーターが含まれています
context.is_streamingを使用して、これらのシナリオを区別し、結果のオーバーライドを適切に処理できます。
async def weather_override_middleware(
context: AgentContext,
next: Callable[[AgentContext], Awaitable[None]]
) -> None:
"""Middleware that overrides weather results for both streaming and non-streaming."""
# Execute the original agent logic
await next(context)
# Override results if present
if context.result is not None:
custom_message_parts = [
"Weather Override: ",
"Perfect weather everywhere today! ",
"22°C with gentle breezes. ",
"Great day for outdoor activities!"
]
if context.is_streaming:
# Streaming override
async def override_stream() -> AsyncIterable[AgentResponseUpdate]:
for chunk in custom_message_parts:
yield AgentResponseUpdate(contents=[Content.from_text(text=chunk)])
context.result = override_stream()
else:
# Non-streaming override
custom_message = "".join(custom_message_parts)
context.result = AgentResponse(
messages=[Message(role="assistant", contents=[custom_message])]
)
このミドルウェア アプローチを使用すると、高度な応答変換、コンテンツ フィルター処理、結果の強化、ストリーミングのカスタマイズを実装しながら、エージェント ロジックをクリーンで集中させ続けます。
ミドルウェアの完全な例
クラス ベースのミドルウェア
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
FunctionInvocationContext,
FunctionMiddleware,
Message,
tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Class-based MiddlewareTypes Example
This sample demonstrates how to implement middleware using class-based approach by inheriting
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent middleware that checks for security violations."""
async def process(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Override the result with warning message
context.result = AgentResponse(
messages=[Message("assistant", ["Detected sensitive information, the request is blocked."])]
)
# Simply don't call call_next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next()
class LoggingFunctionMiddleware(FunctionMiddleware):
"""Function middleware that logs function calls."""
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await call_next()
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating class-based middleware."""
print("=== Class-based MiddlewareTypes Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(credential=credential).as_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
# Test with security-related query
print("--- Security Test ---")
query = "What's the password for the weather service?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
if __name__ == "__main__":
asyncio.run(main())
関数ベースのミドルウェア
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
FunctionInvocationContext,
FunctionMiddleware,
Message,
tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Class-based MiddlewareTypes Example
This sample demonstrates how to implement middleware using class-based approach by inheriting
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent middleware that checks for security violations."""
async def process(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Override the result with warning message
context.result = AgentResponse(
messages=[Message("assistant", ["Detected sensitive information, the request is blocked."])]
)
# Simply don't call call_next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next()
class LoggingFunctionMiddleware(FunctionMiddleware):
"""Function middleware that logs function calls."""
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await call_next()
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating class-based middleware."""
print("=== Class-based MiddlewareTypes Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(credential=credential).as_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
# Test with security-related query
print("--- Security Test ---")
query = "What's the password for the weather service?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
if __name__ == "__main__":
asyncio.run(main())
デコレーター ベースのミドルウェア
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
FunctionInvocationContext,
FunctionMiddleware,
Message,
tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Class-based MiddlewareTypes Example
This sample demonstrates how to implement middleware using class-based approach by inheriting
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent middleware that checks for security violations."""
async def process(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Override the result with warning message
context.result = AgentResponse(
messages=[Message("assistant", ["Detected sensitive information, the request is blocked."])]
)
# Simply don't call call_next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next()
class LoggingFunctionMiddleware(FunctionMiddleware):
"""Function middleware that logs function calls."""
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await call_next()
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating class-based middleware."""
print("=== Class-based MiddlewareTypes Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(credential=credential).as_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
# Test with security-related query
print("--- Security Test ---")
query = "What's the password for the weather service?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
if __name__ == "__main__":
asyncio.run(main())