Skip to content

Instantly share code, notes, and snippets.

@jctaoo
Created August 5, 2025 10:26
Show Gist options

  • Save jctaoo/cfcfc4298456212781837275a3dabb2f to your computer and use it in GitHub Desktop.
vercel ai sdk showcase
import {
API_ROUTE_ERROR_NAME,
ChatExternalData,
ChatFileAttachmentSchema,
ChatRouteError,
ChatSuggestion,
modelConfigSchema,
PersistMessageContent,
PersistMessageErrorContent,
ProviderDefaultConfig,
} from "@/types";
import {
convertToCoreMessages,
CoreMessage,
CoreUserMessage,
createDataStream,
createDataStreamResponse,
FilePart,
ImagePart,
streamText,
TextPart,
tool,
UIMessage,
} from "ai";
import { db } from "@/server/db";
import { auth } from "@/server/auth";
import { ChatRole, LLMProvider, UserRole } from "@prisma/client";
import { NextResponse } from "next/server";
import zod from "zod";
import _ from "lodash";
import { retrieveMessages } from "@/server/api/routers/chat";
import { InputJsonValue } from "@prisma/client/runtime/library";
import { ChatMessage } from "@prisma/client";
import { createAzure } from "@ai-sdk/azure";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import { createOpenAI } from "@ai-sdk/openai";
import { createDeepSeek } from "@ai-sdk/deepseek";
import { DEFAULT_CHAT_MODEL, PREDEFINED_PROVIDER_ID } from "@/constants";
import { getProviderConfig, getUserProviderConfig } from "@/server/utils/llm";
import { tavily } from "@tavily/core";
import { env } from "@/env";
import { RAGService } from "@/server/services/rag_service";
import { SYSTEM_PROMPT } from "@/prompts";
import { getChatSuggestions } from "@/server/utils/chat_suggestion";
import { Session } from "next-auth";
import { type Attachment as VercelAttachment } from "ai";
import { getChatModelCapabilities, mergeSuccessiveMessages } from "./utils";
// Next.js route-segment config: allow streaming responses up to 30 seconds.
export const maxDuration = 30;

// Optional `data` payload of the request body; may carry a model-config override
// (honored for super admins only — see retrieveChatModelConfig).
const BodyDataSchema = zod.object({ modelConfig: modelConfigSchema.optional() }).optional();

type BodyData = zod.infer<typeof BodyDataSchema>;
/**
 * Build an AI SDK client factory for the given provider record.
 *
 * Known provider ids (Azure OpenAI, DeepSeek, OpenAI) get their dedicated
 * SDK factory; every other id falls back to the generic OpenAI-compatible
 * client, using the provider's display name.
 */
function createChatClient(provider: LLMProvider) {
  const providerId = provider.providerId.toLowerCase();
  const { apiKey, endpoint } = provider;

  if (providerId === PREDEFINED_PROVIDER_ID.AZURE_OPENAI) {
    // Azure is addressed by resource name rather than a base URL.
    return createAzure({
      resourceName: endpoint,
      apiKey,
      apiVersion: "2025-01-01-preview",
    });
  }
  if (providerId === PREDEFINED_PROVIDER_ID.DEEP_SEEK) {
    return createDeepSeek({ baseURL: endpoint, apiKey });
  }
  if (providerId === PREDEFINED_PROVIDER_ID.OPENAI) {
    return createOpenAI({ apiKey, baseURL: endpoint });
  }
  // Unknown provider: assume it speaks the OpenAI-compatible HTTP API.
  return createOpenAICompatible({
    apiKey,
    baseURL: endpoint,
    name: provider.name,
  });
}
/**
 * Resolve the LLM provider configuration for the current user.
 *
 * Super admins use the global provider config; every other user inherits the
 * provider configured for the first channel they are a member of.
 *
 * @throws ChatRouteError when a non-admin user belongs to no channel.
 */
async function retrieveChatProvider(session: Session) {
  const { id: userId, role } = session.user;

  if (role === UserRole.SUPER_ADMIN) {
    return getProviderConfig(db, userId);
  }

  // TODO: enhance here, see NOTE — "first channel" is an arbitrary pick.
  const channel = await db.channel.findFirst({
    where: { members: { some: { id: userId } } },
  });
  if (!channel) {
    throw new ChatRouteError("User is not in any channel");
  }
  return getUserProviderConfig(db, userId, channel.id);
}
/**
 * Pick the chat model code and model config for this request.
 *
 * Super admins may override the model via the request body, falling back to
 * DEFAULT_CHAT_MODEL; all other users are pinned to the provider's default
 * model config, where an `alias` takes precedence over the config's `id`.
 *
 * @throws ChatRouteError when the provider has no default model config, or
 *         when neither an alias nor an id yields a model code.
 */
async function retrieveChatModelConfig(session: Session, bodyData: BodyData, provider: LLMProvider) {
  if (session.user.role === UserRole.SUPER_ADMIN) {
    const modelConfig = bodyData?.modelConfig;
    return {
      modelCode: modelConfig?.id ?? DEFAULT_CHAT_MODEL,
      modelConfig,
    };
  }

  const defaultConfig = provider.defaultModelConfig;
  if (!defaultConfig) {
    throw new ChatRouteError("No default model config");
  }

  // `alias` is stripped from the stored config; it only renames the model.
  const { alias, ...modelConfig } = defaultConfig as ProviderDefaultConfig;
  const modelCode = alias ?? modelConfig.id;
  if (!modelCode) {
    throw new ChatRouteError("No model code");
  }
  return { modelCode, modelConfig };
}
// POST /api/chat — streaming chat endpoint (Vercel AI SDK data-stream protocol).
// Flow: authenticate → resolve provider + model → validate session id →
// persist the user message → stream the LLM response (optionally with
// web-search and RAG tools), persisting assistant/tool messages per step →
// record token usage and, when supported, push follow-up suggestions into
// the data stream before it closes.
export async function POST(req: Request) {
const session = await auth();
if (!session) {
return new NextResponse("Unauthorized", { status: 401 });
}
// NOTE(review): only `data` is schema-validated; `message` and `id` are
// trusted as-is from the client — confirm the upstream client contract.
const { message: currentMessage, data, id: chatSessionId } = await req.json();
const bodyData: BodyData = BodyDataSchema.parse(data);
const provider = await retrieveChatProvider(session);
if (!provider) {
const error = new ChatRouteError("No default provider");
console.error("Error in chat route", error);
return error.toResponse();
}
console.log("[chat] using provider", { id: provider.id, name: provider.name });
// Resolve which model to use (admin override or provider default).
// NOTE(review): `modelConfig` is never used below — streamText gets
// `chatClient(modelCode, {})` with empty settings; confirm whether the
// config was meant to be forwarded there.
const { modelCode, modelConfig } = await retrieveChatModelConfig(session, bodyData, provider);
const { supportToolCall, needMergeMessage, supportSuggestion } = getChatModelCapabilities(modelCode, provider);
console.log("[chat] using model", modelCode);
// make sure chatSessionId is valid
if (!chatSessionId) {
return new NextResponse("Missing chat session ID", { status: 400 });
}
const chatSession = await db.chatSession.findUnique({
where: {
id: chatSessionId,
},
});
if (!chatSession) {
return new NextResponse("Invalid chat session ID", { status: 400 });
}
// For LLM context, we don't need tool calls
const initialMessages = await retrieveMessages(db, chatSessionId, false);
// Save user message first (persisted before streaming starts, so the ask is
// recorded even if the LLM call later fails).
const userAskMessage = await db.chatMessage.create({
data: {
chatSessionId,
role: ChatRole.User,
contentBody: {
text: currentMessage.content,
parts: currentMessage.parts || [],
} as unknown as any, // Type assertion to handle JSON serialization
},
});
const chatClient = createChatClient(provider);
// Track message IDs already persisted in onStepFinish, to avoid writing
// the same assistant/tool message twice across steps.
const processedMessageIds = new Set<string>();
// Initialize RAG service
// NOTE(review): the provider-resolver callback ignores its `userId`
// argument and always returns the request's provider — confirm intended.
const ragService = new RAGService(db, async (userId) => provider);
const requestMessagesRaw: CoreMessage[] = [
...convertToCoreMessages(
// TODO: convert to core messages patch, try custom convertToCoreMessages or better data structure.
// "source" parts are re-shaped into experimental_attachments so the SDK
// forwards them as file/image attachments.
[...initialMessages, currentMessage as UIMessage].map((i) => ({
...i,
experimental_attachments: i.parts
?.filter((p) => p.type === "source")
.map(
(i) =>
({
name: i.source.title,
contentType: i.source.providerMetadata?.info?.mime,
url: i.source.url,
}) as VercelAttachment,
),
})),
),
];
// merge successive user or assistant messages (some models reject
// consecutive same-role messages)
const requestMessages = needMergeMessage ? mergeSuccessiveMessages(requestMessagesRaw) : requestMessagesRaw;
// Plain-text projection of the conversation, used later to generate
// follow-up suggestions; assistant texts are appended in onStepFinish.
const textMessages: string[] = requestMessages
.map((m) => {
if (_.isString(m.content)) {
return m.content;
}
if (_.isArray(m.content)) {
const textPart = m.content.find((part) => part.type === "text");
return textPart?.text;
}
return null;
})
.filter((m) => _.isString(m));
// TODO: update other model config and test gemini model with tools?
return createDataStreamResponse({
execute: (dataStream) => {
const result = streamText({
model: chatClient(modelCode, {}),
messages: requestMessages,
system: SYSTEM_PROMPT,
// Tools are only offered when the model supports tool calling.
tools: supportToolCall
? {
search: tool({
description: "Search for information from the web",
parameters: zod.object({
query: zod.string().describe("The search query"),
}),
execute: async ({ query }) => {
const client = tavily({ apiKey: env.TAVILY_API_KEY });
// NOTE(review): @tavily/core documents `maxResults` for limiting
// results; `{ k: 5 }` may be silently ignored — verify.
const response = await client.search(query, { k: 5 });
console.log("Search response", response);
return JSON.stringify(response);
},
}),
rag: tool({
description: "Search for relevant information from knowledge base",
parameters: zod.object({
query: zod.string().describe("The search query to find relevant knowledge base content"),
}),
execute: async ({ query }) => {
const chunks = await ragService.retrieveSimilarChunks(query, session.user.id);
// Format chunks into a user-friendly response
const formattedChunks = chunks.map((chunk) => ({
content: chunk.pageContent,
}));
console.log("RAG response", formattedChunks);
return JSON.stringify(formattedChunks);
},
}),
}
: undefined,
maxSteps: 10,
// Persist each step's assistant/tool messages, chaining them via
// parentId in arrival order within the step.
onStepFinish: async ({ response }) => {
// Process each message in the current step
let parentMessageId: string | null = null;
for (let i = 0; i < response.messages.length; i++) {
const message = response.messages[i];
if (!message || !message.id) continue;
// Skip messages already persisted in an earlier step.
if (processedMessageIds.has(message.id)) {
continue;
}
// Mark this message as processed.
processedMessageIds.add(message.id);
if (message.role === "assistant") {
// Process assistant message
const content = message.content;
let text = "";
let parts: any[] = [];
// Extract text and tool calls from content
if (typeof content === "string") {
text = content;
} else if (_.isArray(content)) {
// Find text parts
const textPart = content.find((part) => part.type === "text");
text = textPart?.text ?? "";
// Store all parts for future reference
parts = content.map((i) => {
// TODO: redacted-reasoning in claude
if (i.type === "reasoning") {
// Re-shape into ReasoningUIPart (vercel ai) for client replay.
return {
type: "reasoning",
reasoning: i.text,
details: [{
type: "text",
text: i.text,
}],
};
}
return i;
}) as any[];
}
// Create assistant message in database
const assistantMessageData: {
chatSessionId: string;
role: ChatRole;
contentBody: InputJsonValue;
parentId?: string;
} = {
chatSessionId,
role: ChatRole.Assistant,
contentBody: {
text,
parts,
messageId: message.id,
} as InputJsonValue,
...(parentMessageId ? { parentId: parentMessageId } : {}),
};
const assistantMessage: ChatMessage = await db.chatMessage.create({
data: assistantMessageData,
});
// Feed assistant text into the suggestion context.
if (text.length > 0) {
textMessages.push(text);
}
// Update parent ID for potential child messages
parentMessageId = assistantMessage.id;
} else if (message.role === "tool") {
// Process tool message (tool response)
const toolContent = message.content;
let toolResult: Record<string, any> | null = null;
// Extract tool result
if (_.isArray(toolContent)) {
const toolResultPart = toolContent.find((part) => part.type === "tool-result");
if (toolResultPart) {
toolResult = {
toolCallId: toolResultPart.toolCallId,
toolName: toolResultPart.toolName,
result: JSON.stringify(toolResultPart.result), // Convert to string to ensure it's serializable
};
}
}
// Create tool message in database
const toolMessageData: {
chatSessionId: string;
role: ChatRole;
contentBody: InputJsonValue;
parentId?: string;
} = {
chatSessionId,
role: ChatRole.Tool,
contentBody: {
// TODO: should be stored using parts instead
toolResult,
messageId: message.id,
} as InputJsonValue,
...(parentMessageId ? { parentId: parentMessageId } : {}),
};
const toolMessage: ChatMessage = await db.chatMessage.create({
data: toolMessageData,
});
// Update parent ID for potential child messages
parentMessageId = toolMessage.id;
}
}
},
// After the full generation: record token usage (attributed to the
// user's ask message) and optionally stream follow-up suggestions.
onFinish: async ({ usage }) => {
// Record token usage in the database
if (usage) {
try {
await db.tokenUsage.create({
data: {
chatSessionId,
promptTokens: usage.promptTokens,
completionTokens: usage.completionTokens,
totalTokens: usage.totalTokens,
model: modelCode,
userId: session.user.id,
chatMessageId: userAskMessage.id,
providerId: provider.id,
},
});
console.log(
`Recorded token usage: ${usage.totalTokens} tokens (${usage.promptTokens} prompt, ${usage.completionTokens} completion) for model ${modelCode}`,
);
} catch (error) {
// Usage accounting is best-effort; never fail the response for it.
console.error("Failed to record token usage:", error);
}
}
// send chat suggestions to client
if (supportSuggestion) {
console.log("[chat] sending chat suggestions to client");
const suggestions = await getChatSuggestions(provider, modelCode, textMessages);
dataStream.writeData({
type: "chat_suggestions",
data: suggestions,
} as ChatExternalData);
}
},
});
// Consume stream to ensure completion even if client disconnects
result.consumeStream();
result.mergeIntoDataStream(dataStream, { sendReasoning: true });
},
onError: (error) => {
if (_.isError(error)) {
console.error("Error in chat route", error);
// Error text is surfaced to the client through the data stream.
return API_ROUTE_ERROR_NAME + ": " + error.message;
}
return "unknown_error";
},
});
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment