Skip to content

Instantly share code, notes, and snippets.

@hydrz
Created January 23, 2026 12:38
Show Gist options
  • Select an option

  • Save hydrz/4cb3bb1c1d5379bd2857864625d37034 to your computer and use it in GitHub Desktop.

Select an option

Save hydrz/4cb3bb1c1d5379bd2857864625d37034 to your computer and use it in GitHub Desktop.
package socks5
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"time"
)
// Authentication METHOD values from RFC 1928, section 3.
const (
	noAuthRequired   byte = 0x00 // NO AUTHENTICATION REQUIRED
	passwordAuth     byte = 0x02 // USERNAME/PASSWORD (RFC 1929)
	noAcceptableAuth byte = 0xff // NO ACCEPTABLE METHODS
)

// passwordAuthVersion is the sub-negotiation VER byte defined by RFC 1929.
const passwordAuthVersion = 1

// socks5Version is the VER byte carried by SOCKS5 requests and replies.
const socks5Version byte = 0x05

// commandType is the CMD byte of a SOCKS5 request: the kind of connection
// the client asks the proxy to establish.
type commandType byte

// SOCKS5 commands defined in RFC 1928, section 4.
const (
	connect      commandType = 0x01
	bind         commandType = 0x02
	udpAssociate commandType = 0x03
)

// String returns "socks <name>" for known commands and "socks <number>"
// for anything else.
func (cmd commandType) String() string {
	var name string
	switch cmd {
	case connect:
		name = "connect"
	case bind:
		name = "bind"
	case udpAssociate:
		name = "udp"
	default:
		name = strconv.Itoa(int(cmd))
	}
	return "socks " + name
}
// addrType is the ATYP byte of a SOCKS5 packet; it selects how the address
// that follows it is encoded.
type addrType byte

// SOCKS5 address types defined in RFC 1928, section 5.
const (
	ipv4       addrType = 0x01
	domainName addrType = 0x03
	ipv6       addrType = 0x04
)

// replyCode is the REP byte of a reply sent from the server to the client.
type replyCode byte

// SOCKS5 reply codes defined in RFC 1928, section 6.
const (
	success              replyCode = 0x00
	generalFailure       replyCode = 0x01
	connectionNotAllowed replyCode = 0x02
	networkUnreachable   replyCode = 0x03
	hostUnreachable      replyCode = 0x04
	connectionRefused    replyCode = 0x05
	ttlExpired           replyCode = 0x06
	commandNotSupported  replyCode = 0x07
	addrTypeNotSupported replyCode = 0x08
)

// replyCodeText holds the description of every defined reply code, indexed
// by the code itself (0x00-0x08 are contiguous).
var replyCodeText = [...]string{
	success:              "succeeded",
	generalFailure:       "general SOCKS server failure",
	connectionNotAllowed: "connection not allowed by ruleset",
	networkUnreachable:   "network unreachable",
	hostUnreachable:      "host unreachable",
	connectionRefused:    "connection refused",
	ttlExpired:           "TTL expired",
	commandNotSupported:  "command not supported",
	addrTypeNotSupported: "address type not supported",
}

// String returns the RFC 1928 description of code, or "unknown code: N"
// for values outside the defined range.
func (code replyCode) String() string {
	if int(code) < len(replyCodeText) {
		return replyCodeText[code]
	}
	return "unknown code: " + strconv.Itoa(int(code))
}
// socksAddr is an address in SOCKS5 wire terms: the ATYP tag, the address
// text and the port. parseSocksAddr decodes it and marshal encodes it.
type socksAddr struct {
addrType addrType // ATYP: ipv4, domainName or ipv6
addr string // textual IP for ipv4/ipv6, otherwise the raw domain name
port uint16
}
// zeroSocksAddr is the IPv4 address 0.0.0.0:0. It is sent as BND.ADDR in
// error replies and as the client address in outgoing UDP ASSOCIATE requests.
var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
// parseSocksAddr reads an ATYP-prefixed address followed by a big-endian
// port (ATYP | ADDR | PORT, RFC 1928 sections 4-5) from r.
//
// IPv4 and IPv6 addresses are returned in textual form; domain names are
// returned verbatim. Read failures wrap the underlying I/O error with %w so
// callers can still match it (for example io.ErrUnexpectedEOF or a timeout).
func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
	var typ [1]byte
	if _, err = io.ReadFull(r, typ[:]); err != nil {
		return socksAddr{}, fmt.Errorf("could not read address type: %w", err)
	}
	dstAddrType := addrType(typ[0])

	var destination string
	switch dstAddrType {
	case ipv4:
		var ip [net.IPv4len]byte
		if _, err = io.ReadFull(r, ip[:]); err != nil {
			return socksAddr{}, fmt.Errorf("could not read IPv4 address: %w", err)
		}
		destination = net.IP(ip[:]).String()
	case domainName:
		var size [1]byte
		if _, err = io.ReadFull(r, size[:]); err != nil {
			return socksAddr{}, fmt.Errorf("could not read domain name size: %w", err)
		}
		name := make([]byte, size[0])
		if _, err = io.ReadFull(r, name); err != nil {
			return socksAddr{}, fmt.Errorf("could not read domain name: %w", err)
		}
		destination = string(name)
	case ipv6:
		var ip [net.IPv6len]byte
		if _, err = io.ReadFull(r, ip[:]); err != nil {
			return socksAddr{}, fmt.Errorf("could not read IPv6 address: %w", err)
		}
		destination = net.IP(ip[:]).String()
	default:
		return socksAddr{}, fmt.Errorf("unsupported address type %#x", byte(dstAddrType))
	}

	var portBytes [2]byte
	if _, err = io.ReadFull(r, portBytes[:]); err != nil {
		return socksAddr{}, fmt.Errorf("could not read port: %w", err)
	}
	return socksAddr{
		addrType: dstAddrType,
		addr:     destination,
		port:     binary.BigEndian.Uint16(portBytes[:]),
	}, nil
}
// marshal encodes s as ATYP | ADDR | PORT (RFC 1928, section 5). Domain
// names are length-prefixed with one byte. It fails if the address text
// does not match addrType, if a domain name exceeds 255 bytes, or if
// addrType is unknown.
func (s socksAddr) marshal() ([]byte, error) {
	out := []byte{byte(s.addrType)}
	switch s.addrType {
	case ipv4:
		ip := net.ParseIP(s.addr).To4()
		if ip == nil {
			return nil, fmt.Errorf("invalid IPv4 address for binding")
		}
		out = append(out, ip...)
	case ipv6:
		ip := net.ParseIP(s.addr).To16()
		if ip == nil {
			return nil, fmt.Errorf("invalid IPv6 address for binding")
		}
		out = append(out, ip...)
	case domainName:
		if len(s.addr) > 255 {
			return nil, fmt.Errorf("invalid domain name for binding")
		}
		out = append(out, byte(len(s.addr)))
		out = append(out, s.addr...)
	default:
		return nil, fmt.Errorf("unsupported address type")
	}
	return binary.BigEndian.AppendUint16(out, s.port), nil
}
// hostPort formats s as "host:port", bracketing IPv6 literals.
func (s socksAddr) hostPort() string {
	return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
}

// Network implements net.Addr. Like String, it uses a value receiver named s,
// so both socksAddr and *socksAddr satisfy net.Addr.
func (s socksAddr) Network() string { return "socks" }

// String implements net.Addr and fmt.Stringer; it returns s.hostPort().
func (s socksAddr) String() string {
	return s.hostPort()
}
// splitHostPort splits "host:port" into the host and a numeric port that
// must lie in 0-65535. IPv6 hosts are returned without brackets.
func splitHostPort(hostport string) (host string, port uint16, err error) {
	h, p, splitErr := net.SplitHostPort(hostport)
	if splitErr != nil {
		return "", 0, splitErr
	}
	num, convErr := strconv.Atoi(p)
	switch {
	case convErr != nil:
		return "", 0, convErr
	case num < 0 || num > 65535:
		return "", 0, fmt.Errorf("invalid port number %d", num)
	}
	return h, uint16(num), nil
}
// UDP relay tuning: each relay goroutine reads datagrams into a buffer of
// bufferSize bytes, and every blocking read gets a readTimeout deadline so
// the relay loops can periodically re-check for cancellation.
const (
	bufferSize  = 8 << 10 // 8 KiB
	readTimeout = 5 * time.Second
)
// Handler serves SOCKS5 (RFC 1928) client connections. It supports the
// CONNECT and UDP ASSOCIATE commands and optional RFC 1929
// username/password authentication.
type Handler struct {
// Logf receives diagnostic messages; log.Printf is used when nil.
Logf func(format string, args ...any)
// Dialer opens outbound connections to targets; a zero-value net.Dialer
// is used when nil.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Authenticator, when non-nil, turns on username/password authentication
// and reports whether the supplied credentials are accepted.
Authenticator func(username, password string) bool
// Per-connection state written while a connection is being served.
// NOTE(review): because this state lives on the Handler itself, one
// Handler value must not serve several connections concurrently unless
// HandleConn isolates it per connection — confirm before sharing.
request *request
udpClientAddr net.Addr
udpTargetConns map[socksAddr]net.Conn
}
// dial opens an outbound connection with h.Dialer, or with a zero-value
// net.Dialer when no custom dialer is configured.
func (h *Handler) dial(ctx context.Context, network, addr string) (net.Conn, error) {
	if h.Dialer != nil {
		return h.Dialer(ctx, network, addr)
	}
	var d net.Dialer
	return d.DialContext(ctx, network, addr)
}
// logf writes a diagnostic message through h.Logf, falling back to
// log.Printf when h.Logf is nil.
func (h *Handler) logf(format string, args ...any) {
	if h.Logf != nil {
		h.Logf(format, args...)
		return
	}
	log.Printf(format, args...)
}
// HandleConn serves one SOCKS5 client connection. It negotiates the
// authentication method, runs RFC 1929 username/password authentication
// when h.Authenticator is set, and then handles a single request (CONNECT
// or UDP ASSOCIATE). It blocks until that request is finished. HandleConn
// does not close conn.
//
// Per-connection state is kept in a private copy of the Handler, so one
// Handler may serve many connections concurrently.
//
// A rejected login returns a non-nil error after the failure status is sent.
func (h *Handler) HandleConn(ctx context.Context, conn net.Conn) error {
	// The request and UDP-association fields live on the Handler struct;
	// work on a fresh copy so concurrent connections never share them.
	sess := &Handler{
		Logf:          h.Logf,
		Dialer:        h.Dialer,
		Authenticator: h.Authenticator,
	}

	authMethod := noAuthRequired
	if sess.Authenticator != nil {
		authMethod = passwordAuth
	}
	if err := parseClientGreeting(conn, authMethod); err != nil {
		// Best effort: tell the client none of its methods is acceptable.
		_, _ = conn.Write([]byte{socks5Version, noAcceptableAuth})
		return err
	}
	if _, err := conn.Write([]byte{socks5Version, authMethod}); err != nil {
		return fmt.Errorf("write method selection: %w", err)
	}
	if sess.Authenticator == nil {
		return sess.handleRequest(ctx, conn)
	}

	user, pwd, err := parseClientAuth(conn)
	if err != nil {
		_, _ = conn.Write([]byte{passwordAuthVersion, 0x01}) // auth failure
		return err
	}
	if !sess.Authenticator(user, pwd) {
		_, _ = conn.Write([]byte{passwordAuthVersion, 0x01}) // auth failure
		return fmt.Errorf("authentication failed for user %q", user)
	}
	if _, err := conn.Write([]byte{passwordAuthVersion, 0x00}); err != nil { // auth success
		return fmt.Errorf("write auth status: %w", err)
	}
	return sess.handleRequest(ctx, conn)
}
// handleRequest reads the client's request, records it in h.request and
// dispatches on its command. A request that cannot be parsed is answered
// with a general-failure reply. An unsupported command is answered with a
// command-not-supported reply. In both cases the error is returned.
func (h *Handler) handleRequest(ctx context.Context, conn net.Conn) error {
	req, err := parseClientRequest(conn)
	if err != nil {
		reply, _ := errorResponse(generalFailure).marshal()
		conn.Write(reply)
		return err
	}
	h.request = req

	if req.command == connect {
		return h.handleTCP(ctx, conn)
	}
	if req.command == udpAssociate {
		return h.handleUDP(ctx, conn)
	}
	reply, _ := errorResponse(commandNotSupported).marshal()
	conn.Write(reply)
	return fmt.Errorf("unsupported command %v", req.command)
}
// handleTCP serves a CONNECT request. It dials h.request.destination over
// TCP with a 5 second dial timeout, then sends a success reply whose
// BND.ADDR/BND.PORT is the outbound connection's local address. After that
// it copies data in both directions and returns as soon as either direction
// finishes, closing the outbound connection. conn itself is not closed here.
//
// If the dial fails or the reply cannot be built, the client receives a
// general-failure reply and no data is relayed.
func (h *Handler) handleTCP(ctx context.Context, conn net.Conn) error {
	replyFailure := func() {
		failure, _ := errorResponse(generalFailure).marshal()
		_, _ = conn.Write(failure)
	}

	dialCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	srv, err := h.dial(dialCtx, "tcp", h.request.destination.hostPort())
	if err != nil {
		replyFailure()
		return err
	}
	defer srv.Close()

	bindHost, bindPort, err := splitHostPort(srv.LocalAddr().String())
	if err != nil {
		replyFailure()
		return fmt.Errorf("parse outbound local address: %w", err)
	}
	res := &response{
		reply: success,
		bindAddr: socksAddr{
			addrType: getAddrType(bindHost),
			addr:     bindHost,
			port:     bindPort,
		},
	}
	reply, err := res.marshal()
	if err != nil {
		// Do not relay data after telling the client the request failed.
		replyFailure()
		return fmt.Errorf("marshal reply: %w", err)
	}
	if _, err := conn.Write(reply); err != nil {
		return fmt.Errorf("write reply: %w", err)
	}

	// Buffered so the goroutine that finishes second never blocks.
	errc := make(chan error, 2)
	go func() {
		_, err := io.Copy(conn, srv)
		if err != nil {
			err = fmt.Errorf("from backend to client: %w", err)
		}
		errc <- err
	}()
	go func() {
		_, err := io.Copy(srv, conn)
		if err != nil {
			err = fmt.Errorf("from client to backend: %w", err)
		}
		errc <- err
	}()
	return <-errc
}
// handleUDP serves a UDP ASSOCIATE request. It opens a UDP socket on the same
// local IP that conn was accepted on, replies with that socket's address as
// BND.ADDR/BND.PORT, and relays datagrams until conn ends (see transferUDP).
// If the socket cannot be opened or the reply cannot be built, the client
// receives a general-failure reply and no relay is started.
func (h *Handler) handleUDP(ctx context.Context, conn net.Conn) error {
	// The DST.ADDR and DST.PORT fields contain the address and port that
	// the client expects to use to send UDP datagrams on for the
	// association. The server MAY use this information to limit access
	// to the association.
	// @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
	//
	// We do NOT limit the access from the client currently in this implementation.
	_ = h.request.destination

	replyFailure := func() {
		failure, _ := errorResponse(generalFailure).marshal()
		_, _ = conn.Write(failure)
	}

	host, _, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		replyFailure()
		return err
	}
	clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
	if err != nil {
		replyFailure()
		return err
	}
	defer clientUDPConn.Close()

	bindHost, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
	if err != nil {
		replyFailure()
		return err
	}
	res := &response{
		reply: success,
		bindAddr: socksAddr{
			addrType: getAddrType(bindHost),
			addr:     bindHost,
			port:     bindPort,
		},
	}
	reply, err := res.marshal()
	if err != nil {
		// Do not start relaying after telling the client the request failed.
		replyFailure()
		return fmt.Errorf("marshal reply: %w", err)
	}
	if _, err := conn.Write(reply); err != nil {
		return fmt.Errorf("write reply: %w", err)
	}
	return h.transferUDP(conn, clientUDPConn)
}
// transferUDP relays datagrams for one UDP association and blocks until the
// TCP control connection associatedTCP reaches EOF or fails to read.
//
// A background goroutine reads datagrams from clientConn via
// handleUDPRequest and forwards them to per-target connections cached in
// h.udpTargetConns; it closes all of those connections when it exits. It
// exits when ctx is cancelled (checked between reads) or when reading
// clientConn fails with net.ErrClosed.
//
// It returns nil on EOF of associatedTCP, otherwise the wrapped read error.
func (h *Handler) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
// NOTE(review): ctx derives from context.Background(), not the caller's
// context, so only the end of associatedTCP (through the deferred cancel)
// stops the relay goroutines — confirm this is intended.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// client -> target
go func() {
// Cancelling here only stops goroutines watching ctx; it does not
// interrupt the io.Copy on associatedTCP below.
defer cancel()
h.udpTargetConns = make(map[socksAddr]net.Conn)
// close all target udp connections when the client connection is closed
defer func() {
for _, conn := range h.udpTargetConns {
_ = conn.Close()
}
}()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := h.handleUDPRequest(ctx, clientConn, buf)
if err != nil {
// handleUDPRequest sets a read deadline on every read; expiry
// just gives the loop a chance to re-check ctx.
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
// Any other error only drops the current datagram.
h.logf("udp transfer: handle udp request fail: %v", err)
}
}
}
}()
// A UDP association terminates when the TCP connection that the UDP
// ASSOCIATE request arrived on terminates. RFC1928
_, err := io.Copy(io.Discard, associatedTCP)
if err != nil {
err = fmt.Errorf("udp associated tcp conn: %w", err)
}
return err
}
// getOrDialTargetConn returns the UDP connection to targetAddr cached in
// h.udpTargetConns, dialing and caching a new one on first use. For each
// newly dialed connection it starts a goroutine that relays replies from the
// target back to the client via handleUDPResponse. That goroutine stops when
// ctx is cancelled (checked between reads) or when a read fails with
// net.ErrClosed or io.EOF.
//
// NOTE(review): h.udpTargetConns is accessed without a lock; this appears
// safe only because the sole caller path is the single client->target
// goroutine in transferUDP. Entries are never removed, so the cache grows
// by one socket per distinct target for the life of the association.
func (h *Handler) getOrDialTargetConn(
ctx context.Context,
clientConn net.PacketConn,
targetAddr socksAddr,
) (net.Conn, error) {
conn, exist := h.udpTargetConns[targetAddr]
if exist {
return conn, nil
}
// NOTE(review): ctx carries no deadline here, so a slow dial (e.g. DNS
// for a domain target) stalls the whole client->target loop.
conn, err := h.dial(ctx, "udp", targetAddr.hostPort())
if err != nil {
return nil, err
}
h.udpTargetConns[targetAddr] = conn
// target -> client
go func() {
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := h.handleUDPResponse(clientConn, targetAddr, conn, buf)
if err != nil {
// Deadline expiry only lets the loop re-check ctx.
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return
}
h.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
return conn, nil
}
// handleUDPRequest reads one datagram from the client and forwards its
// payload to the target named in its SOCKS5 UDP header, dialing the target
// on first use. The read carries a readTimeout deadline so the caller's loop
// can periodically re-check ctx; timeouts surface as wrapped errors.
//
// Fragmented datagrams (FRAG != 0) are dropped with an error, as RFC 1928
// section 7 requires of implementations without reassembly. Only a
// well-formed datagram may update the address that target replies are sent to.
func (h *Handler) handleUDPRequest(
	ctx context.Context,
	clientConn net.PacketConn,
	buf []byte,
) error {
	// add a deadline for the read to avoid blocking forever
	_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
	n, addr, err := clientConn.ReadFrom(buf)
	if err != nil {
		return fmt.Errorf("read from client: %w", err)
	}
	req, data, err := parseUDPRequest(buf[:n])
	if err != nil {
		return fmt.Errorf("parse udp request: %w", err)
	}
	if req.frag != 0 {
		return fmt.Errorf("drop fragmented udp request from %s (frag=%d)", addr, req.frag)
	}
	// Record the sender only after validation, so garbage from another
	// source cannot redirect replies.
	h.udpClientAddr = addr

	targetConn, err := h.getOrDialTargetConn(ctx, clientConn, req.addr)
	if err != nil {
		return fmt.Errorf("dial target %s fail: %w", req.addr, err)
	}
	nn, err := targetConn.Write(data)
	if err != nil {
		return fmt.Errorf("write to target %s fail: %w", req.addr, err)
	}
	if nn != len(data) {
		return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
	}
	return nil
}
// handleUDPResponse reads one datagram from targetConn, prefixes it with a
// SOCKS5 UDP header naming targetAddr as the source, and sends it to the
// client address last recorded in h.udpClientAddr. The read carries a
// readTimeout deadline so the caller's loop can re-check for cancellation.
func (h *Handler) handleUDPResponse(clientConn net.PacketConn, targetAddr socksAddr, targetConn net.Conn, buf []byte) error {
	_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
	n, err := targetConn.Read(buf)
	if err != nil {
		return fmt.Errorf("read from target: %w", err)
	}

	header, err := (&udpRequest{addr: targetAddr}).marshal()
	if err != nil {
		return fmt.Errorf("marshal udp request: %w", err)
	}
	datagram := append(header, buf[:n]...)

	// NOTE(review): h.udpClientAddr is written by the client->target loop
	// while this runs in a per-target goroutine, without synchronization.
	written, err := clientConn.WriteTo(datagram, h.udpClientAddr)
	switch {
	case err != nil:
		return fmt.Errorf("write to client: %w", err)
	case written != len(datagram):
		return fmt.Errorf("write to client: %w", io.ErrShortWrite)
	}
	return nil
}
// isTimeout reports whether err, or any error it wraps, has a Timeout method
// (as net.Error values do) and that method returns true.
func isTimeout(err error) bool {
	var t interface{ Timeout() bool }
	if !errors.As(err, &t) {
		return false
	}
	return t.Timeout()
}
// parseClientGreeting reads the client's method-selection message
// (VER | NMETHODS | METHODS, RFC 1928 section 3). It succeeds only when the
// version is SOCKS5 and authMethod is among the offered methods.
func parseClientGreeting(r io.Reader, authMethod byte) error {
	var hdr [2]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return fmt.Errorf("could not read packet header")
	}
	if hdr[0] != socks5Version {
		return fmt.Errorf("incompatible SOCKS version")
	}
	offered := make([]byte, hdr[1])
	if _, err := io.ReadFull(r, offered); err != nil {
		return fmt.Errorf("could not read methods")
	}
	for i := 0; i < len(offered); i++ {
		if offered[i] == authMethod {
			return nil
		}
	}
	return fmt.Errorf("no acceptable auth methods")
}
// parseClientAuth reads an RFC 1929 username/password request
// (VER | ULEN | UNAME | PLEN | PASSWD) and returns the credentials.
func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
	var hdr [2]byte // VER, ULEN
	if _, e := io.ReadFull(r, hdr[:]); e != nil {
		return "", "", fmt.Errorf("could not read auth packet header")
	}
	if hdr[0] != passwordAuthVersion {
		return "", "", fmt.Errorf("bad SOCKS auth version")
	}
	user := make([]byte, hdr[1])
	if _, e := io.ReadFull(r, user); e != nil {
		return "", "", fmt.Errorf("could not read auth packet username")
	}
	var plen [1]byte
	if _, e := io.ReadFull(r, plen[:]); e != nil {
		return "", "", fmt.Errorf("could not read auth packet password length")
	}
	pass := make([]byte, plen[0])
	if _, e := io.ReadFull(r, pass); e != nil {
		return "", "", fmt.Errorf("could not read auth packet password")
	}
	return string(user), string(pass), nil
}
// getAddrType classifies addr as an IPv4 literal (including IPv4-mapped
// IPv6 forms), an IPv6 literal, or a domain name when it does not parse as
// an IP address.
func getAddrType(addr string) addrType {
	ip := net.ParseIP(addr)
	switch {
	case ip == nil:
		return domainName
	case ip.To4() == nil:
		return ipv6
	default:
		return ipv4
	}
}
// request holds the parsed fields of a SOCKS5 client request
// (RFC 1928, section 4).
type request struct {
command commandType // CMD byte exactly as sent by the client
destination socksAddr // DST.ADDR and DST.PORT
}
// parseClientRequest reads a SOCKS5 request
// (VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT, RFC 1928 section 4) from r.
// It rejects any version other than SOCKS5. The CMD byte is returned
// unvalidated; dispatching on it is the caller's job. On error the returned
// request is nil.
func parseClientRequest(r io.Reader) (*request, error) {
	var hdr [3]byte // VER, CMD, RSV
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, fmt.Errorf("could not read packet header: %w", err)
	}
	if hdr[0] != socks5Version {
		return nil, fmt.Errorf("incompatible SOCKS version %d", hdr[0])
	}
	destination, err := parseSocksAddr(r)
	if err != nil {
		return nil, err
	}
	return &request{
		command:     commandType(hdr[1]),
		destination: destination,
	}, nil
}
// response is a SOCKS5 reply sent from the proxy to the client
// (RFC 1928, section 6).
type response struct {
reply replyCode // REP
bindAddr socksAddr // BND.ADDR and BND.PORT
}
// errorResponse builds a reply carrying code with the all-zero IPv4 bind
// address (zeroSocksAddr). Its marshal never fails.
func errorResponse(code replyCode) *response {
	res := new(response)
	res.reply = code
	res.bindAddr = zeroSocksAddr
	return res
}
// marshal encodes res as VER | REP | RSV | ATYP | BND.ADDR | BND.PORT.
// It returns an error, whatever the reply code, whenever res.bindAddr cannot
// be encoded (see socksAddr.marshal).
func (res *response) marshal() ([]byte, error) {
	addrPkt, err := res.bindAddr.marshal()
	if err != nil {
		return nil, err
	}
	pkt := make([]byte, 0, 3+len(addrPkt))
	pkt = append(pkt, socks5Version, byte(res.reply), 0x00) // VER, REP, RSV
	return append(pkt, addrPkt...), nil
}
// udpRequest is the header of a SOCKS5 UDP relay datagram (RFC 1928,
// section 7). The same layout is used for client->target and
// target->client datagrams.
type udpRequest struct {
frag byte // FRAG: fragment number; X'00' means a standalone datagram
addr socksAddr // DST.ADDR and DST.PORT
}
// parseUDPRequest splits a SOCKS5 UDP relay datagram (RFC 1928, section 7):
//
//	+----+------+------+----------+----------+----------+
//	|RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
//	+----+------+------+----------+----------+----------+
//	| 2  |  1   |  1   | Variable |    2     | Variable |
//	+----+------+------+----------+----------+----------+
//
// It returns the header and DATA; DATA aliases data. FRAG is returned
// as-is, and callers decide whether to accept fragments. On error both
// results are nil.
func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
	if len(data) < 4 {
		return nil, nil, fmt.Errorf("invalid packet length")
	}
	// RSV must be zero.
	if data[0] != 0 || data[1] != 0 {
		return nil, nil, fmt.Errorf("invalid udp request header")
	}
	reader := bytes.NewReader(data[3:])
	addr, err := parseSocksAddr(reader)
	if err != nil {
		return nil, nil, err
	}
	// (*bytes.Reader).Len reports the unread bytes, which are exactly DATA.
	body := data[len(data)-reader.Len():]
	return &udpRequest{frag: data[2], addr: addr}, body, nil
}
// marshal encodes u as RSV (two zero bytes) | FRAG | ATYP | DST.ADDR |
// DST.PORT, the header that precedes the payload in a UDP relay datagram.
func (u *udpRequest) marshal() ([]byte, error) {
	addr, err := u.addr.marshal()
	if err != nil {
		return nil, err
	}
	hdr := make([]byte, 0, 3+len(addr))
	hdr = append(hdr, 0x00, 0x00, u.frag)
	return append(hdr, addr...), nil
}
// Deadlines used by Dialer.DialContext: noDeadline clears a deadline, and
// aLongTimeAgo is a moment in the past used to make pending I/O fail at once.
var (
	noDeadline   time.Time
	aLongTimeAgo = time.Unix(1, 0)
)
// A Dialer connects to destinations through a SOCKS5 proxy server, using
// CONNECT for TCP networks and UDP ASSOCIATE for UDP networks; see
// DialContext.
type Dialer struct {
ProxyNetwork string // network between a proxy server and a client
ProxyAddress string // proxy server address; must be in host:port form
// ProxyDial specifies the optional dial function for
// establishing the transport connection. When nil, a zero-value
// net.Dialer's DialContext is used.
ProxyDial func(context.Context, string, string) (net.Conn, error)
// Username and Password: when either is non-empty, RFC 1929
// username/password authentication is offered. If the server selects
// it, Username must be 1-255 bytes and Password at most 255 bytes.
Username string
Password string
}
// Conn wraps a net.Conn and implements both net.Conn and net.PacketConn interfaces.
// For TCP connections, use Read/Write methods.
// For UDP connections, use ReadFrom/WriteTo methods.
//
// Close closes both the UDP socket (if any) and the TCP connection. For a UDP
// association, closing the TCP control connection also ends the association
// on the server side.
//
// NOTE(review): SetDeadline, SetReadDeadline and SetWriteDeadline come from
// the embedded TCP connection, so they do not affect ReadFrom/WriteTo on
// udpConn.
type Conn struct {
net.Conn // TCP connection (for CONNECT) or TCP control connection (for UDP ASSOCIATE)
udpConn *net.UDPConn // UDP connection (only for UDP ASSOCIATE)
relayAddr *net.UDPAddr // UDP relay server address (only for UDP ASSOCIATE)
boundAddr socksAddr // Bound address from server response
isUDP bool // Whether this is a UDP connection
}
// ReadFrom implements net.PacketConn for UDP associations. It reads one
// datagram from the UDP socket and strips the SOCKS5 UDP header
// (RSV | FRAG | ATYP | DST.ADDR | DST.PORT). It then copies the payload into
// b and returns the payload length together with the address named in the
// header. Fragmented datagrams are rejected, and so is a payload larger
// than b.
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	if !c.isUDP {
		return 0, nil, errors.New("ReadFrom only available for UDP connections")
	}
	if c.udpConn == nil {
		return 0, nil, errors.New("UDP connection not initialized")
	}

	buf := make([]byte, 65535)
	size, _, err := c.udpConn.ReadFromUDP(buf)
	if err != nil {
		return 0, nil, err
	}
	pkt := buf[:size]

	// RSV(2) + FRAG(1) + ATYP(1) is the shortest possible prefix;
	// parseSocksAddr reports any truncation after that.
	if len(pkt) < 4 {
		return 0, nil, fmt.Errorf("packet too short")
	}
	if pkt[0] != 0x00 || pkt[1] != 0x00 {
		return 0, nil, fmt.Errorf("invalid reserved fields")
	}
	if pkt[2] != 0x00 {
		return 0, nil, fmt.Errorf("fragmentation not supported")
	}

	reader := &byteReader{buf: pkt[3:]}
	src, err := parseSocksAddr(reader)
	if err != nil {
		return 0, nil, err
	}
	// parseSocksAddr consumed exactly ATYP | DST.ADDR | DST.PORT, so the
	// reader position marks where the payload begins.
	payload := pkt[3+reader.pos:]
	if len(payload) > len(b) {
		return 0, nil, fmt.Errorf("buffer too small")
	}
	return copy(b, payload), &src, nil
}
// WriteTo implements net.PacketConn for UDP associations. It prefixes b with
// a SOCKS5 UDP header naming addr as the destination and sends the result to
// the proxy's UDP relay. It returns len(b) on success.
//
// addr may be a *socksAddr or any net.Addr whose String form is host:port
// (for example *net.UDPAddr). IP literals are sent as IPv4/IPv6 addresses and
// anything else as a domain name. A nil addr is rejected.
func (c *Conn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
	if !c.isUDP {
		return 0, errors.New("WriteTo only available for UDP connections")
	}
	if c.udpConn == nil {
		return 0, errors.New("UDP connection not initialized")
	}

	var sa socksAddr
	switch a := addr.(type) {
	case nil:
		return 0, errors.New("nil destination address")
	case *socksAddr:
		if a == nil {
			return 0, errors.New("nil destination address")
		}
		sa = *a
	default:
		// Covers *net.UDPAddr and any other host:port-formatted net.Addr.
		host, port, err := splitHostPort(a.String())
		if err != nil {
			return 0, err
		}
		sa = socksAddr{addrType: domainName, addr: host, port: port}
		if ip := net.ParseIP(host); ip != nil {
			sa.addr = ip.String()
			if ip.To4() != nil {
				sa.addrType = ipv4
			} else {
				sa.addrType = ipv6
			}
		}
	}

	addrBytes, err := sa.marshal()
	if err != nil {
		return 0, err
	}
	// RSV(2) + FRAG(1), then ATYP | DST.ADDR | DST.PORT, then the payload.
	packet := make([]byte, 0, 3+len(addrBytes)+len(b))
	packet = append(packet, 0x00, 0x00, 0x00)
	packet = append(packet, addrBytes...)
	packet = append(packet, b...)
	if _, err := c.udpConn.WriteToUDP(packet, c.relayAddr); err != nil {
		return 0, err
	}
	return len(b), nil
}
// LocalAddr returns the UDP socket's local address for an initialized UDP
// association, and the TCP connection's local address otherwise.
func (c *Conn) LocalAddr() net.Addr {
	if !c.isUDP || c.udpConn == nil {
		return c.Conn.LocalAddr()
	}
	return c.udpConn.LocalAddr()
}

// RemoteAddr returns the address reported by the proxy in its reply: the
// BND.ADDR/BND.PORT for CONNECT, or the UDP relay address for UDP ASSOCIATE.
func (c *Conn) RemoteAddr() net.Addr {
	bound := &c.boundAddr
	return bound
}
// Close closes the UDP socket (if any) and then the TCP connection (if any).
// Both are always attempted; the UDP error takes precedence when both fail.
func (c *Conn) Close() error {
	var udpErr, tcpErr error
	if c.udpConn != nil {
		udpErr = c.udpConn.Close()
	}
	if c.Conn != nil {
		tcpErr = c.Conn.Close()
	}
	if udpErr == nil {
		return tcpErr
	}
	return udpErr
}

// IsUDP reports whether c was created for a UDP association.
func (c *Conn) IsUDP() bool { return c.isUDP }
// byteReader is a minimal io.Reader over an in-memory byte slice. pos is the
// index of the next unread byte, so after reading it equals the number of
// bytes consumed.
type byteReader struct {
	buf []byte
	pos int
}

// Read copies unread bytes into p and returns io.EOF once buf is exhausted.
func (r *byteReader) Read(p []byte) (n int, err error) {
	if r.pos < len(r.buf) {
		n = copy(p, r.buf[r.pos:])
		r.pos += n
		return n, nil
	}
	return 0, io.EOF
}
// DialContext connects to address through the SOCKS5 proxy at
// d.ProxyAddress.
//
// For TCP networks ("tcp", "tcp4", "tcp6"), it issues a CONNECT and returns a
// *Conn that can be used with Read/Write methods like a normal net.Conn.
//
// For UDP networks ("udp", "udp4", "udp6"), it issues a UDP ASSOCIATE and
// returns a *Conn that should be used with ReadFrom/WriteTo methods like a
// net.PacketConn.
//
// A deadline on ctx bounds the handshake, and cancelling ctx before the
// handshake completes aborts it. If ctx is cancelled just as the handshake
// succeeds, the connection is closed and ctx.Err() is returned. A non-nil
// *Conn is therefore never returned together with an error.
//
// The returned error value may be a net.OpError. When the Op field of
// net.OpError contains "socks", the Source field contains a proxy server
// address and the Addr field contains a command target address. Both are
// left nil when either address cannot be parsed.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (conn *Conn, ctxErr error) {
	// opError never stores a nil *socksAddr in Source or Addr: a typed nil
	// inside the net.Addr interface is non-nil, and (*net.OpError).Error
	// would panic calling String on it.
	opError := func(op string, err error) error {
		e := &net.OpError{Op: op, Net: network, Err: err}
		if proxy, dst, perr := d.pathAddrs(address); perr == nil {
			e.Source, e.Addr = proxy, dst
		}
		return e
	}

	var cmd commandType
	switch network {
	case "tcp", "tcp4", "tcp6":
		cmd = connect
	case "udp", "udp4", "udp6":
		cmd = udpAssociate
	default:
		return nil, opError("dial", errors.New("network not implemented"))
	}
	if ctx == nil {
		return nil, opError(cmd.String(), errors.New("nil context"))
	}

	dial := d.ProxyDial
	if dial == nil {
		dial = (&net.Dialer{}).DialContext
	}
	c, err := dial(ctx, d.ProxyNetwork, d.ProxyAddress)
	if err != nil {
		return nil, opError(cmd.String(), err)
	}

	// ctx's deadline applies to the handshake only.
	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
		c.SetDeadline(deadline)
		defer c.SetDeadline(noDeadline)
	}
	if ctx != context.Background() {
		// Watch ctx while the handshake runs; on cancellation, push the
		// deadline into the past to unblock any pending I/O on c.
		errCh := make(chan error, 1)
		done := make(chan struct{})
		defer func() {
			close(done)
			if ctxErr == nil {
				ctxErr = <-errCh
			}
			// Cancellation raced with success: don't leak the connection.
			if ctxErr != nil && conn != nil {
				conn.Close()
				conn = nil
			}
		}()
		go func() {
			select {
			case <-ctx.Done():
				c.SetDeadline(aLongTimeAgo)
				errCh <- ctx.Err()
			case <-done:
				errCh <- nil
			}
		}()
	}

	// Perform SOCKS5 handshake
	if err := d.greet(ctx, c, make([]byte, 0, 512)); err != nil {
		c.Close()
		return nil, opError(cmd.String(), err)
	}

	// Send request
	_, dstAddr, err := d.pathAddrs(address)
	if err != nil {
		c.Close()
		return nil, opError(cmd.String(), err)
	}
	if cmd == udpAssociate {
		return d.udpAssociate(ctx, c, network, dstAddr)
	}
	return d.connect(ctx, c, network, dstAddr)
}
// connect sends a CONNECT request for dstAddr on c and reads the server's
// reply (VER | REP | RSV | ATYP | BND.ADDR | BND.PORT, RFC 1928 section 6).
// Method negotiation on c must already be complete.
//
// On success it returns a TCP *Conn whose boundAddr is the reply's
// BND.ADDR/BND.PORT. On failure it closes c and returns a *net.OpError.
func (d *Dialer) connect(_ context.Context, c net.Conn, network string, dstAddr *socksAddr) (*Conn, error) {
	fail := func(err error) (*Conn, error) {
		c.Close()
		opErr := &net.OpError{Op: connect.String(), Net: network, Err: err}
		// Only store non-nil pointers: a typed-nil *socksAddr inside a
		// net.Addr makes (*net.OpError).Error panic.
		if proxy, _, perr := d.pathAddrs(d.ProxyAddress); perr == nil {
			opErr.Source = proxy
		}
		if dstAddr != nil {
			opErr.Addr = dstAddr
		}
		return nil, opErr
	}
	if dstAddr == nil {
		return fail(errors.New("missing destination address"))
	}

	addrBytes, err := dstAddr.marshal()
	if err != nil {
		return fail(err)
	}
	req := make([]byte, 0, 3+len(addrBytes))
	req = append(req, socks5Version, byte(connect), 0x00) // VER + CMD + RSV
	req = append(req, addrBytes...)
	if _, err := c.Write(req); err != nil {
		return fail(err)
	}

	// Read only VER, REP and RSV here: parseSocksAddr reads the ATYP byte
	// itself, so consuming it now would misparse the bound address.
	var hdr [3]byte
	if _, err := io.ReadFull(c, hdr[:]); err != nil {
		return fail(err)
	}
	if hdr[0] != socks5Version {
		return fail(errors.New("invalid SOCKS version"))
	}
	if reply := replyCode(hdr[1]); reply != success {
		return fail(errors.New(reply.String()))
	}

	boundAddr, err := parseSocksAddr(c)
	if err != nil {
		return fail(err)
	}
	return &Conn{
		Conn:      c,
		boundAddr: boundAddr,
		isUDP:     false,
	}, nil
}
// udpAssociate sends a UDP ASSOCIATE request on tcpConn, whose method
// negotiation has already completed. It reads the relay address from the
// reply and opens a local UDP socket for talking to that relay.
//
// If the reply names an unspecified relay IP (0.0.0.0 or ::), which cannot
// be sent to, the host of tcpConn's remote address is used instead.
//
// On success the returned *Conn owns both tcpConn and the UDP socket. On
// failure it closes tcpConn and returns a *net.OpError.
func (d *Dialer) udpAssociate(_ context.Context, tcpConn net.Conn, network string, dstAddr *socksAddr) (*Conn, error) {
	fail := func(err error) (*Conn, error) {
		tcpConn.Close()
		opErr := &net.OpError{Op: udpAssociate.String(), Net: network, Err: err}
		// Only store non-nil pointers: a typed-nil *socksAddr inside a
		// net.Addr makes (*net.OpError).Error panic.
		if proxy, _, perr := d.pathAddrs(d.ProxyAddress); perr == nil {
			opErr.Source = proxy
		}
		if dstAddr != nil {
			opErr.Addr = dstAddr
		}
		return nil, opErr
	}

	// DST.ADDR/DST.PORT of all zeros: the client's UDP source address is
	// not known in advance (RFC 1928, section 4).
	clientAddr, err := zeroSocksAddr.marshal()
	if err != nil {
		return fail(err)
	}
	req := make([]byte, 0, 3+len(clientAddr))
	req = append(req, socks5Version, byte(udpAssociate), 0x00) // VER + CMD + RSV
	req = append(req, clientAddr...)
	if _, err := tcpConn.Write(req); err != nil {
		return fail(err)
	}

	// Read only VER, REP and RSV here: parseSocksAddr reads the ATYP byte
	// itself, so consuming it now would misparse the relay address.
	var hdr [3]byte
	if _, err := io.ReadFull(tcpConn, hdr[:]); err != nil {
		return fail(err)
	}
	if hdr[0] != socks5Version {
		return fail(errors.New("invalid SOCKS version"))
	}
	if reply := replyCode(hdr[1]); reply != success {
		return fail(errors.New(reply.String()))
	}
	relayAddr, err := parseSocksAddr(tcpConn)
	if err != nil {
		return fail(err)
	}

	relayHost := relayAddr.addr
	if ip := net.ParseIP(relayHost); ip != nil && ip.IsUnspecified() {
		if ra := tcpConn.RemoteAddr(); ra != nil {
			if host, _, err := net.SplitHostPort(ra.String()); err == nil {
				relayHost = host
			}
		}
	}
	udpRelayAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(relayHost, strconv.Itoa(int(relayAddr.port))))
	if err != nil {
		return fail(err)
	}

	// A nil local address binds any local IP on an ephemeral port.
	udpConn, err := net.ListenUDP(network, nil)
	if err != nil {
		return fail(err)
	}

	return &Conn{
		Conn:      tcpConn,
		udpConn:   udpConn,
		relayAddr: udpRelayAddr,
		boundAddr: relayAddr,
		isUDP:     true,
	}, nil
}
// greet performs the method-selection exchange (RFC 1928, section 3). It
// always offers NO AUTHENTICATION REQUIRED, and also USERNAME/PASSWORD when
// credentials are configured. If the server selects the latter, it runs
// d.authenticate. b is scratch space; its contents are overwritten.
func (d *Dialer) greet(ctx context.Context, c net.Conn, b []byte) error {
	methods := []byte{noAuthRequired}
	if d.Username != "" || d.Password != "" {
		methods = append(methods, passwordAuth)
	}
	b = append(b[:0], socks5Version, byte(len(methods)))
	b = append(b, methods...)
	if _, err := c.Write(b); err != nil {
		return err
	}

	if _, err := io.ReadFull(c, b[:2]); err != nil {
		return err
	}
	if b[0] != socks5Version {
		return errors.New("invalid SOCKS version")
	}
	switch b[1] {
	case noAuthRequired:
		return nil
	case passwordAuth:
		return d.authenticate(ctx, c, b)
	default:
		return errors.New("unsupported authentication method")
	}
}
// authenticate runs the RFC 1929 username/password sub-negotiation on rw:
// it sends VER | ULEN | UNAME | PLEN | PASSWD and expects VER | STATUS with
// STATUS zero. Username must be 1-255 bytes and Password at most 255 bytes.
// b is scratch space; its contents are overwritten.
func (d *Dialer) authenticate(_ context.Context, rw io.ReadWriter, b []byte) error {
	user, pass := d.Username, d.Password
	if user == "" || len(user) > 255 || len(pass) > 255 {
		return errors.New("invalid username/password")
	}

	b = append(b[:0], passwordAuthVersion, byte(len(user)))
	b = append(b, user...)
	b = append(b, byte(len(pass)))
	b = append(b, pass...)
	if _, err := rw.Write(b); err != nil {
		return err
	}

	if _, err := io.ReadFull(rw, b[:2]); err != nil {
		return err
	}
	switch {
	case b[0] != passwordAuthVersion:
		return errors.New("invalid username/password version")
	case b[1] != byte(success):
		return errors.New("username/password authentication failed")
	}
	return nil
}
// pathAddrs parses d.ProxyAddress and address (both host:port) into
// socksAddr values. IP literals are stored in canonical text form and
// tagged ipv4 or ipv6; any other host is tagged as a domain name. If either
// address fails to parse, both results are nil.
func (d *Dialer) pathAddrs(address string) (proxy, dst *socksAddr, err error) {
	parse := func(hostport string) (*socksAddr, error) {
		host, port, err := splitHostPort(hostport)
		if err != nil {
			return nil, err
		}
		ip := net.ParseIP(host)
		switch {
		case ip == nil:
			return &socksAddr{addrType: domainName, addr: host, port: port}, nil
		case ip.To4() != nil:
			return &socksAddr{addrType: ipv4, addr: ip.String(), port: port}, nil
		default:
			return &socksAddr{addrType: ipv6, addr: ip.String(), port: port}, nil
		}
	}

	if proxy, err = parse(d.ProxyAddress); err != nil {
		return nil, nil, err
	}
	if dst, err = parse(address); err != nil {
		return nil, nil, err
	}
	return proxy, dst, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment