|
package main |
|
|
|
import ( |
|
"context" |
|
"encoding/json" |
|
"fmt" |
|
"io" |
|
"log" |
|
"os" |
|
"os/exec" |
|
"sync" |
|
"sync/atomic" |
|
"time" |
|
"unsafe" |
|
|
|
"github.com/creack/pty" |
|
"github.com/spf13/cobra" |
|
"golang.org/x/sys/unix" |
|
"golang.org/x/term" |
|
) |
|
|
|
// IOCTL_VM_SOCKETS_GET_LOCAL_CID is the ioctl request number getLocalCID
// issues on /dev/vsock to learn this machine's context ID (CID).
const IOCTL_VM_SOCKETS_GET_LOCAL_CID = 0x7b9

// defaultPort is the vsock port both subcommands use unless --port is given.
const defaultPort = 9999

// Channel IDs stored in the first header byte of each frame. They multiplex
// stdin, stdout, stderr and control messages over a single connection.
const (
	channelStdin = iota
	channelStdout
	channelStderr
	channelControl
)

// Framing parameters. A header is one channel byte followed by a 4-byte
// payload length; payloads are capped at maxFrameSize (256 KiB, chosen for
// throughput).
const (
	frameHeaderSize = 5
	maxFrameSize    = 256 * 1024
)
|
|
|
// ClientRequest is the first message of every session: the client sends it as
// JSON to choose the execution mode and the command to run.
type ClientRequest struct {
	// UsePTY asks the server to run the command on a pseudo-terminal and relay
	// raw bytes instead of using the framing protocol.
	UsePTY bool `json:"use_pty"`
	// Command is the program followed by its arguments; when empty the server
	// starts an interactive shell.
	Command []string `json:"command"`
}
|
|
|
// ServerResponse tells the client how the remote command finished. The server
// sends it as JSON once the command has exited.
type ServerResponse struct {
	// ExitCode is the command's exit status; the server reports 255 when the
	// command could not be started or its status could not be determined.
	ExitCode int `json:"exit_code"`
}
|
|
|
// Command-line settings, bound to the cobra flags registered in main.
var (
	port     int  // vsock port to listen on (serve) or connect to (exec)
	single   bool // serve: exit after handling one client (--single / -1)
	forcePTY bool // exec: request a PTY even when a command is given (--tty / -t)
)

// connID numbers accepted connections for log messages; handleClient
// increments it with atomic.AddInt32.
var connID int32
|
|
|
// Frame is one unit of the multiplexing protocol: a payload tagged with the
// channel (channelStdin, channelStdout, ...) it belongs to.
type Frame struct {
	Channel byte   // channel ID taken from the frame header
	Data    []byte // payload; may be empty
}
|
|
|
// writeFrame writes one frame to w. A frame is a frameHeaderSize-byte header
// (the channel byte, then len(data) as a 4-byte big-endian integer) followed by
// data. It returns an error, without writing anything, when data is larger
// than maxFrameSize.
//
// When w is an *os.File and data is non-empty, the header and payload are
// submitted together with writev(2), so the common case costs a single
// syscall; writevAll completes any short write. Other writers get two Write
// calls.
//
// writeFrame is not safe for concurrent use on the same w. Callers writing
// from several goroutines must serialise their calls, or frames may
// interleave on the wire.
func writeFrame(w io.Writer, channel byte, data []byte) error {
	size := len(data)
	if size > maxFrameSize {
		return fmt.Errorf("frame data too large: %d > %d", size, maxFrameSize)
	}

	header := make([]byte, frameHeaderSize)
	header[0] = channel
	// Payload length in big-endian (network byte order).
	header[1] = byte(size >> 24)
	header[2] = byte(size >> 16)
	header[3] = byte(size >> 8)
	header[4] = byte(size)

	if file, ok := w.(*os.File); ok && size > 0 {
		return writevAll(int(file.Fd()), [][]byte{header, data})
	}

	// Fallback for non-file writers (and empty payloads).
	if _, err := w.Write(header); err != nil {
		return err
	}
	if size > 0 {
		if _, err := w.Write(data); err != nil {
			return err
		}
	}
	return nil
}

// writevAll writes every byte of bufs to fd using writev(2).
//
// A single writev may accept only part of the data, for example when a
// signal interrupts a large blocking socket send. If the remainder were
// dropped, the peer would misread the rest of the stream, so writevAll keeps
// calling writev until everything is written. It retries EINTR and reports
// io.ErrShortWrite if the kernel accepts nothing without returning an error.
// The entries of bufs are re-sliced in place as data is consumed.
func writevAll(fd int, bufs [][]byte) error {
	for len(bufs) > 0 {
		n, err := unix.Writev(fd, bufs)
		if err == unix.EINTR {
			continue
		}
		if err != nil {
			return err
		}
		if n == 0 {
			return io.ErrShortWrite
		}
		// Skip past the n bytes the kernel accepted.
		for n > 0 && len(bufs) > 0 {
			if n < len(bufs[0]) {
				bufs[0] = bufs[0][n:]
				n = 0
			} else {
				n -= len(bufs[0])
				bufs = bufs[1:]
			}
		}
	}
	return nil
}
|
|
|
// readFrame reads a frame from the connection |
|
func readFrame(r io.Reader) (*Frame, error) { |
|
header := make([]byte, frameHeaderSize) |
|
if _, err := io.ReadFull(r, header); err != nil { |
|
return nil, err |
|
} |
|
|
|
channel := header[0] |
|
length := int(header[1])<<24 | int(header[2])<<16 | int(header[3])<<8 | int(header[4]) |
|
|
|
if length > maxFrameSize { |
|
return nil, fmt.Errorf("frame too large: %d > %d", length, maxFrameSize) |
|
} |
|
|
|
data := make([]byte, length) |
|
if length > 0 { |
|
if _, err := io.ReadFull(r, data); err != nil { |
|
return nil, err |
|
} |
|
} |
|
|
|
return &Frame{Channel: channel, Data: data}, nil |
|
} |
|
|
|
// getLocalCID returns this machine's vsock context ID (CID). It asks the
// kernel with the IOCTL_VM_SOCKETS_GET_LOCAL_CID ioctl on /dev/vsock.
func getLocalCID() (uint32, error) {
	dev, err := os.Open("/dev/vsock")
	if err != nil {
		return 0, fmt.Errorf("failed to open /dev/vsock: %w", err)
	}
	defer dev.Close()

	// The ioctl stores the CID into localCID through the pointer argument.
	var localCID uint32
	if _, _, errno := unix.Syscall(unix.SYS_IOCTL, dev.Fd(), IOCTL_VM_SOCKETS_GET_LOCAL_CID, uintptr(unsafe.Pointer(&localCID))); errno != 0 {
		return 0, fmt.Errorf("ioctl failed: %v", errno)
	}
	return localCID, nil
}
|
|
|
// runServer starts the vsock server |
|
func runServer() error { |
|
cid, err := getLocalCID() |
|
if err != nil { |
|
return fmt.Errorf("failed to get local CID: %w", err) |
|
} |
|
|
|
log.Printf("Local CID: %d", cid) |
|
|
|
// Create a vsock socket |
|
fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0) |
|
if err != nil { |
|
return fmt.Errorf("failed to create socket: %w", err) |
|
} |
|
defer unix.Close(fd) |
|
|
|
// Bind to the vsock address |
|
sockaddr := &unix.SockaddrVM{ |
|
CID: unix.VMADDR_CID_ANY, |
|
Port: uint32(port), |
|
} |
|
|
|
if err := unix.Bind(fd, sockaddr); err != nil { |
|
return fmt.Errorf("failed to bind: %w", err) |
|
} |
|
|
|
// Listen for connections |
|
if err := unix.Listen(fd, 128); err != nil { |
|
return fmt.Errorf("failed to listen: %w", err) |
|
} |
|
|
|
log.Printf("Listening on vsock(%d:%d)", cid, port) |
|
|
|
// Accept connections in a loop |
|
for { |
|
clientFd, _, err := unix.Accept(fd) |
|
if err != nil { |
|
log.Printf("Accept error: %v", err) |
|
continue |
|
} |
|
|
|
if single { |
|
// Handle client and exit after it's done |
|
handleClient(clientFd) |
|
return nil |
|
} else { |
|
// Handle each client in a goroutine |
|
go handleClient(clientFd) |
|
} |
|
} |
|
} |
|
|
|
// handleClient handles an individual client connection |
|
func handleClient(clientFd int) { |
|
id := atomic.AddInt32(&connID, 1) |
|
connName := fmt.Sprintf("conn-%d", id) |
|
|
|
log.Printf("[%s] Client connected", connName) |
|
defer log.Printf("[%s] Client disconnected", connName) |
|
|
|
// Create a file from the client fd for easier I/O |
|
// Note: closing clientFile will also close the FD |
|
clientFile := os.NewFile(uintptr(clientFd), "vsock-client") |
|
defer clientFile.Close() |
|
|
|
// Read the client request |
|
var req ClientRequest |
|
decoder := json.NewDecoder(clientFile) |
|
if err := decoder.Decode(&req); err != nil { |
|
log.Printf("[%s] Failed to read client request: %v", connName, err) |
|
return |
|
} |
|
|
|
// Determine what to execute |
|
var cmdArgs []string |
|
if len(req.Command) == 0 { |
|
// No command specified, use default shell |
|
shell := os.Getenv("SHELL") |
|
if shell == "" { |
|
shell = "/bin/sh" |
|
} |
|
cmdArgs = []string{shell} |
|
} else { |
|
cmdArgs = req.Command |
|
} |
|
|
|
cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) |
|
cmd.Env = os.Environ() |
|
|
|
var exitCode int |
|
|
|
if req.UsePTY { |
|
// PTY mode |
|
ptmx, err := pty.Start(cmd) |
|
if err != nil { |
|
log.Printf("[%s] Failed to start command with PTY: %v", connName, err) |
|
exitCode = 255 |
|
// Send error exit code |
|
resp := ServerResponse{ExitCode: exitCode} |
|
json.NewEncoder(clientFile).Encode(&resp) |
|
return |
|
} |
|
defer ptmx.Close() |
|
|
|
// Copy data bidirectionally |
|
var wg sync.WaitGroup |
|
wg.Add(2) |
|
|
|
// Client -> PTY |
|
go func() { |
|
defer wg.Done() |
|
io.Copy(ptmx, clientFile) |
|
}() |
|
|
|
// PTY -> Client |
|
go func() { |
|
defer wg.Done() |
|
io.Copy(clientFile, ptmx) |
|
}() |
|
|
|
// Wait for the command to exit |
|
if err := cmd.Wait(); err != nil { |
|
log.Printf("[%s] Command error: %v", connName, err) |
|
if exitErr, ok := err.(*exec.ExitError); ok { |
|
exitCode = exitErr.ExitCode() |
|
} else { |
|
exitCode = 255 |
|
} |
|
} else { |
|
exitCode = 0 |
|
} |
|
|
|
// Close the PTY to stop writing to client |
|
ptmx.Close() |
|
|
|
// Shutdown read side to unblock the Client->PTY goroutine |
|
unix.Shutdown(int(clientFile.Fd()), unix.SHUT_RD) |
|
|
|
// Wait for goroutines |
|
wg.Wait() |
|
|
|
// Send exit code to client |
|
resp := ServerResponse{ExitCode: exitCode} |
|
json.NewEncoder(clientFile).Encode(&resp) |
|
} else { |
|
// Non-PTY mode: use pipes for stdin/stdout/stderr and multiplex them |
|
|
|
// Create pipes for stdin, stdout, stderr |
|
stdinPipe, err := cmd.StdinPipe() |
|
if err != nil { |
|
log.Printf("[%s] Failed to create stdin pipe: %v", connName, err) |
|
exitCode = 255 |
|
respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) |
|
writeFrame(clientFile, channelControl, respData) |
|
return |
|
} |
|
|
|
stdoutPipe, err := cmd.StdoutPipe() |
|
if err != nil { |
|
log.Printf("[%s] Failed to create stdout pipe: %v", connName, err) |
|
exitCode = 255 |
|
respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) |
|
writeFrame(clientFile, channelControl, respData) |
|
return |
|
} |
|
|
|
stderrPipe, err := cmd.StderrPipe() |
|
if err != nil { |
|
log.Printf("[%s] Failed to create stderr pipe: %v", connName, err) |
|
exitCode = 255 |
|
respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) |
|
writeFrame(clientFile, channelControl, respData) |
|
return |
|
} |
|
|
|
if err := cmd.Start(); err != nil { |
|
log.Printf("[%s] Failed to start command: %v", connName, err) |
|
exitCode = 255 |
|
respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) |
|
writeFrame(clientFile, channelControl, respData) |
|
return |
|
} |
|
|
|
var wg sync.WaitGroup |
|
var stdinWg sync.WaitGroup |
|
|
|
// Track stdout/stderr completion separately |
|
wg.Add(2) |
|
stdinWg.Add(1) |
|
|
|
// Client stdin -> Command stdin |
|
go func() { |
|
defer stdinWg.Done() |
|
defer stdinPipe.Close() |
|
for { |
|
frame, err := readFrame(clientFile) |
|
if err != nil { |
|
return |
|
} |
|
if frame.Channel == channelStdin { |
|
if len(frame.Data) > 0 { |
|
stdinPipe.Write(frame.Data) |
|
} |
|
} |
|
} |
|
}() |
|
|
|
// Command stdout -> Client |
|
go func() { |
|
defer wg.Done() |
|
buf := make([]byte, maxFrameSize) |
|
for { |
|
n, err := stdoutPipe.Read(buf) |
|
if n > 0 { |
|
writeFrame(clientFile, channelStdout, buf[:n]) |
|
} |
|
if err != nil { |
|
return |
|
} |
|
} |
|
}() |
|
|
|
// Command stderr -> Client |
|
go func() { |
|
defer wg.Done() |
|
buf := make([]byte, maxFrameSize) |
|
for { |
|
n, err := stderrPipe.Read(buf) |
|
if n > 0 { |
|
writeFrame(clientFile, channelStderr, buf[:n]) |
|
} |
|
if err != nil { |
|
return |
|
} |
|
} |
|
}() |
|
|
|
// Wait for stdout/stderr to be fully read (they'll get EOF when command exits) |
|
wg.Wait() |
|
|
|
// Now it's safe to call Wait - all pipe data has been read |
|
if err := cmd.Wait(); err != nil { |
|
log.Printf("[%s] Command error: %v", connName, err) |
|
if exitErr, ok := err.(*exec.ExitError); ok { |
|
exitCode = exitErr.ExitCode() |
|
} else { |
|
exitCode = 255 |
|
} |
|
} else { |
|
exitCode = 0 |
|
} |
|
|
|
// Shutdown read side to unblock stdin reader goroutine |
|
unix.Shutdown(int(clientFile.Fd()), unix.SHUT_RD) |
|
|
|
// Wait for stdin goroutine to finish |
|
stdinWg.Wait() |
|
|
|
// Send exit code on control channel |
|
resp := ServerResponse{ExitCode: exitCode} |
|
respData, _ := json.Marshal(resp) |
|
writeFrame(clientFile, channelControl, respData) |
|
} |
|
} |
|
|
|
// runClient connects to the vsock-shell server at cid (on the port given by
// --port), asks it to run command, and relays I/O until the remote command
// finishes. A PTY is requested when command is empty or --tty is set.
//
// When a session completes, runClient does not return: it exits the process
// with the remote exit code, or 255 if that code could not be obtained. It
// returns an error only when creating the socket, connecting, sending the
// request, or switching the local terminal to raw mode fails.
func runClient(cid uint32, command []string) error {
	fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0)
	if err != nil {
		return fmt.Errorf("failed to create socket: %w", err)
	}

	sockaddr := &unix.SockaddrVM{
		CID:  cid,
		Port: uint32(port),
	}
	if err := unix.Connect(fd, sockaddr); err != nil {
		unix.Close(fd)
		return fmt.Errorf("failed to connect: %w", err)
	}

	// From here on conn owns fd. Closing fd directly as well would close it
	// twice, the second time when conn is closed or finalized, possibly
	// hitting an unrelated descriptor that reused the number.
	conn := os.NewFile(uintptr(fd), "vsock-conn")
	defer conn.Close()

	usePTY := len(command) == 0 || forcePTY

	req := ClientRequest{UsePTY: usePTY, Command: command}
	if err := json.NewEncoder(conn).Encode(&req); err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}

	var exitCode int
	if usePTY {
		exitCode, err = runClientPTY(conn)
		if err != nil {
			return err
		}
	} else {
		exitCode = runClientPiped(conn)
	}
	os.Exit(exitCode)
	return nil // unreachable: os.Exit does not return
}

// runClientPTY relays a PTY session. Local stdin is copied to conn; when stdin
// is a terminal it is put into raw mode, and restored before returning.
// Everything the server sends is written to stdout until the in-band
// exit-code message. It returns the remote exit code (see readPTYSession).
func runClientPTY(conn *os.File) (int, error) {
	stdinFd := int(os.Stdin.Fd())
	if term.IsTerminal(stdinFd) {
		oldState, err := term.MakeRaw(stdinFd)
		if err != nil {
			return 0, fmt.Errorf("failed to set raw mode: %w", err)
		}
		defer term.Restore(stdinFd, oldState)
	}

	// Stdin -> server. Left running; the process exits when the session ends.
	go func() {
		_, _ = io.Copy(conn, os.Stdin)
	}()

	return readPTYSession(conn), nil
}

// exitCodeMarker is the start of the JSON ServerResponse that ends a PTY
// session. It must match ServerResponse's json tag.
const exitCodeMarker = `{"exit_code":`

// readPTYSession copies PTY output from conn to stdout until it sees
// exitCodeMarker. It then reads the rest of the connection as the JSON
// ServerResponse and returns its exit code. It returns 255 on a read error
// before the marker, or if the JSON is malformed.
//
// NOTE: the marker travels in-band, so it is recognised only when it arrives
// within a single read, and program output that itself contains the marker
// is mistaken for the end of the session. Fixing that requires framing in
// PTY mode too, i.e. a protocol change on both sides.
func readPTYSession(conn io.Reader) int {
	buf := make([]byte, 32*1024)
	var jsonBuf []byte

	for {
		n, err := conn.Read(buf)
		if n > 0 {
			data := buf[:n]
			if start := indexExitCodeMarker(data); start >= 0 {
				if start > 0 {
					os.Stdout.Write(data[:start])
				}
				jsonBuf = append(jsonBuf, data[start:]...)
				break
			}
			os.Stdout.Write(data)
		}
		if err != nil {
			return 255
		}
	}

	// The server closes the connection after the response; collect all of it.
	for {
		n, err := conn.Read(buf)
		jsonBuf = append(jsonBuf, buf[:n]...)
		if err != nil {
			break
		}
	}

	var resp ServerResponse
	if err := json.Unmarshal(jsonBuf, &resp); err != nil {
		return 255
	}
	return resp.ExitCode
}

// indexExitCodeMarker returns the offset of the first exitCodeMarker in data,
// or -1 if data does not contain it.
func indexExitCodeMarker(data []byte) int {
	for i := 0; i+len(exitCodeMarker) <= len(data); i++ {
		if string(data[i:i+len(exitCodeMarker)]) == exitCodeMarker {
			return i
		}
	}
	return -1
}

// runClientPiped relays a non-PTY session using the framing protocol. Local
// stdin is sent as channelStdin frames. Stdout/stderr frames are written to
// the matching local streams. The exit code arrives in the channelControl
// frame and is returned; 255 is returned if the connection fails first or the
// message is malformed.
func runClientPiped(conn *os.File) int {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	exitCodeChan := make(chan int, 1)
	var wg sync.WaitGroup
	wg.Add(2)

	// Stdin -> server (on the stdin channel).
	go func() {
		defer wg.Done()
		buf := make([]byte, maxFrameSize)
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}

			n, err := os.Stdin.Read(buf)
			if n > 0 {
				if werr := writeFrame(conn, channelStdin, buf[:n]); werr != nil {
					return
				}
			}
			if err != nil {
				// Local stdin is exhausted. Half-close the connection so the
				// server's stdin reader hits EOF and closes the remote
				// command's stdin; output and the exit code still flow back.
				unix.Shutdown(int(conn.Fd()), unix.SHUT_WR)
				return
			}
		}
	}()

	// Server -> stdout/stderr, until the exit code arrives.
	go func() {
		defer wg.Done()
		for {
			frame, err := readFrame(conn)
			if err != nil {
				exitCodeChan <- 255
				return
			}

			switch frame.Channel {
			case channelStdout:
				os.Stdout.Write(frame.Data)
			case channelStderr:
				os.Stderr.Write(frame.Data)
			case channelControl:
				var resp ServerResponse
				if err := json.Unmarshal(frame.Data, &resp); err != nil {
					exitCodeChan <- 255
				} else {
					exitCodeChan <- resp.ExitCode
				}
				// Stop the stdin reader at its next iteration.
				cancel()
				return
			}
		}
	}()

	exitCode := <-exitCodeChan

	// Give the goroutines a moment to finish, but don't wait for the stdin
	// reader indefinitely: it may be blocked reading a terminal.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(100 * time.Millisecond):
	}

	return exitCode
}
|
|
|
// main wires up the cobra CLI: "serve" runs the server, and "exec" (alias "x")
// connects to a server at the given CID and runs a command or an interactive
// shell there.
func main() {
	rootCmd := &cobra.Command{
		Use:   "vsock-shell",
		Short: "A tool for executing commands over VM sockets",
		Long:  `vsock-shell provides SSH-like functionality for virtual machine communication using VSOCK sockets.`,
	}

	serverCmd := &cobra.Command{
		Use:   "serve",
		Short: "Run in server mode",
		Long:  `Start a vsock-shell server that listens for client connections and executes commands.`,
		RunE: func(cmd *cobra.Command, args []string) error {
			return runServer()
		},
	}
	serverCmd.Flags().IntVarP(&port, "port", "p", defaultPort, "Port to listen on")
	// --single and --one/-1 both set the same variable.
	serverCmd.Flags().BoolVar(&single, "single", false, "Exit after handling one client connection")
	serverCmd.Flags().BoolVarP(&single, "one", "1", false, "Exit after handling one client connection (shorthand)")

	clientCmd := &cobra.Command{
		Use:     "exec CID [command...]",
		Aliases: []string{"x"},
		Short:   "Run in client mode",
		Long:    `Connect to a vsock-shell server and execute commands or open an interactive shell.`,
		Args:    cobra.MinimumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			cid, err := parseCID(args[0])
			if err != nil {
				return err
			}
			return runClient(cid, args[1:])
		},
	}
	clientCmd.Flags().IntVarP(&port, "port", "p", defaultPort, "Port to connect to")
	clientCmd.Flags().BoolVarP(&forcePTY, "tty", "t", false, "Force PTY allocation")

	rootCmd.AddCommand(serverCmd, clientCmd)

	if err := rootCmd.Execute(); err != nil {
		os.Exit(1)
	}
}

// parseCID parses a vsock context ID written as a decimal number that fits in
// 32 bits.
//
// fmt.Sscanf with a lone "%d" stops at the first non-digit without
// complaint, so "3abc" would be accepted as 3 and "0x3" as 0. The extra %s
// verb captures any such trailing text so it can be rejected.
func parseCID(s string) (uint32, error) {
	var cid uint32
	var trailing string
	n, err := fmt.Sscanf(s, "%d%s", &cid, &trailing)
	switch {
	case n == 0:
		return 0, fmt.Errorf("invalid CID: %v", err)
	case n > 1:
		return 0, fmt.Errorf("invalid CID: unexpected trailing characters %q in %q", trailing, s)
	}
	return cid, nil
}