Skip to content

Instantly share code, notes, and snippets.

@sunary
Created March 31, 2025 07:58
Show Gist options
  • Select an option

  • Save sunary/470cd116537661ebb6739774381e6523 to your computer and use it in GitHub Desktop.

Select an option

Save sunary/470cd116537661ebb6739774381e6523 to your computer and use it in GitHub Desktop.
golang-MCP-gemini
package main
import (
	"context"
	"flag"
	"fmt"
	"log"
	"os"

	"github.com/google/generative-ai-go/genai"
	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"
	"google.golang.org/api/option"
)
// ToolName is the string identifier under which a tool is registered on the
// MCP server.
type ToolName string

// Tool names registered by NewMCPServer.
const (
	CODE_SUMMARY = ToolName("summary")
	CODE_REVIEW  = ToolName("review")
)
// CustomServer bundles the MCP server with the Gemini model that its tool
// handlers send prompts to.
type CustomServer struct {
// server is the underlying MCP server on which tools and notification
// handlers are registered.
server *server.MCPServer
// geminiModel is the generative model used by handleCommand to start chats.
geminiModel *genai.GenerativeModel
}
// NewMCPServer creates the Gemini client, builds the MCP server with logging
// hooks attached, and registers the code summary and code review tools plus
// the notification handler.
//
// The Gemini API key is read from the GEMINI_API_KEY environment variable. If
// it is unset the server still starts (as it did with the previous hard-coded
// empty key) and a warning is logged. A failure to create the Gemini client
// terminates the process via log.Fatal.
func NewMCPServer() CustomServer {
	ctx := context.Background()

	apiKey := os.Getenv("GEMINI_API_KEY")
	if apiKey == "" {
		log.Print("GEMINI_API_KEY is not set; requests to Gemini are expected to fail")
	}
	client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
	if err != nil {
		log.Fatal(err)
	}
	model := client.GenerativeModel("gemini-1.5-flash")

	// Hooks log through the log package (stderr by default) instead of
	// fmt.Printf: with the stdio transport, stdout carries the protocol
	// messages, so diagnostic output there would corrupt the stream.
	hooks := &server.Hooks{}
	hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) {
		log.Printf("beforeAny: %s, %v, %v", method, id, message)
	})
	hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) {
		log.Printf("onSuccess: %s, %v, %v, %v", method, id, message, result)
	})
	hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
		log.Printf("onError: %s, %v, %v, %v", method, id, message, err)
	})
	hooks.AddBeforeInitialize(func(id any, message *mcp.InitializeRequest) {
		log.Printf("beforeInitialize: %v, %v", id, message)
	})
	hooks.AddAfterInitialize(func(id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
		log.Printf("afterInitialize: %v, %v, %v", id, message, result)
	})
	hooks.AddAfterCallTool(func(id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
		log.Printf("afterCallTool: %v, %v, %v", id, message, result)
	})
	hooks.AddBeforeCallTool(func(id any, message *mcp.CallToolRequest) {
		log.Printf("beforeCallTool: %v, %v", id, message)
	})

	mcpServer := CustomServer{
		server: server.NewMCPServer(
			"simple-servers",
			"1.0.0",
			server.WithResourceCapabilities(true, true),
			server.WithPromptCapabilities(true),
			server.WithLogging(),
			server.WithHooks(hooks),
		),
		geminiModel: model,
	}

	mcpServer.server.AddTool(
		mcp.NewTool(
			string(CODE_SUMMARY),
			mcp.WithDescription(`Summarize code changes passed in the "changes" argument.`),
		),
		mcpServer.handleCodeSummaryTool,
	)
	// The review tool must use the review handler; it was previously wired to
	// the summary handler, so both tools produced summaries.
	mcpServer.server.AddTool(
		mcp.NewTool(
			string(CODE_REVIEW),
			mcp.WithDescription(`Review code changes passed in the "changes" argument.`),
		),
		mcpServer.handleCodeReviewTool,
	)
	mcpServer.server.AddNotificationHandler("notification", handleNotification)

	return mcpServer
}
// handleCodeSummaryTool is a tool handler that forwards the request to
// handleCommand using the summary prompt.
func (s *CustomServer) handleCodeSummaryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	const prompt = "summary this code changes:"

	result, err := s.handleCommand(ctx, prompt, request)
	return result, err
}
// handleCodeReviewTool is a tool handler that forwards the request to
// handleCommand using the review prompt.
func (s *CustomServer) handleCodeReviewTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	const prompt = "review this code changes:"

	result, err := s.handleCommand(ctx, prompt, request)
	return result, err
}
// handleCommand sends command, followed by the code changes supplied in the
// tool request, to the Gemini model and converts the reply into an MCP tool
// result.
//
// The optional "changes" argument may be a single string or a list of
// strings. Any other shape returns an error rather than panicking on a failed
// type assertion. If "changes" is absent, only command is sent.
//
// Every text part of every returned candidate becomes one text content entry
// in the result. Candidates without content and non-text parts are skipped.
// An error from the model is wrapped and returned with a nil result.
func (s *CustomServer) handleCommand(ctx context.Context, command string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	parts := []genai.Part{genai.Text(command)}
	if raw, ok := request.Params.Arguments["changes"]; ok {
		switch changes := raw.(type) {
		case string:
			parts = append(parts, genai.Text(changes))
		case []any:
			for i, change := range changes {
				text, ok := change.(string)
				if !ok {
					return nil, fmt.Errorf("changes[%d]: expected string, got %T", i, change)
				}
				parts = append(parts, genai.Text(text))
			}
		default:
			return nil, fmt.Errorf("changes: expected string or list of strings, got %T", raw)
		}
	}

	resp, err := s.geminiModel.StartChat().SendMessage(ctx, parts...)
	if err != nil {
		return nil, fmt.Errorf("failed to send message: %w", err)
	}

	// Allocate a non-nil slice so an empty result is still an empty list.
	conts := make([]mcp.Content, 0, len(resp.Candidates))
	for _, c := range resp.Candidates {
		if c == nil || c.Content == nil {
			continue
		}
		// Walk the candidate's own parts; indexing Parts with the candidate
		// index picked the wrong part or ran out of range.
		for _, part := range c.Content.Parts {
			text, ok := part.(genai.Text)
			if !ok {
				continue
			}
			conts = append(conts, mcp.TextContent{
				Type: "text",
				Text: string(text),
			})
		}
	}

	return &mcp.CallToolResult{
		Content: conts,
	}, nil
}
// handleNotification logs the method name of a received JSON-RPC
// notification. The context is not used.
func handleNotification(ctx context.Context, notification mcp.JSONRPCNotification) {
	method := notification.Method
	log.Printf("Received notification: %s", method)
}
// main parses the transport flag (-t / -transport), builds the MCP server and
// serves it over SSE on :8080 when the transport is "sse", or over stdio for
// any other value. A serving error terminates the process via log.Fatalf.
func main() {
	var transport string
	// Both spellings of the flag write to the same variable.
	for _, name := range []string{"t", "transport"} {
		flag.StringVar(&transport, name, "stdio", "Transport type (stdio or sse)")
	}
	flag.Parse()

	mcpServer := NewMCPServer()

	var err error
	switch transport {
	case "sse":
		sseServer := server.NewSSEServer(mcpServer.server, server.WithBaseURL("http://localhost:8080"))
		log.Printf("SSE server listening on :8080")
		err = sseServer.Start(":8080")
	default:
		// Only "sse" is checked explicitly; everything else uses stdio, the default.
		err = server.ServeStdio(mcpServer.server)
	}
	if err != nil {
		log.Fatalf("Server error: %v", err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment