Vercel AI SDK showcase
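A single Next.js App Router route handler that exercises the Vercel AI SDK end to end: it resolves an LLM provider per user (Azure OpenAI, DeepSeek, OpenAI, or any OpenAI-compatible endpoint), streams a chat completion with `streamText`, exposes a Tavily web-search tool and a RAG tool over a knowledge base, persists user/assistant/tool messages with Prisma, records token usage, and pushes follow-up chat suggestions to the client through the same data stream.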
import {
  API_ROUTE_ERROR_NAME,
  ChatExternalData,
  ChatRouteError,
  modelConfigSchema,
  ProviderDefaultConfig,
} from "@/types";
import {
  convertToCoreMessages,
  CoreMessage,
  createDataStreamResponse,
  streamText,
  tool,
  UIMessage,
  type Attachment as VercelAttachment,
} from "ai";
import { db } from "@/server/db";
import { auth } from "@/server/auth";
import { ChatRole, LLMProvider, UserRole } from "@prisma/client";
import { NextResponse } from "next/server";
import zod from "zod";
import _ from "lodash";
import { retrieveMessages } from "@/server/api/routers/chat";
import { InputJsonValue } from "@prisma/client/runtime/library";
import { ChatMessage } from "@prisma/client";
import { createAzure } from "@ai-sdk/azure";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import { createOpenAI } from "@ai-sdk/openai";
import { createDeepSeek } from "@ai-sdk/deepseek";
import { DEFAULT_CHAT_MODEL, PREDEFINED_PROVIDER_ID } from "@/constants";
import { getProviderConfig, getUserProviderConfig } from "@/server/utils/llm";
import { tavily } from "@tavily/core";
import { env } from "@/env";
import { RAGService } from "@/server/services/rag_service";
import { SYSTEM_PROMPT } from "@/prompts";
import { getChatSuggestions } from "@/server/utils/chat_suggestion";
import { Session } from "next-auth";
import { getChatModelCapabilities, mergeSuccessiveMessages } from "./utils";
// Allow streaming responses up to 30 seconds
export const maxDuration = 30;

const BodyDataSchema = zod
  .object({
    modelConfig: modelConfigSchema.optional(),
  })
  .optional();

type BodyData = zod.infer<typeof BodyDataSchema>;
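// Map a stored provider record to the matching AI SDK provider factory.
// Unrecognized provider IDs fall back to the generic OpenAI-compatible client.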
function createChatClient(provider: LLMProvider) {
  switch (provider.providerId.toLowerCase()) {
    case PREDEFINED_PROVIDER_ID.AZURE_OPENAI:
      return createAzure({
        resourceName: provider.endpoint,
        apiKey: provider.apiKey,
        apiVersion: "2025-01-01-preview",
      });
    case PREDEFINED_PROVIDER_ID.DEEP_SEEK:
      return createDeepSeek({
        baseURL: provider.endpoint,
        apiKey: provider.apiKey,
      });
    case PREDEFINED_PROVIDER_ID.OPENAI:
      return createOpenAI({
        apiKey: provider.apiKey,
        baseURL: provider.endpoint,
      });
    default:
      return createOpenAICompatible({
        apiKey: provider.apiKey,
        baseURL: provider.endpoint,
        name: provider.name,
      });
  }
}
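// Resolve which LLM provider to use: super admins resolve their provider via
// getProviderConfig; other users inherit the provider configured for the first
// channel they belong to.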
async function retrieveChatProvider(session: Session) {
  const role = session.user.role;
  if (role === UserRole.SUPER_ADMIN) {
    return await getProviderConfig(db, session.user.id);
  }
  // TODO: enhance here, see NOTE
  const firstChannel = await db.channel.findFirst({
    where: {
      members: { some: { id: session.user.id } },
    },
  });
  if (!firstChannel) {
    throw new ChatRouteError("User is not in any channel");
  }
  return await getUserProviderConfig(db, session.user.id, firstChannel.id);
}
async function retrieveChatModelConfig(session: Session, bodyData: BodyData, provider: LLMProvider) {
  const role = session.user.role;
  // Super admins may override the model per request; everyone else gets the
  // provider's default model config (alias preferred over the raw model id).
  if (role === UserRole.SUPER_ADMIN) {
    const modelConfig = bodyData?.modelConfig;
    const modelCode = modelConfig?.id ?? DEFAULT_CHAT_MODEL;
    return {
      modelCode,
      modelConfig,
    };
  }
  if (!provider.defaultModelConfig) {
    throw new ChatRouteError("No default model config");
  }
  const { alias, ...modelConfig } = provider.defaultModelConfig as ProviderDefaultConfig;
  const modelCode = alias ?? modelConfig.id;
  if (!modelCode) {
    throw new ChatRouteError("No model code");
  }
  return {
    modelCode,
    modelConfig,
  };
}
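// POST handler. The client sends { id, message, data }: `id` is the chat
// session, `message` is only the latest UI message (earlier history is loaded
// from the database), and `data` optionally carries a model config override.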
export async function POST(req: Request) {
  const session = await auth();
  if (!session) {
    return new NextResponse("Unauthorized", { status: 401 });
  }

  const { message: currentMessage, data, id: chatSessionId } = await req.json();
  const bodyData: BodyData = BodyDataSchema.parse(data);

  const provider = await retrieveChatProvider(session);
  if (!provider) {
    const error = new ChatRouteError("No default provider");
    console.error("Error in chat route", error);
    return error.toResponse();
  }
  console.log("[chat] using provider", { id: provider.id, name: provider.name });

  // handle model config
  const { modelCode, modelConfig } = await retrieveChatModelConfig(session, bodyData, provider);
  const { supportToolCall, needMergeMessage, supportSuggestion } = getChatModelCapabilities(modelCode, provider);
  console.log("[chat] using model", modelCode);

  // make sure chatSessionId is valid
  if (!chatSessionId) {
    return new NextResponse("Missing chat session ID", { status: 400 });
  }
  const chatSession = await db.chatSession.findUnique({
    where: {
      id: chatSessionId,
    },
  });
  if (!chatSession) {
    return new NextResponse("Invalid chat session ID", { status: 400 });
  }

  // For LLM context, we don't need tool calls
  const initialMessages = await retrieveMessages(db, chatSessionId, false);

  // Save user message first
  const userAskMessage = await db.chatMessage.create({
    data: {
      chatSessionId,
      role: ChatRole.User,
      contentBody: {
        text: currentMessage.content,
        parts: currentMessage.parts || [],
      } as unknown as InputJsonValue, // cast for Prisma JSON serialization
    },
  });

  const chatClient = createChatClient(provider);

  // Track message IDs that have already been persisted, so messages are not
  // stored twice when onStepFinish fires across multiple steps
  const processedMessageIds = new Set<string>();

  // Initialize RAG service
  const ragService = new RAGService(db, async (userId) => provider);

  const requestMessagesRaw: CoreMessage[] = [
    ...convertToCoreMessages(
      // TODO: convert to core messages patch, try custom convertToCoreMessages or better data structure.
      [...initialMessages, currentMessage as UIMessage].map((i) => ({
        ...i,
        experimental_attachments: i.parts
          ?.filter((p) => p.type === "source")
          .map(
            (part) =>
              ({
                name: part.source.title,
                contentType: part.source.providerMetadata?.info?.mime,
                url: part.source.url,
              }) as VercelAttachment,
          ),
      })),
    ),
  ];

  // merge successive user or assistant messages
  const requestMessages = needMergeMessage ? mergeSuccessiveMessages(requestMessagesRaw) : requestMessagesRaw;

  // Plain-text view of the conversation, used later to generate chat suggestions
  const textMessages: string[] = requestMessages
    .map((m) => {
      if (_.isString(m.content)) {
        return m.content;
      }
      if (_.isArray(m.content)) {
        const textPart = m.content.find((part) => part.type === "text");
        return textPart?.text;
      }
      return null;
    })
    .filter((m): m is string => _.isString(m));

  // TODO: update other model config and test gemini model with tools?
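  // createDataStreamResponse streams the LLM output to the client and exposes
  // a `dataStream` handle, so custom data parts (like the chat suggestions
  // below) can be written into the same response stream.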
  return createDataStreamResponse({
    execute: (dataStream) => {
      const result = streamText({
        model: chatClient(modelCode, {}),
        messages: requestMessages,
        system: SYSTEM_PROMPT,
        tools: supportToolCall
          ? {
              search: tool({
                description: "Search for information from the web",
                parameters: zod.object({
                  query: zod.string().describe("The search query"),
                }),
                execute: async ({ query }) => {
                  const client = tavily({ apiKey: env.TAVILY_API_KEY });
                  const response = await client.search(query, { maxResults: 5 });
                  console.log("Search response", response);
                  return JSON.stringify(response);
                },
              }),
              rag: tool({
                description: "Search for relevant information from knowledge base",
                parameters: zod.object({
                  query: zod.string().describe("The search query to find relevant knowledge base content"),
                }),
                execute: async ({ query }) => {
                  const chunks = await ragService.retrieveSimilarChunks(query, session.user.id);
                  // Format chunks into a user-friendly response
                  const formattedChunks = chunks.map((chunk) => ({
                    content: chunk.pageContent,
                  }));
                  console.log("RAG response", formattedChunks);
                  return JSON.stringify(formattedChunks);
                },
              }),
            }
          : undefined,
        maxSteps: 10,
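        // Persist assistant and tool messages as each step finishes. With
        // maxSteps > 1 the model can run several assistant/tool rounds, so
        // parentId chains the stored messages in order.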
        onStepFinish: async ({ response }) => {
          // Process each message in the current step
          let parentMessageId: string | null = null;
          for (let i = 0; i < response.messages.length; i++) {
            const message = response.messages[i];
            if (!message || !message.id) continue;

            // Skip messages that were already persisted in an earlier step
            if (processedMessageIds.has(message.id)) {
              continue;
            }
            processedMessageIds.add(message.id);

            if (message.role === "assistant") {
              // Process assistant message
              const content = message.content;
              let text = "";
              let parts: any[] = [];

              // Extract text and tool calls from content
              if (typeof content === "string") {
                text = content;
              } else if (_.isArray(content)) {
                // Find text parts
                const textPart = content.find((part) => part.type === "text");
                text = textPart?.text ?? "";

                // Store all parts for future reference
                parts = content.map((i) => {
                  // TODO: redacted-reasoning in claude
                  if (i.type === "reasoning") {
                    // ReasoningUIPart in vercel ai
                    return {
                      type: "reasoning",
                      reasoning: i.text,
                      details: [
                        {
                          type: "text",
                          text: i.text,
                        },
                      ],
                    };
                  }
                  return i;
                }) as any[];
              }

              // Create assistant message in database
              const assistantMessageData: {
                chatSessionId: string;
                role: ChatRole;
                contentBody: InputJsonValue;
                parentId?: string;
              } = {
                chatSessionId,
                role: ChatRole.Assistant,
                contentBody: {
                  text,
                  parts,
                  messageId: message.id,
                } as InputJsonValue,
                ...(parentMessageId ? { parentId: parentMessageId } : {}),
              };
              const assistantMessage: ChatMessage = await db.chatMessage.create({
                data: assistantMessageData,
              });
              if (text.length > 0) {
                textMessages.push(text);
              }

              // Update parent ID for potential child messages
              parentMessageId = assistantMessage.id;
            } else if (message.role === "tool") {
              // Process tool message (tool response)
              const toolContent = message.content;
              let toolResult: Record<string, any> | null = null;

              // Extract tool result
              if (_.isArray(toolContent)) {
                const toolResultPart = toolContent.find((part) => part.type === "tool-result");
                if (toolResultPart) {
                  toolResult = {
                    toolCallId: toolResultPart.toolCallId,
                    toolName: toolResultPart.toolName,
                    result: JSON.stringify(toolResultPart.result), // stringify to ensure it's serializable
                  };
                }
              }

              // Create tool message in database
              const toolMessageData: {
                chatSessionId: string;
                role: ChatRole;
                contentBody: InputJsonValue;
                parentId?: string;
              } = {
                chatSessionId,
                role: ChatRole.Tool,
                contentBody: {
                  // TODO: should store this as message parts instead
                  toolResult,
                  messageId: message.id,
                } as InputJsonValue,
                ...(parentMessageId ? { parentId: parentMessageId } : {}),
              };
              const toolMessage: ChatMessage = await db.chatMessage.create({
                data: toolMessageData,
              });

              // Update parent ID for potential child messages
              parentMessageId = toolMessage.id;
            }
          }
        },
        onFinish: async ({ usage }) => {
          // Record token usage in the database
          if (usage) {
            try {
              await db.tokenUsage.create({
                data: {
                  chatSessionId,
                  promptTokens: usage.promptTokens,
                  completionTokens: usage.completionTokens,
                  totalTokens: usage.totalTokens,
                  model: modelCode,
                  userId: session.user.id,
                  chatMessageId: userAskMessage.id,
                  providerId: provider.id,
                },
              });
              console.log(
                `Recorded token usage: ${usage.totalTokens} tokens (${usage.promptTokens} prompt, ${usage.completionTokens} completion) for model ${modelCode}`,
              );
            } catch (error) {
              console.error("Failed to record token usage:", error);
            }
          }

          // send chat suggestions to client
          if (supportSuggestion) {
            console.log("[chat] sending chat suggestions to client");
            const suggestions = await getChatSuggestions(provider, modelCode, textMessages);
            dataStream.writeData({
              type: "chat_suggestions",
              data: suggestions,
            } as ChatExternalData);
          }
        },
      });

      // Consume stream to ensure completion even if client disconnects
      result.consumeStream();
      result.mergeIntoDataStream(dataStream, { sendReasoning: true });
    },
    onError: (error) => {
      if (_.isError(error)) {
        console.error("Error in chat route", error);
        return API_ROUTE_ERROR_NAME + ": " + error.message;
      }
      return "unknown_error";
    },
  });
}
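For reference, here is a minimal sketch of a client component that could talk to this route. It is not part of the gist: the `/api/chat` mount point and the `Chat` component are assumptions. Because the handler reads a single `message` from the request body (plus the session `id`) and reloads earlier history from the database, the client overrides the default request body with `experimental_prepareRequestBody` so only the last message is sent. The `chat_suggestions` part written via `dataStream.writeData` arrives in the hook's `data` array.

"use client";

import { useChat } from "@ai-sdk/react";

// Hypothetical client for the route above. Assumes the handler is mounted at
// /api/chat; adjust to wherever the route file actually lives.
export function Chat({ sessionId }: { sessionId: string }) {
  const { messages, input, handleInputChange, handleSubmit, data } = useChat({
    api: "/api/chat",
    id: sessionId,
    // The server expects { id, message, data }, so send only the last message
    // instead of the default full `messages` array.
    experimental_prepareRequestBody: ({ id, messages }) => ({
      id,
      message: messages[messages.length - 1],
    }),
  });

  return (
    <form onSubmit={handleSubmit}>
      {messages.map((m) => (
        <div key={m.id}>
          <b>{m.role}:</b> {m.content}
        </div>
      ))}
      {/* Custom data parts (e.g. chat_suggestions) stream into `data` */}
      <input value={input} onChange={handleInputChange} placeholder="Ask something" />
    </form>
  );
}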