Last active
July 31, 2025 16:38
-
-
Save wizhippo/0f5e14f421e780db1e1f65648db71ff4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
import time
import uuid
from functools import wraps
from typing import Any, Dict, Callable, Annotated, Literal, Optional, Union

from haystack.components.agents import Agent
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses import StreamingChunk
from haystack.tools import tool
from pydantic import BaseModel, Field, ValidationError
# Module-level logger; this module does not attach handlers or set levels itself.
logger = logging.getLogger(name=__name__)
class ConfirmationRequest(BaseModel):
    """Payload a confirmable tool returns instead of executing immediately.

    ``timestamp`` is filled by a ``default_factory`` rather than a custom
    ``__init__``. Pydantic v2's ``model_validate`` / ``model_validate_json``
    do not call ``__init__``, so only a field default guarantees every
    instance, parsed or constructed, carries a real float timestamp.
    """

    # Name of the private executor tool to invoke once the user confirms.
    tool: str
    # Keyword arguments to pass to that executor tool.
    args: Dict[str, Any]
    # Human-readable confirmation prompt shown to the user.
    message: str
    # Creation time in seconds since the epoch (time.time()); used for expiry checks.
    timestamp: float = Field(default_factory=time.time)
class ConfirmationError(Exception):
    """Signals that a pending confirmation could not be resolved (e.g. unknown or expired ID)."""
def requires_confirmation_tool(
    fn: Optional[Callable] = None,
    *,
    name: Optional[str] = None,
    confirm_message: Optional[str] = None,
    **tool_kwargs,
):
    """
    Create a tool that requires confirmation before execution.

    Usable bare (``@requires_confirmation_tool``) or with arguments
    (``@requires_confirmation_tool(confirm_message=...)``). Two Haystack
    tools are built from the wrapped function:

    * a public "requester" tool named ``name``. Calling it does NOT run the
      function; it returns a JSON-serialized ``ConfirmationRequest``
      describing the call.
    * a private executor tool named ``"<name>__exec"`` that really runs the
      function. It is attached to the requester as ``exec_tool``, and the
      requester is flagged with ``confirmable = True``.

    Args:
        fn: The function to wrap (supplied automatically when used bare).
        name: Override the tool name (defaults to the function name).
        confirm_message: ``str.format`` template for the prompt (defaults to
            "Confirm running **{name}**?"). ``{name}`` is the public tool
            name, and every tool argument is available by its parameter name.
            An argument literally called ``name`` takes precedence over the
            tool name. If the template cannot be rendered, the default prompt
            is used.
        **tool_kwargs: Forwarded to Haystack's ``@tool`` for both tools.

    Returns:
        The public requester tool, or a decorator producing it when ``fn`` is omitted.
    """

    def _decorate(f):
        public_name = name or f.__name__
        private_name = f"{public_name}__exec"
        default_message = f"Confirm running **{public_name}**?"

        def _render_message(kwargs: Dict[str, Any]) -> str:
            """Render the confirmation prompt, falling back to the default on any template error."""
            if not confirm_message:
                return default_message
            try:
                # A single merged mapping avoids the TypeError that
                # format(name=..., **kwargs) raises when the tool has a `name` argument.
                return confirm_message.format_map({"name": public_name, **kwargs})
            except (KeyError, IndexError, AttributeError, ValueError) as e:
                # KeyError: unknown field; IndexError/ValueError: positional or malformed
                # fields; AttributeError: bad attribute lookup such as {arg.missing}.
                logger.warning(f"Invalid format key in confirm_message: {e}")
                return default_message

        # Private executor: does the actual work once the user has confirmed.
        @tool(name=private_name, **tool_kwargs)
        @wraps(f)
        def executor(**kwargs):
            try:
                return f(**kwargs)
            except Exception as e:
                logger.error(f"Error executing {private_name}: {e}")
                raise

        # Public requester: records what would run and asks for confirmation instead.
        @tool(name=public_name, **tool_kwargs)
        @wraps(f)
        def requester(**kwargs):
            return ConfirmationRequest(
                tool=private_name,
                args=kwargs,
                message=_render_message(kwargs),
            ).model_dump_json()

        # Metadata that ConfirmationAwareAgent uses to pair requester and executor.
        requester.confirmable = True
        requester.exec_tool = executor
        return requester

    return _decorate(fn) if fn is not None else _decorate
class ConfirmationAwareAgent:
    """Haystack ``Agent`` wrapper that pauses on confirmable tools until the user decides.

    Tools created with ``requires_confirmation_tool`` are exposed to the LLM
    only through their public requester half. The requesters are also
    registered as exit conditions, so the agent run stops once one is called.
    The resulting ``ConfirmationRequest`` is stored in
    ``pending_confirmations`` under its tool-call ID until
    ``handle_confirmation`` / ``handle_confirmation_async`` either runs the
    private executor tool or cancels.
    """

    def __init__(self, *, chat_generator, tools, exit_conditions=None,
                 confirmation_timeout: float = 300.0, max_pending: int = 10, **agent_kwargs):
        """
        Initialize the confirmation-aware agent.

        Args:
            chat_generator: Chat generator passed through to ``Agent``.
            tools: Regular tools plus tools built with ``requires_confirmation_tool``.
            exit_conditions: Extra ``Agent`` exit conditions (default ``["text"]``);
                the names of all confirmable tools are always prepended.
            confirmation_timeout: Time in seconds before confirmations expire (default: 5 minutes)
            max_pending: Maximum number of pending confirmations (default: 10)
            **agent_kwargs: Forwarded to ``Agent``.
        """
        exit_conditions = exit_conditions or ["text"]
        self.confirmation_timeout = confirmation_timeout
        self.max_pending = max_pending

        # Separate confirmable and regular tools
        self.confirmable_tools = [t for t in tools if getattr(t, "confirmable", False)]
        self.regular_tools = [t for t in tools if not getattr(t, "confirmable", False)]
        self.private_tools = [t.exec_tool for t in self.confirmable_tools if hasattr(t, "exec_tool")]
        # Public requester name -> its private executor name. Only results produced by one
        # of these requesters, and targeting that requester's own executor, are accepted
        # as confirmation requests. This keeps arbitrary tool output that merely looks
        # like a request from triggering an executor.
        self._executor_for = {
            t.name: t.exec_tool.name for t in self.confirmable_tools if hasattr(t, "exec_tool")
        }
        # Public tools are confirmable + regular tools
        self.public_tools = self.confirmable_tools + self.regular_tools
        self.all_tools = self.public_tools + self.private_tools

        # The LLM only sees public tools; executors run via handle_confirmation*.
        self.agent = Agent(
            chat_generator=chat_generator,
            tools=self.public_tools,
            exit_conditions=[t.name for t in self.confirmable_tools] + exit_conditions,
            **agent_kwargs,
        )
        self.pending_confirmations: Dict[str, ConfirmationRequest] = {}
        self._lock = asyncio.Lock()

    def run(self, **kwargs) -> Dict[str, Any]:
        """Run the agent synchronously and register any confirmation requests it produced."""
        result = self.agent.run(**kwargs)
        return self._process_result(result)

    async def run_async(self, **kwargs) -> Dict[str, Any]:
        """Async counterpart of ``run``."""
        result = await self.agent.run_async(**kwargs)
        return self._process_result(result)

    def handle_confirmation(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Sync version of confirmation handling."""
        return self._execute_confirmation_sync(confirmation_id, confirmed)

    async def handle_confirmation_async(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Async version of confirmation handling."""
        return await self._execute_confirmation_async(confirmation_id, confirmed)

    def _is_expired(self, req: ConfirmationRequest, now: float) -> bool:
        """Return True when ``req`` is older than ``confirmation_timeout`` seconds."""
        return now - req.timestamp > self.confirmation_timeout

    def _cleanup_expired_confirmations(self):
        """Remove expired confirmations."""
        now = time.time()
        expired_ids = [
            cid for cid, req in self.pending_confirmations.items()
            if self._is_expired(req, now)
        ]
        for cid in expired_ids:
            del self.pending_confirmations[cid]
            logger.info(f"Expired confirmation {cid}")

    @staticmethod
    def _trailing_tool_messages(messages):
        """Return the contiguous run of tool messages at the end of ``messages``, oldest first.

        One LLM turn with parallel tool calls yields one tool message per call,
        so a confirmation request is not necessarily the very last message.
        """
        start = len(messages)
        while start > 0 and messages[start - 1].is_from("tool"):
            start -= 1
        return messages[start:]

    def _parse_confirmation(self, message):
        """Return ``(confirmation_id, ConfirmationRequest)`` for a genuine request, else None."""
        tool_result = message.tool_call_result
        if not tool_result or not tool_result.result or tool_result.error:
            return None

        origin = tool_result.origin
        expected_executor = self._executor_for.get(origin.tool_name)
        if expected_executor is None:
            # Not produced by a confirmable requester: never treat it as a request.
            return None

        payload = tool_result.result
        try:
            if isinstance(payload, str):
                confirmation_req = ConfirmationRequest.model_validate_json(payload)
            else:
                confirmation_req = ConfirmationRequest.model_validate(payload)
        except ValidationError as e:
            logger.debug(f"Payload is not a confirmation request: {e}")
            return None

        if confirmation_req.tool != expected_executor:
            logger.warning(
                f"Ignoring confirmation request from {origin.tool_name} "
                f"targeting unexpected tool {confirmation_req.tool}"
            )
            return None
        if confirmation_req.timestamp is None:
            # Guard against payloads parsed without a timestamp so expiry math never sees None.
            confirmation_req.timestamp = time.time()

        # Some chat generators leave tool-call IDs empty; never key pending requests on None.
        confirmation_id = origin.id or uuid.uuid4().hex
        return confirmation_id, confirmation_req

    def _process_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Register confirmation requests found in the agent's final tool messages.

        Each accepted request is stored in ``pending_confirmations``. The first
        one is exposed as ``result["confirmation_request"]`` (``{"id", "message"}``)
        and all of them as ``result["confirmation_requests"]``. When the
        pending limit is hit or processing fails, ``result["error"]`` is set.
        """
        messages = result.get("messages", [])
        if not messages:
            return result

        found = []
        try:
            candidates = [
                parsed
                for parsed in map(self._parse_confirmation, self._trailing_tool_messages(messages))
                if parsed is not None
            ]
            if candidates:
                # Clean up expired confirmations
                self._cleanup_expired_confirmations()
            for confirmation_id, confirmation_req in candidates:
                # Check if we're at max pending
                if len(self.pending_confirmations) >= self.max_pending:
                    logger.warning("Maximum pending confirmations reached")
                    result["error"] = "Too many pending confirmations. Please resolve existing ones first."
                    break
                self.pending_confirmations[confirmation_id] = confirmation_req
                found.append({"id": confirmation_id, "message": confirmation_req.message})
        except Exception:
            logger.exception("Error processing confirmation request")
            result["error"] = "Error processing confirmation request"

        if found:
            result["confirmation_request"] = found[0]
            result["confirmation_requests"] = found
        return result

    def _pop_pending(self, confirmation_id: str, confirmed: bool) -> Union[ChatMessage, ToolCall]:
        """Consume a pending confirmation; return a cancellation message or the executor ToolCall.

        Raises:
            ConfirmationError: if the ID is unknown (or already used) or the request expired.
        """
        confirmation_req = self.pending_confirmations.pop(confirmation_id, None)
        if confirmation_req is None:
            raise ConfirmationError("Unknown or expired confirmation ID")
        # Check if confirmation has expired
        if self._is_expired(confirmation_req, time.time()):
            raise ConfirmationError("Confirmation has expired")
        if not confirmed:
            return ChatMessage.from_assistant(text="❌ Operation cancelled.")
        # Create tool call for the private executor
        return ToolCall.from_dict({
            "tool_name": confirmation_req.tool,
            "arguments": confirmation_req.args,
            "id": confirmation_id
        })

    def _create_confirmation_sync(self, confirmation_id: str, confirmed: bool) -> Union[ChatMessage, ToolCall]:
        """Sync version of confirmation creation."""
        return self._pop_pending(confirmation_id, confirmed)

    async def _create_confirmation(self, confirmation_id: str, confirmed: bool) -> Union[ChatMessage, ToolCall]:
        """Create confirmation response under the async lock."""
        async with self._lock:
            return self._pop_pending(confirmation_id, confirmed)

    def _execute_confirmation_sync(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Sync version of confirmation execution; errors are returned as assistant messages."""
        try:
            res = self._create_confirmation_sync(confirmation_id, confirmed)
            if isinstance(res, ChatMessage):
                return res
            # Execute the tool call
            call_message = ChatMessage.from_assistant(tool_calls=[res])
            try:
                result = ToolInvoker(tools=self.all_tools).run(messages=[call_message])
                return self._extract_tool_result(result)
            except Exception as e:
                logger.exception(f"Error executing tool: {e}")
                return ChatMessage.from_assistant(text=f"❌ Error executing tool: {str(e)}")
        except ConfirmationError as e:
            return ChatMessage.from_assistant(text=f"❌ {str(e)}")
        except Exception as e:
            logger.exception(f"Unexpected error in confirmation execution: {e}")
            return ChatMessage.from_assistant(text="❌ An unexpected error occurred.")

    async def _execute_confirmation_async(self, confirmation_id: str, confirmed: bool) -> ChatMessage:
        """Async version of confirmation execution; errors are returned as assistant messages."""
        try:
            res = await self._create_confirmation(confirmation_id, confirmed)
            if isinstance(res, ChatMessage):
                return res
            # Execute the tool call
            call_message = ChatMessage.from_assistant(tool_calls=[res])
            try:
                result = await ToolInvoker(tools=self.all_tools).run_async(messages=[call_message])
                return self._extract_tool_result(result)
            except Exception as e:
                logger.exception(f"Error executing tool: {e}")
                return ChatMessage.from_assistant(text=f"❌ Error executing tool: {str(e)}")
        except ConfirmationError as e:
            return ChatMessage.from_assistant(text=f"❌ {str(e)}")
        except Exception as e:
            logger.exception(f"Unexpected error in confirmation execution: {e}")
            return ChatMessage.from_assistant(text="❌ An unexpected error occurred.")

    @staticmethod
    def _extract_tool_result(result: Dict[str, Any]) -> ChatMessage:
        """Extract the first tool result from ToolInvoker output as a tool ChatMessage."""
        tool_messages = result.get("tool_messages", [])
        if not tool_messages:
            return ChatMessage.from_assistant(text="❌ No tool response received.")
        tool_result = tool_messages[0].tool_call_result
        return ChatMessage.from_tool(
            tool_result=tool_result.result,
            origin=tool_result.origin,
            error=tool_result.error
        )

    def get_pending_confirmations(self) -> Dict[str, str]:
        """Get all non-expired pending confirmations with their messages."""
        self._cleanup_expired_confirmations()
        return {cid: req.message for cid, req in self.pending_confirmations.items()}
# Example usage with cleaner tool definitions
def example_usage():
    """Register a Chainlit demo app built around ``ConfirmationAwareAgent``.

    Defines two confirmable demo tools and an ``OpenAIChatGenerator`` that
    reads ``OPENAI_API_KEY`` from the environment. It also sets up starter
    prompts and the ``on_chat_start`` / ``on_message`` handlers. Calling this
    only registers the handlers with Chainlit; it does not start a server.
    """
    import chainlit as cl
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.utils import Secret

    @requires_confirmation_tool(confirm_message="Send support request to **{recipient}** about '{query}'?")
    def contact_support_tool(
        query: Annotated[str, "the email subject, must not be empty"],
        recipient: Annotated[Literal['fred', 'bob'], "the support team recipient, must not be empty"],
    ) -> str:
        """Contact the support team with a query."""
        return f"✅ Support request sent to {recipient} team: {query}"

    @requires_confirmation_tool(confirm_message="Send email to {recipients} with subject '{subject}'?")
    def send_email_tool(
        recipients: Annotated[list[str], "the email recipients, must not be empty, if missing ask for it"],
        subject: Annotated[str, "the email subject, must not be empty, if missing ask for it"],
        body: Annotated[str, "the email body, must not be empty, if missing ask for it"],
    ) -> str:
        """Email the recipients."""
        return f"✅ Email sent to {recipients} with subject '{subject}' and body '{body}'"

    chat_gen = OpenAIChatGenerator(
        model="gpt-4o-mini",
        api_key=Secret.from_env_var("OPENAI_API_KEY"),
    )

    @cl.set_starters
    async def set_starters():
        return [
            cl.Starter(
                label="Send email to [email protected]",
                message="Send an email to [email protected] with subject 'Hello' and body 'How are you?'",
                icon="✉️",
            ),
            cl.Starter(
                label="Contact support",
                message="Contact billing support about my invoice question",
                icon="🎧",
            ),
            cl.Starter(
                label="Contact support",
                message="Contact billing support about my invoice question confirmed=true",
                icon="🎧",
            ),
        ]

    @cl.on_chat_start
    async def on_chat_start():
        agent = ConfirmationAwareAgent(
            chat_generator=chat_gen,
            tools=[contact_support_tool, send_email_tool]
        )
        cl.user_session.set("agent", agent)
        cl.user_session.set("message_history", [])

    @cl.on_message
    async def on_message(message: cl.Message):
        agent = cl.user_session.get("agent")
        query = message.content
        message_history = cl.user_session.get("message_history", [])
        msg = await cl.Message(content="Thinking...").send()
        if not agent:
            msg.content = "Error: Agent not initialized!"
            await msg.update()
            return

        async def my_streaming_callback(chunk: StreamingChunk):
            if chunk.content:
                if chunk.start:
                    msg.content = ""
                msg.content = msg.content + chunk.content
                await msg.update()

        # Run the agent
        result = await agent.run_async(
            messages=message_history + [ChatMessage.from_user(query)],
            streaming_callback=my_streaming_callback,
            tool_invoker_kwargs={
                "streaming_callback": my_streaming_callback,
                "enable_streaming_callback_passthrough": True
            }
        )
        if messages := result.get("messages"):
            message_history = messages

        # Handle confirmation request
        if confirmation := result.get("confirmation_request"):
            await msg.remove()
            ask = await cl.AskActionMessage(
                content=confirmation["message"],
                actions=[
                    cl.Action(
                        name="confirm",
                        payload={"cid": confirmation["id"], "confirmed": True},
                        label="✅ Confirm"
                    ),
                    cl.Action(
                        name="cancel",
                        payload={"cid": confirmation["id"], "confirmed": False},
                        label="❌ Cancel"
                    ),
                ],
            ).send()
            msg = await cl.Message(content="").send()
            payload = ask.get("payload") if ask else None
            if payload:
                response = await agent.handle_confirmation_async(
                    payload["cid"],
                    payload["confirmed"]
                )
            else:
                # No answer (AskActionMessage returns None, e.g. on timeout): cancel the
                # request rather than leaving an empty message and a dangling pending entry.
                response = await agent.handle_confirmation_async(confirmation["id"], False)
            content = (
                response.tool_call_result.result
                if response.tool_call_result
                else response.text
            )
            message_history = message_history + [ChatMessage.from_assistant(text=content)]
            msg.content = content
            await msg.update()
        # Surface processing errors (e.g. too many pending confirmations) instead of raw tool JSON
        elif error := result.get("error"):
            msg.content = f"❌ {error}"
            await msg.update()
        # Handle regular response
        elif messages := result.get("messages"):
            last_message = messages[-1]
            if last_message.is_from('tool') and last_message.tool_call_result:
                msg.content = last_message.tool_call_result.result
                await msg.update()
            else:
                msg.content = last_message.text or "No response"
                await msg.update()

        cl.user_session.set("message_history", message_history)
if __name__ == "__main__":
    # Launched as a script: register the demo handlers in this process, then
    # hand this file to Chainlit's runner.
    # NOTE(review): `chainlit run <file>` presumably imports this module under a
    # different __name__, so example_usage() would not be called on that path; confirm.
    import chainlit.cli as chainlit_cli

    example_usage()
    chainlit_cli.run_chainlit(__file__)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment