Skip to content

Instantly share code, notes, and snippets.

@theobjectivedad
Created April 17, 2025 12:35
Show Gist options
  • Select an option

  • Save theobjectivedad/9b9e47944bf25523a33ea302b09a962a to your computer and use it in GitHub Desktop.

Select an option

Save theobjectivedad/9b9e47944bf25523a33ea302b09a962a to your computer and use it in GitHub Desktop.
FastAgent Issue 88 Monkeypatch
from typing import List, Optional
from mcp.types import (
EmbeddedResource,
ImageContent,
TextContent,
)
from mcp_agent import PromptMessageMultipart, RequestParams, TextContent
from mcp_agent.agents.workflow.chain_agent import ChainAgent
from mcp_agent.core.prompt import Prompt
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
from mcp_agent.llm.providers.multipart_converter_openai import OpenAIConverter
async def _apply_prompt_provider_specific(
    self: "OpenAIAugmentedLLM",
    multipart_messages: List["PromptMessageMultipart"],
    request_params: RequestParams | None = None,
) -> PromptMessageMultipart:
    """Apply a multipart prompt to the OpenAI LLM, optionally forcing generation.

    Drop-in replacement for ``OpenAIAugmentedLLM._apply_prompt_provider_specific``
    (fast-agent issue #88). Messages that are not sent as the "current" message
    are converted with ``OpenAIConverter`` and added to ``self.history`` as
    prompt history.

    A response is generated when the last message is from the user, or when
    ``request_params`` carries the extra (undeclared) field ``always_generate``
    with a truthy value. Otherwise the last message is returned unchanged.

    Args:
        multipart_messages: Prompt messages; must contain at least one message.
        request_params: Optional request parameters. ``always_generate`` is read
            from the model's ``__pydantic_extra__`` mapping.

    Returns:
        The generated assistant message, or the last input message when no
        generation is performed.
    """
    # TODO -- this is very similar to Anthropic (just the converter class changes).
    # TODO -- potential refactor to base class, standardize Converter interface
    last_message = multipart_messages[-1]

    # Issue #88: ``always_generate`` is an undeclared extra field, so it lives in
    # ``__pydantic_extra__``. That attribute is None when a model does not allow
    # extras (and absent when request_params is None), hence the fallback to {}.
    extras = getattr(request_params, "__pydantic_extra__", None) or {}
    always_generate = bool(extras.get("always_generate", False))
    should_generate = always_generate or last_message.role == "user"

    # When generating, the last message is handed to generate_internal() as the
    # current message, so it is kept OUT of history; adding it here as well would
    # presumably send it to the model twice. When not generating, every message
    # belongs in history.
    messages_to_add = (
        multipart_messages[:-1] if should_generate else multipart_messages
    )
    converted = [OpenAIConverter.convert_to_openai(msg) for msg in messages_to_add]
    self.history.extend(converted, is_prompt=True)

    if not should_generate:
        # For assistant messages: Return the last message content as-is
        self.logger.debug(
            "Last message in prompt is from assistant, returning it directly"
        )
        return last_message

    if last_message.role == "user":
        self.logger.debug(
            "Last message in prompt is from user, generating assistant response"
        )
    else:
        self.logger.debug(
            "always_generate is set, generating response to last assistant message"
        )

    message_param = OpenAIConverter.convert_to_openai(last_message)
    responses: List[
        TextContent | ImageContent | EmbeddedResource
    ] = await self.generate_internal(
        message_param,
        request_params,
    )
    return Prompt.assistant(*responses)
async def generate(
    self: "ChainAgent",
    multipart_messages: List[PromptMessageMultipart],
    request_params: Optional[RequestParams] = None,
) -> PromptMessageMultipart:
    """
    Chain the request through multiple agents in sequence.

    Drop-in replacement for ``ChainAgent.generate`` (fast-agent issue #88).

    Non-cumulative mode: the first agent receives ``multipart_messages``; each
    later agent receives only the previous agent's content wrapped in a single
    user message. ``request_params`` is not forwarded in this mode.

    Cumulative mode: every agent receives the original messages plus all
    previous agents' responses. Each agent also receives a copy of
    ``request_params`` with the extra field ``always_generate`` set to True, so
    that it generates even though its prompt ends with an assistant response.
    The request and all responses are joined into one XML-tagged message.

    Args:
        multipart_messages: Initial messages to send to the first agent
        request_params: Optional request parameters (used in cumulative mode)

    Returns:
        The response from the final agent in the chain (non-cumulative), or an
        assistant message holding the tagged request and every agent's
        response (cumulative).
    """
    if not self.cumulative:
        response: PromptMessageMultipart = await self.agents[0].generate(
            multipart_messages
        )
        # Process the rest of the agents in the chain
        for agent in self.agents[1:]:
            next_message = Prompt.user(*response.content)
            response = await agent.generate(multipart_messages=[next_message])
        return response

    # The original user request is the last message; tolerate an empty list
    # instead of crashing on None.all_text().
    user_message = multipart_messages[-1] if multipart_messages else None
    request_body = user_message.all_text() if user_message is not None else ""
    final_results: List[str] = [
        f"<fastagent:request>{request_body}</fastagent:request>"
    ]

    # Issue #88: pydantic models do not support item assignment
    # (``params["key"] = value`` raises TypeError), so copy the model and set the
    # extra field as an attribute. This assumes RequestParams allows extra fields,
    # which RequestParams(always_generate=True) below relies on as well.
    if request_params is not None:
        derived_request_params = request_params.model_copy()
        setattr(derived_request_params, "always_generate", True)
    else:
        derived_request_params = RequestParams(always_generate=True)

    # Process through each agent in sequence; each sees the original messages
    # followed by every earlier agent's response.
    all_responses: List[PromptMessageMultipart] = []
    for agent in self.agents:
        chain_messages = [*multipart_messages, *all_responses]
        current_response = await agent.generate(
            chain_messages, derived_request_params
        )
        all_responses.append(current_response)
        final_results.append(
            f"<fastagent:response agent='{agent.name}'>"
            f"{current_response.all_text()}</fastagent:response>"
        )

    return PromptMessageMultipart(
        role="assistant",
        content=[TextContent(type="text", text="\n\n".join(final_results))],
    )
def apply_88() -> None:
    """Install the fast-agent issue #88 monkeypatches.

    Rebinds ``OpenAIAugmentedLLM._apply_prompt_provider_specific`` and
    ``ChainAgent.generate`` to this module's implementations, in that order.

    See: https://github.com/evalstate/fast-agent/issues/88
    """
    # pylint: disable=protected-access
    patch_table = (
        (
            OpenAIAugmentedLLM,
            "_apply_prompt_provider_specific",
            _apply_prompt_provider_specific,
        ),
        (ChainAgent, "generate", generate),
    )
    for owner, attribute, replacement in patch_table:
        setattr(owner, attribute, replacement)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment