Last active
March 12, 2025 13:46
-
-
Save mtnbarreto/72800b64cd8e672802a15e39b7c88761 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A single `StepResult`, additionally tagged with the swarm agent that produced the step. */
type SwarmStepResult<TOOLS extends Record<string, CoreTool>> = { agent: Agent } & StepResult<TOOLS>;

/**
 * The subset of `TextStreamPart` chunks that `streamSwarm` hands to its `onChunk`
 * callback, each tagged with the agent that was active when it was produced.
 */
type SwarmChunk = { agent: Agent } & Extract<
  TextStreamPart<any>,
  {
    type:
      | "text-delta"
      | "reasoning"
      | "tool-call"
      | "tool-call-streaming-start"
      | "tool-call-delta"
      | "tool-result";
  }
>;
/**
 * Runs a multi-agent ("swarm") conversation and exposes it as one stream of
 * `TextStreamPart` chunks.
 *
 * Starting with `agent`, it repeatedly calls `streamText` for the currently active
 * agent. The call uses the agent's own `model`, `system` prompt and `toolChoice`,
 * falling back to the swarm-level `model` / `toolChoice`. A `system` given as a
 * function is called with `context`.
 *
 * When an agent's generation ends with unhandled handover tool calls:
 * - the first such call is executed;
 * - the active agent (and optionally the context) is switched;
 * - a synthetic tool result is appended to the message history;
 * - the other handover tool calls are removed from the assistant message;
 * - the loop continues with the new agent.
 *
 * The stream ends when a `finish` chunk arrives whose finish reason is not
 * `"tool-calls"`, or when the swarm-wide step budget (`maxSteps`) is used up.
 *
 * @param options.agent - the agent that handles the first turn.
 * @param options.prompt - a user string (wrapped as a single user message) or a message list.
 * @param options.context - value passed to dynamic system prompts, tools and handover tools;
 *   handover tools may replace it.
 * @param options.model - fallback model for agents that do not define one.
 * @param options.maxSteps - total step budget shared by all agents (default 100).
 * @param options.toolChoice - fallback tool choice for agents that do not define one.
 * @param options.abortSignal - forwarded to `streamText` and to handover tools.
 * @param options.debug - when true, logs handovers, context updates and every raw chunk.
 * @param options.onChunk - called for each chunk `streamText` reports, tagged with the active agent.
 * @param options.onStepFinish - called after every step, tagged with the active agent.
 * @param options.onAgentFinish - called when one agent's `streamText` call finishes.
 * @param options.onFinish - called with every per-agent finish event once a generation
 *   ends without pending tool calls.
 * @param options.dataStream - accepted but currently not used by this implementation.
 * @returns a readable stream that is also async-iterable.
 */
export function streamSwarm<CONTEXT = any>({
  agent: activeAgent,
  prompt,
  context,
  model,
  maxSteps = 100,
  toolChoice,
  abortSignal,
  debug = false,
  onChunk,
  onStepFinish,
  onAgentFinish,
  onFinish,
  dataStream,
}: {
  agent: Agent;
  prompt: CoreMessage[] | string;
  context?: CONTEXT;
  model: LanguageModel;
  maxSteps?: number;
  toolChoice?: CoreToolChoice<{ [k: string]: CoreTool }>;
  abortSignal?: AbortSignal;
  debug?: boolean;
  onChunk?: (event: { chunk: SwarmChunk }) => Promise<void> | void;
  onStepFinish?: (event: SwarmStepResult<any>) => Promise<void> | void;
  onAgentFinish?: (
    event: Omit<StepResult<any>, "stepType" | "isContinued"> & { readonly steps: StepResult<any>[] } & { agent: Agent }
  ) => Promise<void> | void;
  onFinish?: (events: (Omit<StepResult<any>, "stepType" | "isContinued"> & { readonly steps: StepResult<any>[] } & { agent: Agent })[]) => void;
  dataStream?: DataStreamWriter;
}): AsyncIterableStream<TextStreamPart<any>> {
  const initialMessages = typeof prompt === "string" ? [{ role: "user" as const, content: prompt }] : prompt;
  let lastResult: StreamTextResult<any, any>;
  // One entry per completed step (pushed in onStepFinish), so its length is the
  // number of steps the swarm has used so far.
  const responseMessages: ResponseMessage[][] = [];
  const allFinishEvents: (Omit<StepResult<any>, "stepType" | "isContinued"> & { readonly steps: StepResult<any>[] } & { agent: Agent })[] = [];

  const response = new ReadableStream<TextStreamPart<any>>({
    async start(controller) {
      // Keep invoking agents while the swarm-wide step budget is not exhausted.
      outerLoop: while (responseMessages.length < maxSteps) {
        lastResult = streamText({
          model: activeAgent.model ?? model,
          system: typeof activeAgent.system === "function" ? activeAgent.system(context) : activeAgent.system,
          abortSignal,
          tools: toolsFromAgent<CONTEXT>(activeAgent, context),
          // Give this agent only the steps that remain in the shared budget.
          // Passing the full `maxSteps` to every agent let a swarm run far past it.
          maxSteps: maxSteps - responseMessages.length,
          toolChoice: activeAgent.toolChoice ?? toolChoice,
          messages: [...initialMessages, ...responseMessages.flat()],
          onChunk({ chunk }) {
            return onChunk?.({ chunk: { ...chunk, agent: activeAgent } }) ?? Promise.resolve();
          },
          onStepFinish: (event: StepResult<any>) => {
            // Record this step's messages; they are replayed to the next streamText call.
            responseMessages.push(event.response.messages);
            return onStepFinish?.({ ...event, agent: activeAgent }) ?? Promise.resolve();
          },
          onFinish: (event: Omit<StepResult<any>, "stepType" | "isContinued"> & { readonly steps: StepResult<any>[] }) => {
            const lastStep = event.steps.at(-1);
            // Captured before a possible handover so the finish event is attributed
            // to the agent that actually produced it.
            const previousActiveAgent = activeAgent;
            if (lastStep?.finishReason === "tool-calls") {
              // The generation stopped with tool calls that streamText did not execute.
              const { toolCalls, toolResults } = lastStep;
              const unhandledHandoverCalls = getUnhandledHandoverCalls(toolCalls, toolResults, activeAgent);
              if (unhandledHandoverCalls.length > 0) {
                // Only the first handover call is executed; the rest are dropped below.
                const handoverCall = unhandledHandoverCalls[0];
                const handoverTool = activeAgent.tools?.[handoverCall.toolName]! as AgentHandoverTool<CONTEXT, any>;
                const result = handoverTool.execute(handoverCall.args, {
                  context: context as any,
                  abortSignal,
                });
                activeAgent = result.agent;
                context = result.context ?? context;
                if (debug) {
                  console.log(`\x1b[36mHanding over to agent ${activeAgent.name}\x1b[0m`);
                  if (result.context != null) {
                    console.log(`\x1b[36mUpdated context: ${JSON.stringify(result.context, null, 2)}\x1b[0m`);
                  }
                }
                const handoverToolResult: ToolResultPart = {
                  type: "tool-result",
                  toolCallId: handoverCall.toolCallId,
                  toolName: handoverCall.toolName,
                  result: `Handing over to agent ${activeAgent.name}`,
                };

                // Locate the last step's trailing tool message (if any) and the assistant
                // message that issued the tool calls, before appending anything.
                const lastStepMessages = responseMessages.at(-1);
                const lastMessage = lastStepMessages?.at(-1);
                const toolMessage: CoreToolMessage | undefined =
                  lastMessage?.role === "tool" ? (lastMessage as CoreToolMessage) : undefined;
                const assistantMessage = lastStepMessages?.at(toolMessage === undefined ? -1 : -2) as
                  | CoreAssistantMessage
                  | undefined;

                // Record the handover result, so the next agent sees a completed tool call.
                if (toolMessage !== undefined) {
                  toolMessage.content.push(handoverToolResult);
                } else {
                  lastStepMessages?.push({ role: "tool", content: [handoverToolResult], id: generateId() });
                }

                // Remove the handover calls that were not executed. Otherwise the history
                // would contain tool calls without matching results.
                if (assistantMessage != null && typeof assistantMessage.content !== "string") {
                  const unusedHandoverToolCallIds = unhandledHandoverCalls.slice(1).map((call) => call.toolCallId);
                  assistantMessage.content = assistantMessage.content.filter(
                    (part) => part.type !== "tool-call" || !unusedHandoverToolCallIds.includes(part.toolCallId)
                  );
                }
              }
            }

            // Rebuild the last step so its response messages include the edits made above.
            const steps = event.steps.slice(0, -1);
            const lastStepToFix = event.steps.at(-1)!;
            const lastStepResult: StepResult<any> = {
              ...lastStepToFix,
              response: { ...lastStepToFix.response, messages: responseMessages.at(-1)! },
            };
            const newEvent: Omit<StepResult<any>, "stepType" | "isContinued"> & { readonly steps: StepResult<any>[] } & { agent: Agent } = {
              ...event,
              steps: [...steps, lastStepResult],
              agent: previousActiveAgent,
            };
            allFinishEvents.push(newEvent);
            const ret = onAgentFinish?.(newEvent) ?? Promise.resolve();
            // NOTE(review): if the step budget runs out right after a handover, this
            // branch is never reached and onFinish is not called.
            if (lastStep?.finishReason !== "tool-calls") {
              onFinish?.(allFinishEvents);
            }
            return ret;
          },
        });

        const fullStreamReader: ReadableStreamReader<TextStreamPart<any>> = lastResult.fullStream.getReader();
        // Tool calls/results seen during this agent's run; used to detect pending handovers.
        const recordedToolCalls: Array<CoreToolCallUnion<any>> = [];
        const recordedToolResults: Array<CoreToolResultUnion<any>> = [];
        innerLoop: while (true) {
          const { done, value } = await fullStreamReader.read();
          if (done) {
            break innerLoop;
          }
          // Raw chunk logging is a debugging aid; it used to run unconditionally.
          if (debug) {
            console.log("Chunk value => ", value);
          }
          if (value.type === "tool-call") {
            controller.enqueue(value);
            recordedToolCalls.push(value);
          } else if (value.type === "tool-result") {
            controller.enqueue(value);
            recordedToolResults.push(value);
          } else if (value.type === "step-finish") {
            if (value.finishReason !== "tool-calls") {
              controller.enqueue({ ...value });
            } else {
              const handoverCalls = getUnhandledHandoverCalls(recordedToolCalls, recordedToolResults, activeAgent);
              if (handoverCalls.length > 0) {
                // Emit a synthetic result for the first handover call so stream consumers
                // see the call completed, and mark the step as continued.
                // NOTE(review): this text uses the tool name, whereas the result stored in
                // the message history (onFinish) uses the target agent's name.
                const handoverToolResultTextStreamPart: TextStreamPart<any> = {
                  type: "tool-result",
                  toolCallId: handoverCalls[0].toolCallId,
                  toolName: handoverCalls[0].toolName,
                  result: `Handing over to agent ${handoverCalls[0].toolName}`,
                  args: handoverCalls[0].args,
                };
                controller.enqueue({ ...value, isContinued: true });
                controller.enqueue(handoverToolResultTextStreamPart);
              } else {
                controller.enqueue({ ...value });
              }
            }
          } else if (value.type === "finish" && value.finishReason !== "tool-calls") {
            // A final answer was produced: forward the finish chunk and stop the swarm.
            controller.enqueue(value);
            break outerLoop;
          } else {
            controller.enqueue(value);
          }
        } // end inner loop
      } // end outer loop
      controller.close();
    },
  });
  return createAsyncIterableStream(response);
}
/** A `ReadableStream` that can also be consumed with `for await ... of`. */
export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

/**
 * Wraps `source` so it can be consumed either as a `ReadableStream` or as an
 * async iterable.
 *
 * The source is piped through an identity `TransformStream`, and a
 * `[Symbol.asyncIterator]` method is attached to the resulting stream. Each call
 * of that method acquires a reader, so only one iteration can be active at a time.
 *
 * The iterator releases its reader lock once the stream is exhausted. If the
 * consumer stops early (`break`, `return`, or a throw inside `for await`), the
 * runtime calls `return()`, which cancels the stream and releases the lock.
 * Without `return()`, the stream would stay locked and the producer would never
 * be told to stop.
 *
 * @param source - the stream to wrap; it becomes locked by the internal pipe.
 * @returns the piped stream, augmented with `Symbol.asyncIterator`.
 */
export function createAsyncIterableStream<T>(source: ReadableStream<T>): AsyncIterableStream<T> {
  const stream = source.pipeThrough(new TransformStream<T, T>()) as AsyncIterableStream<T>;

  stream[Symbol.asyncIterator] = () => {
    const reader = stream.getReader();
    // Set once the reader lock has been released, so cleanup happens only once.
    let finished = false;

    return {
      async next(): Promise<IteratorResult<T>> {
        if (finished) {
          return { done: true, value: undefined };
        }
        const { done, value } = await reader.read();
        if (done) {
          finished = true;
          reader.releaseLock();
          return { done: true, value: undefined };
        }
        return { done: false, value };
      },
      async return(): Promise<IteratorResult<T>> {
        if (!finished) {
          finished = true;
          try {
            // Propagate the early exit upstream so the producer can stop.
            await reader.cancel();
          } finally {
            reader.releaseLock();
          }
        }
        return { done: true, value: undefined };
      },
    };
  };

  return stream;
}
| import { formatDataStreamPart } from "ai"; | |
| import { TextStreamPart } from "ai"; | |
| import { DataStreamString } from "@ai-sdk/ui-utils"; | |
/**
 * Creates a `TransformStream` that converts `TextStreamPart` chunks into
 * data-stream protocol strings using `formatDataStreamPart`.
 *
 * All options are optional, and so is the options object itself: calling
 * `getStreamPartsTransformer()` uses the defaults.
 *
 * @param options.getErrorMessage - maps an `error` chunk's payload to the text that
 *   is sent; the default returns the constant "An error occurred" and does not
 *   include the error's details.
 * @param options.sendUsage - when false, `finish_step` / `finish_message` parts
 *   carry `usage: undefined` instead of token counts (default true).
 * @param options.sendReasoning - when false, `reasoning` chunks are dropped (default true).
 * @returns a transformer from `TextStreamPart<any>` to `DataStreamString`. A chunk
 *   type not handled below makes `transform` throw, which errors the stream.
 */
export const getStreamPartsTransformer = ({
  getErrorMessage = () => "An error occurred",
  sendUsage = true,
  sendReasoning = true,
}: {
  getErrorMessage?: (error: any) => string;
  sendUsage?: boolean;
  sendReasoning?: boolean;
} = {}) =>
  new TransformStream<TextStreamPart<any>, DataStreamString>({
    transform(chunk, controller) {
      const chunkType = chunk.type;
      switch (chunkType) {
        case "text-delta": {
          controller.enqueue(formatDataStreamPart("text", chunk.textDelta));
          break;
        }
        case "reasoning": {
          if (sendReasoning) {
            controller.enqueue(formatDataStreamPart("reasoning", chunk.textDelta));
          }
          break;
        }
        case "tool-call-streaming-start": {
          controller.enqueue(
            formatDataStreamPart("tool_call_streaming_start", {
              toolCallId: chunk.toolCallId,
              toolName: chunk.toolName,
            })
          );
          break;
        }
        case "tool-call-delta": {
          controller.enqueue(
            formatDataStreamPart("tool_call_delta", {
              toolCallId: chunk.toolCallId,
              argsTextDelta: chunk.argsTextDelta,
            })
          );
          break;
        }
        case "tool-call": {
          controller.enqueue(
            formatDataStreamPart("tool_call", {
              toolCallId: chunk.toolCallId,
              toolName: chunk.toolName,
              args: chunk.args,
            })
          );
          break;
        }
        case "tool-result": {
          controller.enqueue(
            formatDataStreamPart("tool_result", {
              toolCallId: chunk.toolCallId,
              result: chunk.result,
            })
          );
          break;
        }
        case "error": {
          controller.enqueue(formatDataStreamPart("error", getErrorMessage(chunk.error)));
          break;
        }
        case "step-start": {
          controller.enqueue(
            formatDataStreamPart("start_step", {
              messageId: chunk.messageId,
            })
          );
          break;
        }
        case "step-finish": {
          controller.enqueue(
            formatDataStreamPart("finish_step", {
              finishReason: chunk.finishReason,
              usage: sendUsage
                ? {
                    promptTokens: chunk.usage.promptTokens,
                    completionTokens: chunk.usage.completionTokens,
                  }
                : undefined,
              isContinued: chunk.isContinued,
            })
          );
          break;
        }
        case "finish": {
          controller.enqueue(
            formatDataStreamPart("finish_message", {
              finishReason: chunk.finishReason,
              usage: sendUsage
                ? {
                    promptTokens: chunk.usage.promptTokens,
                    completionTokens: chunk.usage.completionTokens,
                  }
                : undefined,
            })
          );
          break;
        }
        default: {
          // Compile-time exhaustiveness check; at runtime an unexpected chunk type
          // errors the stream.
          const exhaustiveCheck: never = chunkType;
          throw new Error(`Unknown chunk type: ${exhaustiveCheck}`);
        }
      }
    },
  });
/** A `ReadableStream` that can also be consumed with `for await ... of`. */
export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

/**
 * Wraps `source` so it can be consumed either as a `ReadableStream` or as an
 * async iterable.
 *
 * The source is piped through an identity `TransformStream`, and a
 * `[Symbol.asyncIterator]` method is attached to the resulting stream. Each call
 * of that method acquires a reader, so only one iteration can be active at a time.
 *
 * The iterator releases its reader lock once the stream is exhausted. If the
 * consumer stops early (`break`, `return`, or a throw inside `for await`), the
 * runtime calls `return()`, which cancels the stream and releases the lock.
 * Without `return()`, the stream would stay locked and the producer would never
 * be told to stop.
 *
 * @param source - the stream to wrap; it becomes locked by the internal pipe.
 * @returns the piped stream, augmented with `Symbol.asyncIterator`.
 */
export function createAsyncIterableStream<T>(source: ReadableStream<T>): AsyncIterableStream<T> {
  const stream = source.pipeThrough(new TransformStream<T, T>()) as AsyncIterableStream<T>;

  stream[Symbol.asyncIterator] = () => {
    const reader = stream.getReader();
    // Set once the reader lock has been released, so cleanup happens only once.
    let finished = false;

    return {
      async next(): Promise<IteratorResult<T>> {
        if (finished) {
          return { done: true, value: undefined };
        }
        const { done, value } = await reader.read();
        if (done) {
          finished = true;
          reader.releaseLock();
          return { done: true, value: undefined };
        }
        return { done: false, value };
      },
      async return(): Promise<IteratorResult<T>> {
        if (!finished) {
          finished = true;
          try {
            // Propagate the early exit upstream so the producer can stop.
            await reader.cancel();
          } finally {
            reader.releaseLock();
          }
        }
        return { done: true, value: undefined };
      },
    };
  };

  return stream;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment