Created
January 23, 2026 12:38
-
-
Save hydrz/4cb3bb1c1d5379bd2857864625d37034 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package socks5 | |
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"strconv"
	"sync"
	"time"
)
// Protocol version bytes.
const (
	// socks5Version is the VER byte that starts SOCKS5 greetings and
	// requests (RFC 1928).
	socks5Version byte = 0x05

	// passwordAuthVersion is the VER byte of the RFC 1929
	// username/password sub-negotiation.
	passwordAuthVersion = 1
)

// Authentication METHOD values from RFC 1928, section 3.
const (
	noAuthRequired   byte = 0x00 // NO AUTHENTICATION REQUIRED
	passwordAuth     byte = 0x02 // USERNAME/PASSWORD
	noAcceptableAuth byte = 0xff // NO ACCEPTABLE METHODS
)
// commandType is the CMD byte of a SOCKS5 request: the kind of connection
// the client asks the server to establish.
type commandType byte

// SOCKS5 commands from RFC 1928, section 4.
const (
	connect      commandType = 0x01 // CONNECT
	bind         commandType = 0x02 // BIND
	udpAssociate commandType = 0x03 // UDP ASSOCIATE
)

// String returns a short label for cmd. Unknown commands are rendered with
// their numeric value.
func (cmd commandType) String() string {
	var label string
	switch cmd {
	case connect:
		label = "socks connect"
	case bind:
		label = "socks bind"
	case udpAssociate:
		label = "socks udp"
	default:
		label = "socks " + strconv.Itoa(int(cmd))
	}
	return label
}
// addrType is the ATYP byte that tells how the address following it in a
// SOCKS5 message is encoded.
type addrType byte

// Address types from RFC 1928, section 5.
const (
	ipv4       addrType = 0x01 // 4-byte IPv4 address
	domainName addrType = 0x03 // length-prefixed domain name
	ipv6       addrType = 0x04 // 16-byte IPv6 address
)
// replyCode is the REP byte of a server reply: the outcome of the client's
// request.
type replyCode byte

// Reply codes from RFC 1928, section 6.
const (
	success              replyCode = 0x00
	generalFailure       replyCode = 0x01
	connectionNotAllowed replyCode = 0x02
	networkUnreachable   replyCode = 0x03
	hostUnreachable      replyCode = 0x04
	connectionRefused    replyCode = 0x05
	ttlExpired           replyCode = 0x06
	commandNotSupported  replyCode = 0x07
	addrTypeNotSupported replyCode = 0x08
)

// replyCodeText maps each defined reply code, used as the index, to its
// descriptive text.
var replyCodeText = [...]string{
	success:              "succeeded",
	generalFailure:       "general SOCKS server failure",
	connectionNotAllowed: "connection not allowed by ruleset",
	networkUnreachable:   "network unreachable",
	hostUnreachable:      "host unreachable",
	connectionRefused:    "connection refused",
	ttlExpired:           "TTL expired",
	commandNotSupported:  "command not supported",
	addrTypeNotSupported: "address type not supported",
}

// String returns the descriptive text for code, or "unknown code: N" for
// values beyond the defined range.
func (code replyCode) String() string {
	if int(code) < len(replyCodeText) {
		return replyCodeText[code]
	}
	return "unknown code: " + strconv.Itoa(int(code))
}
| type socksAddr struct { | |
| addrType addrType | |
| addr string | |
| port uint16 | |
| } | |
| var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} | |
| func parseSocksAddr(r io.Reader) (addr socksAddr, err error) { | |
| var addrTypeData [1]byte | |
| _, err = io.ReadFull(r, addrTypeData[:]) | |
| if err != nil { | |
| return socksAddr{}, fmt.Errorf("could not read address type") | |
| } | |
| dstAddrType := addrType(addrTypeData[0]) | |
| var destination string | |
| switch dstAddrType { | |
| case ipv4: | |
| var ip [4]byte | |
| _, err = io.ReadFull(r, ip[:]) | |
| if err != nil { | |
| return socksAddr{}, fmt.Errorf("could not read IPv4 address") | |
| } | |
| destination = net.IP(ip[:]).String() | |
| case domainName: | |
| var dstSizeByte [1]byte | |
| _, err = io.ReadFull(r, dstSizeByte[:]) | |
| if err != nil { | |
| return socksAddr{}, fmt.Errorf("could not read domain name size") | |
| } | |
| dstSize := int(dstSizeByte[0]) | |
| domainName := make([]byte, dstSize) | |
| _, err = io.ReadFull(r, domainName) | |
| if err != nil { | |
| return socksAddr{}, fmt.Errorf("could not read domain name") | |
| } | |
| destination = string(domainName) | |
| case ipv6: | |
| var ip [16]byte | |
| _, err = io.ReadFull(r, ip[:]) | |
| if err != nil { | |
| return socksAddr{}, fmt.Errorf("could not read IPv6 address") | |
| } | |
| destination = net.IP(ip[:]).String() | |
| default: | |
| return socksAddr{}, fmt.Errorf("unsupported address type") | |
| } | |
| var portBytes [2]byte | |
| _, err = io.ReadFull(r, portBytes[:]) | |
| if err != nil { | |
| return socksAddr{}, fmt.Errorf("could not read port") | |
| } | |
| port := binary.BigEndian.Uint16(portBytes[:]) | |
| return socksAddr{ | |
| addrType: dstAddrType, | |
| addr: destination, | |
| port: port, | |
| }, nil | |
| } | |
| func (s socksAddr) marshal() ([]byte, error) { | |
| var addr []byte | |
| switch s.addrType { | |
| case ipv4: | |
| addr = net.ParseIP(s.addr).To4() | |
| if addr == nil { | |
| return nil, fmt.Errorf("invalid IPv4 address for binding") | |
| } | |
| case domainName: | |
| if len(s.addr) > 255 { | |
| return nil, fmt.Errorf("invalid domain name for binding") | |
| } | |
| addr = make([]byte, 0, len(s.addr)+1) | |
| addr = append(addr, byte(len(s.addr))) | |
| addr = append(addr, []byte(s.addr)...) | |
| case ipv6: | |
| addr = net.ParseIP(s.addr).To16() | |
| if addr == nil { | |
| return nil, fmt.Errorf("invalid IPv6 address for binding") | |
| } | |
| default: | |
| return nil, fmt.Errorf("unsupported address type") | |
| } | |
| pkt := []byte{byte(s.addrType)} | |
| pkt = append(pkt, addr...) | |
| pkt = binary.BigEndian.AppendUint16(pkt, s.port) | |
| return pkt, nil | |
| } | |
| func (s socksAddr) hostPort() string { | |
| return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) | |
| } | |
| func (a *socksAddr) Network() string { return "socks" } | |
| func (s socksAddr) String() string { | |
| return s.hostPort() | |
| } | |
// splitHostPort splits a "host:port" string (IPv6 hosts in brackets) and
// parses the port as an unsigned 16-bit decimal number.
//
// strconv.ParseUint with bitSize 16 rejects signs ("+80", "-1"),
// non-digits, and values above 65535 in a single step. strconv.Atoi would
// accept a leading sign.
func splitHostPort(hostport string) (string, uint16, error) {
	host, portStr, err := net.SplitHostPort(hostport)
	if err != nil {
		return "", 0, err
	}
	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return "", 0, fmt.Errorf("invalid port number %q: %w", portStr, err)
	}
	return host, uint16(port), nil
}
// UDP relay buffer size and per-read timeout.
const (
	// bufferSize must hold a whole datagram. The largest UDP payload is
	// 65,507 bytes over IPv4, and a read into a smaller buffer silently
	// truncates the datagram. The previous 8 KiB limit cut off larger
	// packets.
	bufferSize = 64 * 1024

	// readTimeout bounds each blocking read so the relay loops can notice
	// cancellation.
	readTimeout = 5 * time.Second
)
| type Handler struct { | |
| Logf func(format string, args ...any) | |
| Dialer func(ctx context.Context, network, addr string) (net.Conn, error) | |
| Authenticator func(username, password string) bool | |
| request *request | |
| udpClientAddr net.Addr | |
| udpTargetConns map[socksAddr]net.Conn | |
| } | |
| func (h *Handler) dial(ctx context.Context, network, addr string) (net.Conn, error) { | |
| dial := h.Dialer | |
| if dial == nil { | |
| dialer := &net.Dialer{} | |
| dial = dialer.DialContext | |
| } | |
| return dial(ctx, network, addr) | |
| } | |
| func (h *Handler) logf(format string, args ...any) { | |
| logf := h.Logf | |
| if logf == nil { | |
| logf = log.Printf | |
| } | |
| logf(format, args...) | |
| } | |
| func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error { | |
| needAuth := h.Authenticator != nil | |
| authMethod := noAuthRequired | |
| if needAuth { | |
| authMethod = passwordAuth | |
| } | |
| err := parseClientGreeting(conn, authMethod) | |
| if err != nil { | |
| conn.Write([]byte{socks5Version, noAcceptableAuth}) | |
| return err | |
| } | |
| conn.Write([]byte{socks5Version, authMethod}) | |
| if !needAuth { | |
| return h.handleRequest(ctx, conn) | |
| } | |
| user, pwd, err := parseClientAuth(conn) | |
| if err != nil || !h.Authenticator(user, pwd) { | |
| conn.Write([]byte{1, 1}) // auth error | |
| return err | |
| } | |
| conn.Write([]byte{1, 0}) // auth success | |
| return h.handleRequest(ctx, conn) | |
| } | |
| func (h *Handler) handleRequest(ctx context.Context, conn net.Conn) error { | |
| req, err := parseClientRequest(conn) | |
| if err != nil { | |
| res := errorResponse(generalFailure) | |
| buf, _ := res.marshal() | |
| conn.Write(buf) | |
| return err | |
| } | |
| h.request = req | |
| switch req.command { | |
| case connect: | |
| return h.handleTCP(ctx, conn) | |
| case udpAssociate: | |
| return h.handleUDP(ctx, conn) | |
| default: | |
| res := errorResponse(commandNotSupported) | |
| buf, _ := res.marshal() | |
| conn.Write(buf) | |
| return fmt.Errorf("unsupported command %v", req.command) | |
| } | |
| } | |
| func (h *Handler) handleTCP(ctx context.Context, conn net.Conn) error { | |
| dialCtx, cancel := context.WithTimeout(ctx, 5*time.Second) | |
| defer cancel() | |
| srv, err := h.dial( | |
| dialCtx, | |
| "tcp", | |
| h.request.destination.hostPort(), | |
| ) | |
| if err != nil { | |
| res := errorResponse(generalFailure) | |
| buf, _ := res.marshal() | |
| conn.Write(buf) | |
| return err | |
| } | |
| defer srv.Close() | |
| localAddr := srv.LocalAddr().String() | |
| serverAddr, serverPort, err := splitHostPort(localAddr) | |
| if err != nil { | |
| return err | |
| } | |
| res := &response{ | |
| reply: success, | |
| bindAddr: socksAddr{ | |
| addrType: getAddrType(serverAddr), | |
| addr: serverAddr, | |
| port: serverPort, | |
| }, | |
| } | |
| buf, err := res.marshal() | |
| if err != nil { | |
| res = errorResponse(generalFailure) | |
| buf, _ = res.marshal() | |
| } | |
| conn.Write(buf) | |
| errc := make(chan error, 2) | |
| go func() { | |
| _, err := io.Copy(conn, srv) | |
| if err != nil { | |
| err = fmt.Errorf("from backend to client: %w", err) | |
| } | |
| errc <- err | |
| }() | |
| go func() { | |
| _, err := io.Copy(srv, conn) | |
| if err != nil { | |
| err = fmt.Errorf("from client to backend: %w", err) | |
| } | |
| errc <- err | |
| }() | |
| return <-errc | |
| } | |
| func (h *Handler) handleUDP(ctx context.Context, conn net.Conn) error { | |
| // The DST.ADDR and DST.PORT fields contain the address and port that | |
| // the client expects to use to send UDP datagrams on for the | |
| // association. The server MAY use this information to limit access | |
| // to the association. | |
| // @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928. | |
| // | |
| // We do NOT limit the access from the client currently in this implementation. | |
| _ = h.request.destination | |
| addr := conn.LocalAddr() | |
| host, _, err := net.SplitHostPort(addr.String()) | |
| if err != nil { | |
| return err | |
| } | |
| clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0")) | |
| if err != nil { | |
| res := errorResponse(generalFailure) | |
| buf, _ := res.marshal() | |
| conn.Write(buf) | |
| return err | |
| } | |
| defer clientUDPConn.Close() | |
| bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) | |
| if err != nil { | |
| return err | |
| } | |
| res := &response{ | |
| reply: success, | |
| bindAddr: socksAddr{ | |
| addrType: getAddrType(bindAddr), | |
| addr: bindAddr, | |
| port: bindPort, | |
| }, | |
| } | |
| buf, err := res.marshal() | |
| if err != nil { | |
| res = errorResponse(generalFailure) | |
| buf, _ = res.marshal() | |
| } | |
| conn.Write(buf) | |
| return h.transferUDP(conn, clientUDPConn) | |
| } | |
| func (h *Handler) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error { | |
| ctx, cancel := context.WithCancel(context.Background()) | |
| defer cancel() | |
| // client -> target | |
| go func() { | |
| defer cancel() | |
| h.udpTargetConns = make(map[socksAddr]net.Conn) | |
| // close all target udp connections when the client connection is closed | |
| defer func() { | |
| for _, conn := range h.udpTargetConns { | |
| _ = conn.Close() | |
| } | |
| }() | |
| buf := make([]byte, bufferSize) | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| return | |
| default: | |
| err := h.handleUDPRequest(ctx, clientConn, buf) | |
| if err != nil { | |
| if isTimeout(err) { | |
| continue | |
| } | |
| if errors.Is(err, net.ErrClosed) { | |
| return | |
| } | |
| h.logf("udp transfer: handle udp request fail: %v", err) | |
| } | |
| } | |
| } | |
| }() | |
| // A UDP association terminates when the TCP connection that the UDP | |
| // ASSOCIATE request arrived on terminates. RFC1928 | |
| _, err := io.Copy(io.Discard, associatedTCP) | |
| if err != nil { | |
| err = fmt.Errorf("udp associated tcp conn: %w", err) | |
| } | |
| return err | |
| } | |
| func (h *Handler) getOrDialTargetConn( | |
| ctx context.Context, | |
| clientConn net.PacketConn, | |
| targetAddr socksAddr, | |
| ) (net.Conn, error) { | |
| conn, exist := h.udpTargetConns[targetAddr] | |
| if exist { | |
| return conn, nil | |
| } | |
| conn, err := h.dial(ctx, "udp", targetAddr.hostPort()) | |
| if err != nil { | |
| return nil, err | |
| } | |
| h.udpTargetConns[targetAddr] = conn | |
| // target -> client | |
| go func() { | |
| buf := make([]byte, bufferSize) | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| return | |
| default: | |
| err := h.handleUDPResponse(clientConn, targetAddr, conn, buf) | |
| if err != nil { | |
| if isTimeout(err) { | |
| continue | |
| } | |
| if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { | |
| return | |
| } | |
| h.logf("udp transfer: handle udp response fail: %v", err) | |
| } | |
| } | |
| } | |
| }() | |
| return conn, nil | |
| } | |
| func (h *Handler) handleUDPRequest( | |
| ctx context.Context, | |
| clientConn net.PacketConn, | |
| buf []byte, | |
| ) error { | |
| // add a deadline for the read to avoid blocking forever | |
| _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) | |
| n, addr, err := clientConn.ReadFrom(buf) | |
| if err != nil { | |
| return fmt.Errorf("read from client: %w", err) | |
| } | |
| h.udpClientAddr = addr | |
| req, data, err := parseUDPRequest(buf[:n]) | |
| if err != nil { | |
| return fmt.Errorf("parse udp request: %w", err) | |
| } | |
| targetConn, err := h.getOrDialTargetConn(ctx, clientConn, req.addr) | |
| if err != nil { | |
| return fmt.Errorf("dial target %s fail: %w", req.addr, err) | |
| } | |
| nn, err := targetConn.Write(data) | |
| if err != nil { | |
| return fmt.Errorf("write to target %s fail: %w", req.addr, err) | |
| } | |
| if nn != len(data) { | |
| return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite) | |
| } | |
| return nil | |
| } | |
| func (h *Handler) handleUDPResponse( | |
| clientConn net.PacketConn, | |
| targetAddr socksAddr, | |
| targetConn net.Conn, | |
| buf []byte, | |
| ) error { | |
| // add a deadline for the read to avoid blocking forever | |
| _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) | |
| n, err := targetConn.Read(buf) | |
| if err != nil { | |
| return fmt.Errorf("read from target: %w", err) | |
| } | |
| hdr := udpRequest{addr: targetAddr} | |
| pkt, err := hdr.marshal() | |
| if err != nil { | |
| return fmt.Errorf("marshal udp request: %w", err) | |
| } | |
| data := append(pkt, buf[:n]...) | |
| // use addr from client to send back | |
| nn, err := clientConn.WriteTo(data, h.udpClientAddr) | |
| if err != nil { | |
| return fmt.Errorf("write to client: %w", err) | |
| } | |
| if nn != len(data) { | |
| return fmt.Errorf("write to client: %w", io.ErrShortWrite) | |
| } | |
| return nil | |
| } | |
// isTimeout reports whether err, or any error it wraps, has a Timeout
// method that returns true. Deadline expiry on net connections is reported
// this way.
func isTimeout(err error) bool {
	var timeoutErr interface{ Timeout() bool }
	if !errors.As(err, &timeoutErr) {
		return false
	}
	return timeoutErr.Timeout()
}
| // parseClientGreeting parses a request initiation packet. | |
| func parseClientGreeting(r io.Reader, authMethod byte) error { | |
| var hdr [2]byte | |
| _, err := io.ReadFull(r, hdr[:]) | |
| if err != nil { | |
| return fmt.Errorf("could not read packet header") | |
| } | |
| if hdr[0] != socks5Version { | |
| return fmt.Errorf("incompatible SOCKS version") | |
| } | |
| count := int(hdr[1]) | |
| methods := make([]byte, count) | |
| _, err = io.ReadFull(r, methods) | |
| if err != nil { | |
| return fmt.Errorf("could not read methods") | |
| } | |
| for _, m := range methods { | |
| if m == authMethod { | |
| return nil | |
| } | |
| } | |
| return fmt.Errorf("no acceptable auth methods") | |
| } | |
| func parseClientAuth(r io.Reader) (usr, pwd string, err error) { | |
| var hdr [2]byte | |
| if _, err := io.ReadFull(r, hdr[:]); err != nil { | |
| return "", "", fmt.Errorf("could not read auth packet header") | |
| } | |
| if hdr[0] != passwordAuthVersion { | |
| return "", "", fmt.Errorf("bad SOCKS auth version") | |
| } | |
| usrLen := int(hdr[1]) | |
| usrBytes := make([]byte, usrLen) | |
| if _, err := io.ReadFull(r, usrBytes); err != nil { | |
| return "", "", fmt.Errorf("could not read auth packet username") | |
| } | |
| var hdrPwd [1]byte | |
| if _, err := io.ReadFull(r, hdrPwd[:]); err != nil { | |
| return "", "", fmt.Errorf("could not read auth packet password length") | |
| } | |
| pwdLen := int(hdrPwd[0]) | |
| pwdBytes := make([]byte, pwdLen) | |
| if _, err := io.ReadFull(r, pwdBytes); err != nil { | |
| return "", "", fmt.Errorf("could not read auth packet password") | |
| } | |
| return string(usrBytes), string(pwdBytes), nil | |
| } | |
| func getAddrType(addr string) addrType { | |
| if ip := net.ParseIP(addr); ip != nil { | |
| if ip.To4() != nil { | |
| return ipv4 | |
| } | |
| return ipv6 | |
| } | |
| return domainName | |
| } | |
| // request represents data contained within a SOCKS5 | |
| // connection request packet. | |
| type request struct { | |
| command commandType | |
| destination socksAddr | |
| } | |
| // parseClientRequest converts raw packet bytes into a | |
| // SOCKS5Request struct. | |
| func parseClientRequest(r io.Reader) (*request, error) { | |
| var hdr [3]byte | |
| _, err := io.ReadFull(r, hdr[:]) | |
| if err != nil { | |
| return nil, fmt.Errorf("could not read packet header") | |
| } | |
| cmd := hdr[1] | |
| destination, err := parseSocksAddr(r) | |
| return &request{ | |
| command: commandType(cmd), | |
| destination: destination, | |
| }, err | |
| } | |
| // response contains the contents of | |
| // a response packet sent from the proxy | |
| // to the client. | |
| type response struct { | |
| reply replyCode | |
| bindAddr socksAddr | |
| } | |
| func errorResponse(code replyCode) *response { | |
| return &response{reply: code, bindAddr: zeroSocksAddr} | |
| } | |
| // marshal converts a SOCKS5Response struct into | |
| // a packet. If res.reply == Success, it may throw an error on | |
| // receiving an invalid bind address. Otherwise, it will not throw. | |
| func (res *response) marshal() ([]byte, error) { | |
| pkt := make([]byte, 3) | |
| pkt[0] = socks5Version | |
| pkt[1] = byte(res.reply) | |
| pkt[2] = 0 // null reserved byte | |
| addrPkt, err := res.bindAddr.marshal() | |
| if err != nil { | |
| return nil, err | |
| } | |
| return append(pkt, addrPkt...), nil | |
| } | |
| type udpRequest struct { | |
| frag byte | |
| addr socksAddr | |
| } | |
| // +----+------+------+----------+----------+----------+ | |
| // |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | | |
| // +----+------+------+----------+----------+----------+ | |
| // | 2 | 1 | 1 | Variable | 2 | Variable | | |
| // +----+------+------+----------+----------+----------+ | |
| func parseUDPRequest(data []byte) (*udpRequest, []byte, error) { | |
| if len(data) < 4 { | |
| return nil, nil, fmt.Errorf("invalid packet length") | |
| } | |
| // reserved bytes | |
| if !(data[0] == 0 && data[1] == 0) { | |
| return nil, nil, fmt.Errorf("invalid udp request header") | |
| } | |
| frag := data[2] | |
| reader := bytes.NewReader(data[3:]) | |
| addr, err := parseSocksAddr(reader) | |
| bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length | |
| body := data[len(data)-bodyLen:] | |
| return &udpRequest{ | |
| frag: frag, | |
| addr: addr, | |
| }, body, err | |
| } | |
| func (u *udpRequest) marshal() ([]byte, error) { | |
| pkt := make([]byte, 3) | |
| pkt[0] = 0 | |
| pkt[1] = 0 | |
| pkt[2] = u.frag | |
| addrPkt, err := u.addr.marshal() | |
| if err != nil { | |
| return nil, err | |
| } | |
| return append(pkt, addrPkt...), nil | |
| } | |
// Deadline sentinels used while dialing through the proxy.
var (
	// noDeadline is the zero time.Time. Passing it to SetDeadline removes
	// any deadline.
	noDeadline time.Time

	// aLongTimeAgo is a deadline already in the past. Setting it makes
	// pending and future I/O on a connection fail at once, which aborts a
	// handshake when its context is canceled.
	aLongTimeAgo = time.Unix(1, 0)
)
// Dialer connects to destinations through a SOCKS5 proxy.
type Dialer struct {
	// ProxyNetwork and ProxyAddress identify the proxy server. They are
	// passed unchanged to the dial function.
	ProxyNetwork string
	ProxyAddress string

	// ProxyDial, if non-nil, establishes the transport connection to the
	// proxy. Otherwise a zero net.Dialer's DialContext is used.
	ProxyDial func(context.Context, string, string) (net.Conn, error)

	// Username and Password enable the authentication step of the
	// handshake when either one is non-empty.
	Username string
	Password string
}
| // Conn wraps a net.Conn and implements both net.Conn and net.PacketConn interfaces. | |
| // For TCP connections, use Read/Write methods. | |
| // For UDP connections, use ReadFrom/WriteTo methods. | |
| type Conn struct { | |
| net.Conn // TCP connection (for CONNECT) or TCP control connection (for UDP ASSOCIATE) | |
| udpConn *net.UDPConn // UDP connection (only for UDP ASSOCIATE) | |
| relayAddr *net.UDPAddr // UDP relay server address (only for UDP ASSOCIATE) | |
| boundAddr socksAddr // Bound address from server response | |
| isUDP bool // Whether this is a UDP connection | |
| } | |
| // ReadFrom implements net.PacketConn interface for UDP connections | |
| func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { | |
| if !c.isUDP { | |
| return 0, nil, errors.New("ReadFrom only available for UDP connections") | |
| } | |
| if c.udpConn == nil { | |
| return 0, nil, errors.New("UDP connection not initialized") | |
| } | |
| buf := make([]byte, 65535) | |
| n, _, err = c.udpConn.ReadFromUDP(buf) | |
| if err != nil { | |
| return 0, nil, err | |
| } | |
| // Parse UDP request header (RFC 1928 section 7) | |
| if n < 10 { | |
| return 0, nil, fmt.Errorf("packet too short") | |
| } | |
| if buf[0] != 0x00 || buf[1] != 0x00 { | |
| return 0, nil, fmt.Errorf("invalid reserved fields") | |
| } | |
| if buf[2] != 0x00 { | |
| return 0, nil, fmt.Errorf("fragmentation not supported") | |
| } | |
| // Parse address | |
| reader := &byteReader{buf: buf[3:n], pos: 0} | |
| socksAddr, err := parseSocksAddr(reader) | |
| if err != nil { | |
| return 0, nil, err | |
| } | |
| // Calculate header length | |
| headerLen := 3 + 1 // RSV + FRAG + ATYP | |
| switch socksAddr.addrType { | |
| case ipv4: | |
| headerLen += 4 | |
| case ipv6: | |
| headerLen += 16 | |
| case domainName: | |
| headerLen += 1 + len(socksAddr.addr) | |
| } | |
| headerLen += 2 // PORT | |
| // Copy data after header | |
| dataLen := n - headerLen | |
| if dataLen > len(b) { | |
| return 0, nil, fmt.Errorf("buffer too small") | |
| } | |
| copy(b, buf[headerLen:n]) | |
| return dataLen, &socksAddr, nil | |
| } | |
| // WriteTo implements net.PacketConn interface for UDP connections | |
| func (c *Conn) WriteTo(b []byte, addr net.Addr) (n int, err error) { | |
| if !c.isUDP { | |
| return 0, errors.New("WriteTo only available for UDP connections") | |
| } | |
| if c.udpConn == nil { | |
| return 0, errors.New("UDP connection not initialized") | |
| } | |
| var sa socksAddr | |
| // Parse target address | |
| switch a := addr.(type) { | |
| case *socksAddr: | |
| sa = *a | |
| case *net.UDPAddr: | |
| host, port, err := splitHostPort(a.String()) | |
| if err != nil { | |
| return 0, err | |
| } | |
| if ip := net.ParseIP(host); ip != nil { | |
| if ip.To4() != nil { | |
| sa = socksAddr{addrType: ipv4, addr: ip.String(), port: port} | |
| } else { | |
| sa = socksAddr{addrType: ipv6, addr: ip.String(), port: port} | |
| } | |
| } else { | |
| sa = socksAddr{addrType: domainName, addr: host, port: port} | |
| } | |
| default: | |
| host, port, err := splitHostPort(addr.String()) | |
| if err != nil { | |
| return 0, err | |
| } | |
| if ip := net.ParseIP(host); ip != nil { | |
| if ip.To4() != nil { | |
| sa = socksAddr{addrType: ipv4, addr: ip.String(), port: port} | |
| } else { | |
| sa = socksAddr{addrType: ipv6, addr: ip.String(), port: port} | |
| } | |
| } else { | |
| sa = socksAddr{addrType: domainName, addr: host, port: port} | |
| } | |
| } | |
| // Build UDP request header (RFC 1928 section 7) | |
| header := []byte{0x00, 0x00, 0x00} // RSV + FRAG | |
| addrBytes, err := sa.marshal() | |
| if err != nil { | |
| return 0, err | |
| } | |
| header = append(header, addrBytes...) | |
| // Combine header and data | |
| packet := append(header, b...) | |
| // Send to relay server | |
| _, err = c.udpConn.WriteToUDP(packet, c.relayAddr) | |
| if err != nil { | |
| return 0, err | |
| } | |
| return len(b), nil | |
| } | |
| // LocalAddr returns the local network address for UDP, or the TCP local address | |
| func (c *Conn) LocalAddr() net.Addr { | |
| if c.isUDP && c.udpConn != nil { | |
| return c.udpConn.LocalAddr() | |
| } | |
| return c.Conn.LocalAddr() | |
| } | |
| // RemoteAddr returns the remote network address | |
| func (c *Conn) RemoteAddr() net.Addr { | |
| return &c.boundAddr | |
| } | |
| // Close closes the connection | |
| func (c *Conn) Close() error { | |
| var err1, err2 error | |
| if c.udpConn != nil { | |
| err1 = c.udpConn.Close() | |
| } | |
| if c.Conn != nil { | |
| err2 = c.Conn.Close() | |
| } | |
| if err1 != nil { | |
| return err1 | |
| } | |
| return err2 | |
| } | |
| // IsUDP returns true if this is a UDP connection | |
| func (c *Conn) IsUDP() bool { | |
| return c.isUDP | |
| } | |
// byteReader is a minimal io.Reader over a byte slice. It exposes its read
// position so callers can tell how many bytes have been consumed.
type byteReader struct {
	buf []byte
	pos int // index of the next unread byte
}

// Read copies unread bytes into p and advances pos. It returns io.EOF once
// every byte has been consumed.
func (r *byteReader) Read(p []byte) (n int, err error) {
	if r.pos < len(r.buf) {
		n = copy(p, r.buf[r.pos:])
		r.pos += n
		return n, nil
	}
	return 0, io.EOF
}
| // DialContext connects to the provided address on the provided network. | |
| // | |
| // For TCP networks ("tcp", "tcp4", "tcp6"), it returns a *Conn that can be used | |
| // with Read/Write methods like a normal net.Conn. | |
| // | |
| // For UDP networks ("udp", "udp4", "udp6"), it returns a *Conn that should be used | |
| // with ReadFrom/WriteTo methods like a net.PacketConn. | |
| // | |
| // The returned error value may be a net.OpError. When the Op field of | |
| // net.OpError contains "socks", the Source field contains a proxy | |
| // server address and the Addr field contains a command target address. | |
| func (d *Dialer) DialContext(ctx context.Context, network, address string) (_ *Conn, ctxErr error) { | |
| var cmd commandType | |
| switch network { | |
| case "tcp", "tcp4", "tcp6": | |
| cmd = connect | |
| case "udp", "udp4", "udp6": | |
| cmd = udpAssociate | |
| default: | |
| proxy, dst, _ := d.pathAddrs(address) | |
| return nil, &net.OpError{Op: "dial", Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")} | |
| } | |
| if ctx == nil { | |
| proxy, dst, _ := d.pathAddrs(address) | |
| return nil, &net.OpError{Op: cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} | |
| } | |
| dialer := d.ProxyDial | |
| if dialer == nil { | |
| dialer = (&net.Dialer{}).DialContext | |
| } | |
| c, err := dialer(ctx, d.ProxyNetwork, d.ProxyAddress) | |
| if err != nil { | |
| proxy, dst, _ := d.pathAddrs(address) | |
| return nil, &net.OpError{Op: cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} | |
| } | |
| if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { | |
| c.SetDeadline(deadline) | |
| defer c.SetDeadline(noDeadline) | |
| } | |
| if ctx != context.Background() { | |
| errCh := make(chan error, 1) | |
| done := make(chan struct{}) | |
| defer func() { | |
| close(done) | |
| if ctxErr == nil { | |
| ctxErr = <-errCh | |
| } | |
| }() | |
| go func() { | |
| select { | |
| case <-ctx.Done(): | |
| c.SetDeadline(aLongTimeAgo) | |
| errCh <- ctx.Err() | |
| case <-done: | |
| errCh <- nil | |
| } | |
| }() | |
| } | |
| // Perform SOCKS5 handshake | |
| b := make([]byte, 0, 512) | |
| if err := d.greet(ctx, c, b); err != nil { | |
| c.Close() | |
| proxy, dst, _ := d.pathAddrs(address) | |
| return nil, &net.OpError{Op: cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} | |
| } | |
| // Send request | |
| _, dstAddr, err := d.pathAddrs(address) | |
| if err != nil { | |
| c.Close() | |
| proxy, dst, _ := d.pathAddrs(address) | |
| return nil, &net.OpError{Op: cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} | |
| } | |
| if cmd == udpAssociate { | |
| return d.udpAssociate(ctx, c, network, dstAddr) | |
| } | |
| // TCP CONNECT | |
| return d.connect(ctx, c, network, dstAddr) | |
| } | |
| // connect handles TCP CONNECT command | |
| func (d *Dialer) connect(_ context.Context, c net.Conn, network string, dstAddr *socksAddr) (*Conn, error) { | |
| b := make([]byte, 0, 512) | |
| b = append(b, socks5Version, byte(connect), 0x00) // VER + CMD + RSV | |
| addrBytes, err := dstAddr.marshal() | |
| if err != nil { | |
| c.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: connect.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| b = append(b, addrBytes...) | |
| if _, err := c.Write(b); err != nil { | |
| c.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: connect.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| // Read response | |
| if _, err := io.ReadFull(c, b[:4]); err != nil { | |
| c.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: connect.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| if b[0] != socks5Version { | |
| c.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: connect.String(), Net: network, Source: proxy, Addr: dstAddr, Err: errors.New("invalid SOCKS version")} | |
| } | |
| if reply := replyCode(b[1]); reply != success { | |
| c.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: connect.String(), Net: network, Source: proxy, Addr: dstAddr, Err: errors.New(reply.String())} | |
| } | |
| // Parse bound address | |
| boundAddr, err := parseSocksAddr(c) | |
| if err != nil { | |
| c.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: connect.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| return &Conn{ | |
| Conn: c, | |
| boundAddr: boundAddr, | |
| isUDP: false, | |
| }, nil | |
| } | |
| // udpAssociate handles UDP ASSOCIATE command | |
| func (d *Dialer) udpAssociate(_ context.Context, tcpConn net.Conn, network string, dstAddr *socksAddr) (*Conn, error) { | |
| b := make([]byte, 0, 512) | |
| b = append(b, socks5Version, byte(udpAssociate), 0x00) // VER + CMD + RSV | |
| // Send 0.0.0.0:0 as client UDP address | |
| clientAddr := zeroSocksAddr | |
| addrBytes, err := clientAddr.marshal() | |
| if err != nil { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| b = append(b, addrBytes...) | |
| if _, err := tcpConn.Write(b); err != nil { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| // Read response | |
| if _, err := io.ReadFull(tcpConn, b[:4]); err != nil { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| if b[0] != socks5Version { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: errors.New("invalid SOCKS version")} | |
| } | |
| if reply := replyCode(b[1]); reply != success { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: errors.New(reply.String())} | |
| } | |
| // Parse UDP relay address | |
| relayAddr, err := parseSocksAddr(tcpConn) | |
| if err != nil { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| // Create UDP connection using ListenUDP | |
| udpAddr, err := net.ResolveUDPAddr(network, "") | |
| if err != nil { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| udpConn, err := net.ListenUDP(network, udpAddr) | |
| if err != nil { | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| // Parse relay address to *net.UDPAddr | |
| udpRelayAddr, err := net.ResolveUDPAddr("udp", relayAddr.hostPort()) | |
| if err != nil { | |
| udpConn.Close() | |
| tcpConn.Close() | |
| proxy, _, _ := d.pathAddrs("") | |
| return nil, &net.OpError{Op: udpAssociate.String(), Net: network, Source: proxy, Addr: dstAddr, Err: err} | |
| } | |
| return &Conn{ | |
| Conn: tcpConn, | |
| udpConn: udpConn, | |
| relayAddr: udpRelayAddr, | |
| boundAddr: relayAddr, | |
| isUDP: true, | |
| }, nil | |
| } | |
| func (d *Dialer) greet(ctx context.Context, c net.Conn, b []byte) error { | |
| b = append(b[:0], socks5Version) | |
| // authentication | |
| if d.Username != "" || d.Password != "" { | |
| b = append(b, 2, noAuthRequired, passwordAuth) | |
| } else { | |
| b = append(b, 1, noAuthRequired) | |
| } | |
| if _, err := c.Write(b); err != nil { | |
| return err | |
| } | |
| if _, err := io.ReadFull(c, b[:2]); err != nil { | |
| return err | |
| } | |
| if b[0] != socks5Version { | |
| return errors.New("invalid SOCKS version") | |
| } | |
| switch b[1] { | |
| case noAuthRequired: | |
| // no authentication | |
| case passwordAuth: | |
| if err := d.authenticate(ctx, c, b); err != nil { | |
| return err | |
| } | |
| default: | |
| return errors.New("unsupported authentication method") | |
| } | |
| return nil | |
| } | |
| func (d *Dialer) authenticate(_ context.Context, rw io.ReadWriter, b []byte) error { | |
| if len(d.Username) == 0 || len(d.Username) > 255 || len(d.Password) > 255 { | |
| return errors.New("invalid username/password") | |
| } | |
| b = b[:0] | |
| b = append(b, passwordAuthVersion) | |
| b = append(b, byte(len(d.Username))) | |
| b = append(b, d.Username...) | |
| b = append(b, byte(len(d.Password))) | |
| b = append(b, d.Password...) | |
| if _, err := rw.Write(b); err != nil { | |
| return err | |
| } | |
| if _, err := io.ReadFull(rw, b[:2]); err != nil { | |
| return err | |
| } | |
| if b[0] != passwordAuthVersion { | |
| return errors.New("invalid username/password version") | |
| } | |
| if b[1] != byte(success) { | |
| return errors.New("username/password authentication failed") | |
| } | |
| return nil | |
| } | |
| func (d *Dialer) pathAddrs(address string) (proxy, dst *socksAddr, err error) { | |
| for i, s := range []string{d.ProxyAddress, address} { | |
| host, port, err := splitHostPort(s) | |
| if err != nil { | |
| return nil, nil, err | |
| } | |
| a := &socksAddr{port: port} | |
| if ip := net.ParseIP(host); ip != nil { | |
| if ip.To4() != nil { | |
| a.addrType = ipv4 | |
| a.addr = ip.String() | |
| } else { | |
| a.addrType = ipv6 | |
| a.addr = ip.String() | |
| } | |
| } else { | |
| a.addrType = domainName | |
| a.addr = host | |
| } | |
| if i == 0 { | |
| proxy = a | |
| } else { | |
| dst = a | |
| } | |
| } | |
| return | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment