Created
January 26, 2025 16:41
-
-
Save prnthh/cd5db3b7507f399cc60e441eff3e0572 to your computer and use it in GitHub Desktop.
RemoteThink - Remote Function Calling using Llama Tool Use
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| const http = require('http'); | |
| const bonjour = require('bonjour')(); | |
| const crypto = require('crypto'); | |
| const os = require('os'); | |
// Service Discovery and HTTP Server implementation
/**
 * Advertises this process on the local network via Bonjour (service type
 * `remoteThink`), records peer instances as they appear and disappear, and
 * serves `POST /handle-remote` on an OS-assigned port.
 */
class ServiceDiscovery {
  /**
   * Starts the HTTP server (asynchronously) and the Bonjour browser.
   *
   * @param {object} [options]
   * @param {(toolCall: *) => (object|Promise<object>)} [options.onToolCall]
   *   Handles the `toolCall` field of an incoming /handle-remote request; the
   *   returned object is sent back as the JSON response body. When omitted,
   *   requests are acknowledged with a placeholder
   *   "Success! Received tool call: …" message.
   * @param {number} [options.maxBodyBytes=1048576] Request bodies larger than
   *   this many bytes are rejected with HTTP 413.
   */
  constructor({ onToolCall = null, maxBodyBytes = 1024 * 1024 } = {}) {
    this.instances = new Map(); // "host:port" -> Bonjour service record
    this.server = null;
    this.service = null;
    this.browser = null;
    this.onToolCall = typeof onToolCall === 'function' ? onToolCall : null;
    this.maxBodyBytes = maxBodyBytes;
    // Random per-process ID, advertised in the TXT record so we can skip our
    // own advertisement when browsing.
    this.instanceId = crypto.randomBytes(4).toString('hex');
    this.hostname = os.hostname();
    // Resolves with the bound port once listening. The catch keeps a bind
    // failure from becoming an unhandled rejection; anyone awaiting `ready`
    // still observes the error.
    this.ready = this.initializeServer();
    this.ready.catch((error) => console.error('[Service Discovery] Server failed to start:', error));
    this.setupDiscovery();
  }

  /**
   * Bind the HTTP server on an OS-chosen port and advertise it once listening.
   * @returns {Promise<number>} Resolves with the bound port; rejects on a bind error.
   */
  initializeServer() {
    return new Promise((resolve, reject) => {
      this.server = http.createServer(this.handleRequest.bind(this));
      this.server.once('error', reject);
      // '::' with ipv6Only=false accepts both IPv6 and IPv4 connections (dual-stack);
      // port 0 lets the OS pick a free port.
      this.server.listen({ host: '::', port: 0, ipv6Only: false }, () => {
        this.server.off('error', reject);
        const { port } = this.server.address();
        this.advertiseService(port);
        console.log(`[Service Discovery] Instance ${this.instanceId} running on port ${port}`);
        resolve(port);
      });
    });
  }

  /**
   * List this host's non-internal interface addresses (IPv6 wrapped in
   * brackets), or a loopback IPv4 placeholder when there are none.
   * @returns {Array<{family: (string|number), address: string}>}
   */
  getIPAddress() {
    const addresses = [];
    for (const entries of Object.values(os.networkInterfaces())) {
      for (const net of entries) {
        if (net.internal) continue;
        // Some Node 18.x releases report `family` as the number 6 instead of 'IPv6'.
        const isIPv6 = net.family === 'IPv6' || net.family === 6;
        addresses.push({
          family: net.family,
          address: isIPv6 ? `[${net.address}]` : net.address,
        });
      }
    }
    return addresses.length > 0 ? addresses : [{ family: 'IPv4', address: '127.0.0.1' }];
  }

  /**
   * Publish this instance over Bonjour as type `remoteThink` with identifying
   * TXT fields. On a name conflict a new instanceId (and therefore a new name)
   * is generated and publishing is retried.
   * @param {number} port Port the HTTP server is listening on.
   */
  advertiseService(port) {
    this.service = bonjour.publish({
      name: `RemoteThink-${this.hostname}-${this.instanceId}`,
      type: 'remoteThink',
      port,
      txt: {
        instanceId: this.instanceId,
        hostname: this.hostname,
        pid: process.pid,
        port: port.toString(),
      },
    });
    this.service.on('error', (err) => {
      if (err?.message?.includes('Conflict')) {
        console.log('Detected name conflict, regenerating...');
        this.instanceId = crypto.randomBytes(4).toString('hex');
        this.advertiseService(port);
        return;
      }
      // Surface other publish failures instead of ignoring them.
      console.error('[Service Discovery] Advertisement error:', err);
    });
  }

  /**
   * Browse for `remoteThink` peers: add them to `instances` when they come up
   * and remove them when they go down. Our own advertisement is skipped.
   */
  setupDiscovery() {
    this.browser = bonjour.find({ type: 'remoteThink' });
    this.browser.on('up', (service) => {
      // txt may be absent on records published by other software.
      if (service.txt?.instanceId === this.instanceId) return;
      const key = this.getInstanceKey(service);
      console.log(`[Service Discovery] Discovered ${service.name} at ${key}`);
      this.instances.set(key, service);
    });
    // Forget peers that stop advertising so remote calls are not routed to dead hosts.
    this.browser.on('down', (service) => {
      const key = this.getInstanceKey(service);
      if (this.instances.delete(key)) {
        console.log(`[Service Discovery] Lost ${service.name} at ${key}`);
      }
    });
  }

  /**
   * @param {{host: string, port: number}} service Bonjour service record.
   * @returns {string} Map key of the form "host:port".
   */
  getInstanceKey(service) {
    return `${service.host}:${service.port}`;
  }

  /**
   * HTTP entry point. Accepts only `POST /handle-remote` with a JSON body
   * `{ toolCall }`; everything else gets 404. Parse or handler errors yield 500
   * (stack included only when NODE_ENV === 'development').
   */
  handleRequest(req, res) {
    if (req.method !== 'POST' || req.url !== '/handle-remote') {
      res.writeHead(404).end();
      return;
    }
    // Collect raw Buffers and decode once at the end: decoding each chunk on
    // its own corrupts multi-byte UTF-8 characters that straddle chunks.
    const chunks = [];
    let receivedBytes = 0;
    let rejected = false;
    req.on('data', (chunk) => {
      if (rejected) return;
      receivedBytes += chunk.length;
      if (receivedBytes > this.maxBodyBytes) {
        // The endpoint is reachable by anyone on the LAN; don't buffer unbounded input.
        rejected = true;
        this.sendResponse(res, 413, { error: 'Request body too large' });
        return;
      }
      chunks.push(chunk);
    });
    req.on('end', async () => {
      if (rejected) return;
      try {
        const data = JSON.parse(Buffer.concat(chunks).toString('utf8'));
        const toolCall = data?.toolCall;
        console.log("[Remote Tool Use] Received tool call:", toolCall);
        const result = this.onToolCall
          ? await this.onToolCall(toolCall)
          : { message: "Success! Received tool call: " + toolCall };
        this.sendResponse(res, 200, result);
      } catch (error) {
        console.error('Request failed:', error);
        this.sendResponse(res, 500, {
          error: error.message,
          stack: process.env.NODE_ENV === 'development' ? error.stack : undefined,
        });
      }
    });
  }

  /** Write `data` as a JSON response with the given status code. */
  sendResponse(res, status, data) {
    res.writeHead(status, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify(data));
  }

  /**
   * Ask the OS for a currently free port by briefly listening on port 0.
   * The port can be taken by another process before the caller binds it.
   * @returns {Promise<number>}
   */
  getAvailablePort() {
    return new Promise((resolve) => {
      const server = http.createServer();
      server.listen(0, () => {
        const { port } = server.address();
        server.close(() => resolve(port));
      });
    });
  }

  /** @returns {object[]} Snapshot of the currently known peer service records. */
  getAvailableInstances() {
    return Array.from(this.instances.values());
  }
}
// Library for handling AI tool calls
/**
 * Build the tool registry and the helpers the demo driver uses to run an LLM
 * tool-use loop.
 *
 * Side effect: constructs a ServiceDiscovery (HTTP server plus Bonjour
 * advertisement/browsing, as defined in this file). Peers' `POST /handle-remote`
 * requests are routed into `executeToolChain` through its `onToolCall` option;
 * a ServiceDiscovery that ignores that option keeps its own default handler.
 *
 * @returns {{TOOLS_LIST: object[], handleToolCalls: Function, makeRequest: Function}}
 */
const createToolLibrary = () => {
  // Upper bound on a tool -> result.toolCall -> tool ... chain, so a solver that
  // keeps requesting follow-up calls cannot loop forever.
  const MAX_CHAIN_LENGTH = 10;
  // How long remoteThink waits for a peer before giving up (milliseconds).
  const REMOTE_TIMEOUT_MS = 5000;

  // Helper function to safely parse JSON; returns null when `str` is not valid JSON.
  const tryParseJson = (str) => {
    try {
      return JSON.parse(str);
    } catch {
      return null;
    }
  };

  // Blank strings would otherwise silently become 0 via Number("").
  const toNumber = (value) =>
    typeof value === 'string' && value.trim() === '' ? NaN : Number(value);

  /**
   * Coerce a single argument to the JSON-schema `type` declared for it.
   * @throws {Error} When the value cannot represent the declared type.
   */
  const coerceValue = (key, type, value) => {
    switch (type) {
      case 'number': {
        const num = toNumber(value);
        if (Number.isNaN(num)) throw new Error(`Invalid number value for ${key}`);
        return num;
      }
      case 'integer': {
        // Stricter than parseInt, which silently accepts "3.7" or "12abc".
        const num = toNumber(value);
        if (!Number.isInteger(num)) throw new Error(`Invalid integer value for ${key}`);
        return num;
      }
      case 'boolean': {
        // Boolean("false") is true, so string forms must be matched explicitly.
        if (typeof value === 'boolean') return value;
        const text = String(value).trim().toLowerCase();
        if (text === 'true') return true;
        if (text === 'false') return false;
        throw new Error(`Invalid boolean value for ${key}`);
      }
      case 'string':
        // Keep strings verbatim: JSON-parsing "123" or "true" would change their type.
        return typeof value === 'string' ? value : JSON.stringify(value);
      default:
        // Objects/arrays (or untyped params) may arrive JSON-encoded.
        return typeof value === 'string' ? (tryParseJson(value) ?? value) : value;
    }
  };

  /**
   * Validate and coerce a tool call's arguments against the tool's parameter
   * schema. Only keys declared in the schema are passed through; absent
   * optional keys are omitted rather than coerced.
   * @throws {Error} On non-object arguments, missing required keys, or bad values.
   */
  const coerceArguments = (tool, rawArgs) => {
    // Some models emit `arguments` as a JSON-encoded string rather than an object.
    const args = typeof rawArgs === 'string' ? tryParseJson(rawArgs) : rawArgs;
    if (args === null || typeof args !== 'object' || Array.isArray(args)) {
      throw new Error('Invalid tool call format');
    }
    const { properties = {}, required = [] } = tool.function.parameters ?? {};
    const processedArgs = {};
    for (const [key, prop] of Object.entries(properties)) {
      const value = Object.hasOwn(args, key) ? args[key] : undefined;
      if (value === undefined || value === null) {
        if (required.includes(key)) throw new Error(`Missing required parameter: ${key}`);
        continue;
      }
      processedArgs[key] = coerceValue(key, prop?.type, value);
    }
    return processedArgs;
  };

  // URL-ready host for a discovered peer: prefer an IPv4 address (IPv6
  // link-local addresses are often unreachable without an interface zone),
  // otherwise bracket the first address, otherwise use the advertised host name.
  const formatRemoteHost = (instance) => {
    const addresses = Array.isArray(instance.addresses) ? instance.addresses : [];
    const ipv4 = addresses.find((addr) => !addr.includes(':'));
    if (ipv4) return ipv4;
    if (addresses.length > 0) return `[${addresses[0]}]`;
    return instance.host ?? null;
  };

  const TOOLS_LIST = [
    {
      type: "function",
      function: {
        name: "get_stock_fundamentals",
        parameters: {
          type: "object",
          properties: {
            symbol: { type: "string" },
          },
          required: ["symbol"],
        },
        description: `get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol.
Args:
symbol (str): The stock symbol.
Returns:
dict: A dictionary containing fundamental data.
Keys:
- 'symbol': The stock symbol.
- 'company_name': The long name of the company.
- 'sector': The sector to which the company belongs.
- 'industry': The industry to which the company belongs.
- 'market_cap': The market capitalization of the company.
- 'pe_ratio': The forward price-to-earnings ratio.
- 'pb_ratio': The price-to-book ratio.
- 'dividend_yield': The dividend yield.
- 'eps': The trailing earnings per share.
- 'beta': The beta value of the stock.
- '52_week_high': The 52-week high price of the stock.
- '52_week_low': The 52-week low price of the stock.
`,
      },
      // Placeholder data: returns fixed figures for any symbol.
      solver: ({ symbol }) => ({
        symbol,
        company_name: `Dummy Company for ${symbol}`,
        sector: "Technology",
        industry: "Software",
        market_cap: 1000000000,
        pe_ratio: 20.5,
        pb_ratio: 2.3,
        dividend_yield: 0.02,
        eps: 5.67,
        beta: 1.1,
        "52_week_high": 100.0,
        "52_week_low": 80.0,
      }),
    },
    {
      type: "function",
      function: {
        name: "calculator",
        parameters: {
          type: "object",
          properties: {
            operand1: { type: "number" },
            operand2: { type: "number" },
            operator: {
              type: "string",
              enum: ["+", "-", "*", "/"]
            },
          },
          required: ["operand1", "operand2", "operator"],
        },
        description: `calculator(operand1: number, operand2: number, operator: str) -> dict - Performs basic arithmetic operations.
Args:
operand1 (number): The first operand.
operand2 (number): The second operand.
operator (str): The arithmetic operator (+, -, *, /).
Returns:
dict: A dictionary containing either the result or an error message.
Keys:
- 'result': The numerical result (present if operation is valid).
- 'error': Error message (present if invalid operator or division by zero).
`,
      },
      solver: ({ operand1, operand2, operator }) => {
        console.log("Calculator args:", operand1, operand2, operator);
        switch (operator) {
          case "+":
            return { result: operand1 + operand2 };
          case "-":
            return { result: operand1 - operand2 };
          case "*":
            return { result: operand1 * operand2 };
          case "/":
            if (operand2 === 0) {
              return { error: "Division by zero is not allowed." };
            }
            return { result: operand1 / operand2 };
          default:
            return { error: "Invalid operator. Supported operators: +, -, *, /" };
        }
      },
    },
    {
      type: "function",
      function: {
        name: "remoteThink",
        parameters: {
          type: "object",
          properties: {
            subthought: {
              type: "string",
              description: "Raw string of the subthought to execute"
            }
          },
          required: ["subthought"]
        },
        description: `remoteThink(subthought: str) -> dict - Execute a subthought (as a raw string) on a remote instance.`
      },
      // Forwards `subthought` to a randomly chosen discovered peer's /handle-remote
      // endpoint and returns its `finalResult` (or the whole JSON body if absent).
      // Failures are returned as { error } objects rather than thrown.
      solver: async ({ subthought }) => {
        const instances = serviceDiscovery.getAvailableInstances();
        if (instances.length === 0) return { error: "No instances available" };
        const instance = instances[Math.floor(Math.random() * instances.length)];
        const host = formatRemoteHost(instance);
        if (!host) return { error: `No reachable address for ${instance.name}` };
        try {
          const response = await fetch(`http://${host}:${instance.port}/handle-remote`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ toolCall: subthought }),
            // fetch has no `timeout` option; an abort signal is the supported way to bound it.
            signal: AbortSignal.timeout(REMOTE_TIMEOUT_MS),
          });
          if (!response.ok) {
            return { error: `Remote error: ${response.status} ${await response.text()}` };
          }
          const result = await response.json();
          return result?.finalResult ?? result;
        } catch (error) {
          return { error: `Network error: ${error.message}` };
        }
      }
    }
  ];

  const findTool = (name) => TOOLS_LIST.find((tool) => tool.function.name === name);

  /**
   * Execute a tool call, then any follow-up call a result requests through its
   * `toolCall` property, until no follow-up remains.
   * @param {string|object} initialToolCall `{ name, arguments }` or its JSON string.
   * @returns {Promise<unknown[]>} Each executed tool's result, in order.
   * @throws {Error} On malformed calls, unknown tools, invalid arguments, or a
   *   chain longer than MAX_CHAIN_LENGTH.
   */
  const executeToolChain = async (initialToolCall) => {
    const results = [];
    let currentToolCall = initialToolCall;
    while (currentToolCall) {
      if (results.length >= MAX_CHAIN_LENGTH) {
        throw new Error(`Tool chain exceeded ${MAX_CHAIN_LENGTH} steps`);
      }
      if (typeof currentToolCall === 'string') {
        try {
          currentToolCall = JSON.parse(currentToolCall);
        } catch (error) {
          throw new Error(`Invalid tool call string: ${error.message}`, { cause: error });
        }
      }
      if (
        currentToolCall === null ||
        typeof currentToolCall !== 'object' ||
        !currentToolCall.name ||
        currentToolCall.arguments == null
      ) {
        throw new Error('Invalid tool call format');
      }
      const tool = findTool(currentToolCall.name);
      if (!tool) {
        throw new Error(`Tool not found: ${currentToolCall.name}`);
      }
      const result = await tool.solver(coerceArguments(tool, currentToolCall.arguments));
      results.push(result);
      currentToolCall = result?.toolCall ?? null;
    }
    return results;
  };

  /**
   * Extract every `<tool_call>…</tool_call>` JSON payload from an LLM reply and
   * run each tool with schema-coerced arguments.
   * @param {string} responseContent Raw assistant message text.
   * @returns {Promise<Array<{name: string, result: object}>>} One entry per tool
   *   call found. Unparsable calls and unknown tools are reported with
   *   `name: "error"` so the model can see and correct them.
   */
  const handleToolCalls = async (responseContent) => {
    if (typeof responseContent !== 'string') return [];
    // Tolerate any whitespace (or none) around the JSON: the prompt asks for
    // newlines but models do not always comply. A fresh /g regex per call
    // avoids stale lastIndex state.
    const toolCallPattern = /<tool_call>\s*([\s\S]*?)\s*<\/tool_call>/g;
    const results = [];
    for (const [, rawCall] of responseContent.matchAll(toolCallPattern)) {
      const toolCall = tryParseJson(rawCall);
      if (toolCall === null || typeof toolCall !== 'object') {
        console.error("Failed to parse tool call:", rawCall);
        results.push({ name: "error", result: { error: `Failed to parse tool call: ${rawCall}` } });
        continue;
      }
      const { name, arguments: args } = toolCall;
      const tool = findTool(name);
      if (!tool) {
        results.push({
          name: "error",
          result: { error: `No solver found for tool: ${name}` }
        });
        continue;
      }
      try {
        const result = await tool.solver(coerceArguments(tool, args));
        results.push({ name, result });
      } catch (error) {
        console.error(`Error executing tool ${name}:`, error);
        results.push({
          name,
          result: { error: `Execution failed for tool: ${name}: ${error.message}` }
        });
      }
    }
    return results;
  };

  /**
   * POST a chat payload to a local LLM server and return the parsed JSON.
   * The default URL is an OpenAI-style /v1/chat/completions endpoint; Ollama's
   * native alternative is "http://127.0.0.1:11434/api/chat".
   * @returns {Promise<object|null>} Parsed response, or null on any failure (logged).
   */
  const makeRequest = async (payload, url = "http://localhost:1234/v1/chat/completions") => {
    try {
      const response = await fetch(url, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(payload),
      });
      if (!response.ok) {
        throw new Error(`HTTP error! Status: ${response.status}`);
      }
      return await response.json();
    } catch (error) {
      console.error("Request failed:", error);
      return null;
    }
  };

  // Serve peers' /handle-remote requests with the same executor; the response
  // shape ({ results, finalResult }) is what the remoteThink solver reads.
  // remoteThink's solver references this binding lazily, after it is initialized.
  const serviceDiscovery = new ServiceDiscovery({
    onToolCall: async (toolCall) => {
      const results = await executeToolChain(toolCall);
      return { results, finalResult: results[results.length - 1] };
    },
  });

  return { TOOLS_LIST, handleToolCalls, makeRequest };
};
// Demo driver: send one user question, execute any tool calls the model makes,
// and feed the results back until it answers without tools (or we give up).
(async () => {
  const { TOOLS_LIST, handleToolCalls, makeRequest } = createToolLibrary();
  // Initialize messages with system and user messages
  const messages = [
    {
      role: "system",
      content: `You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.
You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions.
Here are the available tools: <tools> ${TOOLS_LIST.map((tool) => JSON.stringify(tool)).join("\n\n")} </tools>
Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"arguments": <args-dict>, "name": <function-name>}
</tool_call>`
    },
    {
      role: "user",
      content: "do 43 times 12 and then subtract 91.",
    },
  ];

  const maxIterations = 5; // Safety guard against infinite loops
  let finalAnswer = null;

  for (let iterationCount = 1; iterationCount <= maxIterations; iterationCount++) {
    const payload = {
      // model: "llama3.2", // presumably required by servers without a preloaded model — confirm per server
      stream: false,
      messages,
      // NOTE(review): Ollama-style `options`; OpenAI-compatible servers may ignore it.
      options: {
        temperature: 0,
        seed: 1337,
      },
    };
    const response = await makeRequest(payload);
    if (!response) break;

    // Ollama's /api/chat returns `message`; OpenAI-style endpoints return `choices[]`.
    // `??` (not `||`) so an empty-string reply is kept instead of crashing on `choices`.
    const responseContent = response.message?.content ?? response.choices?.[0]?.message?.content;
    if (typeof responseContent !== "string") {
      console.error(`[Iteration ${iterationCount}] Unexpected LLM response shape:`, response);
      break;
    }
    console.log(`[Iteration ${iterationCount}] LLM Response:`, responseContent);

    const toolResponses = await handleToolCalls(responseContent);
    if (toolResponses.length === 0) {
      // No more tool calls - final answer
      finalAnswer = responseContent;
      break;
    }

    console.log(`[Iteration ${iterationCount}] Tool Responses:`, toolResponses);
    // Keep the assistant turn that requested the tools, then one message per tool result.
    messages.push({
      role: "assistant",
      content: responseContent
    });
    for (const { name, result } of toolResponses) {
      messages.push({
        role: "tool",
        name,
        content: JSON.stringify(result)
      });
    }
  }

  // Compare against null so an empty (but successful) answer is still reported.
  if (finalAnswer !== null) {
    console.log("[Final Answer]", finalAnswer);
  } else {
    console.log("Reached maximum iterations or encountered an error");
  }
})().catch((error) => {
  // Without this the async driver's failures would surface as unhandled rejections.
  console.error("Unexpected error:", error);
  process.exitCode = 1;
});
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Works for me using Ollama to serve Llama 3.2 3B and Llama 3.3 70B.