Skip to content

Instantly share code, notes, and snippets.

@mtnbarreto
Last active March 12, 2025 13:46
Show Gist options
  • Select an option

  • Save mtnbarreto/72800b64cd8e672802a15e39b7c88761 to your computer and use it in GitHub Desktop.

Select an option

Save mtnbarreto/72800b64cd8e672802a15e39b7c88761 to your computer and use it in GitHub Desktop.
/** A `StepResult` tagged with the agent it is attributed to. */
type SwarmStepResult<TOOLS extends Record<string, CoreTool>> = StepResult<TOOLS> & {
  agent: Agent;
};

/** `type` discriminators of the stream parts that make up a `SwarmChunk`. */
type SwarmChunkType =
  | "text-delta"
  | "reasoning"
  | "tool-call"
  | "tool-call-streaming-start"
  | "tool-call-delta"
  | "tool-result";

/**
 * The members of `TextStreamPart` whose `type` is one of `SwarmChunkType`,
 * each tagged with an agent. This is the payload type of the `onChunk` callback.
 */
type SwarmChunk = Extract<TextStreamPart<any>, { type: SwarmChunkType }> & { agent: Agent };
/** Finish payload of one agent run, tagged with the agent that was active during that run. */
type SwarmAgentFinishEvent = Omit<StepResult<any>, "stepType" | "isContinued"> & {
  readonly steps: StepResult<any>[];
} & { agent: Agent };

/**
 * Streams a multi-agent ("swarm") run as one stream of `TextStreamPart`s.
 *
 * Starting with `agent`, `streamText` is invoked in a loop. When a run finishes with
 * finish reason `"tool-calls"` and `getUnhandledHandoverCalls` reports at least one call,
 * the first handover tool is executed, the agent it returns (and its context, if any)
 * becomes active, a synthetic tool result is appended to the collected messages, and the
 * next run starts with those messages. The loop stops when a `finish` part with any other
 * finish reason arrives, when `maxSteps` step results have been collected, or when a run
 * ends without recording a single step.
 *
 * Options:
 * - `agent`: agent used for the first run.
 * - `prompt`: a user message as a string, or a prepared message list.
 * - `context`: value passed to function-style system prompts, `toolsFromAgent` and handover tools;
 *   replaced when a handover returns a non-nullish context.
 * - `model`: used when the active agent has no `model` of its own.
 * - `maxSteps`: bound on the number of collected step results; also forwarded to every `streamText` call.
 * - `toolChoice`: used when the active agent has no `toolChoice` of its own.
 * - `abortSignal`: forwarded to `streamText` and to handover tools.
 * - `debug`: when true, logs every raw stream part and every handover to the console.
 * - `onChunk` / `onStepFinish`: forwarded `streamText` callbacks with the active agent attached.
 * - `onAgentFinish`: invoked at the end of every run with that run's (patched) finish event.
 * - `onFinish`: invoked with all finish events collected so far when a run ends with a
 *   finish reason other than `"tool-calls"`.
 * - `dataStream`: accepted for interface compatibility; not used by this function.
 *
 * @returns A stream of the forwarded parts that is also async-iterable.
 */
export function streamSwarm<CONTEXT = any>({
  agent: activeAgent,
  prompt,
  context,
  model,
  maxSteps = 100,
  toolChoice,
  abortSignal,
  debug = false,
  onChunk,
  onStepFinish,
  onAgentFinish,
  onFinish,
  dataStream,
}: {
  agent: Agent;
  prompt: CoreMessage[] | string;
  context?: CONTEXT;
  model: LanguageModel;
  maxSteps?: number;
  toolChoice?: CoreToolChoice<{ [k: string]: CoreTool }>;
  abortSignal?: AbortSignal;
  debug?: boolean;
  onChunk?: (event: { chunk: SwarmChunk }) => Promise<void> | void;
  onStepFinish?: (event: SwarmStepResult<any>) => Promise<void> | void;
  onAgentFinish?: (event: SwarmAgentFinishEvent) => Promise<void> | void;
  onFinish?: (events: SwarmAgentFinishEvent[]) => void;
  dataStream?: DataStreamWriter;
}): AsyncIterableStream<TextStreamPart<any>> {
  const initialMessages = typeof prompt === "string" ? [{ role: "user" as const, content: prompt }] : prompt;
  // One entry per finished step, holding the messages reported for that step.
  const responseMessages: ResponseMessage[][] = [];
  const allFinishEvents: SwarmAgentFinishEvent[] = [];

  const response = new ReadableStream<TextStreamPart<any>>({
    async start(controller) {
      outerLoop: while (responseMessages.length < maxSteps) {
        // Used after the run to detect that no step was recorded at all.
        const stepCountBeforeRun = responseMessages.length;

        const lastResult: StreamTextResult<any, any> = streamText({
          model: activeAgent.model ?? model,
          system: typeof activeAgent.system === "function" ? activeAgent.system(context) : activeAgent.system,
          abortSignal,
          tools: toolsFromAgent<CONTEXT>(activeAgent, context),
          maxSteps: maxSteps,
          toolChoice: activeAgent.toolChoice ?? toolChoice,
          messages: [...initialMessages, ...responseMessages.flat()],
          onChunk({ chunk }) {
            return onChunk?.({ chunk: { ...chunk, agent: activeAgent } }) ?? Promise.resolve();
          },
          onStepFinish: (event: StepResult<any>) => {
            // Keep the step's messages; they are replayed to the next streamText call.
            responseMessages.push(event.response.messages);
            return onStepFinish?.({ ...event, agent: activeAgent }) ?? Promise.resolve();
          },
          onFinish: (event: Omit<StepResult<any>, "stepType" | "isContinued"> & { readonly steps: StepResult<any>[] }) => {
            const lastStep = event.steps.at(-1);
            // A handover below replaces `activeAgent`; the finish event must name the agent of this run.
            const previousActiveAgent = activeAgent;

            if (lastStep?.finishReason === "tool-calls") {
              // The generation stopped with a tool call that has no result yet.
              const { toolCalls, toolResults } = lastStep;
              const unhandledHandoverCalls = getUnhandledHandoverCalls(toolCalls, toolResults, activeAgent);

              if (unhandledHandoverCalls.length > 0) {
                // Only the first handover call is executed; the remaining ones are removed below.
                const handoverCall = unhandledHandoverCalls[0];
                const handoverTool = activeAgent.tools?.[handoverCall.toolName]! as AgentHandoverTool<CONTEXT, any>;
                const result = handoverTool.execute(handoverCall.args, {
                  context: context as any,
                  abortSignal,
                });
                activeAgent = result.agent;
                context = result.context ?? context;
                if (debug) {
                  console.log(`\x1b[36mHanding over to agent ${activeAgent.name}\x1b[0m`);
                  if (result.context != null) {
                    console.log(`\x1b[36mUpdated context: ${JSON.stringify(result.context, null, 2)}\x1b[0m`);
                  }
                }

                const handoverToolResult: ToolResultPart = {
                  type: "tool-result",
                  toolCallId: handoverCall.toolCallId,
                  toolName: handoverCall.toolName,
                  result: `Handing over to agent ${activeAgent.name}`,
                };

                // Locate the tool message (if the step already ended with one) and the assistant
                // message that issued the calls, both inside the most recent step's messages.
                const lastStepMessages = responseMessages.at(-1);
                const trailingMessage = lastStepMessages?.at(-1);
                const toolMessage: CoreToolMessage | undefined =
                  trailingMessage?.role === "tool" ? (trailingMessage as CoreToolMessage) : undefined;
                const assistantMessage = lastStepMessages?.at(toolMessage === undefined ? -1 : -2) as CoreAssistantMessage;

                // Attach the handover result: reuse the existing tool message or append a new one.
                if (toolMessage !== undefined) {
                  toolMessage.content.push(handoverToolResult);
                } else {
                  lastStepMessages?.push({ role: "tool", content: [handoverToolResult], id: generateId() });
                }

                // Drop the ignored handover calls so no tool call is left without a result.
                if (typeof assistantMessage.content !== "string") {
                  const unusedHandoverToolCallIds = unhandledHandoverCalls.slice(1).map((call) => call.toolCallId);
                  assistantMessage.content = assistantMessage.content.filter(
                    (part) => part.type !== "tool-call" || !unusedHandoverToolCallIds.includes(part.toolCallId)
                  );
                }
              }
            }

            // Rebuild the last step so its messages are the collected (possibly patched) ones.
            const steps = event.steps.slice(0, -1);
            const lastStepToFix = event.steps.at(-1)!;
            const lastStepResult: StepResult<any> = {
              ...lastStepToFix,
              response: { ...lastStepToFix.response, messages: responseMessages.at(-1)! },
            };
            const newEvent: SwarmAgentFinishEvent = {
              ...event,
              steps: [...steps, lastStepResult],
              agent: previousActiveAgent,
            };
            allFinishEvents.push(newEvent);

            const ret = onAgentFinish?.(newEvent) ?? Promise.resolve();
            if (lastStep?.finishReason !== "tool-calls") {
              onFinish?.(allFinishEvents);
            }
            return ret;
          },
        });

        const fullStreamReader: ReadableStreamReader<TextStreamPart<any>> = lastResult.fullStream.getReader();
        // Tool calls and results seen during this run, used to spot unhandled handover calls.
        const recordedToolCalls: Array<CoreToolCallUnion<any>> = [];
        const recordedToolResults: Array<CoreToolResultUnion<any>> = [];

        innerLoop: while (true) {
          const { done, value } = await fullStreamReader.read();
          if (done) {
            break innerLoop;
          }
          if (debug) {
            console.log("Chunk value => ", value);
          }

          if (value.type === "tool-call") {
            controller.enqueue(value);
            recordedToolCalls.push(value);
          } else if (value.type === "tool-result") {
            controller.enqueue(value);
            recordedToolResults.push(value);
          } else if (value.type === "step-finish") {
            if (value.finishReason !== "tool-calls") {
              controller.enqueue({ ...value });
            } else {
              const handoverCalls = getUnhandledHandoverCalls(recordedToolCalls, recordedToolResults, activeAgent);
              if (handoverCalls.length > 0) {
                // Only the first handover call is answered; the step is marked as continued
                // and a synthetic tool result for the handover is emitted right after it.
                const handoverToolResultTextStreamPart: TextStreamPart<any> = {
                  type: "tool-result",
                  toolCallId: handoverCalls[0].toolCallId,
                  toolName: handoverCalls[0].toolName,
                  result: `Handing over to agent ${handoverCalls[0].toolName}`,
                  args: handoverCalls[0].args,
                };
                controller.enqueue({ ...value, isContinued: true });
                controller.enqueue(handoverToolResultTextStreamPart);
              } else {
                controller.enqueue({ ...value });
              }
            }
          } else if (value.type === "finish" && value.finishReason !== "tool-calls") {
            // Final answer reached: forward the finish part and stop the whole swarm.
            controller.enqueue(value);
            break outerLoop;
          } else {
            controller.enqueue(value);
          }
        } // end inner loop

        // A run that recorded no step cannot change the agent or the messages, so running it
        // again would repeat the exact same request forever.
        if (responseMessages.length === stepCountBeforeRun) {
          break outerLoop;
        }
      } // end outer loop
      controller.close();
    },
  });
  return createAsyncIterableStream(response);
}
/** A `ReadableStream` that can also be consumed with `for await...of`. */
export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

/**
 * Pipes `source` through an identity `TransformStream` and installs an async iterator
 * on the resulting stream, so it can be consumed as a stream or with `for await...of`.
 *
 * The iterator releases its reader lock when the stream ends or a read fails, and
 * cancels the stream (then releases the lock) when iteration is abandoned early,
 * e.g. by `break` or an exception inside a `for await` loop.
 *
 * @param source - Stream to wrap; it is locked by the pipe afterwards.
 * @returns The piped stream with `[Symbol.asyncIterator]` attached.
 */
export function createAsyncIterableStream<T>(source: ReadableStream<T>): AsyncIterableStream<T> {
  const stream = source.pipeThrough(new TransformStream<T, T>());

  (stream as AsyncIterableStream<T>)[Symbol.asyncIterator] = () => {
    const reader = stream.getReader();
    let finished = false;

    // Releases the reader lock exactly once.
    const release = (): void => {
      if (!finished) {
        finished = true;
        reader.releaseLock();
      }
    };

    return {
      async next(): Promise<IteratorResult<T>> {
        if (finished) {
          return { done: true, value: undefined };
        }
        try {
          const { done, value } = await reader.read();
          if (!done) {
            return { done: false, value };
          }
        } catch (error) {
          release();
          throw error;
        }
        release();
        return { done: true, value: undefined };
      },
      // Called by `for await` on early exit: stop the producer and unlock the stream.
      async return(): Promise<IteratorResult<T>> {
        if (!finished) {
          try {
            await reader.cancel();
          } finally {
            release();
          }
        }
        return { done: true, value: undefined };
      },
    };
  };

  return stream as AsyncIterableStream<T>;
}
import { formatDataStreamPart } from "ai";
import { TextStreamPart } from "ai";
import { DataStreamString } from "@ai-sdk/ui-utils";
/**
 * Creates a `TransformStream` that turns each `TextStreamPart` into a data-stream string
 * by calling `formatDataStreamPart` with the matching part name.
 *
 * All options are optional, and the options object itself may be omitted.
 *
 * @param options.getErrorMessage - Maps the error of an `error` part to the text that is sent;
 *   defaults to a function returning `"An error occurred"`.
 * @param options.sendUsage - When false, `usage` is sent as `undefined` in `finish_step` and
 *   `finish_message` parts. Defaults to true.
 * @param options.sendReasoning - When false, `reasoning` parts are dropped. Defaults to true.
 * @returns The transformer; its `transform` throws on a part type it does not know.
 */
export const getStreamPartsTransformer = ({
  getErrorMessage = () => "An error occurred",
  sendUsage = true,
  sendReasoning = true,
}: {
  getErrorMessage?: (error: any) => string;
  sendUsage?: boolean;
  sendReasoning?: boolean;
} = {}) => {
  // Projects a usage record onto the two token counters that are sent, or hides it entirely.
  const usageToSend = (usage: { promptTokens: number; completionTokens: number }) =>
    sendUsage
      ? {
          promptTokens: usage.promptTokens,
          completionTokens: usage.completionTokens,
        }
      : undefined;

  return new TransformStream<TextStreamPart<any>, DataStreamString>({
    transform: async (chunk, controller) => {
      const chunkType = chunk.type;
      switch (chunkType) {
        case "text-delta": {
          controller.enqueue(formatDataStreamPart("text", chunk.textDelta));
          break;
        }
        case "reasoning": {
          if (sendReasoning) {
            controller.enqueue(formatDataStreamPart("reasoning", chunk.textDelta));
          }
          break;
        }
        case "tool-call-streaming-start": {
          controller.enqueue(
            formatDataStreamPart("tool_call_streaming_start", {
              toolCallId: chunk.toolCallId,
              toolName: chunk.toolName,
            })
          );
          break;
        }
        case "tool-call-delta": {
          controller.enqueue(
            formatDataStreamPart("tool_call_delta", {
              toolCallId: chunk.toolCallId,
              argsTextDelta: chunk.argsTextDelta,
            })
          );
          break;
        }
        case "tool-call": {
          controller.enqueue(
            formatDataStreamPart("tool_call", {
              toolCallId: chunk.toolCallId,
              toolName: chunk.toolName,
              args: chunk.args,
            })
          );
          break;
        }
        case "tool-result": {
          controller.enqueue(
            formatDataStreamPart("tool_result", {
              toolCallId: chunk.toolCallId,
              result: chunk.result,
            })
          );
          break;
        }
        case "error": {
          controller.enqueue(formatDataStreamPart("error", getErrorMessage(chunk.error)));
          break;
        }
        case "step-start": {
          controller.enqueue(
            formatDataStreamPart("start_step", {
              messageId: chunk.messageId,
            })
          );
          break;
        }
        case "step-finish": {
          controller.enqueue(
            formatDataStreamPart("finish_step", {
              finishReason: chunk.finishReason,
              usage: usageToSend(chunk.usage),
              isContinued: chunk.isContinued,
            })
          );
          break;
        }
        case "finish": {
          controller.enqueue(
            formatDataStreamPart("finish_message", {
              finishReason: chunk.finishReason,
              usage: usageToSend(chunk.usage),
            })
          );
          break;
        }
        default: {
          // Compile-time exhaustiveness check; at runtime an unexpected part type is an error.
          const exhaustiveCheck: never = chunkType;
          throw new Error(`Unknown chunk type: ${exhaustiveCheck}`);
        }
      }
    },
  });
};
/** A `ReadableStream` that can also be consumed with `for await...of`. */
export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

/**
 * Pipes `source` through an identity `TransformStream` and installs an async iterator
 * on the resulting stream, so it can be consumed as a stream or with `for await...of`.
 *
 * The iterator releases its reader lock when the stream ends or a read fails, and
 * cancels the stream (then releases the lock) when iteration is abandoned early,
 * e.g. by `break` or an exception inside a `for await` loop.
 *
 * @param source - Stream to wrap; it is locked by the pipe afterwards.
 * @returns The piped stream with `[Symbol.asyncIterator]` attached.
 */
export function createAsyncIterableStream<T>(source: ReadableStream<T>): AsyncIterableStream<T> {
  const stream = source.pipeThrough(new TransformStream<T, T>());

  (stream as AsyncIterableStream<T>)[Symbol.asyncIterator] = () => {
    const reader = stream.getReader();
    let finished = false;

    // Releases the reader lock exactly once.
    const release = (): void => {
      if (!finished) {
        finished = true;
        reader.releaseLock();
      }
    };

    return {
      async next(): Promise<IteratorResult<T>> {
        if (finished) {
          return { done: true, value: undefined };
        }
        try {
          const { done, value } = await reader.read();
          if (!done) {
            return { done: false, value };
          }
        } catch (error) {
          release();
          throw error;
        }
        release();
        return { done: true, value: undefined };
      },
      // Called by `for await` on early exit: stop the producer and unlock the stream.
      async return(): Promise<IteratorResult<T>> {
        if (!finished) {
          try {
            await reader.cancel();
          } finally {
            release();
          }
        }
        return { done: true, value: undefined };
      },
    };
  };

  return stream as AsyncIterableStream<T>;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment