Skip to content

Instantly share code, notes, and snippets.

@wizhippo
Last active July 31, 2025 16:38
Show Gist options
  • Select an option

  • Save wizhippo/0f5e14f421e780db1e1f65648db71ff4 to your computer and use it in GitHub Desktop.

Select an option

Save wizhippo/0f5e14f421e780db1e1f65648db71ff4 to your computer and use it in GitHub Desktop.
import asyncio
import logging
import time
import uuid
from functools import wraps
from typing import Any, Dict, Callable, Annotated, Literal, Union

from haystack.components.agents import Agent
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses import StreamingChunk
from haystack.tools import tool
from pydantic import BaseModel, Field, ValidationError
# Module-wide logger, named after this module so handlers/filters can target it.
logger = logging.getLogger(name=__name__)
class ConfirmationRequest(BaseModel):
    """Serializable description of a tool call that awaits user confirmation.

    Instances are emitted as JSON by the public tool built in
    ``requires_confirmation_tool`` and parsed back with ``model_validate`` /
    ``model_validate_json`` when the agent result is processed.
    """

    # Name of the (private) tool to invoke once the user confirms.
    tool: str
    # Keyword arguments to pass to that tool.
    args: Dict[str, Any]
    # Human-readable prompt shown to the user.
    message: str
    # Creation time from time.time(); compared against the confirmation timeout.
    # A default_factory (instead of an __init__ override) also applies when the
    # model is built through model_validate / model_validate_json, which do not
    # call a custom __init__ in pydantic v2.
    timestamp: float = Field(default_factory=time.time)
class ConfirmationError(Exception):
    """Signals a confirmation that cannot be honoured (e.g. an unknown or expired ID)."""
def requires_confirmation_tool(fn: Callable = None, *, name: str = None, confirm_message: str = None, **tool_kwargs):
    """
    Create a tool that requires confirmation before execution.

    Two Haystack tools are built from the wrapped function:

    * a public tool (named ``name`` or the function name) that does NOT run the
      function; it returns a JSON-serialized ``ConfirmationRequest`` describing
      the requested call, and
    * a private executor tool named ``"<public name>__exec"`` that runs the
      function for real.

    The public tool is returned with two extra attributes: ``confirmable = True``
    and ``exec_tool`` (the private executor).

    Usable bare (``@requires_confirmation_tool``) or with arguments
    (``@requires_confirmation_tool(confirm_message=...)``).

    Args:
        fn: The function to wrap (supplied automatically when used as a bare decorator).
        name: Override the tool name (defaults to function name).
        confirm_message: Custom confirmation message template, formatted with the
            call's arguments plus ``name`` (the public tool name). Falls back to
            "Confirm running **{name}**?" when absent or when formatting fails.
        **tool_kwargs: Extra keyword arguments forwarded to ``tool(...)`` for both
            the public and the private tool.

    Returns:
        The public tool, or a decorator that produces it when ``fn`` is not given.
    """

    def _decorate(f):
        public_name = name or f.__name__
        private_name = f"{public_name}__exec"
        default_message = f"Confirm running **{public_name}**?"

        # Private executor: does the actual work once the user has confirmed.
        @tool(name=private_name, **tool_kwargs)
        @wraps(f)
        def executor(**kwargs):
            try:
                return f(**kwargs)
            except Exception:
                # Keep the traceback in the log; the error is still propagated.
                logger.exception(f"Error executing {private_name}")
                raise

        # Public tool: records the requested call instead of executing it.
        @tool(name=public_name, **tool_kwargs)
        @wraps(f)
        def requester(**kwargs):
            message = default_message
            if confirm_message:
                # Merge into a single mapping: passing ``name=`` alongside
                # ``**kwargs`` raised TypeError whenever the tool itself had an
                # argument called ``name``. The tool name wins for {name}.
                fields = {**kwargs, "name": public_name}
                try:
                    message = confirm_message.format(**fields)
                except (KeyError, IndexError, ValueError) as e:
                    # KeyError: unknown placeholder; IndexError: positional "{0}";
                    # ValueError: malformed template such as an unmatched brace.
                    logger.warning(f"Invalid confirm_message template: {e!r}")
            return ConfirmationRequest(
                tool=private_name,
                args=kwargs,
                message=message,
            ).model_dump_json()

        # Metadata consumed by ConfirmationAwareAgent to split public/private tools.
        requester.confirmable = True
        requester.exec_tool = executor
        return requester

    return _decorate(fn) if fn else _decorate
class ConfirmationAwareAgent:
    """Wrap a Haystack ``Agent`` so that confirmable tools pause for user approval.

    Confirmable tools (those carrying ``confirmable = True``, e.g. built by
    ``requires_confirmation_tool``) are exposed to the LLM under their public
    name, and their names are added to the agent's exit conditions. When the
    run ends on such a tool's result, ``run``/``run_async`` parse it as a
    ``ConfirmationRequest``, park it in ``pending_confirmations`` and add a
    ``"confirmation_request"`` entry to the result. The caller resolves it later
    with ``handle_confirmation``/``handle_confirmation_async``, which, when
    confirmed, runs the private executor tool through ``ToolInvoker``.
    """

    def __init__(self, *, chat_generator, tools, exit_conditions=None,
                 confirmation_timeout: float = 300.0, max_pending: int = 10, **agent_kwargs):
        """
        Initialize the confirmation-aware agent.

        Args:
            chat_generator: Chat generator passed through to the underlying ``Agent``.
            tools: Regular tools and/or tools built by ``requires_confirmation_tool``.
            exit_conditions: Extra agent exit conditions (default: ``["text"]``);
                the names of all confirmable tools are always prepended.
            confirmation_timeout: Time in seconds before confirmations expire (default: 5 minutes)
            max_pending: Maximum number of pending confirmations (default: 10)
            **agent_kwargs: Forwarded unchanged to ``Agent``.
        """
        # Materialize once: ``tools`` is iterated several times below, which
        # would silently drop tools if a generator/iterator were passed.
        tools = list(tools)
        exit_conditions = exit_conditions or ["text"]
        self.confirmation_timeout = confirmation_timeout
        self.max_pending = max_pending

        # Separate confirmable and regular tools
        self.confirmable_tools = [t for t in tools if getattr(t, "confirmable", False)]
        self.regular_tools = [t for t in tools if not getattr(t, "confirmable", False)]
        # Private executors are never handed to the LLM; they only run after an
        # explicit confirmation, via ToolInvoker(tools=self.all_tools).
        self.private_tools = [t.exec_tool for t in self.confirmable_tools if hasattr(t, "exec_tool")]
        # Public tools are confirmable + regular tools
        self.public_tools = self.confirmable_tools + self.regular_tools
        self.all_tools = self.public_tools + self.private_tools

        # Create agent with only public tools
        self.agent = Agent(
            chat_generator=chat_generator,
            tools=self.public_tools,
            # End the run as soon as a confirmable tool is called so the request
            # can be shown to the user before anything else happens.
            exit_conditions=[t.name for t in self.confirmable_tools] + exit_conditions,
            **agent_kwargs,
        )
        self.pending_confirmations: Dict[str, ConfirmationRequest] = {}
        # Guards pop-and-validate of pending confirmations on the async path.
        self._lock = asyncio.Lock()

    def run(self, **kwargs) -> Dict[str, Any]:
        """Run the agent synchronously and register any confirmation request."""
        result = self.agent.run(**kwargs)
        return self._process_result(result)

    async def run_async(self, **kwargs) -> Dict[str, Any]:
        """Run the agent asynchronously and register any confirmation request."""
        result = await self.agent.run_async(**kwargs)
        return self._process_result(result)

    def handle_confirmation(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Sync version of confirmation handling."""
        return self._execute_confirmation_sync(confirmation_id, confirmed)

    async def handle_confirmation_async(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Async version of confirmation handling."""
        return await self._execute_confirmation_async(confirmation_id, confirmed)

    def _is_expired(self, req: ConfirmationRequest, now: float = None) -> bool:
        """Return True if ``req`` is older than ``confirmation_timeout`` seconds."""
        if now is None:
            now = time.time()
        return now - req.timestamp > self.confirmation_timeout

    def _cleanup_expired_confirmations(self):
        """Remove expired confirmations."""
        now = time.time()
        expired_ids = [
            cid for cid, req in self.pending_confirmations.items()
            if self._is_expired(req, now)
        ]
        for cid in expired_ids:
            del self.pending_confirmations[cid]
            logger.info(f"Expired confirmation {cid}")

    @staticmethod
    def _parse_confirmation_request(payload: Any) -> Union[ConfirmationRequest, None]:
        """Return ``payload`` as a ``ConfirmationRequest``, or None if it is not one.

        String payloads are parsed as JSON; anything else is validated directly.
        """
        try:
            if isinstance(payload, str):
                return ConfirmationRequest.model_validate_json(payload)
            return ConfirmationRequest.model_validate(payload)
        except ValidationError as e:
            logger.debug(f"Payload is not a confirmation request: {e}")
            return None

    def _process_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Check if the result contains a confirmation request.

        If the last message is a tool result whose payload is a
        ``ConfirmationRequest``, the request is stored in
        ``pending_confirmations`` and ``result["confirmation_request"]`` is set
        to ``{"id": ..., "message": ...}``. On failure ``result["error"]`` is
        set instead. ``result`` is mutated in place and returned.
        """
        messages = result.get("messages", [])
        if not messages:
            return result

        last_message = messages[-1]
        tool_result = last_message.tool_call_result if last_message.is_from("tool") else None
        if not (tool_result and tool_result.result):
            return result

        try:
            confirmation_req = self._parse_confirmation_request(tool_result.result)
            if confirmation_req is None:
                return result

            self._cleanup_expired_confirmations()
            if len(self.pending_confirmations) >= self.max_pending:
                logger.warning("Maximum pending confirmations reached")
                result["error"] = "Too many pending confirmations. Please resolve existing ones first."
                return result

            # ToolCall.id is optional; without a fallback every id-less request
            # would be stored under the same key (None) and overwrite the others.
            confirmation_id = tool_result.origin.id or uuid.uuid4().hex
            # Defensive: a payload validated without a timestamp may carry None,
            # which would break the expiry arithmetic; start the clock now.
            if confirmation_req.timestamp is None:
                confirmation_req.timestamp = time.time()
            self.pending_confirmations[confirmation_id] = confirmation_req

            result["confirmation_request"] = {
                "id": confirmation_id,
                "message": confirmation_req.message,
            }
        except Exception:
            logger.exception("Error processing confirmation request")
            result["error"] = "Error processing confirmation request"
        return result

    def _build_confirmation(self, confirmation_id: str, confirmed: bool) -> Union[ChatMessage, ToolCall]:
        """Consume a pending confirmation and turn the user's answer into an action.

        Returns a cancellation ``ChatMessage`` when ``confirmed`` is False, or a
        ``ToolCall`` targeting the private executor tool when it is True.

        Raises:
            ConfirmationError: if the ID is unknown or the request has expired.
        """
        # Pop first: an ID can be consumed at most once, even if it turns out to
        # be expired or the user cancels.
        confirmation_req = self.pending_confirmations.pop(confirmation_id, None)
        if confirmation_req is None:
            raise ConfirmationError("Unknown or expired confirmation ID")
        if self._is_expired(confirmation_req):
            raise ConfirmationError("Confirmation has expired")
        if not confirmed:
            return ChatMessage.from_assistant(text="❌ Operation cancelled.")
        # Create tool call for the private executor
        return ToolCall.from_dict({
            "tool_name": confirmation_req.tool,
            "arguments": confirmation_req.args,
            "id": confirmation_id,
        })

    def _create_confirmation_sync(self, confirmation_id: str, confirmed: bool) -> Union[ChatMessage, ToolCall]:
        """Sync version of confirmation creation."""
        return self._build_confirmation(confirmation_id, confirmed)

    async def _create_confirmation(self, confirmation_id: str, confirmed: bool) -> Union[ChatMessage, ToolCall]:
        """Create confirmation response with async safety."""
        async with self._lock:
            return self._build_confirmation(confirmation_id, confirmed)

    def _execute_confirmation_sync(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Sync version of confirmation execution."""
        try:
            res = self._create_confirmation_sync(confirmation_id, confirmed)
            if isinstance(res, ChatMessage):
                return res
            # Execute the tool call
            call_message = ChatMessage.from_assistant(tool_calls=[res])
            try:
                result = ToolInvoker(tools=self.all_tools).run(messages=[call_message])
                return self._extract_tool_result(result)
            except Exception as e:
                logger.exception("Error executing tool")
                return ChatMessage.from_assistant(text=f"❌ Error executing tool: {str(e)}")
        except ConfirmationError as e:
            return ChatMessage.from_assistant(text=f"❌ {str(e)}")
        except Exception:
            logger.exception("Unexpected error in confirmation execution")
            return ChatMessage.from_assistant(text="❌ An unexpected error occurred.")

    async def _execute_confirmation_async(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Async version of confirmation execution."""
        try:
            res = await self._create_confirmation(confirmation_id, confirmed)
            if isinstance(res, ChatMessage):
                return res
            # Execute the tool call
            call_message = ChatMessage.from_assistant(tool_calls=[res])
            try:
                result = await ToolInvoker(tools=self.all_tools).run_async(messages=[call_message])
                return self._extract_tool_result(result)
            except Exception as e:
                logger.exception("Error executing tool")
                return ChatMessage.from_assistant(text=f"❌ Error executing tool: {str(e)}")
        except ConfirmationError as e:
            return ChatMessage.from_assistant(text=f"❌ {str(e)}")
        except Exception:
            logger.exception("Unexpected error in confirmation execution")
            return ChatMessage.from_assistant(text="❌ An unexpected error occurred.")

    @staticmethod
    def _extract_tool_result(result: Dict[str, Any]) -> ChatMessage:
        """Extract the tool result from ToolInvoker output.

        Only the first entry of ``result["tool_messages"]`` is used; a single
        tool call is issued per confirmation.
        """
        tool_messages = result.get("tool_messages", [])
        if not tool_messages:
            return ChatMessage.from_assistant(text="❌ No tool response received.")
        tool_result = tool_messages[0].tool_call_result
        return ChatMessage.from_tool(
            tool_result=tool_result.result,
            origin=tool_result.origin,
            error=tool_result.error,
        )

    def get_pending_confirmations(self) -> Dict[str, str]:
        """Get all pending (non-expired) confirmations, mapped ID -> message."""
        self._cleanup_expired_confirmations()
        return {cid: req.message for cid, req in self.pending_confirmations.items()}
# Example usage with cleaner tool definitions
def example_usage():
    """Register a Chainlit demo app built on ``ConfirmationAwareAgent``.

    Defines two confirmable tools (``contact_support_tool``, ``send_email_tool``),
    an ``OpenAIChatGenerator`` whose API key is read from ``OPENAI_API_KEY``, and
    the Chainlit handlers (starters, chat start, message). Nothing is returned;
    the handlers are registered through Chainlit's decorators as a side effect.
    """
    import chainlit as cl
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.utils import Secret

    @requires_confirmation_tool(confirm_message="Send support request to **{recipient}** about '{query}'?")
    def contact_support_tool(
        query: Annotated[str, "the email subject, must not be empty"],
        recipient: Annotated[Literal['fred', 'bob'], "the support team recipient, must not be empty"],
    ) -> str:
        """Contact the support team with a query."""
        return f"✅ Support request sent to {recipient} team: {query}"

    @requires_confirmation_tool(confirm_message="Send email to {recipients} with subject '{subject}'?")
    def send_email_tool(
        recipients: Annotated[list[str], "the email recipients, must not be empty, if missing ask for it"],
        subject: Annotated[str, "the email subject, must not be empty, if missing ask for it"],
        body: Annotated[str, "the email body, must not be empty, if missing ask for it"],
    ) -> str:
        """Email the recipients."""
        return f"✅ Email sent to {recipients} with subject '{subject}' and body '{body}'"

    chat_gen = OpenAIChatGenerator(
        model="gpt-4o-mini",
        api_key=Secret.from_env_var("OPENAI_API_KEY"),
    )

    @cl.set_starters
    async def set_starters():
        return [
            cl.Starter(
                label="Send email to [email protected]",
                message="Send an email to [email protected] with subject 'Hello' and body 'How are you?'",
                icon="✉️",
            ),
            cl.Starter(
                label="Contact support",
                message="Contact billing support about my invoice question",
                icon="🎧",
            ),
            cl.Starter(
                label="Contact support",
                message="Contact billing support about my invoice question confirmed=true",
                icon="🎧",
            ),
        ]

    @cl.on_chat_start
    async def on_chat_start():
        # One agent and one message history per Chainlit user session.
        agent = ConfirmationAwareAgent(
            chat_generator=chat_gen,
            tools=[contact_support_tool, send_email_tool],
        )
        cl.user_session.set("agent", agent)
        cl.user_session.set("message_history", [])

    @cl.on_message
    async def on_message(message: cl.Message):
        agent = cl.user_session.get("agent")
        query = message.content
        message_history = cl.user_session.get("message_history", [])
        msg = await cl.Message(content="Thinking...").send()
        if not agent:
            msg.content = "Error: Agent not initialized!"
            await msg.update()
            return

        async def my_streaming_callback(chunk: StreamingChunk):
            if chunk.content:
                # The first chunk of a stream replaces the "Thinking..." placeholder.
                if chunk.start:
                    msg.content = ""
                msg.content = msg.content + chunk.content
                await msg.update()

        # Run the agent
        result = await agent.run_async(
            messages=message_history + [ChatMessage.from_user(query)],
            streaming_callback=my_streaming_callback,
            tool_invoker_kwargs={
                "streaming_callback": my_streaming_callback,
                "enable_streaming_callback_passthrough": True,
            },
        )
        if messages := result.get("messages"):
            message_history = messages

        # Handle confirmation request
        if confirmation := result.get("confirmation_request"):
            await msg.remove()
            ask = await cl.AskActionMessage(
                content=confirmation["message"],
                actions=[
                    cl.Action(
                        name="confirm",
                        payload={"cid": confirmation["id"], "confirmed": True},
                        label="✅ Confirm",
                    ),
                    cl.Action(
                        name="cancel",
                        payload={"cid": confirmation["id"], "confirmed": False},
                        label="❌ Cancel",
                    ),
                ],
            ).send()
            msg = await cl.Message(content="").send()
            payload = ask.get("payload") if ask else None
            if not payload:
                # No answer (e.g. the ask timed out): cancel explicitly so the
                # request does not linger as pending and the reply is not empty.
                payload = {"cid": confirmation["id"], "confirmed": False}
            response = await agent.handle_confirmation_async(
                payload["cid"],
                payload["confirmed"],
            )
            content = (
                response.tool_call_result.result
                if response.tool_call_result
                else response.text
            )
            message_history = message_history + [ChatMessage.from_assistant(text=content)]
            msg.content = content
            await msg.update()
        # Surface failures reported by the agent wrapper (e.g. too many pending
        # confirmations); otherwise the raw tool payload would be shown below.
        elif error := result.get("error"):
            msg.content = f"❌ {error}"
            await msg.update()
        # Handle regular response
        elif messages := result.get("messages"):
            last_message = messages[-1]
            if last_message.is_from('tool') and last_message.tool_call_result:
                msg.content = last_message.tool_call_result.result
            else:
                msg.content = last_message.text or "No response"
            await msg.update()

        cl.user_session.set("message_history", message_history)
if __name__ == "__main__":
    # Script entry point: register the demo's Chainlit handlers, then start
    # the Chainlit runner on this very file.
    import chainlit.cli as chainlit_cli

    example_usage()
    chainlit_cli.run_chainlit(__file__)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment