diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go new file mode 100644 index 0000000..5b7f54b --- /dev/null +++ b/cmd/gost/cmd.go @@ -0,0 +1,88 @@ +package main + +import ( + "errors" + "fmt" + "net/url" + "strings" + + "github.com/go-gost/gost/pkg/config" +) + +var ( + ErrInvalidService = errors.New("invalid service") + ErrInvalidNode = errors.New("invalid node") +) + +type stringList []string + +func (l *stringList) String() string { + return fmt.Sprintf("%s", *l) +} +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { + cfg := &config.Config{} + + var chain *config.ChainConfig + if len(nodes) > 0 { + chain = &config.ChainConfig{ + Name: "chain-0", + } + cfg.Chains = append(cfg.Chains, chain) + } + + for i, node := range nodes { + url, err := checkCmd(node) + if err != nil { + return nil, err + } + chain.Hops = append(chain.Hops, &config.HopConfig{ + Name: fmt.Sprintf("hop-%d", i), + Nodes: []*config.NodeConfig{ + { + Name: "node-0", + URL: url, + }, + }, + }) + } + + for i, svc := range services { + url, err := checkCmd(svc) + if err != nil { + return nil, err + } + service := &config.ServiceConfig{ + Name: fmt.Sprintf("service-%d", i), + URL: url, + } + if chain != nil { + service.Chain = chain.Name + } + cfg.Services = append(cfg.Services, service) + } + + return cfg, nil +} + +func checkCmd(s string) (string, error) { + s = strings.TrimSpace(s) + if s == "" { + return "", ErrInvalidService + } + + if !strings.Contains(s, "://") { + s = "auto://" + s + } + + u, err := url.Parse(s) + if err != nil { + return "", err + } + + return u.String(), nil +} diff --git a/cmd/gost/config.go b/cmd/gost/config.go index f587376..446db13 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -29,11 +29,11 @@ func buildService(cfg *config.Config) (services []*service.Service) { } for _, bypassCfg := range cfg.Bypasses { - bypasses[bypassCfg.Name] = bypassFromConfig(&bypassCfg) + bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) } for _, chainCfg := range cfg.Chains { - chains[chainCfg.Name] = chainFromConfig(&chainCfg) + chains[chainCfg.Name] = chainFromConfig(chainCfg) } for _, svc := range cfg.Services { @@ -47,9 +47,14 @@ func buildService(cfg *config.Config) (services []*service.Service) { }) ln := registry.GetListener(svc.Listener.Type)( listener.AddrOption(svc.Addr), - listener.ChainOption(chains[svc.Listener.Chain]), listener.LoggerOption(listenerLogger), ) + + cln, chainable := ln.(listener.Chainable) + if chainable { + cln.Chain(chains[svc.Chain]) + } + if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil { listenerLogger.Fatal("init: ", err) } @@ -60,13 +65,17 @@ func buildService(cfg *config.Config) (services []*service.Service) { }) h := registry.GetHandler(svc.Handler.Type)( - handler.ChainOption(chains[svc.Handler.Chain]), - handler.BypassOption(bypasses[svc.Handler.Bypass]), + handler.ChainOption(chains[svc.Chain]), + handler.BypassOption(bypasses[svc.Bypass]), handler.LoggerOption(handlerLogger), ) if forwarder, ok := h.(handler.Forwarder); ok { - forwarder.Forward(forwarderFromConfig(svc.Forwarder)) + chain := chains[svc.Chain] + if chainable { + chain = nil + } + forwarder.Forward(forwarderFromConfig(svc.Forwarder), chain) } if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { @@ -145,6 +154,9 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { } func logFromConfig(cfg *config.LogConfig) logger.Logger { + if cfg == nil { + cfg = &config.LogConfig{} + } opts := []logger.LoggerOption{ logger.FormatLoggerOption(logger.LogFormat(cfg.Format)), logger.LevelLoggerOption(logger.LogLevel(cfg.Level)), @@ -152,9 +164,9 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { var out io.Writer = os.Stderr switch cfg.Output { - case "stdout": + case "stdout", "": out = os.Stdout - case "stderr", "": + case "stderr": out = os.Stderr default: f, err := os.OpenFile(cfg.Output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) @@ -201,7 +213,7 @@ func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { } func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { - if cfg == nil { + if cfg == nil || len(cfg.Targets) == 0 { return nil } diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index 451bf48..c93d21e 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -11,10 +11,10 @@ services: - name: http+tcp url: "http://gost:gost@:8000" addr: ":28000" + chain: chain01 + # bypass: bypass01 handler: type: http - chain: chain01 - # bypass: bypass01 metadata: proxyAgent: "gost/3.0" retry: 3 @@ -30,10 +30,10 @@ services: - name: ss url: "ss://chacha20:gost@:8000" addr: ":28338" + # chain: chain01 + # bypass: bypass01 handler: type: ss - # chain: chain01 - # bypass: bypass01 metadata: method: chacha20-ietf password: gost @@ -48,10 +48,10 @@ services: - name: socks5 url: "socks5://gost:gost@:1080" addr: ":21080" + # chain: chain-ss + # bypass: bypass01 handler: type: socks5 - # chain: chain-ss - # bypass: bypass01 metadata: auths: - gost:gost @@ -121,6 +121,7 @@ services: - name: rtcp addr: ":28100" + # chain: chain-socks5 forwarder: targets: - 192.168.8.8:80 @@ -131,7 +132,6 @@ services: retry: 3 listener: type: rtcp - # chain: chain-socks5 metadata: keepAlive: 15s mux: true diff --git a/cmd/gost/main.go b/cmd/gost/main.go index fa43391..ea9a84f 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,9 +1,12 @@ package main import ( - stdlog "log" + "flag" + "fmt" "net/http" _ "net/http/pprof" + "os" + "runtime" "github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/logger" @@ -11,14 +14,63 @@ import ( var ( log = logger.NewLogger() + + cfgFile string + outputCfgFile string + services stringList + nodes stringList + debug bool ) +func init() { + var printVersion bool + + flag.Var(&services, "L", "service list") + flag.Var(&nodes, "F", "chain node list") + flag.StringVar(&cfgFile, "C", "", "configure file") + flag.BoolVar(&printVersion, "V", false, "print version") + flag.BoolVar(&debug, "D", false, "debug mode") + flag.StringVar(&outputCfgFile, "O", "", "write config to FILE") + flag.Parse() + + if printVersion { + fmt.Fprintf(os.Stdout, "gost %s (%s %s/%s)\n", + version, runtime.Version(), runtime.GOOS, runtime.GOARCH) + os.Exit(0) + } +} + func main() { - stdlog.SetFlags(stdlog.LstdFlags | stdlog.Lshortfile) cfg := &config.Config{} - if err := cfg.Load(); err != nil { + var err error + if len(services) > 0 { + cfg, err = buildConfigFromCmd(services, nodes) + if debug && cfg != nil { + if cfg.Log == nil { + cfg.Log = &config.LogConfig{} + } + cfg.Log.Level = string(logger.DebugLevel) + } + } else { + if cfgFile != "" { + err = cfg.ReadFile(cfgFile) + } else { + err = cfg.Load() + } + } + if err != nil { log.Fatal(err) } + + normConfig(cfg) + + if outputCfgFile != "" { + if err := cfg.WriteFile(outputCfgFile); err != nil { + log.Fatal(err) + } + os.Exit(0) + } + log = logFromConfig(cfg.Log) if cfg.Profiling != nil && cfg.Profiling.Enabled { diff --git a/cmd/gost/norm.go b/cmd/gost/norm.go new file mode 100644 index 0000000..29af6ec --- /dev/null +++ b/cmd/gost/norm.go @@ -0,0 +1,100 @@ +package main + +import ( + "net/url" + "strings" + + "github.com/go-gost/gost/pkg/config" +) + +// normConfig normalizes the config. +func normConfig(cfg *config.Config) { + for _, svc := range cfg.Services { + normService(svc) + } + for _, chain := range cfg.Chains { + normChain(chain) + } +} + +func normService(svc *config.ServiceConfig) { + if svc.URL == "" { + return + } + + u, _ := url.Parse(svc.URL) + + var handler, listener string + schemes := strings.Split(u.Scheme, "+") + if len(schemes) == 1 { + handler = schemes[0] + listener = schemes[0] + } + if len(schemes) == 2 { + handler = schemes[0] + listener = schemes[1] + } + + md := make(map[string]interface{}) + for k, v := range u.Query() { + if len(v) > 0 { + md[k] = v[0] + } + } + + svc.Addr = u.Host + svc.Handler = &config.HandlerConfig{ + Type: handler, + Metadata: md, + } + svc.Listener = &config.ListenerConfig{ + Type: listener, + Metadata: md, + } + + if remotes := strings.Trim(u.EscapedPath(), "/"); remotes != "" { + svc.Forwarder = &config.ForwarderConfig{ + Targets: strings.Split(remotes, ","), + } + } +} + +func normChain(chain *config.ChainConfig) { + for _, hop := range chain.Hops { + for _, node := range hop.Nodes { + if node.URL == "" { + continue + } + + u, _ := url.Parse(node.URL) + + var connector, dialer string + schemes := strings.Split(u.Scheme, "+") + if len(schemes) == 1 { + connector = schemes[0] + dialer = schemes[0] + } + if len(schemes) == 2 { + connector = schemes[0] + dialer = schemes[1] + } + + md := make(map[string]interface{}) + for k, v := range u.Query() { + if len(v) > 0 { + md[k] = v[0] + } + } + + node.Addr = u.Host + node.Connector = &config.ConnectorConfig{ + Type: connector, + Metadata: md, + } + node.Dialer = &config.DialerConfig{ + Type: dialer, + Metadata: md, + } + } + } +} diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 35a32f0..94baa70 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -14,6 +14,7 @@ import ( // Register handlers _ "github.com/go-gost/gost/pkg/handler/forward" _ "github.com/go-gost/gost/pkg/handler/http" + _ "github.com/go-gost/gost/pkg/handler/relay" _ "github.com/go-gost/gost/pkg/handler/socks/v4" _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" diff --git a/cmd/gost/version.go b/cmd/gost/version.go new file mode 100644 index 0000000..3b7d97e --- /dev/null +++ b/cmd/gost/version.go @@ -0,0 +1,5 @@ +package main + +const ( + version = "3.0.0-alpha" +) diff --git a/go.mod b/go.mod index 4ebcbbb..7d16411 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 + github.com/go-gost/relay v0.1.1-0.20211117133957-4b109438dc89 github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.3 github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index a92866f..6e39a29 100644 --- a/go.sum +++ b/go.sum @@ -111,22 +111,14 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= -github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJA= -github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/gosocks5 v0.3.1-0.20211107101454-adcd9c8808ae h1:IPcN2DQQxPiKE/hYudFYYeeN3036tQOjnWpsQBLo4Bw= -github.com/go-gost/gosocks5 v0.3.1-0.20211107101454-adcd9c8808ae/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/gosocks5 v0.3.1-0.20211107103216-55ee58d8201a h1:H35F3INQHliXXipi3aU6F0PJ4OzE5EZVeOZohUd/sKc= -github.com/go-gost/gosocks5 v0.3.1-0.20211107103216-55ee58d8201a/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/gosocks5 v0.3.1-0.20211107150557-ff084b955b6a h1:LQ189f7tRprIJwx3znV+V2KMw1Yjo+up38Tujo5vGFo= -github.com/go-gost/gosocks5 v0.3.1-0.20211107150557-ff084b955b6a/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/gosocks5 v0.3.1-0.20211107153135-23b5baedc2aa h1:4yBKO6CPj5LokDeVJy3jbvQTcclG6lMk7zQMQ1/MAYo= -github.com/go-gost/gosocks5 v0.3.1-0.20211107153135-23b5baedc2aa/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d h1:mjoFToMUWNN06IwOyXOk9bEsev3T5RUoC9n4Xt7ZDkg= -github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea h1:mrm6bMpdxBvInvBuDbUaAQWV60r/PaByLIG9fQJEEIc= -github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= +github.com/go-gost/relay v0.1.0 h1:UOf2YwAzzaUjY5mdpMuLfSw0vz62iIFYk7oJQkuhlGw= +github.com/go-gost/relay v0.1.0/go.mod h1:YFCpddLOFE3NlIkeDWRdEs8gL/GFsqXdtaf8SV5v4YQ= +github.com/go-gost/relay v0.1.1-0.20211028021513-03c783f893bc h1:F8wBeQYP8JvzIG/6rwRsC+R+D97lstozbsqwRknd4XU= +github.com/go-gost/relay v0.1.1-0.20211028021513-03c783f893bc/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= +github.com/go-gost/relay v0.1.1-0.20211117133957-4b109438dc89 h1:1EtXLpAYeGVcptB0Jt8AeRe+GQnbTjbqeYA3L02pCIY= +github.com/go-gost/relay v0.1.1-0.20211117133957-4b109438dc89/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= diff --git a/pkg/chain/node.go b/pkg/chain/node.go index 67d8469..ff42843 100644 --- a/pkg/chain/node.go +++ b/pkg/chain/node.go @@ -76,12 +76,12 @@ func (g *NodeGroup) Next() *Node { return nil } - selector := g.selector - if selector == nil { - return g.nodes[0] + s := g.selector + if s == nil { + s = DefaultSelector } - return selector.Select(g.nodes...) + return s.Select(g.nodes...) } type FailMarker struct { diff --git a/pkg/chain/selector.go b/pkg/chain/selector.go index 2588407..5c96338 100644 --- a/pkg/chain/selector.go +++ b/pkg/chain/selector.go @@ -14,6 +14,10 @@ const ( DefaultFailTimeout = 30 * time.Second ) +var ( + DefaultSelector = NewSelector(RoundRobinStrategy()) +) + type Selector interface { Select(nodes ...*Node) *Node } diff --git a/pkg/common/util/relay/conn.go b/pkg/common/util/relay/conn.go new file mode 100644 index 0000000..51477b4 --- /dev/null +++ b/pkg/common/util/relay/conn.go @@ -0,0 +1,58 @@ +package relay + +import ( + "encoding/binary" + "errors" + "io" + "math" + "net" +) + +type packetConn struct { + net.Conn +} + +func UDPTunConn(conn net.Conn) net.Conn { + return &packetConn{ + Conn: conn, + } +} + +func (c *packetConn) Read(b []byte) (n int, err error) { + var bb [2]byte + _, err = io.ReadFull(c.Conn, bb[:]) + if err != nil { + return + } + + dlen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dlen { + return io.ReadFull(c.Conn, b[:dlen]) + } + buf := make([]byte, dlen) + _, err = io.ReadFull(c.Conn, buf) + n = copy(b, buf) + + return +} + +func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + return +} + +func (c *packetConn) Write(b []byte) (n int, err error) { + if len(b) > math.MaxUint16 { + err = errors.New("write: data maximum exceeded") + return + } + + var bb [2]byte + binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) + _, err = c.Conn.Write(bb[:]) + if err != nil { + return + } + return c.Conn.Write(b) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 96776d6..cd9a54d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,9 +2,11 @@ package config import ( "io" + "os" "time" "github.com/spf13/viper" + "gopkg.in/yaml.v2" ) var ( @@ -19,9 +21,9 @@ func init() { } type LogConfig struct { - Output string - Level string - Format string + Output string `yaml:",omitempty"` + Level string `yaml:",omitempty"` + Format string `yaml:",omitempty"` } type ProfilingConfig struct { @@ -37,25 +39,22 @@ type SelectorConfig struct { type BypassConfig struct { Name string - Reverse bool + Reverse bool `yaml:",omitempty"` Matchers []string } type ListenerConfig struct { Type string - Chain string - Metadata map[string]interface{} + Metadata map[string]interface{} `yaml:",omitempty"` } type HandlerConfig struct { Type string - Chain string - Bypass string - Metadata map[string]interface{} + Metadata map[string]interface{} `yaml:",omitempty"` } type ForwarderConfig struct { Targets []string - Selector *SelectorConfig + Selector *SelectorConfig `yaml:",omitempty"` } type DialerConfig struct { @@ -70,40 +69,42 @@ type ConnectorConfig struct { type ServiceConfig struct { Name string - URL string - Addr string - Listener *ListenerConfig - Handler *HandlerConfig - Forwarder *ForwarderConfig + URL string `yaml:",omitempty"` + Addr string `yaml:",omitempty"` + Chain string `yaml:",omitempty"` + Bypass string `yaml:",omitempty"` + Listener *ListenerConfig `yaml:",omitempty"` + Handler *HandlerConfig `yaml:",omitempty"` + Forwarder *ForwarderConfig `yaml:",omitempty"` } type ChainConfig struct { Name string - Selector *SelectorConfig - Hops []HopConfig + Selector *SelectorConfig `yaml:",omitempty"` + Hops []*HopConfig } type HopConfig struct { Name string - Selector *SelectorConfig - Nodes []NodeConfig + Selector *SelectorConfig `yaml:",omitempty"` + Nodes []*NodeConfig } type NodeConfig struct { Name string - URL string - Addr string - Dialer *DialerConfig - Connector *ConnectorConfig - Bypass string + URL string `yaml:",omitempty"` + Addr string `yaml:",omitempty"` + Dialer *DialerConfig `yaml:",omitempty"` + Connector *ConnectorConfig `yaml:",omitempty"` + Bypass string `yaml:",omitempty"` } type Config struct { - Log *LogConfig - Profiling *ProfilingConfig - Services []ServiceConfig - Chains []ChainConfig - Bypasses []BypassConfig + Log *LogConfig `yaml:",omitempty"` + Profiling *ProfilingConfig `yaml:",omitempty"` + Services []*ServiceConfig + Chains []*ChainConfig `yaml:",omitempty"` + Bypasses []*BypassConfig `yaml:",omitempty"` } func (c *Config) Load() error { @@ -129,3 +130,16 @@ func (c *Config) ReadFile(file string) error { } return v.Unmarshal(c) } + +func (c *Config) WriteFile(file string) error { + f, err := os.Create(file) + if err != nil { + return err + } + defer f.Close() + + enc := yaml.NewEncoder(f) + defer enc.Close() + + return enc.Encode(c) +} diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index 8df23b0..38e9d8a 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -52,8 +52,13 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add switch network { case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + c.logger.Error(err) + return nil, err + } default: - err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + err := fmt.Errorf("network %s is unsupported", network) c.logger.Error(err) return nil, err } diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go index 124fb79..f079167 100644 --- a/pkg/connector/socks/v4/connector.go +++ b/pkg/connector/socks/v4/connector.go @@ -51,8 +51,13 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a switch network { case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + c.logger.Error(err) + return nil, err + } default: - err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + err := fmt.Errorf("network %s is unsupported", network) c.logger.Error(err) return nil, err } diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index 6e8abf5..2b31999 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -97,8 +97,13 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a case "udp", "udp4", "udp6": return c.connectUDP(ctx, conn, network, address) case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + c.logger.Error(err) + return nil, err + } default: - err := fmt.Errorf("network %s unsupported", network) + err := fmt.Errorf("network %s is unsupported", network) c.logger.Error(err) return nil, err } diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index 28c2128..e1407ac 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -52,10 +52,15 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre switch network { case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + c.logger.Error(err) + return nil, err + } case "udp", "udp4", "udp6": return c.connectUDP(ctx, conn, network, address) default: - err := fmt.Errorf("network %s unsupported", network) + err := fmt.Errorf("network %s is unsupported", network) c.logger.Error(err) return nil, err } @@ -99,7 +104,7 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre } func (c *ssConnector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { - if c.md.enableUDP { + if !c.md.enableUDP { err := errors.New("UDP relay is disabled") c.logger.Error(err) return nil, err diff --git a/pkg/connector/ss/metadata.go b/pkg/connector/ss/metadata.go index 4157d06..3825f23 100644 --- a/pkg/connector/ss/metadata.go +++ b/pkg/connector/ss/metadata.go @@ -22,7 +22,7 @@ func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { password = "password" key = "key" connectTimeout = "timeout" - noDelay = "noDelay" + noDelay = "nodelay" enableUDP = "udp" // enable UDP relay udpBufferSize = "udpBufferSize" // udp buffer size ) diff --git a/pkg/handler/forward/handler.go b/pkg/handler/forward/handler.go index e4cd309..4f4bc52 100644 --- a/pkg/handler/forward/handler.go +++ b/pkg/handler/forward/handler.go @@ -43,8 +43,9 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *forwardHandler) Forward(group *chain.NodeGroup) { +func (h *forwardHandler) Forward(group *chain.NodeGroup, chain *chain.Chain) { h.group = group + h.chain = chain } func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 1016d91..80c4d10 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -14,5 +14,5 @@ type Handler interface { } type Forwarder interface { - Forward(*chain.NodeGroup) + Forward(*chain.NodeGroup, *chain.Chain) } diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go new file mode 100644 index 0000000..34686ee --- /dev/null +++ b/pkg/handler/relay/forward.go @@ -0,0 +1,49 @@ +package relay + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" +) + +func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string) { + target := h.group.Next() + if target == nil { + h.logger.Error("no target available") + return + } + + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": target.Addr(), + }) + + h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + + cc, err := r.Dial(ctx, network, target.Addr()) + if err != nil { + h.logger.Error(err) + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + target.Marker().Mark() + return + } + defer cc.Close() + target.Marker().Reset() + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) +} diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go new file mode 100644 index 0000000..7f44508 --- /dev/null +++ b/pkg/handler/relay/handler.go @@ -0,0 +1,143 @@ +package relay + +import ( + "context" + "net" + "strconv" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/relay" +) + +func init() { + registry.RegisterHandler("relay", NewHandler) +} + +type relayHandler struct { + group *chain.NodeGroup + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &relayHandler{ + chain: options.Chain, + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *relayHandler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +// Forward implements handler.Forwarder. +func (h *relayHandler) Forward(group *chain.NodeGroup, chain *chain.Chain) { + h.group = group + h.chain = chain +} + +func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + if h.md.readTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + } + + req := relay.Request{} + if _, err := req.ReadFrom(conn); err != nil { + h.logger.Error(err) + return + } + + conn.SetReadDeadline(time.Time{}) + + if req.Version != relay.Version1 { + h.logger.Error("bad version") + return + } + + var user, pass string + var target string + for _, f := range req.Features { + if f.Type() == relay.FeatureUserAuth { + feature := f.(*relay.UserAuthFeature) + user, pass = feature.Username, feature.Password + } + if f.Type() == relay.FeatureTargetAddr { + feature := f.(*relay.TargetAddrFeature) + target = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + } + } + + if user != "" { + h.logger = h.logger.WithFields(map[string]interface{}{"user": user}) + } + if target != "" { + h.logger = h.logger.WithFields(map[string]interface{}{"dst": target}) + } + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + if h.md.authenticator != nil && !h.md.authenticator.Authenticate(user, pass) { + resp.Status = relay.StatusUnauthorized + resp.WriteTo(conn) + h.logger.Error("unauthorized") + return + } + + network := "tcp" + if (req.Flags & relay.FUDP) == relay.FUDP { + network = "udp" + } + + if h.group != nil { + if target != "" { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + h.logger.Error("forbidden") + return + } + // forward mode + h.handleForward(ctx, conn, network) + return + } + + if target == "" { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + h.logger.Error("bad request") + return + } + + // proxy mode + h.handleProxy(ctx, conn, network, target) +} diff --git a/pkg/handler/relay/metadata.go b/pkg/handler/relay/metadata.go new file mode 100644 index 0000000..f2070db --- /dev/null +++ b/pkg/handler/relay/metadata.go @@ -0,0 +1,41 @@ +package relay + +import ( + "strings" + "time" + + "github.com/go-gost/gost/pkg/auth" + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + authenticator auth.Authenticator + readTimeout time.Duration + retryCount int +} + +func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { + const ( + authsKey = "auths" + readTimeout = "readTimeout" + retryCount = "retry" + ) + + if v, _ := md.Get(authsKey).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if s, _ := auth.(string); s != "" { + ss := strings.SplitN(s, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) + } + } + } + h.md.authenticator = authenticator + } + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + return +} diff --git a/pkg/handler/relay/proxy.go b/pkg/handler/relay/proxy.go new file mode 100644 index 0000000..1887da7 --- /dev/null +++ b/pkg/handler/relay/proxy.go @@ -0,0 +1,57 @@ +package relay + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gost/pkg/chain" + util_relay "github.com/go-gost/gost/pkg/common/util/relay" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/relay" +) + +func (h *relayHandler) handleProxy(ctx context.Context, conn net.Conn, network, address string) { + h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + if h.bypass != nil && h.bypass.Contains(address) { + h.logger.Info("bypass: ", address) + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + return + } + + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, network, address) + if err != nil { + resp.Status = relay.StatusNetworkUnreachable + resp.WriteTo(conn) + return + } + defer cc.Close() + + if _, err := resp.WriteTo(conn); err != nil { + h.logger.Error(err) + } + + if network == "udp" { + conn = util_relay.UDPTunConn(conn) + } + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), address) +} diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 24762c4..8759094 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -64,13 +64,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { // standard UDP relay. if pc, ok := conn.(net.PacketConn); ok { - if h.md.enableUDP { - h.handleUDP(ctx, conn.RemoteAddr(), pc) - return - } else { - h.logger.Error("UDP relay is diabled") - } - + h.handleUDP(ctx, pc, conn.RemoteAddr()) return } @@ -84,21 +78,17 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { br := bufio.NewReader(conn) data, err := br.Peek(3) + conn.SetReadDeadline(time.Time{}) if err != nil { h.logger.Error(err) h.discard(conn) return } - conn.SetReadDeadline(time.Time{}) conn = handler.NewBufferReaderConn(conn, br) if data[2] == 0xff { - if h.md.enableUDP { - // UDP-over-TCP relay - h.handleUDPTun(ctx, conn) - } else { - h.logger.Error("UDP relay is diabled") - } + // UDP-over-TCP relay + h.handleUDPTun(ctx, conn) return } @@ -110,8 +100,6 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { return } - conn.SetReadDeadline(time.Time{}) - h.logger = h.logger.WithFields(map[string]interface{}{ "dst": addr.String(), }) diff --git a/pkg/handler/ss/udp.go b/pkg/handler/ss/udp.go index 6cdbbc8..a75d65a 100644 --- a/pkg/handler/ss/udp.go +++ b/pkg/handler/ss/udp.go @@ -11,7 +11,12 @@ import ( "github.com/go-gost/gost/pkg/common/util/ss" ) -func (h *ssHandler) handleUDP(ctx context.Context, raddr net.Addr, conn net.PacketConn) { +func (h *ssHandler) handleUDP(ctx context.Context, conn net.PacketConn, raddr net.Addr) { + if !h.md.enableUDP { + h.logger.Error("UDP relay is diabled") + return + } + if h.md.cipher != nil { conn = h.md.cipher.PacketConn(conn) } @@ -50,6 +55,11 @@ func (h *ssHandler) handleUDP(ctx context.Context, raddr net.Addr, conn net.Pack } func (h *ssHandler) handleUDPTun(ctx context.Context, conn net.Conn) { + if !h.md.enableUDP { + h.logger.Error("UDP relay is diabled") + return + } + // obtain a udp connection r := (&chain.Router{}). WithChain(h.chain). diff --git a/pkg/listener/listener.go b/pkg/listener/listener.go index 078cd5f..297d52a 100644 --- a/pkg/listener/listener.go +++ b/pkg/listener/listener.go @@ -4,6 +4,7 @@ import ( "errors" "net" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/metadata" ) @@ -21,3 +22,7 @@ type Listener interface { type Accepter interface { Accept() (net.Conn, error) } + +type Chainable interface { + Chain(chain *chain.Chain) +} diff --git a/pkg/listener/option.go b/pkg/listener/option.go index 570dea5..ceb6b88 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -1,13 +1,11 @@ package listener import ( - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" ) type Options struct { Addr string - Chain *chain.Chain Logger logger.Logger } @@ -19,12 +17,6 @@ func AddrOption(addr string) Option { } } -func ChainOption(chain *chain.Chain) Option { - return func(opts *Options) { - opts.Chain = chain - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index 61852cb..2e030f3 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -40,12 +40,16 @@ func NewListener(opts ...listener.Option) listener.Listener { } return &rtcpListener{ addr: options.Addr, - chain: options.Chain, closed: make(chan struct{}), logger: options.Logger, } } +// implements listener.Chainable interface +func (l *rtcpListener) Chain(chain *chain.Chain) { + l.chain = chain +} + func (l *rtcpListener) Init(md md.Metadata) (err error) { if err = l.parseMetadata(md); err != nil { return diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go index 6995a03..56032e1 100644 --- a/pkg/listener/rudp/listener.go +++ b/pkg/listener/rudp/listener.go @@ -39,12 +39,16 @@ func NewListener(opts ...listener.Option) listener.Listener { } return &rudpListener{ addr: options.Addr, - chain: options.Chain, closed: make(chan struct{}), logger: options.Logger, } } +// implements listener.Chainable interface +func (l *rudpListener) Chain(chain *chain.Chain) { + l.chain = chain +} + func (l *rudpListener) Init(md md.Metadata) (err error) { if err = l.parseMetadata(md); err != nil { return diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index d63294a..2c7dd10 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -84,15 +84,15 @@ func NewLogger(opts ...LoggerOption) Logger { } switch options.Format { - case JSONFormat: + case TextFormat: + log.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) + default: log.SetFormatter(&logrus.JSONFormatter{ DisableHTMLEscape: true, // PrettyPrint: true, }) - default: - log.SetFormatter(&logrus.TextFormatter{ - FullTimestamp: true, - }) } switch options.Level {