From b2b76a10a0ebb5dbed3a13015e2f07c94334c5c2 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 28 Jan 2022 11:57:26 +0800 Subject: [PATCH] fix rudp --- pkg/common/util/udp/listener.go | 9 ++++ pkg/connector/socks/v5/bind.go | 2 +- pkg/handler/relay/bind.go | 84 +++------------------------------ pkg/handler/socks/v5/udp_tun.go | 56 +++++++++++++--------- pkg/handler/ss/udp/handler.go | 4 +- pkg/service/service.go | 4 +- 6 files changed, 54 insertions(+), 105 deletions(-) diff --git a/pkg/common/util/udp/listener.go b/pkg/common/util/udp/listener.go index ce89d3e..d0748f4 100644 --- a/pkg/common/util/udp/listener.go +++ b/pkg/common/util/udp/listener.go @@ -18,6 +18,7 @@ type listener struct { connPool *ConnPool mux sync.Mutex closed chan struct{} + errChan chan error logger logger.Logger } @@ -30,6 +31,7 @@ func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dat readQueueSize: dataQueueSize, readBufferSize: dataBufferSize, closed: make(chan struct{}), + errChan: make(chan error, 1), logger: logger, } go ln.listenLoop() @@ -43,6 +45,11 @@ func (ln *listener) Accept() (conn net.Conn, err error) { return case <-ln.closed: return nil, net.ErrClosed + case err = <-ln.errChan: + if err == nil { + err = net.ErrClosed + } + return } } @@ -58,6 +65,8 @@ func (ln *listener) listenLoop() { n, raddr, err := ln.conn.ReadFrom(*b) if err != nil { + ln.errChan <- err + close(ln.errChan) return } diff --git a/pkg/connector/socks/v5/bind.go b/pkg/connector/socks/v5/bind.go index afe58ec..0a9a9db 100644 --- a/pkg/connector/socks/v5/bind.go +++ b/pkg/connector/socks/v5/bind.go @@ -21,7 +21,7 @@ func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, addr "network": network, "address": address, }) - log.Infof("bind: %s/%s", address, network) + log.Infof("bind on %s/%s", address, network) options := connector.BindOptions{} for _, opt := range opts { diff --git a/pkg/handler/relay/bind.go b/pkg/handler/relay/bind.go index 0f97e78..dd2f8b8 100644 --- a/pkg/handler/relay/bind.go +++ b/pkg/handler/relay/bind.go @@ -6,7 +6,6 @@ import ( "net" "time" - "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" @@ -113,13 +112,14 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr }) log.Debugf("bind on %s OK", pc.LocalAddr()) + relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + WithBypass(h.options.Bypass). + WithLogger(log) + relay.SetBufferSize(h.md.udpBufferSize) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - h.tunnelServerUDP( - socks.UDPTunServerConn(conn), - pc, - log, - ) + relay.Run() log.WithFields(map[string]interface{}{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) @@ -189,75 +189,3 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L }(rc) } } - -func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn, log logger.Logger) (err error) { - bufSize := h.md.udpBufferSize - errc := make(chan error, 2) - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := tunnel.ReadFrom(*b) - if err != nil { - return err - } - - if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) { - log.Warn("bypass: ", raddr) - return nil - } - - if _, err := c.WriteTo((*b)[:n], raddr); err != nil { - return err - } - - log.Debugf("%s >>> %s data: %d", - c.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := c.ReadFrom(*b) - if err != nil { - return err - } - - if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) { - log.Warn("bypass: ", raddr) - return nil - } - - if _, err := tunnel.WriteTo((*b)[:n], raddr); err != nil { - return err - } - log.Debugf("%s <<< %s data: %d", - c.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - return <-errc -} diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index 2fcb816..4dbbd31 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -16,35 +16,47 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network "cmd": "udp-tun", }) - if !h.md.enableUDP { - reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) - reply.Write(conn) - log.Debug(reply) - log.Error("UDP relay is diabled") - return + bindAddr, _ := net.ResolveUDPAddr(network, address) + if bindAddr == nil { + bindAddr = &net.UDPAddr{} } - // dummy bind - reply := gosocks5.NewReply(gosocks5.Succeeded, nil) + if bindAddr.Port == 0 { + // relay mode + if !h.md.enableUDP { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + reply.Write(conn) + log.Debug(reply) + log.Error("UDP relay is diabled") + return + } + } else { + // BIND mode + if !h.md.enableBind { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + reply.Write(conn) + log.Debug(reply) + log.Error("BIND is diabled") + return + } + } + + pc, err := net.ListenUDP(network, bindAddr) + if err != nil { + log.Error(err) + return + } + defer pc.Close() + + saddr := gosocks5.Addr{} + saddr.ParseFrom(pc.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) if err := reply.Write(conn); err != nil { log.Error(err) return } log.Debug(reply) - - // obtain a udp connection - c, err := h.router.Dial(ctx, "udp", "") // UDP association - if err != nil { - log.Error(err) - return - } - defer c.Close() - - pc, ok := c.(net.PacketConn) - if !ok { - log.Errorf("wrong connection type") - return - } + log.Debugf("bind on %s OK", pc.LocalAddr()) relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). WithBypass(h.options.Bypass). diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index bed446d..2e6498d 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -109,10 +109,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { } t := time.Now() - log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) + log.Infof("%s <-> %s", conn.LocalAddr(), cc.LocalAddr()) h.relayPacket(pc, cc, log) log.WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) + Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr()) } func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) { diff --git a/pkg/service/service.go b/pkg/service/service.go index 76c72e0..3290e8a 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -50,11 +50,11 @@ func (s *Service) serve() error { if e != nil { if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { - tempDelay = 100 * time.Millisecond + tempDelay = 1 * time.Second } else { tempDelay *= 2 } - if max := 1 * time.Second; tempDelay > max { + if max := 5 * time.Second; tempDelay > max { tempDelay = max } s.logger.Warnf("accept: %v, retrying in %v", e, tempDelay)