From a0ee8bc45cdb0ce0529d4ab5efc0b1d4678836c7 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 18 Jan 2022 23:54:59 +0800 Subject: [PATCH] add resolver for cmd --- cmd/gost/cmd.go | 28 +++++- cmd/gost/config.go | 10 +- cmd/gost/register.go | 1 + pkg/config/config.go | 12 +-- pkg/dialer/http3/conn.go | 148 ++++++++++++++++++++++++++++ pkg/dialer/http3/dialer.go | 112 +++++++++++++++++++++ pkg/dialer/http3/metadata.go | 48 +++++++++ pkg/dialer/pht/dialer.go | 21 ++-- pkg/handler/dns/handler.go | 2 +- pkg/handler/dns/metadata.go | 11 +-- pkg/resolver/exchanger/exchanger.go | 6 ++ 11 files changed, 373 insertions(+), 26 deletions(-) create mode 100644 pkg/dialer/http3/conn.go create mode 100644 pkg/dialer/http3/dialer.go create mode 100644 pkg/dialer/http3/metadata.go diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index a84d4b1..2bab3d9 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -75,6 +75,28 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } } cfg.Services = append(cfg.Services, service) + + md := metadata.MapMetadata(service.Handler.Metadata) + if v := metadata.GetString(md, "resolver"); v != "" { + resolverCfg := &config.ResolverConfig{ + Name: fmt.Sprintf("resolver-%d", len(cfg.Resolvers)), + } + for _, rs := range strings.Split(v, ",") { + if rs == "" { + continue + } + resolverCfg.Nameservers = append( + resolverCfg.Nameservers, + config.NameserverConfig{ + Addr: rs, + }, + ) + } + service.Handler.Resolver = resolverCfg.Name + cfg.Resolvers = append(cfg.Resolvers, resolverCfg) + md.Del("resolver") + } + } return cfg, nil @@ -159,6 +181,10 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { tlsConfig = nil } + if v := metadata.GetString(md, "dns"); v != "" { + md.Set("dns", strings.Split(v, ",")) + } + svc.Handler = &config.HandlerConfig{ Type: handler, Auths: auths, @@ -259,7 +285,7 @@ func normCmd(s string) (*url.URL, error) { return nil, ErrInvalidCmd } - if !strings.Contains(s, "://") { + if s[0] == ':' { s = "auto://" + s } diff --git a/cmd/gost/config.go b/cmd/gost/config.go index bcc6b58..2e0e78c 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -359,7 +359,15 @@ func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) { Hostname: server.Hostname, }) } - return resolver_impl.NewResolver(nameservers) + + logger := log.WithFields(map[string]interface{}{ + "kind": "resolver", + "resolver": cfg.Name, + }) + return resolver_impl.NewResolver( + nameservers, + resolver_impl.LoggerResolverOption(logger), + ) } func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper { diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 70c1f9a..d92366c 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -18,6 +18,7 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/ftcp" _ "github.com/go-gost/gost/pkg/dialer/http2" _ "github.com/go-gost/gost/pkg/dialer/http2/h2" + _ "github.com/go-gost/gost/pkg/dialer/http3" _ "github.com/go-gost/gost/pkg/dialer/kcp" _ "github.com/go-gost/gost/pkg/dialer/obfs/http" _ "github.com/go-gost/gost/pkg/dialer/obfs/tls" diff --git a/pkg/config/config.go b/pkg/config/config.go index 541a3fb..40f1ee4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -57,12 +57,12 @@ type BypassConfig struct { type NameserverConfig struct { Addr string - Chain string - Prefer string - ClientIP string `yaml:"clientIP"` - Hostname string - TTL time.Duration - Timeout time.Duration + Chain string `yaml:",omitempty"` + Prefer string `yaml:",omitempty"` + ClientIP string `yaml:"clientIP,omitempty"` + Hostname string `yaml:",omitempty"` + TTL time.Duration `yaml:",omitempty"` + Timeout time.Duration `yaml:",omitempty"` } type ResolverConfig struct { diff --git a/pkg/dialer/http3/conn.go b/pkg/dialer/http3/conn.go new file mode 100644 index 0000000..6343e64 --- /dev/null +++ b/pkg/dialer/http3/conn.go @@ -0,0 +1,148 @@ +package http3 + +import ( + "bufio" + "bytes" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + "time" + + "github.com/go-gost/gost/pkg/logger" +) + +type conn struct { + cid string + addr string + client *http.Client + buf []byte + rxc chan []byte + closed chan struct{} + md metadata + logger logger.Logger +} + +func (c *conn) Read(b []byte) (n int, err error) { + if len(c.buf) == 0 { + select { + case c.buf = <-c.rxc: + case <-c.closed: + err = net.ErrClosed + return + } + } + + n = copy(b, c.buf) + c.buf = c.buf[n:] + + return +} + +func (c *conn) Write(b []byte) (n int, err error) { + if len(b) == 0 { + return + } + + buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b)) + buf.WriteByte('\n') + + url := fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pushPath, c.cid) + r, err := http.NewRequest(http.MethodPost, url, buf) + if err != nil { + return + } + + resp, err := c.client.Do(r) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + err = errors.New(resp.Status) + return + } + + n = len(b) + return +} + +func (c *conn) readLoop() { + defer c.Close() + + url := fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pullPath, c.cid) + for { + err := func() error { + r, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return err + } + + resp, err := c.client.Do(r) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + b, err := base64.StdEncoding.DecodeString(scanner.Text()) + if err != nil { + return err + } + select { + case c.rxc <- b: + case <-c.closed: + return net.ErrClosed + } + } + + return scanner.Err() + }() + + if err != nil { + c.logger.Error(err) + return + } + } +} + +func (c *conn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} + +func (c *conn) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.addr) + if addr == nil { + addr = &net.TCPAddr{} + } + + return addr +} + +func (c *conn) Close() error { + select { + case <-c.closed: + default: + close(c.closed) + } + return nil +} + +func (c *conn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (c *conn) SetDeadline(t time.Time) error { + return nil +} diff --git a/pkg/dialer/http3/dialer.go b/pkg/dialer/http3/dialer.go new file mode 100644 index 0000000..0303638 --- /dev/null +++ b/pkg/dialer/http3/dialer.go @@ -0,0 +1,112 @@ +package http3 + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" + "time" + + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/lucas-clemente/quic-go/http3" +) + +func init() { + registry.RegisterDialer("http3", NewDialer) +} + +type http3Dialer struct { + client *http.Client + md metadata + logger logger.Logger + options dialer.Options +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := dialer.Options{} + for _, opt := range opts { + opt(&options) + } + + tr := &http3.RoundTripper{ + TLSClientConfig: options.TLSConfig, + } + client := &http.Client{ + Timeout: 60 * time.Second, + Transport: tr, + } + return &http3Dialer{ + client: client, + logger: options.Logger, + options: options, + } +} + +func (d *http3Dialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *http3Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + token, err := d.authorize(ctx, addr) + if err != nil { + d.logger.Error(err) + return nil, err + } + + c := &conn{ + cid: token, + addr: addr, + client: d.client, + rxc: make(chan []byte, 128), + closed: make(chan struct{}), + md: d.md, + logger: d.logger, + } + go c.readLoop() + + return c, nil +} + +func (d *http3Dialer) authorize(ctx context.Context, addr string) (token string, err error) { + url := fmt.Sprintf("https://%s%s", addr, d.md.authorizePath) + r, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return + } + + if d.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + d.logger.Debug(string(dump)) + } + + resp, err := d.client.Do(r) + if err != nil { + return + } + defer resp.Body.Close() + + if d.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + d.logger.Debug(string(dump)) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return + } + + if strings.HasPrefix(string(data), "token=") { + token = strings.TrimPrefix(string(data), "token=") + } + if token == "" { + err = errors.New("authorize failed") + } + return +} diff --git a/pkg/dialer/http3/metadata.go b/pkg/dialer/http3/metadata.go new file mode 100644 index 0000000..615ade8 --- /dev/null +++ b/pkg/dialer/http3/metadata.go @@ -0,0 +1,48 @@ +package http3 + +import ( + "strings" + "time" + + mdata "github.com/go-gost/gost/pkg/metadata" +) + +const ( + dialTimeout = "dialTimeout" + defaultAuthorizePath = "/authorize" + defaultPushPath = "/push" + defaultPullPath = "/pull" +) + +const ( + defaultDialTimeout = 5 * time.Second +) + +type metadata struct { + dialTimeout time.Duration + authorizePath string + pushPath string + pullPath string +} + +func (d *http3Dialer) parseMetadata(md mdata.Metadata) (err error) { + const ( + authorizePath = "authorizePath" + pushPath = "pushPath" + pullPath = "pullPath" + ) + + d.md.authorizePath = mdata.GetString(md, authorizePath) + if !strings.HasPrefix(d.md.authorizePath, "/") { + d.md.authorizePath = defaultAuthorizePath + } + d.md.pushPath = mdata.GetString(md, pushPath) + if !strings.HasPrefix(d.md.pushPath, "/") { + d.md.pushPath = defaultPushPath + } + d.md.pullPath = mdata.GetString(md, pullPath) + if !strings.HasPrefix(d.md.pullPath, "/") { + d.md.pullPath = defaultPullPath + } + return +} diff --git a/pkg/dialer/pht/dialer.go b/pkg/dialer/pht/dialer.go index 03a0669..af9bf36 100644 --- a/pkg/dialer/pht/dialer.go +++ b/pkg/dialer/pht/dialer.go @@ -24,6 +24,7 @@ func init() { type phtDialer struct { tlsEnabled bool + client *http.Client md metadata logger logger.Logger options dialer.Options @@ -55,10 +56,10 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { } func (d *phtDialer) Init(md md.Metadata) (err error) { - return d.parseMetadata(md) -} + if err = d.parseMetadata(md); err != nil { + return + } -func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { tr := &http.Transport{ // Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -75,11 +76,15 @@ func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp tr.TLSClientConfig = d.options.TLSConfig } - client := &http.Client{ + d.client = &http.Client{ Timeout: 60 * time.Second, Transport: tr, } - token, err := d.authorize(ctx, client, addr) + return nil +} + +func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + token, err := d.authorize(ctx, addr) if err != nil { d.logger.Error(err) return nil, err @@ -88,7 +93,7 @@ func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp c := &conn{ cid: token, addr: addr, - client: client, + client: d.client, tlsEnabled: d.tlsEnabled, rxc: make(chan []byte, 128), closed: make(chan struct{}), @@ -100,7 +105,7 @@ func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp return c, nil } -func (d *phtDialer) authorize(ctx context.Context, client *http.Client, addr string) (token string, err error) { +func (d *phtDialer) authorize(ctx context.Context, addr string) (token string, err error) { var url string if d.tlsEnabled { url = fmt.Sprintf("https://%s%s", addr, d.md.authorizePath) @@ -117,7 +122,7 @@ func (d *phtDialer) authorize(ctx context.Context, client *http.Client, addr str d.logger.Debug(string(dump)) } - resp, err := client.Do(r) + resp, err := d.client.Do(r) if err != nil { return } diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index d49d19d..002df53 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -63,7 +63,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { } h.logger = h.options.Logger - for _, server := range h.md.servers { + for _, server := range h.md.dns { server = strings.TrimSpace(server) if server == "" { continue diff --git a/pkg/handler/dns/metadata.go b/pkg/handler/dns/metadata.go index 4433edd..c9b3628 100644 --- a/pkg/handler/dns/metadata.go +++ b/pkg/handler/dns/metadata.go @@ -2,7 +2,6 @@ package dns import ( "net" - "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -14,8 +13,7 @@ type metadata struct { timeout time.Duration clientIP net.IP // nameservers - servers []string - dns []string // compatible with v2 + dns []string } func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -24,7 +22,6 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { ttl = "ttl" timeout = "timeout" clientIP = "clientIP" - servers = "servers" dns = "dns" ) @@ -38,11 +35,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { if sip != "" { h.md.clientIP = net.ParseIP(sip) } - h.md.servers = mdata.GetStrings(md, servers) - h.md.dns = strings.Split(mdata.GetString(md, dns), ",") - if len(h.md.dns) > 0 { - h.md.servers = append(h.md.servers, h.md.dns...) - } + h.md.dns = mdata.GetStrings(md, dns) return } diff --git a/pkg/resolver/exchanger/exchanger.go b/pkg/resolver/exchanger/exchanger.go index b592b77..0a5933b 100644 --- a/pkg/resolver/exchanger/exchanger.go +++ b/pkg/resolver/exchanger/exchanger.go @@ -71,6 +71,8 @@ type exchanger struct { } // NewExchanger create an Exchanger. +// The addr should be URL-like format, +// e.g. udp://1.1.1.1:53, tls://1.1.1.1:853, https://1.0.0.1/dns-query func NewExchanger(addr string, opts ...Option) (Exchanger, error) { var options Options for _, opt := range opts { @@ -85,6 +87,10 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) { return nil, err } + if options.timeout <= 0 { + options.timeout = 5 * time.Second + } + ex := &exchanger{ network: u.Scheme, addr: u.Host,