From be374b6488a4cb95d20a33b79023651743e76d87 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 8 Mar 2022 22:34:39 +0800 Subject: [PATCH] add icmp tunnel --- cmd/gost/register.go | 2 + pkg/common/util/icmp/conn.go | 213 ++++++++++++++++++++++++++++++++ pkg/dialer/icmp/conn.go | 42 +++++++ pkg/dialer/icmp/dialer.go | 130 +++++++++++++++++++ pkg/dialer/icmp/metadata.go | 29 +++++ pkg/dialer/kcp/conn.go | 1 - pkg/dialer/kcp/dialer.go | 114 +++++++---------- pkg/dialer/quic/conn.go | 1 - pkg/dialer/quic/dialer.go | 100 +++++---------- pkg/dialer/ws/dialer.go | 12 +- pkg/dialer/ws/mux/dialer.go | 25 ++-- pkg/handler/tun/handler.go | 14 ++- pkg/handler/tun/metadata.go | 2 +- pkg/internal/util/pht/server.go | 1 + pkg/internal/util/quic/conn.go | 19 +-- pkg/internal/util/ws/ws.go | 25 +++- pkg/listener/icmp/conn.go | 21 ++++ pkg/listener/icmp/listener.go | 145 ++++++++++++++++++++++ pkg/listener/icmp/metadata.go | 41 ++++++ pkg/listener/kcp/listener.go | 1 - pkg/listener/quic/listener.go | 17 +-- 21 files changed, 769 insertions(+), 186 deletions(-) create mode 100644 pkg/common/util/icmp/conn.go create mode 100644 pkg/dialer/icmp/conn.go create mode 100644 pkg/dialer/icmp/dialer.go create mode 100644 pkg/dialer/icmp/metadata.go create mode 100644 pkg/listener/icmp/conn.go create mode 100644 pkg/listener/icmp/listener.go create mode 100644 pkg/listener/icmp/metadata.go diff --git a/cmd/gost/register.go b/cmd/gost/register.go index ed353ad..f4edbe8 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -19,6 +19,7 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/http2" _ "github.com/go-gost/gost/pkg/dialer/http2/h2" _ "github.com/go-gost/gost/pkg/dialer/http3" + _ "github.com/go-gost/gost/pkg/dialer/icmp" _ "github.com/go-gost/gost/pkg/dialer/kcp" _ "github.com/go-gost/gost/pkg/dialer/obfs/http" _ "github.com/go-gost/gost/pkg/dialer/obfs/tls" @@ -58,6 +59,7 @@ import ( _ "github.com/go-gost/gost/pkg/listener/http2" _ "github.com/go-gost/gost/pkg/listener/http2/h2" _ "github.com/go-gost/gost/pkg/listener/http3" + _ "github.com/go-gost/gost/pkg/listener/icmp" _ "github.com/go-gost/gost/pkg/listener/kcp" _ "github.com/go-gost/gost/pkg/listener/obfs/http" _ "github.com/go-gost/gost/pkg/listener/obfs/tls" diff --git a/pkg/common/util/icmp/conn.go b/pkg/common/util/icmp/conn.go new file mode 100644 index 0000000..0b1167d --- /dev/null +++ b/pkg/common/util/icmp/conn.go @@ -0,0 +1,213 @@ +package icmp + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "net" + "sync/atomic" + + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/logger" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" +) + +const ( + readBufferSize = 1500 + writeBufferSize = 1500 + magicNumber = 0x474F5354 +) + +var ( + ErrInvalidPacket = errors.New("icmp: invalid packet") + ErrInvalidType = errors.New("icmp: invalid type") +) + +type clientConn struct { + net.PacketConn + id int + seq uint32 +} + +func ClientConn(conn net.PacketConn, id int) net.PacketConn { + return &clientConn{ + PacketConn: conn, + id: id, + } +} + +func (c *clientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + buf := bufpool.Get(readBufferSize) + defer bufpool.Put(buf) + + for { + n, addr, err = c.PacketConn.ReadFrom(*buf) + if err != nil { + return + } + + m, err := icmp.ParseMessage(1, (*buf)[:n]) + if err != nil { + logger.Default().Error("icmp: parse message %v", err) + return 0, addr, err + } + echo, ok := m.Body.(*icmp.Echo) + if !ok || m.Type != ipv4.ICMPTypeEchoReply { + logger.Default().Warnf("icmp: invalid type %s (discarded)", m.Type) + continue // discard + } + + if echo.ID != c.id { + logger.Default().Warnf("icmp: id mismatch got %d, should be %d (discarded)", echo.ID, c.id) + continue + } + + if len(echo.Data) < 4 || + binary.BigEndian.Uint32(echo.Data[:4]) != magicNumber { + logger.Default().Warn("icmp: invalid message (discarded)") + continue + } + n = copy(b, echo.Data[4:]) + break + } + + if v, ok := addr.(*net.IPAddr); ok { + addr = &net.UDPAddr{ + IP: v.IP, + Port: c.id, + } + } + // logger.Default().Infof("icmp: read from: %v %d", addr, n) + + return +} + +func (c *clientConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + // logger.Default().Infof("icmp: write to: %v %d", addr, len(b)) + switch v := addr.(type) { + case *net.UDPAddr: + addr = &net.IPAddr{IP: v.IP} + } + + buf := bufpool.Get(writeBufferSize) + defer bufpool.Put(buf) + + binary.BigEndian.PutUint32((*buf)[:4], magicNumber) + copy((*buf)[4:], b) + + echo := icmp.Echo{ + ID: c.id, + Seq: int(atomic.AddUint32(&c.seq, 1)), + Data: (*buf)[:len(b)+4], + } + m := icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &echo, + } + wb, err := m.Marshal(nil) + if err != nil { + return 0, err + } + _, err = c.PacketConn.WriteTo(wb, addr) + n = len(b) + return +} + +type serverConn struct { + net.PacketConn + seqs [65535]uint32 +} + +func ServerConn(conn net.PacketConn) net.PacketConn { + return &serverConn{ + PacketConn: conn, + } +} + +func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + buf := bufpool.Get(readBufferSize) + defer bufpool.Put(buf) + + for { + n, addr, err = c.PacketConn.ReadFrom(*buf) + if err != nil { + return + } + + m, err := icmp.ParseMessage(1, (*buf)[:n]) + if err != nil { + logger.Default().Error("icmp: parse message %v", err) + return 0, addr, err + } + + echo, ok := m.Body.(*icmp.Echo) + if !ok || m.Type != ipv4.ICMPTypeEcho || echo.ID <= 0 { + logger.Default().Warnf("icmp: invalid type %s (discarded)", m.Type) + continue + } + + atomic.StoreUint32(&c.seqs[uint16(echo.ID-1)], uint32(echo.Seq)) + + if len(echo.Data) < 4 || + binary.BigEndian.Uint32(echo.Data[:4]) != magicNumber { + logger.Default().Warn("icmp: invalid message (discarded)") + continue + } + + n = copy(b, echo.Data[4:]) + + if v, ok := addr.(*net.IPAddr); ok { + addr = &net.UDPAddr{ + IP: v.IP, + Port: echo.ID, + } + } + break + } + + // logger.Default().Infof("icmp: read from: %v %d", addr, n) + + return +} + +func (c *serverConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + // logger.Default().Infof("icmp: write to: %v %d", addr, len(b)) + var id int + switch v := addr.(type) { + case *net.UDPAddr: + addr = &net.IPAddr{IP: v.IP} + id = v.Port + } + + if id <= 0 || id > math.MaxUint16 { + err = fmt.Errorf("icmp: invalid message id %v", addr) + return + } + + buf := bufpool.Get(writeBufferSize) + defer bufpool.Put(buf) + + binary.BigEndian.PutUint32((*buf)[:4], magicNumber) + copy((*buf)[4:], b) + + echo := icmp.Echo{ + ID: id, + Seq: int(atomic.LoadUint32(&c.seqs[id-1])), + Data: (*buf)[:len(b)+4], + } + m := icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &echo, + } + wb, err := m.Marshal(nil) + if err != nil { + return 0, err + } + _, err = c.PacketConn.WriteTo(wb, addr) + n = len(b) + return +} diff --git a/pkg/dialer/icmp/conn.go b/pkg/dialer/icmp/conn.go new file mode 100644 index 0000000..13c05c5 --- /dev/null +++ b/pkg/dialer/icmp/conn.go @@ -0,0 +1,42 @@ +package quic + +import ( + "context" + "net" + + "github.com/lucas-clemente/quic-go" +) + +type quicSession struct { + session quic.Session +} + +func (session *quicSession) GetConn() (*quicConn, error) { + stream, err := session.session.OpenStreamSync(context.Background()) + if err != nil { + return nil, err + } + return &quicConn{ + Stream: stream, + laddr: session.session.LocalAddr(), + raddr: session.session.RemoteAddr(), + }, nil +} + +func (session *quicSession) Close() error { + return session.session.CloseWithError(quic.ApplicationErrorCode(0), "closed") +} + +type quicConn struct { + quic.Stream + laddr net.Addr + raddr net.Addr +} + +func (c *quicConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *quicConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/dialer/icmp/dialer.go b/pkg/dialer/icmp/dialer.go new file mode 100644 index 0000000..d4b5a5e --- /dev/null +++ b/pkg/dialer/icmp/dialer.go @@ -0,0 +1,130 @@ +package quic + +import ( + "context" + "math" + "math/rand" + "net" + "sync" + "time" + + icmp_pkg "github.com/go-gost/gost/pkg/common/util/icmp" + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/lucas-clemente/quic-go" + "golang.org/x/net/icmp" +) + +func init() { + registry.DialerRegistry().Register("icmp", NewDialer) +} + +type icmpDialer struct { + sessions map[string]*quicSession + sessionMutex sync.Mutex + logger logger.Logger + md metadata + options dialer.Options +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := dialer.Options{} + for _, opt := range opts { + opt(&options) + } + + return &icmpDialer{ + sessions: make(map[string]*quicSession), + logger: options.Logger, + options: options, + } +} + +func (d *icmpDialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +func (d *icmpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + if _, _, err := net.SplitHostPort(addr); err != nil { + addr = net.JoinHostPort(addr, "0") + } + + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + session, ok := d.sessions[addr] + if !ok { + options := &dialer.DialOptions{} + for _, opt := range opts { + opt(options) + } + + var pc net.PacketConn + pc, err = icmp.ListenPacket("ip4:icmp", "") + if err != nil { + return + } + + id := raddr.Port + if id == 0 { + id = rand.New(rand.NewSource(time.Now().UnixNano())).Intn(math.MaxUint16) + 1 + raddr.Port = id + } + pc = icmp_pkg.ClientConn(pc, id) + + session, err = d.initSession(ctx, raddr, pc) + if err != nil { + d.logger.Error(err) + pc.Close() + return nil, err + } + + d.sessions[addr] = session + } + + conn, err = session.GetConn() + if err != nil { + session.Close() + delete(d.sessions, addr) + return nil, err + } + + return +} + +func (d *icmpDialer) initSession(ctx context.Context, addr net.Addr, conn net.PacketConn) (*quicSession, error) { + quicConfig := &quic.Config{ + KeepAlive: d.md.keepAlive, + HandshakeIdleTimeout: d.md.handshakeTimeout, + MaxIdleTimeout: d.md.maxIdleTimeout, + Versions: []quic.VersionNumber{ + quic.Version1, + quic.VersionDraft29, + }, + } + + tlsCfg := d.options.TLSConfig + tlsCfg.NextProtos = []string{"http/3", "quic/v1"} + + session, err := quic.DialContext(ctx, conn, addr, addr.String(), tlsCfg, quicConfig) + if err != nil { + return nil, err + } + return &quicSession{session: session}, nil +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *icmpDialer) Multiplex() bool { + return true +} diff --git a/pkg/dialer/icmp/metadata.go b/pkg/dialer/icmp/metadata.go new file mode 100644 index 0000000..9e4face --- /dev/null +++ b/pkg/dialer/icmp/metadata.go @@ -0,0 +1,29 @@ +package quic + +import ( + "time" + + mdata "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + keepAlive bool + maxIdleTimeout time.Duration + handshakeTimeout time.Duration +} + +func (d *icmpDialer) parseMetadata(md mdata.Metadata) (err error) { + const ( + keepAlive = "keepAlive" + handshakeTimeout = "handshakeTimeout" + maxIdleTimeout = "maxIdleTimeout" + ) + + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + + d.md.keepAlive = mdata.GetBool(md, keepAlive) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + + return +} diff --git a/pkg/dialer/kcp/conn.go b/pkg/dialer/kcp/conn.go index c46a9f4..41cf3d5 100644 --- a/pkg/dialer/kcp/conn.go +++ b/pkg/dialer/kcp/conn.go @@ -7,7 +7,6 @@ import ( ) type muxSession struct { - conn net.Conn session *smux.Session } diff --git a/pkg/dialer/kcp/dialer.go b/pkg/dialer/kcp/dialer.go index d53025e..71a93ff 100644 --- a/pkg/dialer/kcp/dialer.go +++ b/pkg/dialer/kcp/dialer.go @@ -26,17 +26,19 @@ type kcpDialer struct { sessionMutex sync.Mutex logger logger.Logger md metadata + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &kcpDialer{ sessions: make(map[string]*muxSession), logger: options.Logger, + options: options, } } @@ -50,12 +52,12 @@ func (d *kcpDialer) Init(md md.Metadata) (err error) { return nil } -// Multiplex implements dialer.Multiplexer interface. -func (d *kcpDialer) Multiplex() bool { - return true -} - func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -70,86 +72,55 @@ func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } + var pc net.PacketConn if d.md.config.TCP { - raddr, err := net.ResolveUDPAddr("udp", addr) + pc, err = tcpraw.Dial("tcp", addr) if err != nil { return nil, err } - - pc, err := tcpraw.Dial("tcp", addr) - if err != nil { - return nil, err - } - conn = &fakeTCPConn{ + pc = &fakeTCPConn{ raddr: raddr, PacketConn: pc, } } else { - conn, err = options.NetDialer.Dial(ctx, "udp", addr) + c, err := options.NetDialer.Dial(ctx, "udp", addr) if err != nil { return nil, err } + + var ok bool + pc, ok = c.(net.PacketConn) + if !ok { + c.Close() + return nil, errors.New("quic: wrong connection type") + } + } + + session, err = d.initSession(ctx, raddr, pc) + if err != nil { + d.logger.Error(err) + pc.Close() + return nil, err } - session = &muxSession{conn: conn} d.sessions[addr] = session } - return session.conn, err -} - -// Handshake implements dialer.Handshaker -func (d *kcpDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { - opts := &dialer.HandshakeOptions{} - for _, option := range options { - option(opts) - } - - d.sessionMutex.Lock() - defer d.sessionMutex.Unlock() - - if d.md.handshakeTimeout > 0 { - conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - } - - session, ok := d.sessions[opts.Addr] - if session != nil && session.conn != conn { - conn.Close() - return nil, errors.New("kcp: unrecognized connection") - } - - if !ok || session.session == nil { - s, err := d.initSession(ctx, opts.Addr, conn) - if err != nil { - d.logger.Error(err) - conn.Close() - delete(d.sessions, opts.Addr) - return nil, err - } - session = s - d.sessions[opts.Addr] = session - } - cc, err := session.GetConn() + conn, err = session.GetConn() if err != nil { session.Close() - delete(d.sessions, opts.Addr) + delete(d.sessions, addr) return nil, err } - return cc, nil + return } -func (d *kcpDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*muxSession, error) { - pc, ok := conn.(net.PacketConn) - if !ok { - return nil, errors.New("kcp: wrong connection type") - } - +func (d *kcpDialer) initSession(ctx context.Context, addr net.Addr, conn net.PacketConn) (*muxSession, error) { config := d.md.config - kcpconn, err := kcp.NewConn(addr, + kcpconn, err := kcp.NewConn(addr.String(), kcp_util.BlockCrypt(config.Key, config.Crypt, kcp_util.DefaultSalt), - config.DataShard, config.ParityShard, pc) + config.DataShard, config.ParityShard, conn) if err != nil { return nil, err } @@ -162,15 +133,15 @@ func (d *kcpDialer) initSession(ctx context.Context, addr string, conn net.Conn) kcpconn.SetACKNoDelay(config.AckNodelay) if config.DSCP > 0 { - if err := kcpconn.SetDSCP(config.DSCP); err != nil { - d.logger.Warn("SetDSCP: ", err) + if er := kcpconn.SetDSCP(config.DSCP); er != nil { + d.logger.Warn("SetDSCP: ", er) } } - if err := kcpconn.SetReadBuffer(config.SockBuf); err != nil { - d.logger.Warn("SetReadBuffer: ", err) + if er := kcpconn.SetReadBuffer(config.SockBuf); er != nil { + d.logger.Warn("SetReadBuffer: ", er) } - if err := kcpconn.SetWriteBuffer(config.SockBuf); err != nil { - d.logger.Warn("SetWriteBuffer: ", err) + if er := kcpconn.SetWriteBuffer(config.SockBuf); er != nil { + d.logger.Warn("SetWriteBuffer: ", er) } // stream multiplex @@ -185,5 +156,10 @@ func (d *kcpDialer) initSession(ctx context.Context, addr string, conn net.Conn) if err != nil { return nil, err } - return &muxSession{conn: conn, session: session}, nil + return &muxSession{session: session}, nil +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *kcpDialer) Multiplex() bool { + return true } diff --git a/pkg/dialer/quic/conn.go b/pkg/dialer/quic/conn.go index e1af828..13c05c5 100644 --- a/pkg/dialer/quic/conn.go +++ b/pkg/dialer/quic/conn.go @@ -8,7 +8,6 @@ import ( ) type quicSession struct { - conn net.Conn session quic.Session } diff --git a/pkg/dialer/quic/dialer.go b/pkg/dialer/quic/dialer.go index 3c55dfd..57f787b 100644 --- a/pkg/dialer/quic/dialer.go +++ b/pkg/dialer/quic/dialer.go @@ -5,7 +5,6 @@ import ( "errors" "net" "sync" - "time" "github.com/go-gost/gost/pkg/dialer" quic_util "github.com/go-gost/gost/pkg/internal/util/quic" @@ -48,12 +47,16 @@ func (d *quicDialer) Init(md md.Metadata) (err error) { return nil } -// Multiplex implements dialer.Multiplexer interface. -func (d *quicDialer) Multiplex() bool { - return true -} - func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + if _, _, err := net.SplitHostPort(addr); err != nil { + addr = net.JoinHostPort(addr, "0") + } + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -64,80 +67,41 @@ func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO opt(options) } - host := d.md.host - if host == "" { - host = options.Host - } - if h, _, _ := net.SplitHostPort(host); h != "" { - host = h - } - conn, err = options.NetDialer.Dial(ctx, "udp", "") + c, err := options.NetDialer.Dial(ctx, "udp", "") if err != nil { return nil, err } + pc, ok := c.(net.PacketConn) + if !ok { + c.Close() + return nil, errors.New("quic: wrong connection type") + } if d.md.cipherKey != nil { - conn = quic_util.CipherConn(conn.(*net.UDPConn), d.md.cipherKey) + pc = quic_util.CipherPacketConn(pc, d.md.cipherKey) + } + + session, err = d.initSession(ctx, udpAddr, pc) + if err != nil { + d.logger.Error(err) + pc.Close() + return nil, err } - session = &quicSession{conn: conn} d.sessions[addr] = session } - return session.conn, err -} - -// Handshake implements dialer.Handshaker -func (d *quicDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { - opts := &dialer.HandshakeOptions{} - for _, option := range options { - option(opts) - } - - d.sessionMutex.Lock() - defer d.sessionMutex.Unlock() - - if d.md.handshakeTimeout > 0 { - conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - } - - session, ok := d.sessions[opts.Addr] - if session != nil && session.conn != conn { - conn.Close() - return nil, errors.New("quic: unrecognized connection") - } - if !ok || session.session == nil { - s, err := d.initSession(ctx, opts.Addr, conn) - if err != nil { - d.logger.Error(err) - conn.Close() - delete(d.sessions, opts.Addr) - return nil, err - } - session = s - d.sessions[opts.Addr] = session - } - cc, err := session.GetConn() + conn, err = session.GetConn() if err != nil { session.Close() - delete(d.sessions, opts.Addr) + delete(d.sessions, addr) return nil, err } - return cc, nil + return } -func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*quicSession, error) { - pc, ok := conn.(net.PacketConn) - if !ok { - return nil, errors.New("quic: wrong connection type") - } - - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } +func (d *quicDialer) initSession(ctx context.Context, addr net.Addr, conn net.PacketConn) (*quicSession, error) { quicConfig := &quic.Config{ KeepAlive: d.md.keepAlive, HandshakeIdleTimeout: d.md.handshakeTimeout, @@ -151,10 +115,14 @@ func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn tlsCfg := d.options.TLSConfig tlsCfg.NextProtos = []string{"http/3", "quic/v1"} - session, err := quic.DialContext(ctx, pc, udpAddr, addr, tlsCfg, quicConfig) + session, err := quic.DialContext(ctx, conn, addr, addr.String(), tlsCfg, quicConfig) if err != nil { - d.logger.Error(err) return nil, err } - return &quicSession{conn: conn, session: session}, nil + return &quicSession{session: session}, nil +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *quicDialer) Multiplex() bool { + return true } diff --git a/pkg/dialer/ws/dialer.go b/pkg/dialer/ws/dialer.go index f6e8633..82573e6 100644 --- a/pkg/dialer/ws/dialer.go +++ b/pkg/dialer/ws/dialer.go @@ -72,8 +72,8 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial } if d.md.handshakeTimeout > 0 { - conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) - defer conn.SetDeadline(time.Time{}) + conn.SetReadDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetReadDeadline(time.Time{}) } host := d.md.host @@ -103,6 +103,8 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial } resp.Body.Close() + cc := ws_util.Conn(c) + if d.md.keepAlive > 0 { c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) c.SetPongHandler(func(string) error { @@ -110,13 +112,13 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial d.options.Logger.Infof("pong: set read deadline: %v", d.md.keepAlive*2) return nil }) - go d.keepAlive(c) + go d.keepAlive(cc) } - return ws_util.Conn(c), nil + return cc, nil } -func (d *wsDialer) keepAlive(conn *websocket.Conn) { +func (d *wsDialer) keepAlive(conn ws_util.WebsocketConn) { ticker := time.NewTicker(d.md.keepAlive) defer ticker.Stop() diff --git a/pkg/dialer/ws/mux/dialer.go b/pkg/dialer/ws/mux/dialer.go index 8b0320d..518fedd 100644 --- a/pkg/dialer/ws/mux/dialer.go +++ b/pkg/dialer/ws/mux/dialer.go @@ -103,11 +103,6 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia d.sessionMutex.Lock() defer d.sessionMutex.Unlock() - if d.md.handshakeTimeout > 0 { - conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - } - session, ok := d.sessions[opts.Addr] if session != nil && session.conn != conn { conn.Close() @@ -156,23 +151,31 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) dialer.TLSClientConfig = d.options.TLSConfig } + if d.md.handshakeTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(d.md.handshakeTimeout)) + } + c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header) if err != nil { return nil, err } resp.Body.Close() + if d.md.handshakeTimeout > 0 { + conn.SetReadDeadline(time.Time{}) + } + + cc := ws_util.Conn(c) + if d.md.keepAlive > 0 { c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) return nil }) - go d.keepAlive(c) + go d.keepAlive(cc) } - conn = ws_util.Conn(c) - // stream multiplex smuxConfig := smux.DefaultConfig() smuxConfig.KeepAliveDisabled = d.md.muxKeepAliveDisabled @@ -192,14 +195,14 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) smuxConfig.MaxStreamBuffer = d.md.muxMaxStreamBuffer } - session, err := smux.Client(conn, smuxConfig) + session, err := smux.Client(cc, smuxConfig) if err != nil { return nil, err } - return &muxSession{conn: conn, session: session}, nil + return &muxSession{conn: cc, session: session}, nil } -func (d *mwsDialer) keepAlive(conn *websocket.Conn) { +func (d *mwsDialer) keepAlive(conn ws_util.WebsocketConn) { ticker := time.NewTicker(d.md.keepAlive) defer ticker.Stop() diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index 88e5374..ffbc91c 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -132,7 +132,7 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add var err error var pc net.PacketConn if addr != nil { - cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) + cc, err := h.router.Dial(ctx, addr.Network(), "") if err != nil { return err } @@ -140,7 +140,8 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add var ok bool pc, ok = cc.(net.PacketConn) if !ok { - return errors.New("invalid connnection") + cc.Close() + return errors.New("wrong connection type") } } else { laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) @@ -153,8 +154,9 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add if h.cipher != nil { pc = h.cipher.PacketConn(pc) } + defer pc.Close() - return h.transport(conn, pc, addr, log) + return h.transport(conn, pc, addr, config, log) }() if err != nil { log.Error(err) @@ -183,7 +185,7 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add } -func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error { +func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, config *tun_util.Config, log logger.Logger) error { errc := make(chan error, 1) go func() { @@ -236,7 +238,7 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr return err } - addr := h.findRouteFor(dst) + addr := h.findRouteFor(dst, config.Routes...) if addr == nil { log.Warnf("no route for %s -> %s", src, dst) return nil @@ -317,7 +319,7 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr log.Warnf("no route for %s -> %s", src, addr) } - if addr := h.findRouteFor(dst); addr != nil { + if addr := h.findRouteFor(dst, config.Routes...); addr != nil { log.Debugf("find route: %s -> %s", dst, addr) _, err := conn.WriteTo((*b)[:n], addr) diff --git a/pkg/handler/tun/metadata.go b/pkg/handler/tun/metadata.go index 52bd853..8c62cb7 100644 --- a/pkg/handler/tun/metadata.go +++ b/pkg/handler/tun/metadata.go @@ -18,7 +18,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.key = mdata.GetString(md, key) h.md.bufferSize = mdata.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { - h.md.bufferSize = 1024 + h.md.bufferSize = 1500 } return } diff --git a/pkg/internal/util/pht/server.go b/pkg/internal/util/pht/server.go index 0efa227..327c2a5 100644 --- a/pkg/internal/util/pht/server.go +++ b/pkg/internal/util/pht/server.go @@ -69,6 +69,7 @@ func LoggerServerOption(logger logger.Logger) ServerOption { } } +// TODO: remove stale clients from conns type Server struct { addr net.Addr httpServer *http.Server diff --git a/pkg/internal/util/quic/conn.go b/pkg/internal/util/quic/conn.go index 7d49a5c..0fe3ed3 100644 --- a/pkg/internal/util/quic/conn.go +++ b/pkg/internal/util/quic/conn.go @@ -10,26 +10,19 @@ import ( ) type cipherConn struct { - *net.UDPConn + net.PacketConn key []byte } -func CipherConn(conn *net.UDPConn, key []byte) net.Conn { +func CipherPacketConn(conn net.PacketConn, key []byte) net.PacketConn { return &cipherConn{ - UDPConn: conn, - key: key, - } -} - -func CipherPacketConn(conn *net.UDPConn, key []byte) net.PacketConn { - return &cipherConn{ - UDPConn: conn, - key: key, + PacketConn: conn, + key: key, } } func (conn *cipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { - n, addr, err = conn.UDPConn.ReadFrom(data) + n, addr, err = conn.PacketConn.ReadFrom(data) if err != nil { return } @@ -49,7 +42,7 @@ func (conn *cipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { return } - _, err = conn.UDPConn.WriteTo(b, addr) + _, err = conn.PacketConn.WriteTo(b, addr) if err != nil { return } diff --git a/pkg/internal/util/ws/ws.go b/pkg/internal/util/ws/ws.go index 53290f0..da4beaa 100644 --- a/pkg/internal/util/ws/ws.go +++ b/pkg/internal/util/ws/ws.go @@ -2,17 +2,25 @@ package ws import ( "net" + "sync" "time" "github.com/gorilla/websocket" ) -type websocketConn struct { - *websocket.Conn - rb []byte +type WebsocketConn interface { + net.Conn + WriteMessage(int, []byte) error + ReadMessage() (int, []byte, error) } -func Conn(conn *websocket.Conn) net.Conn { +type websocketConn struct { + *websocket.Conn + rb []byte + mux sync.Mutex +} + +func Conn(conn *websocket.Conn) WebsocketConn { return &websocketConn{ Conn: conn, } @@ -20,7 +28,7 @@ func Conn(conn *websocket.Conn) net.Conn { func (c *websocketConn) Read(b []byte) (n int, err error) { if len(c.rb) == 0 { - _, c.rb, err = c.ReadMessage() + _, c.rb, err = c.Conn.ReadMessage() } n = copy(b, c.rb) c.rb = c.rb[n:] @@ -33,6 +41,13 @@ func (c *websocketConn) Write(b []byte) (n int, err error) { return } +func (c *websocketConn) WriteMessage(messageType int, data []byte) error { + c.mux.Lock() + defer c.mux.Unlock() + + return c.Conn.WriteMessage(messageType, data) +} + func (c *websocketConn) SetDeadline(t time.Time) error { if err := c.SetReadDeadline(t); err != nil { return err diff --git a/pkg/listener/icmp/conn.go b/pkg/listener/icmp/conn.go new file mode 100644 index 0000000..ee1a26c --- /dev/null +++ b/pkg/listener/icmp/conn.go @@ -0,0 +1,21 @@ +package quic + +import ( + "net" + + "github.com/lucas-clemente/quic-go" +) + +type quicConn struct { + quic.Stream + laddr net.Addr + raddr net.Addr +} + +func (c *quicConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *quicConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/listener/icmp/listener.go b/pkg/listener/icmp/listener.go new file mode 100644 index 0000000..7a1a832 --- /dev/null +++ b/pkg/listener/icmp/listener.go @@ -0,0 +1,145 @@ +package quic + +import ( + "context" + "net" + + "github.com/go-gost/gost/pkg/common/metrics" + icmp_pkg "github.com/go-gost/gost/pkg/common/util/icmp" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/lucas-clemente/quic-go" + "golang.org/x/net/icmp" +) + +func init() { + registry.ListenerRegistry().Register("icmp", NewListener) +} + +type icmpListener struct { + ln quic.Listener + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &icmpListener{ + logger: options.Logger, + options: options, + } +} + +func (l *icmpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + addr := l.options.Addr + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + + var conn net.PacketConn + conn, err = icmp.ListenPacket("ip4:icmp", addr) + if err != nil { + return + } + conn = icmp_pkg.ServerConn(conn) + conn = metrics.WrapPacketConn(l.options.Service, conn) + + config := &quic.Config{ + KeepAlive: l.md.keepAlive, + HandshakeIdleTimeout: l.md.handshakeTimeout, + MaxIdleTimeout: l.md.maxIdleTimeout, + Versions: []quic.VersionNumber{ + quic.Version1, + quic.VersionDraft29, + }, + } + + tlsCfg := l.options.TLSConfig + tlsCfg.NextProtos = []string{"http/3", "quic/v1"} + + ln, err := quic.Listen(conn, tlsCfg, config) + if err != nil { + return + } + + l.ln = ln + l.cqueue = make(chan net.Conn, l.md.backlog) + l.errChan = make(chan error, 1) + + go l.listenLoop() + + return +} + +func (l *icmpListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.cqueue: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *icmpListener) Close() error { + return l.ln.Close() +} + +func (l *icmpListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *icmpListener) listenLoop() { + for { + ctx := context.Background() + session, err := l.ln.Accept(ctx) + if err != nil { + l.logger.Error("accept: ", err) + l.errChan <- err + close(l.errChan) + return + } + l.logger.Infof("new client session: %v", session.RemoteAddr()) + go l.mux(ctx, session) + } +} + +func (l *icmpListener) mux(ctx context.Context, session quic.Session) { + defer session.CloseWithError(0, "closed") + + for { + stream, err := session.AcceptStream(ctx) + if err != nil { + l.logger.Error("accept stream: ", err) + return + } + + conn := &quicConn{ + Stream: stream, + laddr: session.LocalAddr(), + raddr: session.RemoteAddr(), + } + select { + case l.cqueue <- conn: + case <-stream.Context().Done(): + stream.Close() + default: + stream.Close() + l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr()) + } + } +} diff --git a/pkg/listener/icmp/metadata.go b/pkg/listener/icmp/metadata.go new file mode 100644 index 0000000..bb7b1d8 --- /dev/null +++ b/pkg/listener/icmp/metadata.go @@ -0,0 +1,41 @@ +package quic + +import ( + "time" + + mdata "github.com/go-gost/gost/pkg/metadata" +) + +const ( + defaultBacklog = 128 +) + +type metadata struct { + keepAlive bool + handshakeTimeout time.Duration + maxIdleTimeout time.Duration + + cipherKey []byte + backlog int +} + +func (l *icmpListener) parseMetadata(md mdata.Metadata) (err error) { + const ( + keepAlive = "keepAlive" + handshakeTimeout = "handshakeTimeout" + maxIdleTimeout = "maxIdleTimeout" + + backlog = "backlog" + ) + + l.md.backlog = mdata.GetInt(md, backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + l.md.keepAlive = mdata.GetBool(md, keepAlive) + l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + + return +} diff --git a/pkg/listener/kcp/listener.go b/pkg/listener/kcp/listener.go index 6b8703f..1ea3bd0 100644 --- a/pkg/listener/kcp/listener.go +++ b/pkg/listener/kcp/listener.go @@ -49,7 +49,6 @@ func (l *kcpListener) Init(md md.Metadata) (err error) { config.Init() var conn net.PacketConn - if config.TCP { conn, err = tcpraw.Listen("tcp", l.options.Addr) } else { diff --git a/pkg/listener/quic/listener.go b/pkg/listener/quic/listener.go index 16b5326..4e17c7d 100644 --- a/pkg/listener/quic/listener.go +++ b/pkg/listener/quic/listener.go @@ -42,21 +42,24 @@ func (l *quicListener) Init(md md.Metadata) (err error) { return } - laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) + addr := l.options.Addr + if _, _, err := net.SplitHostPort(addr); err != nil { + addr = net.JoinHostPort(addr, "0") + } + + var laddr *net.UDPAddr + laddr, err = net.ResolveUDPAddr("udp", addr) if err != nil { return } - uc, err := net.ListenUDP("udp", laddr) + var conn net.PacketConn + conn, err = net.ListenUDP("udp", laddr) if err != nil { return } - - var conn net.PacketConn = uc - // conn = metrics.WrapPacketConn(l.options.Service, conn) - if l.md.cipherKey != nil { - conn = quic_util.CipherPacketConn(uc, l.md.cipherKey) + conn = quic_util.CipherPacketConn(conn, l.md.cipherKey) } config := &quic.Config{