diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index 079a91d..ac6cfcd 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -45,7 +45,7 @@ services: # bypass: bypass01 - name: ssu url: "ss://chacha20:gost@:8000" - addr: ":8338" + addr: ":8388" handler: type: ssu metadata: @@ -54,7 +54,8 @@ services: readTimeout: 5s retry: 3 listener: - type: udp + type: tcp + # chain: chain-ssu - name: socks5+tcp url: "socks5://gost:gost@:1080" addr: ":1080" @@ -71,7 +72,7 @@ services: type: tcp metadata: keepAlive: 15s - # chain: chain-socks5 + chain: chain-socks5 # bypass: bypass01 - name: socks5+tcp url: "socks5://gost:gost@:1080" @@ -178,6 +179,20 @@ chains: dialer: type: tcp metadata: {} +- name: chain-ssu + hops: + - name: hop01 + nodes: + - name: node01 + addr: ":8339" + url: "http://gost:gost@:8081" + # bypass: bypass01 + connector: + type: ssu + metadata: {} + dialer: + type: udp + metadata: {} bypasses: - name: bypass01 diff --git a/cmd/gost/register.go b/cmd/gost/register.go index d9074c6..095e445 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -6,9 +6,11 @@ import ( _ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/socks/v5" _ "github.com/go-gost/gost/pkg/connector/ss" + _ "github.com/go-gost/gost/pkg/connector/ssu" // Register dialers _ "github.com/go-gost/gost/pkg/dialer/tcp" + _ "github.com/go-gost/gost/pkg/dialer/udp" // Register handlers _ "github.com/go-gost/gost/pkg/handler/http" diff --git a/go.mod b/go.mod index ba7e0c6..4ebcbbb 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/coreos/go-iptables v0.5.0 // indirect github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da github.com/go-gost/gosocks4 v0.0.1 - github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea + github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.3 github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index 03953c5..a92866f 100644 --- a/go.sum +++ b/go.sum @@ -125,6 +125,8 @@ github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d h1:mjoFToMUWNN0 github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea h1:mrm6bMpdxBvInvBuDbUaAQWV60r/PaByLIG9fQJEEIc= github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= +github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= +github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index ed36280..bfcb54b 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -43,6 +43,21 @@ func (c *httpConnector) Init(md md.Metadata) (err error) { } func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "local": conn.LocalAddr().String(), + "remote": conn.RemoteAddr().String(), + "network": network, + "address": address, + }) + + switch network { + case "tcp", "tcp4", "tcp6": + default: + err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + c.logger.Error(err) + return nil, err + } + req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: address}, @@ -56,11 +71,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add } req.Header.Set("Proxy-Connection", "keep-alive") - c.logger = c.logger.WithFields(map[string]interface{}{ - "local": conn.LocalAddr().String(), - "remote": conn.RemoteAddr().String(), - "target": address, - }) c.logger.Infof("connect: ", address) if user := c.md.User; user != nil { diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go index cd2d578..fa02303 100644 --- a/pkg/connector/socks/v4/connector.go +++ b/pkg/connector/socks/v4/connector.go @@ -42,10 +42,20 @@ func (c *socks4Connector) Init(md md.Metadata) (err error) { func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { c.logger = c.logger.WithFields(map[string]interface{}{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - "target": address, + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, }) + + switch network { + case "tcp", "tcp4", "tcp6": + default: + err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + c.logger.Error(err) + return nil, err + } + c.logger.Info("connect: ", address) var addr *gosocks4.Addr @@ -87,19 +97,14 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a c.logger.Error(err) return nil, err } - if c.logger.IsLevelEnabled(logger.DebugLevel) { - c.logger.Debug(req) - } + c.logger.Debug(req) reply, err := gosocks4.ReadReply(conn) if err != nil { c.logger.Error(err) return nil, err } - - if c.logger.IsLevelEnabled(logger.DebugLevel) { - c.logger.Debug(reply) - } + c.logger.Debug(reply) if reply.Code != gosocks4.Granted { return nil, fmt.Errorf("error: %d", reply.Code) diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index 492df20..b4d946b 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "net" "net/url" "strings" @@ -79,6 +80,7 @@ func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Con cc := gosocks5.ClientConn(conn, c.selector) if err := cc.Handleshake(); err != nil { + c.logger.Error(err) return nil, err } @@ -87,12 +89,22 @@ func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Con func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { c.logger = c.logger.WithFields(map[string]interface{}{ - "target": address, + "network": network, + "address": address, }) + + switch network { + case "tcp", "tcp4", "tcp6": + default: + err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + c.logger.Error(err) + return nil, err + } + c.logger.Info("connect: ", address) - addr, err := gosocks5.NewAddr(address) - if err != nil { + addr := gosocks5.Addr{} + if err := addr.ParseFrom(address); err != nil { c.logger.Error(err) return nil, err } @@ -102,25 +114,19 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a defer conn.SetDeadline(time.Time{}) } - req := gosocks5.NewRequest(gosocks5.CmdConnect, addr) + req := gosocks5.NewRequest(gosocks5.CmdConnect, &addr) if err := req.Write(conn); err != nil { c.logger.Error(err) return nil, err } - - if c.logger.IsLevelEnabled(logger.DebugLevel) { - c.logger.Debug(req) - } + c.logger.Debug(req) reply, err := gosocks5.ReadReply(conn) if err != nil { c.logger.Error(err) return nil, err } - - if c.logger.IsLevelEnabled(logger.DebugLevel) { - c.logger.Debug(reply) - } + c.logger.Debug(reply) if reply.Rep != gosocks5.Succeeded { return nil, errors.New("service unavailable") diff --git a/pkg/connector/socks/v5/selector.go b/pkg/connector/socks/v5/selector.go index 2b9e29e..f3a300f 100644 --- a/pkg/connector/socks/v5/selector.go +++ b/pkg/connector/socks/v5/selector.go @@ -18,9 +18,7 @@ type clientSelector struct { } func (s *clientSelector) Methods() []uint8 { - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debug("methods: ", s.methods) - } + s.logger.Debug("methods: ", s.methods) return s.methods } @@ -33,9 +31,7 @@ func (s *clientSelector) Select(methods ...uint8) (method uint8) { } func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debug("method selected: ", method) - } + s.logger.Debug("method selected: ", method) switch method { case socks.MethodTLS: @@ -57,18 +53,14 @@ func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro s.logger.Error(err) return nil, err } - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debug(req) - } + s.logger.Debug(req) resp, err := gosocks5.ReadUserPassResponse(conn) if err != nil { s.logger.Error(err) return nil, err } - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debug(resp) - } + s.logger.Debug(resp) if resp.Status != gosocks5.Succeeded { return nil, gosocks5.ErrAuthFailure diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index e5464d9..92dd8f0 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -2,6 +2,7 @@ package ss import ( "context" + "fmt" "net" "time" @@ -40,21 +41,30 @@ func (c *ssConnector) Init(md md.Metadata) (err error) { func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { c.logger = c.logger.WithFields(map[string]interface{}{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - "target": address, + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, }) + + switch network { + case "tcp", "tcp4", "tcp6": + default: + err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + c.logger.Error(err) + return nil, err + } c.logger.Infof("connect: ", address) - socksAddr, err := gosocks5.NewAddr(address) - if err != nil { - c.logger.Error("parse addr: ", err) + addr := gosocks5.Addr{} + if err := addr.ParseFrom(address); err != nil { + c.logger.Error(err) return nil, err } rawaddr := bufpool.Get(512) defer bufpool.Put(rawaddr) - n, err := socksAddr.Encode(rawaddr) + n, err := addr.Encode(rawaddr) if err != nil { c.logger.Error("encoding addr: ", err) return nil, err diff --git a/pkg/connector/ssu/connector.go b/pkg/connector/ssu/connector.go new file mode 100644 index 0000000..69dcf3e --- /dev/null +++ b/pkg/connector/ssu/connector.go @@ -0,0 +1,105 @@ +package ssu + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/internal/utils/socks" + "github.com/go-gost/gost/pkg/internal/utils/ss" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("ssu", NewConnector) +} + +type ssuConnector struct { + md metadata + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &ssuConnector{ + logger: options.Logger, + } +} + +func (c *ssuConnector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + + switch network { + case "udp", "udp4", "udp6": + default: + err := fmt.Errorf("network %s unsupported, should be udp, udp4 or udp6", network) + c.logger.Error(err) + return nil, err + } + + c.logger.Info("connect: ", address) + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + taddr, _ := net.ResolveUDPAddr(network, address) + if taddr == nil { + taddr = &net.UDPAddr{} + } + + pc, ok := conn.(net.PacketConn) + if ok { + if c.md.cipher != nil { + pc = c.md.cipher.PacketConn(pc) + } + + return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.bufferSize), nil + } + + return socks.UDPTunClientConn(conn, taddr), nil +} + +func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) { + c.md.cipher, err = ss.ShadowCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) + if err != nil { + return + } + + c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.bufferSize = md.GetInt(bufferSize) + if c.md.bufferSize > 0 { + if c.md.bufferSize < 512 { + c.md.bufferSize = 512 + } + if c.md.bufferSize > 65*1024 { + c.md.bufferSize = 65 * 1024 + } + } else { + c.md.bufferSize = 4096 + } + + return +} diff --git a/pkg/connector/ssu/metadata.go b/pkg/connector/ssu/metadata.go new file mode 100644 index 0000000..037f611 --- /dev/null +++ b/pkg/connector/ssu/metadata.go @@ -0,0 +1,21 @@ +package ssu + +import ( + "time" + + "github.com/shadowsocks/go-shadowsocks2/core" +) + +const ( + method = "method" + password = "password" + key = "key" + connectTimeout = "timeout" + bufferSize = "bufferSize" +) + +type metadata struct { + cipher core.Cipher + connectTimeout time.Duration + bufferSize int +} diff --git a/pkg/dialer/tcp/dialer.go b/pkg/dialer/tcp/dialer.go index c816431..dfa27e0 100644 --- a/pkg/dialer/tcp/dialer.go +++ b/pkg/dialer/tcp/dialer.go @@ -46,12 +46,10 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp if err != nil { d.logger.Error(err) } else { - if d.logger.IsLevelEnabled(logger.DebugLevel) { - d.logger.WithFields(map[string]interface{}{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial with dial func") } return conn, err } @@ -61,12 +59,10 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp if err != nil { d.logger.Error(err) } else { - if d.logger.IsLevelEnabled(logger.DebugLevel) { - d.logger.WithFields(map[string]interface{}{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial direct") - } + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial direct") } return conn, err } diff --git a/pkg/dialer/udp/conn.go b/pkg/dialer/udp/conn.go new file mode 100644 index 0000000..33e962d --- /dev/null +++ b/pkg/dialer/udp/conn.go @@ -0,0 +1,17 @@ +package udp + +import "net" + +type conn struct { + *net.UDPConn +} + +func (c *conn) WriteTo(b []byte, addr net.Addr) (int, error) { + return c.UDPConn.Write(b) +} + +func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.UDPConn.Read(b) + addr = c.RemoteAddr() + return +} diff --git a/pkg/dialer/udp/dialer.go b/pkg/dialer/udp/dialer.go new file mode 100644 index 0000000..2146f6e --- /dev/null +++ b/pkg/dialer/udp/dialer.go @@ -0,0 +1,54 @@ +package udp + +import ( + "context" + "net" + + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterDialer("udp", NewDialer) +} + +type udpDialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &udpDialer{ + logger: options.Logger, + } +} + +func (d *udpDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + taddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + c, err := net.DialUDP("udp", nil, taddr) + if err != nil { + return nil, err + } + return &conn{ + UDPConn: c, + }, nil +} + +func (d *udpDialer) parseMetadata(md md.Metadata) (err error) { + return +} diff --git a/pkg/dialer/udp/metadata.go b/pkg/dialer/udp/metadata.go new file mode 100644 index 0000000..cc46071 --- /dev/null +++ b/pkg/dialer/udp/metadata.go @@ -0,0 +1,15 @@ +package udp + +import "time" + +const ( + dialTimeout = "dialTimeout" +) + +const ( + defaultDialTimeout = 5 * time.Second +) + +type metadata struct { + dialTimeout time.Duration +} diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index dc92d4c..73b26d9 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -70,19 +70,15 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { h.logger.Error(err) return } - conn.SetReadDeadline(time.Time{}) + h.logger.Debug(req) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(req) - } + conn.SetReadDeadline(time.Time{}) if h.md.authenticator != nil && !h.md.authenticator.Authenticate(string(req.Userid), "") { resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } @@ -107,9 +103,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g if h.bypass != nil && h.bypass.Contains(addr) { resp := gosocks4.NewReply(gosocks4.Rejected, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) h.logger.Info("bypass: ", addr) return } @@ -122,9 +116,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g if err != nil { resp := gosocks4.NewReply(gosocks4.Failed, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } @@ -135,9 +127,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g h.logger.Error(err) return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) diff --git a/pkg/handler/socks/v5/bind.go b/pkg/handler/socks/v5/bind.go index a4085ac..4ae5003 100644 --- a/pkg/handler/socks/v5/bind.go +++ b/pkg/handler/socks/v5/bind.go @@ -7,7 +7,6 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" ) func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { @@ -33,9 +32,7 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso if err != nil { resp := gosocks5.NewReply(gosocks5.Failure, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } defer cc.Close() @@ -45,9 +42,7 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso h.logger.Error(err) resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } @@ -65,32 +60,25 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, addr strin if err := reply.Write(conn); err != nil { h.logger.Error(err) } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply.String()) - } + h.logger.Debug(reply) return } - socksAddr, err := gosocks5.NewAddr(ln.Addr().String()) - if err != nil { + socksAddr := gosocks5.Addr{} + if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil { h.logger.Warn(err) - socksAddr = &gosocks5.Addr{ - Type: gosocks5.AddrIPv4, - } } // Issue: may not reachable when host has multi-interface socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) socksAddr.Type = 0 - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr) if err := reply.Write(conn); err != nil { h.logger.Error(err) ln.Close() return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply.String()) - } + h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ "bind": socksAddr.String(), @@ -143,14 +131,13 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis } defer rc.Close() - raddr, _ := gosocks5.NewAddr(rc.RemoteAddr().String()) - reply := gosocks5.NewReply(gosocks5.Succeeded, raddr) + raddr := gosocks5.Addr{} + raddr.ParseFrom(rc.RemoteAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr) if err := reply.Write(pc2); err != nil { h.logger.Error(err) } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply.String()) - } + h.logger.Debug(reply) h.logger.Infof("peer accepted: %s", raddr.String()) start := time.Now() diff --git a/pkg/handler/socks/v5/connect.go b/pkg/handler/socks/v5/connect.go index a335afc..8752516 100644 --- a/pkg/handler/socks/v5/connect.go +++ b/pkg/handler/socks/v5/connect.go @@ -7,7 +7,6 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" ) func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr string) { @@ -20,9 +19,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s if h.bypass != nil && h.bypass.Contains(addr) { resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) h.logger.Info("bypass: ", addr) return } @@ -35,9 +32,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s if err != nil { resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } @@ -48,9 +43,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s h.logger.Error(err) return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index f2d62b0..3015712 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -83,12 +83,9 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { h.logger.Error(err) return } + h.logger.Debug(req) conn.SetReadDeadline(time.Time{}) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(req) - } - switch req.Cmd { case gosocks5.CmdConnect: h.handleConnect(ctx, conn, req.Addr.String()) @@ -104,9 +101,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { h.logger.Errorf("unknown cmd: %d", req.Cmd) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } } diff --git a/pkg/handler/socks/v5/mbind.go b/pkg/handler/socks/v5/mbind.go index db4ed6a..ca576c9 100644 --- a/pkg/handler/socks/v5/mbind.go +++ b/pkg/handler/socks/v5/mbind.go @@ -8,7 +8,6 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/internal/utils/mux" - "github.com/go-gost/gost/pkg/logger" ) func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { @@ -34,9 +33,7 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g if err != nil { resp := gosocks5.NewReply(gosocks5.Failure, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } defer cc.Close() @@ -46,9 +43,7 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g h.logger.Error(err) resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(resp) - } + h.logger.Debug(resp) return } @@ -71,32 +66,26 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, addr st if err := reply.Write(conn); err != nil { h.logger.Error(err) } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply.String()) - } + h.logger.Debug(reply) return } - socksAddr, err := gosocks5.NewAddr(ln.Addr().String()) + socksAddr := gosocks5.Addr{} + socksAddr.ParseFrom(ln.Addr().String()) if err != nil { h.logger.Warn(err) - socksAddr = &gosocks5.Addr{ - Type: gosocks5.AddrIPv4, - } } // Issue: may not reachable when host has multi-interface socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) socksAddr.Type = 0 - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr) if err := reply.Write(conn); err != nil { h.logger.Error(err) ln.Close() return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply.String()) - } + h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ "bind": socksAddr.String(), diff --git a/pkg/handler/socks/v5/selector.go b/pkg/handler/socks/v5/selector.go index 3472539..e739e34 100644 --- a/pkg/handler/socks/v5/selector.go +++ b/pkg/handler/socks/v5/selector.go @@ -23,9 +23,7 @@ func (selector *serverSelector) Methods() []uint8 { } func (s *serverSelector) Select(methods ...uint8) (method uint8) { - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods) - } + s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods) method = gosocks5.MethodNoAuth for _, m := range methods { if m == socks.MethodTLS && !s.noTLS { @@ -48,9 +46,7 @@ func (s *serverSelector) Select(methods ...uint8) (method uint8) { } func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debugf("%d %d", gosocks5.Ver5, method) - } + s.logger.Debugf("%d %d", gosocks5.Ver5, method) switch method { case socks.MethodTLS: conn = tls.Server(conn, s.TLSConfig) @@ -65,9 +61,7 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro s.logger.Error(err) return nil, err } - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debug(req.String()) - } + s.logger.Debug(req) if s.Authenticator != nil && !s.Authenticator.Authenticate(req.Username, req.Password) { @@ -76,9 +70,8 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro s.logger.Error(err) return nil, err } - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Info(resp.String()) - } + s.logger.Info(resp) + return nil, gosocks5.ErrAuthFailure } @@ -87,9 +80,8 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro s.logger.Error(err) return nil, err } - if s.logger.IsLevelEnabled(logger.DebugLevel) { - s.logger.Debug(resp.String()) - } + s.logger.Debug(resp) + case gosocks5.MethodNoAcceptable: return nil, gosocks5.ErrBadMethod } diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index c67cf39..ec35246 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -1,7 +1,6 @@ package v5 import ( - "bytes" "context" "errors" "io" @@ -13,7 +12,6 @@ import ( "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/internal/bufpool" "github.com/go-gost/gost/pkg/internal/utils/socks" - "github.com/go-gost/gost/pkg/logger" ) func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosocks5.Request) { @@ -26,27 +24,21 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc h.logger.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply) - } + h.logger.Debug(reply) return } defer relay.Close() - saddr, _ := gosocks5.NewAddr(relay.LocalAddr().String()) - if saddr == nil { - saddr = &gosocks5.Addr{} - } + saddr := gosocks5.Addr{} + saddr.ParseFrom(relay.LocalAddr().String()) saddr.Type = 0 saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's - reply := gosocks5.NewReply(gosocks5.Succeeded, saddr) + reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) if err := reply.Write(conn); err != nil { h.logger.Error(err) return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply) - } + h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ "bind": saddr.String(), @@ -62,7 +54,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc } defer peer.Close() - go h.relayUDP(relay, peer) + go h.relayUDP( + socks.NewUDPConn(relay, h.md.udpBufferSize), + peer, + ) } else { tun, err := h.getUDPTun(ctx) if err != nil { @@ -71,15 +66,18 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc } defer tun.Close() - go h.tunnelClientUDP(relay, tun) + go h.tunnelClientUDP( + socks.NewUDPConn(relay, h.md.udpBufferSize), + socks.UDPTunClientConn(tun, nil), + ) } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), &saddr) io.Copy(ioutil.Discard, conn) h.logger. WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), saddr) + Infof("%s >-< %s", conn.RemoteAddr(), &saddr) } func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error) { @@ -108,17 +106,13 @@ func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error if err = req.Write(conn); err != nil { return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(req) - } + h.logger.Debug(req) reply, err := gosocks5.ReadReply(conn) if err != nil { return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply) - } + h.logger.Debug(reply) if reply.Rep != gosocks5.Succeeded { err = errors.New("UDP associate failed") @@ -128,119 +122,72 @@ func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error return } -func (h *socks5Handler) tunnelClientUDP(c net.PacketConn, tunnel net.Conn) (err error) { +func (h *socks5Handler) tunnelClientUDP(c, tun net.PacketConn) (err error) { bufSize := h.md.udpBufferSize errc := make(chan error, 2) - var clientAddr net.Addr - go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - for { - n, laddr, err := c.ReadFrom(b) + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := c.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := tun.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s >>> %s data: %d", + tun.LocalAddr(), raddr, n) + + return nil + }() + if err != nil { errc <- err return } - - if clientAddr == nil { - clientAddr = laddr - } - - var addr gosocks5.Addr - header := gosocks5.UDPHeader{ - Addr: &addr, - } - hlen, err := header.ReadFrom(bytes.NewReader(b[:n])) - if err != nil { - errc <- err - return - } - - raddr, err := net.ResolveUDPAddr("udp", addr.String()) - if err != nil { - continue // drop silently - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - continue // bypass - } - - dgram := gosocks5.UDPDatagram{ - Header: &header, - Data: b[hlen:n], - } - dgram.Header.Rsv = uint16(len(dgram.Data)) - - if _, err := dgram.WriteTo(tunnel); err != nil { - errc <- err - return - } - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s >>> %s: %v data: %d", - clientAddr, raddr, b[:hlen], len(dgram.Data)) - } } }() go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - const dataPos = 262 - for { - addr := gosocks5.Addr{} - header := gosocks5.UDPHeader{ - Addr: &addr, - } + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := tun.ReadFrom(b) + if err != nil { + return err + } + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := c.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s <<< %s data: %d", + tun.LocalAddr(), raddr, n) + + return nil + }() - data := b[dataPos:] - dgram := gosocks5.UDPDatagram{ - Header: &header, - Data: data, - } - _, err := dgram.ReadFrom(tunnel) if err != nil { errc <- err return } - // NOTE: the dgram.Data may be reallocated if the provided buffer is too short, - // we drop it for simplicity. As this occurs, you should enlarge the buffer size. - if len(dgram.Data) > len(data) { - h.logger.Warnf("buffer too short, dropped") - continue - } - - // pipe from tunnel to relay - if clientAddr == nil { - h.logger.Warnf("ignore unexpected peer from %s", addr) - continue - } - - raddr := addr.String() - if h.bypass != nil && h.bypass.Contains(raddr) { - h.logger.Warn("bypass: ", raddr) - continue // bypass - } - - addrLen := addr.Length() - addr.Encode(b[dataPos-addrLen : dataPos]) - - hlen := addrLen + 3 - if _, err := c.WriteTo(b[dataPos-hlen:dataPos+len(dgram.Data)], clientAddr); err != nil { - errc <- err - return - } - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s <<< %s: %v data: %d", - clientAddr, addr.String(), b[dataPos-hlen:dataPos], len(dgram.Data)) - } } }() @@ -251,91 +198,69 @@ func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) { bufSize := h.md.udpBufferSize errc := make(chan error, 2) - var clientAddr net.Addr - go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - for { - n, laddr, err := c.ReadFrom(b) + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := c.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := peer.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s >>> %s data: %d", + peer.LocalAddr(), raddr, n) + + return nil + }() + if err != nil { errc <- err return } - if clientAddr == nil { - clientAddr = laddr - } - - var addr gosocks5.Addr - header := gosocks5.UDPHeader{ - Addr: &addr, - } - hlen, err := header.ReadFrom(bytes.NewReader(b[:n])) - if err != nil { - errc <- err - return - } - - raddr, err := net.ResolveUDPAddr("udp", addr.String()) - if err != nil { - continue // drop silently - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - continue // bypass - } - - data := b[hlen:n] - if _, err := peer.WriteTo(data, raddr); err != nil { - errc <- err - return - } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s >>> %s: %v data: %d", - clientAddr, raddr, b[:hlen], len(data)) - } } }() go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - const dataPos = 262 - for { - n, raddr, err := peer.ReadFrom(b[dataPos:]) + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := peer.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := c.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s <<< %s data: %d", + peer.LocalAddr(), raddr, n) + + return nil + }() + if err != nil { errc <- err return } - if clientAddr == nil { - continue - } - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - continue // bypass - } - - socksAddr, _ := gosocks5.NewAddr(raddr.String()) - if socksAddr == nil { - socksAddr = &gosocks5.Addr{} - } - addrLen := socksAddr.Length() - socksAddr.Encode(b[dataPos-addrLen : dataPos]) - - hlen := addrLen + 3 - if _, err := c.WriteTo(b[dataPos-hlen:dataPos+n], clientAddr); err != nil { - errc <- err - return - } - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s <<< %s: %v data: %d", - clientAddr, raddr, b[dataPos-hlen:dataPos], n) - } } }() diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index 449ce96..dbd1891 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/internal/bufpool" - "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/internal/utils/socks" ) func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *gosocks5.Request) { @@ -35,9 +35,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go h.logger.Error(err) return } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply) - } + h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ "bind": saddr.String(), @@ -45,7 +43,10 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr) - h.tunnelServerUDP(conn, relay) + h.tunnelServerUDP( + socks.UDPTunServerConn(conn), + relay, + ) h.logger. WithFields(map[string]interface{}{ "duration": time.Since(t), @@ -64,9 +65,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go h.logger.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply) - } + h.logger.Debug(reply) return } defer cc.Close() @@ -76,9 +75,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go h.logger.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(reply) - } + h.logger.Debug(reply) } t := time.Now() @@ -91,97 +88,72 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go Infof("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) } -func (h *socks5Handler) tunnelServerUDP(tunnel net.Conn, c net.PacketConn) (err error) { +func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { bufSize := h.md.udpBufferSize errc := make(chan error, 2) go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - const dataPos = 262 - for { - addr := gosocks5.Addr{} - header := gosocks5.UDPHeader{ - Addr: &addr, - } + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := tunnel.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := c.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s >>> %s data: %d", + c.LocalAddr(), raddr, n) + + return nil + }() - data := b[dataPos:] - dgram := gosocks5.UDPDatagram{ - Header: &header, - Data: data, - } - _, err := dgram.ReadFrom(tunnel) if err != nil { errc <- err return } - // NOTE: the dgram.Data may be reallocated if the provided buffer is too short, - // we drop it for simplicity. As this occurs, you should enlarge the buffer size. - if len(dgram.Data) > len(data) { - h.logger.Warnf("buffer too short, dropped") - continue - } - - raddr, err := net.ResolveUDPAddr("udp", addr.String()) - if err != nil { - continue // drop silently - } - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr.String()) - continue // bypass - } - - if _, err := c.WriteTo(dgram.Data, raddr); err != nil { - errc <- err - return - } - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s >>> %s: %v data: %d", - tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data)) - } } }() go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - for { - n, raddr, err := c.ReadFrom(b) + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := c.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := tunnel.WriteTo(b[:n], raddr); err != nil { + return err + } + h.logger.Debugf("%s <<< %s data: %d", + c.LocalAddr(), raddr, n) + + return nil + }() + if err != nil { errc <- err return } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr.String()) - continue // bypass - } - - addr, _ := gosocks5.NewAddr(raddr.String()) - if addr == nil { - addr = &gosocks5.Addr{} - } - header := gosocks5.UDPHeader{ - Rsv: uint16(n), - Addr: addr, - } - dgram := gosocks5.UDPDatagram{ - Header: &header, - Data: b[:n], - } - - if _, err := dgram.WriteTo(tunnel); err != nil { - errc <- err - return - } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s <<< %s: %v data: %d", - tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data)) - } } }() diff --git a/pkg/handler/ssu/handler.go b/pkg/handler/ssu/handler.go index 93e8146..ee80c2b 100644 --- a/pkg/handler/ssu/handler.go +++ b/pkg/handler/ssu/handler.go @@ -1,16 +1,15 @@ package ssu import ( - "bytes" "context" "net" "time" - "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/internal/bufpool" + "github.com/go-gost/gost/pkg/internal/utils/socks" "github.com/go-gost/gost/pkg/internal/utils/ss" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -91,7 +90,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) - h.relayPacket(pc, cc) + h.relayPacket( + ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize), + cc, + ) h.logger. WithFields(map[string]interface{}{"duration": time.Since(t)}). Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) @@ -104,7 +106,7 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) - h.tunnelUDP(conn, cc) + h.tunnelUDP(socks.UDPTunServerConn(conn), cc) h.logger. WithFields(map[string]interface{}{"duration": time.Since(t)}). Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) @@ -112,47 +114,30 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { bufSize := h.md.bufferSize - errc := make(chan error, 2) - var clientAddr net.Addr go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - for { err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + n, addr, err := pc1.ReadFrom(b) if err != nil { return err } - if clientAddr == nil { - clientAddr = addr - } - rb := bytes.NewBuffer(b[:n]) - saddr := gosocks5.Addr{} - if _, err := saddr.ReadFrom(rb); err != nil { - return err - } - taddr, err := net.ResolveUDPAddr("udp", saddr.String()) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(taddr.String()) { - h.logger.Warn("bypass: ", taddr) + if h.bypass != nil && h.bypass.Contains(addr.String()) { + h.logger.Warn("bypass: ", addr) return nil } - if _, err = pc2.WriteTo(rb.Bytes(), taddr); err != nil { + if _, err = pc2.WriteTo(b[:n], addr); err != nil { return err } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s >>> %s: %v, data: %d", - addr, taddr, saddr.String(), rb.Len()) - } + h.logger.Debugf("%s >>> %s data: %d", + pc2.LocalAddr(), addr, n) return nil }() @@ -164,41 +149,27 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { }() go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - const dataPos = 259 - for { err := func() error { - n, raddr, err := pc2.ReadFrom(b[dataPos:]) + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := pc2.ReadFrom(b) if err != nil { return err } - if clientAddr == nil { - return nil - } if h.bypass != nil && h.bypass.Contains(raddr.String()) { h.logger.Warn("bypass: ", raddr) return nil } - socksAddr, _ := gosocks5.NewAddr(raddr.String()) - if socksAddr == nil { - socksAddr = &gosocks5.Addr{} - } - addrLen := socksAddr.Length() - socksAddr.Encode(b[dataPos-addrLen : dataPos]) - - if _, err = pc1.WriteTo(b[dataPos-addrLen:dataPos+n], clientAddr); err != nil { + if _, err = pc1.WriteTo(b[:n], raddr); err != nil { return err } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s <<< %s: %v data: %d", - clientAddr, raddr, b[dataPos-addrLen:dataPos], n) - } + h.logger.Debugf("%s <<< %s data: %d", + pc2.LocalAddr(), raddr, n) return nil }() @@ -212,7 +183,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { return <-errc } -func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) { +func (h *ssuHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) { bufSize := h.md.bufferSize errc := make(chan error, 2) @@ -220,49 +191,32 @@ func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) { b := bufpool.Get(bufSize) defer bufpool.Put(b) - const dataPos = 262 - for { - addr := gosocks5.Addr{} - header := gosocks5.UDPHeader{ - Addr: &addr, - } + err := func() error { + n, addr, err := tunnel.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(addr.String()) { + h.logger.Warn("bypass: ", addr.String()) + return nil // bypass + } + + if _, err := c.WriteTo(b[:n], addr); err != nil { + return err + } + + h.logger.Debugf("%s >>> %s data: %d", + c.LocalAddr(), addr, n) + + return nil + }() - data := b[dataPos:] - dgram := gosocks5.UDPDatagram{ - Header: &header, - Data: data, - } - _, err := dgram.ReadFrom(tunnel) if err != nil { errc <- err return } - // NOTE: the dgram.Data may be reallocated if the provided buffer is too short, - // we drop it for simplicity. As this occurs, you should enlarge the buffer size. - if len(dgram.Data) > len(data) { - h.logger.Warnf("buffer too short, dropped") - continue - } - - raddr, err := net.ResolveUDPAddr("udp", addr.String()) - if err != nil { - continue // drop silently - } - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr.String()) - continue // bypass - } - - if _, err := c.WriteTo(dgram.Data, raddr); err != nil { - errc <- err - return - } - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s >>> %s: %v data: %d", - tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data)) - } } }() @@ -271,38 +225,31 @@ func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) { defer bufpool.Put(b) for { - n, raddr, err := c.ReadFrom(b) + err := func() error { + n, raddr, err := c.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr.String()) + return nil // bypass + } + + if _, err := tunnel.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s <<< %s data: %d", + c.LocalAddr(), raddr, n) + + return nil + }() + if err != nil { errc <- err return } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr.String()) - continue // bypass - } - - addr, _ := gosocks5.NewAddr(raddr.String()) - if addr == nil { - addr = &gosocks5.Addr{} - } - header := gosocks5.UDPHeader{ - Rsv: uint16(n), - Addr: addr, - } - dgram := gosocks5.UDPDatagram{ - Header: &header, - Data: b[:n], - } - - if _, err := dgram.WriteTo(tunnel); err != nil { - errc <- err - return - } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debugf("%s <<< %s: %v data: %d", - tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data)) - } } }() diff --git a/pkg/internal/utils/socks/conn.go b/pkg/internal/utils/socks/conn.go new file mode 100644 index 0000000..fc0ee6c --- /dev/null +++ b/pkg/internal/utils/socks/conn.go @@ -0,0 +1,173 @@ +package socks + +import ( + "bytes" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/internal/bufpool" +) + +var ( + _ net.PacketConn = (*UDPTunConn)(nil) + _ net.Conn = (*UDPTunConn)(nil) + + _ net.PacketConn = (*UDPConn)(nil) + _ net.Conn = (*UDPConn)(nil) +) + +type UDPTunConn struct { + net.Conn + taddr net.Addr +} + +func UDPTunClientConn(c net.Conn, targetAddr net.Addr) *UDPTunConn { + return &UDPTunConn{ + Conn: c, + taddr: targetAddr, + } +} + +func UDPTunServerConn(c net.Conn) *UDPTunConn { + return &UDPTunConn{ + Conn: c, + } +} + +func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + socksAddr := gosocks5.Addr{} + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + _, err = dgram.ReadFrom(c.Conn) + if err != nil { + return + } + + n = len(dgram.Data) + if n > len(b) { + n = copy(b, dgram.Data) + } + addr, err = net.ResolveUDPAddr("udp", socksAddr.String()) + + return +} + +func (c *UDPTunConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + socksAddr := gosocks5.Addr{} + if err = socksAddr.ParseFrom(addr.String()); err != nil { + return + } + + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + dgram.Header.Rsv = uint16(len(dgram.Data)) + _, err = dgram.WriteTo(c.Conn) + n = len(b) + + return +} + +func (c *UDPTunConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +var ( + DefaultBufferSize = 4096 +) + +type UDPConn struct { + net.PacketConn + raddr net.Addr + taddr net.Addr + bufferSize int +} + +func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn { + return &UDPConn{ + PacketConn: c, + bufferSize: bufferSize, + } +} + +// ReadFrom reads an UDP datagram. +// NOTE: for server side, +// the returned addr is the target address the client want to relay to. +func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + rbuf := bufpool.Get(c.bufferSize) + defer bufpool.Put(rbuf) + + n, c.raddr, err = c.PacketConn.ReadFrom(rbuf) + if err != nil { + return + } + + socksAddr := gosocks5.Addr{} + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + hlen, err := header.ReadFrom(bytes.NewReader(rbuf[:n])) + if err != nil { + return + } + n = copy(b, rbuf[hlen:n]) + + addr, err = net.ResolveUDPAddr("udp", socksAddr.String()) + return +} + +func (c *UDPConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + wbuf := bufpool.Get(c.bufferSize) + defer bufpool.Put(wbuf) + + socksAddr := gosocks5.Addr{} + if err = socksAddr.ParseFrom(addr.String()); err != nil { + return + } + + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + + buf := bytes.NewBuffer(wbuf[:0]) + _, err = dgram.WriteTo(buf) + if err != nil { + return + } + + _, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr) + n = len(b) + + return +} + +func (c *UDPConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +func (c *UDPConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/internal/utils/ss/conn.go b/pkg/internal/utils/ss/conn.go new file mode 100644 index 0000000..8eeeb6c --- /dev/null +++ b/pkg/internal/utils/ss/conn.go @@ -0,0 +1,96 @@ +package ss + +import ( + "bytes" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/internal/bufpool" +) + +var ( + DefaultBufferSize = 4096 +) + +var ( + _ net.PacketConn = (*UDPConn)(nil) + _ net.Conn = (*UDPConn)(nil) +) + +type UDPConn struct { + net.PacketConn + raddr net.Addr + taddr net.Addr + bufferSize int +} + +func UDPClientConn(c net.PacketConn, remoteAddr, targetAddr net.Addr, bufferSize int) *UDPConn { + return &UDPConn{ + PacketConn: c, + raddr: remoteAddr, + taddr: targetAddr, + bufferSize: bufferSize, + } +} + +func UDPServerConn(c net.PacketConn, remoteAddr net.Addr, bufferSize int) *UDPConn { + return &UDPConn{ + PacketConn: c, + raddr: remoteAddr, + bufferSize: bufferSize, + } +} + +func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + rbuf := bufpool.Get(c.bufferSize) + defer bufpool.Put(rbuf) + + n, _, err = c.PacketConn.ReadFrom(rbuf) + if err != nil { + return + } + + saddr := gosocks5.Addr{} + addrLen, err := saddr.ReadFrom(bytes.NewReader(rbuf[:n])) + if err != nil { + return + } + + n = copy(b, rbuf[addrLen:n]) + addr, err = net.ResolveUDPAddr("udp", saddr.String()) + + return +} + +func (c *UDPConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + wbuf := bufpool.Get(c.bufferSize) + defer bufpool.Put(wbuf) + + socksAddr := gosocks5.Addr{} + if err = socksAddr.ParseFrom(addr.String()); err != nil { + return + } + + addrLen, err := socksAddr.Encode(wbuf) + if err != nil { + return + } + + n = copy(wbuf[addrLen:], b) + _, err = c.PacketConn.WriteTo(wbuf[:addrLen+n], c.raddr) + + return +} + +func (c *UDPConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +func (c *UDPConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/listener/udp/conn.go b/pkg/listener/udp/conn.go index 2e29d43..eba1e6d 100644 --- a/pkg/listener/udp/conn.go +++ b/pkg/listener/udp/conn.go @@ -6,109 +6,184 @@ import ( "sync" "sync/atomic" "time" + + "github.com/go-gost/gost/pkg/internal/bufpool" + "github.com/go-gost/gost/pkg/logger" ) -// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. -type serverConn struct { +// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. +type conn struct { net.PacketConn - raddr net.Addr + remoteAddr net.Addr rc chan []byte // data receive queue - fresh int32 + idle int32 closed chan struct{} closeMutex sync.Mutex - config *serverConnConfig } -type serverConnConfig struct { - ttl time.Duration - qsize int - onClose func() -} - -func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn { - if conn == nil || raddr == nil { - return nil - } - - if cfg == nil { - cfg = &serverConnConfig{} - } - c := &serverConn{ - PacketConn: conn, - raddr: raddr, - rc: make(chan []byte, cfg.qsize), +func newConn(c net.PacketConn, raddr net.Addr, queue int) *conn { + return &conn{ + PacketConn: c, + remoteAddr: raddr, + rc: make(chan []byte, queue), closed: make(chan struct{}), - config: cfg, } - go c.ttlWait() - return c } -func (c *serverConn) send(b []byte) error { +func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { select { - case c.rc <- b: - return nil - default: - return errors.New("queue is full") + case bb := <-c.rc: + n = copy(b, bb) + c.SetIdle(false) + bufpool.Put(bb) + + case <-c.closed: + err = net.ErrClosed + return } + + addr = c.remoteAddr + + return } -func (c *serverConn) Read(b []byte) (n int, err error) { +func (c *conn) Read(b []byte) (n int, err error) { n, _, err = c.ReadFrom(b) return } -func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - select { - case bb := <-c.rc: - n = copy(b, bb) - atomic.StoreInt32(&c.fresh, 1) - case <-c.closed: - err = errors.New("read from closed connection") - return - } - - addr = c.raddr - - return +func (c *conn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.remoteAddr) } -func (c *serverConn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.raddr) -} - -func (c *serverConn) Close() error { +func (c *conn) Close() error { c.closeMutex.Lock() defer c.closeMutex.Unlock() select { case <-c.closed: - return errors.New("connection is closed") default: - if c.config.onClose != nil { - c.config.onClose() - } close(c.closed) } return nil } -func (c *serverConn) RemoteAddr() net.Addr { - return c.raddr +func (c *conn) RemoteAddr() net.Addr { + return c.remoteAddr } -func (c *serverConn) ttlWait() { - ticker := time.NewTicker(c.config.ttl) +func (c *conn) IsIdle() bool { + return atomic.LoadInt32(&c.idle) > 0 +} + +func (c *conn) SetIdle(idle bool) { + v := int32(0) + if idle { + v = 1 + } + atomic.StoreInt32(&c.idle, v) +} + +func (c *conn) Queue(b []byte) error { + select { + case c.rc <- b: + return nil + + case <-c.closed: + return net.ErrClosed + + default: + return errors.New("recv queue is full") + } +} + +type connPool struct { + m sync.Map + ttl time.Duration + closed chan struct{} + logger logger.Logger +} + +func newConnPool(ttl time.Duration) *connPool { + p := &connPool{ + ttl: ttl, + closed: make(chan struct{}), + } + go p.idleCheck() + return p +} + +func (p *connPool) WithLogger(logger logger.Logger) *connPool { + p.logger = logger + return p +} + +func (p *connPool) Get(key interface{}) (c *conn, ok bool) { + v, ok := p.m.Load(key) + if ok { + c, ok = v.(*conn) + } + return +} + +func (p *connPool) Set(key interface{}, c *conn) { + p.m.Store(key, c) +} + +func (p *connPool) Delete(key interface{}) { + p.m.Delete(key) +} + +func (p *connPool) Close() { + select { + case <-p.closed: + return + default: + } + + close(p.closed) + + p.m.Range(func(k, v interface{}) bool { + if c, ok := v.(*conn); ok && c != nil { + c.Close() + } + return true + }) +} + +func (p *connPool) idleCheck() { + ticker := time.NewTicker(p.ttl) defer ticker.Stop() for { select { case <-ticker.C: - if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) { - c.Close() - return + size := 0 + idles := 0 + p.m.Range(func(key, value interface{}) bool { + c, ok := value.(*conn) + if !ok || c == nil { + p.Delete(key) + return true + } + size++ + + if c.IsIdle() { + idles++ + p.Delete(key) + c.Close() + return true + } + + c.SetIdle(true) + + return true + }) + + if idles > 0 { + p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles) } - case <-c.closed: + case <-p.closed: return } } diff --git a/pkg/listener/udp/listener.go b/pkg/listener/udp/listener.go index f5995a1..e595672 100644 --- a/pkg/listener/udp/listener.go +++ b/pkg/listener/udp/listener.go @@ -2,9 +2,8 @@ package udp import ( "net" - "sync" - "sync/atomic" + "github.com/go-gost/gost/pkg/internal/bufpool" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -16,13 +15,14 @@ func init() { } type udpListener struct { - addr string - md metadata - conn net.PacketConn - connChan chan net.Conn - errChan chan error - connPool connPool - logger logger.Logger + addr string + md metadata + conn net.PacketConn + connChan chan net.Conn + errChan chan error + closeChan chan struct{} + connPool *connPool + logger logger.Logger } func NewListener(opts ...listener.Option) listener.Listener { @@ -31,8 +31,10 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &udpListener{ - addr: options.Addr, - logger: options.Logger, + addr: options.Addr, + errChan: make(chan error, 1), + closeChan: make(chan struct{}), + logger: options.Logger, } } @@ -46,15 +48,13 @@ func (l *udpListener) Init(md md.Metadata) (err error) { return } - var conn net.PacketConn - conn, err = net.ListenUDP("udp", laddr) + l.conn, err = net.ListenUDP("udp", laddr) if err != nil { return } - l.conn = conn l.connChan = make(chan net.Conn, l.md.connQueueSize) - l.errChan = make(chan error, 1) + l.connPool = newConnPool(l.md.ttl).WithLogger(l.logger) go l.listenLoop() @@ -74,12 +74,14 @@ func (l *udpListener) Accept() (conn net.Conn, err error) { } func (l *udpListener) Close() error { - err := l.conn.Close() - l.connPool.Range(func(k interface{}, v *serverConn) bool { - v.Close() - return true - }) - return err + select { + case <-l.closeChan: + return nil + default: + close(l.closeChan) + l.connPool.Close() + return l.conn.Close() + } } func (l *udpListener) Addr() net.Addr { @@ -88,43 +90,43 @@ func (l *udpListener) Addr() net.Addr { func (l *udpListener) listenLoop() { for { - b := make([]byte, l.md.readBufferSize) + b := bufpool.Get(l.md.readBufferSize) n, raddr, err := l.conn.ReadFrom(b) if err != nil { - l.logger.Error("accept:", err) l.errChan <- err close(l.errChan) return } - conn, ok := l.connPool.Get(raddr.String()) - if !ok { - conn = newServerConn(l.conn, raddr, - &serverConnConfig{ - ttl: l.md.ttl, - qsize: l.md.readQueueSize, - onClose: func() { - l.connPool.Delete(raddr.String()) - }, - }) - - select { - case l.connChan <- conn: - l.connPool.Set(raddr.String(), conn) - default: - conn.Close() - l.logger.Error("connection queue is full") - } + c := l.getConn(raddr) + if c == nil { + bufpool.Put(b) + continue } - if err := conn.send(b[:n]); err != nil { - l.logger.Warn("data discarded:", err) + if err := c.Queue(b[:n]); err != nil { + l.logger.Warn("data discarded: ", err) } - l.logger.Debug("recv", n) } } +func (l *udpListener) getConn(addr net.Addr) *conn { + c, ok := l.connPool.Get(addr.String()) + if !ok { + c = newConn(l.conn, addr, l.md.readQueueSize) + select { + case l.connChan <- c: + l.connPool.Set(addr.String(), c) + default: + c.Close() + l.logger.Warnf("connection queue is full, client %s discarded", addr.String()) + return nil + } + } + return c +} + func (l *udpListener) parseMetadata(md md.Metadata) (err error) { l.md.ttl = md.GetDuration(ttl) if l.md.ttl <= 0 { @@ -147,36 +149,3 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) { return } - -type connPool struct { - size int64 - m sync.Map -} - -func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) { - v, ok := p.m.Load(key) - if ok { - conn, ok = v.(*serverConn) - } - return -} - -func (p *connPool) Set(key interface{}, conn *serverConn) { - p.m.Store(key, conn) - atomic.AddInt64(&p.size, 1) -} - -func (p *connPool) Delete(key interface{}) { - p.m.Delete(key) - atomic.AddInt64(&p.size, -1) -} - -func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) { - p.m.Range(func(k, v interface{}) bool { - return f(k, v.(*serverConn)) - }) -} - -func (p *connPool) Size() int64 { - return atomic.LoadInt64(&p.size) -} diff --git a/pkg/listener/udp/metadata.go b/pkg/listener/udp/metadata.go index 76bc478..ed575ab 100644 --- a/pkg/listener/udp/metadata.go +++ b/pkg/listener/udp/metadata.go @@ -4,7 +4,7 @@ import "time" const ( defaultTTL = 60 * time.Second - defaultReadBufferSize = 1024 + defaultReadBufferSize = 4096 defaultReadQueueSize = 128 defaultConnQueueSize = 128 )