Created
March 31, 2025 07:58
-
-
Save sunary/470cd116537661ebb6739774381e6523 to your computer and use it in GitHub Desktop.
golang-MCP-gemini
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
import (
	"context"
	"flag"
	"fmt"
	"log"
	"os"

	"github.com/google/generative-ai-go/genai"
	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"
	"google.golang.org/api/option"
)
// ToolName is the wire name of a tool exposed by this MCP server.
type ToolName string

// Names of the tools this server registers.
// NOTE(review): idiomatic Go would spell these ToolSummary / ToolReview; the
// ALL_CAPS names are kept because other code in this file refers to them.
const (
	CODE_SUMMARY = ToolName("summary") // summarise a set of code changes
	CODE_REVIEW  = ToolName("review")  // review a set of code changes
)
| type CustomServer struct { | |
| server *server.MCPServer | |
| geminiModel *genai.GenerativeModel | |
| } | |
| func NewMCPServer() CustomServer { | |
| ctx := context.Background() | |
| client, err := genai.NewClient(ctx, option.WithAPIKey("")) | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| model := client.GenerativeModel("gemini-1.5-flash") | |
| hooks := &server.Hooks{} | |
| hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) { | |
| fmt.Printf("beforeAny: %s, %v, %v\n", method, id, message) | |
| }) | |
| hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) { | |
| fmt.Printf("onSuccess: %s, %v, %v, %v\n", method, id, message, result) | |
| }) | |
| hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) { | |
| fmt.Printf("onError: %s, %v, %v, %v\n", method, id, message, err) | |
| }) | |
| hooks.AddBeforeInitialize(func(id any, message *mcp.InitializeRequest) { | |
| fmt.Printf("beforeInitialize: %v, %v\n", id, message) | |
| }) | |
| hooks.AddAfterInitialize(func(id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { | |
| fmt.Printf("afterInitialize: %v, %v, %v\n", id, message, result) | |
| }) | |
| hooks.AddAfterCallTool(func(id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { | |
| fmt.Printf("afterCallTool: %v, %v, %v\n", id, message, result) | |
| }) | |
| hooks.AddBeforeCallTool(func(id any, message *mcp.CallToolRequest) { | |
| fmt.Printf("beforeCallTool: %v, %v\n", id, message) | |
| }) | |
| mcpServer := CustomServer{ | |
| server: server.NewMCPServer( | |
| "simple-servers", | |
| "1.0.0", | |
| server.WithResourceCapabilities(true, true), | |
| server.WithPromptCapabilities(true), | |
| server.WithLogging(), | |
| server.WithHooks(hooks), | |
| ), | |
| geminiModel: model, | |
| } | |
| mcpServer.server.AddTool(mcp.NewTool(string(CODE_SUMMARY)), mcpServer.handleCodeSummaryTool) | |
| mcpServer.server.AddTool(mcp.NewTool(string(CODE_REVIEW)), mcpServer.handleCodeSummaryTool) | |
| mcpServer.server.AddNotificationHandler("notification", handleNotification) | |
| return mcpServer | |
| } | |
| func (s *CustomServer) handleCodeSummaryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | |
| return s.handleCommand(ctx, "summary this code changes:", request) | |
| } | |
| func (s *CustomServer) handleCodeReviewTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | |
| return s.handleCommand(ctx, "review this code changes:", request) | |
| } | |
| func (s *CustomServer) handleCommand(ctx context.Context, command string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | |
| cs := s.geminiModel.StartChat() | |
| fileChanges := []genai.Part{genai.Text(command)} | |
| if v, ok := request.Params.Arguments["changes"]; ok { | |
| for _, change := range v.([]interface{}) { | |
| fileChanges = append(fileChanges, genai.Text(change.(string))) | |
| } | |
| } | |
| resp, err := cs.SendMessage(ctx, fileChanges...) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to send message: %w", err) | |
| } | |
| conts := make([]mcp.Content, len(resp.Candidates)) | |
| for i, c := range resp.Candidates { | |
| conts[i] = mcp.TextContent{ | |
| Type: "text", | |
| Text: string(c.Content.Parts[i].(genai.Text)), | |
| } | |
| } | |
| return &mcp.CallToolResult{ | |
| Content: conts, | |
| }, nil | |
| } | |
| func handleNotification( | |
| ctx context.Context, | |
| notification mcp.JSONRPCNotification, | |
| ) { | |
| log.Printf("Received notification: %s", notification.Method) | |
| } | |
| func main() { | |
| var transport string | |
| flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)") | |
| flag.StringVar(&transport, "transport", "stdio", "Transport type (stdio or sse)") | |
| flag.Parse() | |
| mcpServer := NewMCPServer() | |
| // Only check for "sse" since stdio is the default | |
| if transport == "sse" { | |
| sseServer := server.NewSSEServer(mcpServer.server, server.WithBaseURL("http://localhost:8080")) | |
| log.Printf("SSE server listening on :8080") | |
| if err := sseServer.Start(":8080"); err != nil { | |
| log.Fatalf("Server error: %v", err) | |
| } | |
| } else { | |
| if err := server.ServeStdio(mcpServer.server); err != nil { | |
| log.Fatalf("Server error: %v", err) | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment