diff --git a/gost.yml b/gost.yml index 87cc83c..189952a 100644 --- a/gost.yml +++ b/gost.yml @@ -14,16 +14,17 @@ profiling: resolvers: - name: resolver-0 - ttl: 60s - prefer: ipv4 - clientIP: 1.2.3.4 - nameServers: + nameservers: - addr: udp://8.8.8.8:53 - timeout: 5s + chain: chain-0 + ttl: 60s + prefer: ipv4 + clientIP: 1.2.3.4 + timeout: 3s - addr: tcp://1.1.1.1:53 - addr: tls://1.1.1.1:853 - addr: https://1.0.0.1/dns-query - domain: cloudflare-dns.com + hostname: cloudflare-dns.com services: - name: http+tcp diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index 1ac65b1..632dd16 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -6,12 +6,14 @@ import ( "errors" "net" "strconv" + "strings" "time" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/handler" + resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -27,6 +29,7 @@ type dnsHandler struct { chain *chain.Chain bypass bypass.Bypass exchangers []exchanger.Exchanger + cache *resolver_util.Cache logger logger.Logger md metadata } @@ -37,8 +40,11 @@ func NewHandler(opts ...handler.Option) handler.Handler { opt(options) } + cache := resolver_util.NewCache().WithLogger(options.Logger) + return &dnsHandler{ bypass: options.Bypass, + cache: cache, logger: options.Logger, } } @@ -49,9 +55,14 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { } for _, server := range h.md.servers { + server = strings.TrimSpace(server) + if server == "" { + continue + } ex, err := exchanger.NewExchanger( server, exchanger.ChainOption(h.chain), + exchanger.TimeoutOption(h.md.timeout), exchanger.LoggerOption(h.logger), ) if err != nil { @@ -61,14 +72,18 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { h.exchangers = append(h.exchangers, ex) } if len(h.exchangers) == 0 { - ex, _ := exchanger.NewExchanger( - "udp://127.0.0.53:53", + addr := "udp://127.0.0.1:53" + ex, err := exchanger.NewExchanger( + addr, exchanger.ChainOption(h.chain), + exchanger.TimeoutOption(h.md.timeout), exchanger.LoggerOption(h.logger), ) - if ex != nil { - h.exchangers = append(h.exchangers, ex) + h.logger.Warnf("resolver not found, default to %s", addr) + if err != nil { + return err } + h.exchangers = append(h.exchangers, ex) } return } @@ -106,7 +121,6 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { reply, err := h.exchange(ctx, b[:n]) if err != nil { - h.logger.Error(err) return } @@ -118,6 +132,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { mq := dns.Msg{} if err := mq.Unpack(msg); err != nil { + h.logger.Error(err) return nil, err } @@ -125,6 +140,8 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { return nil, errors.New("msg: empty question") } + resolver_util.AddSubnetOpt(&mq, h.md.clientIP) + if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(mq.String()) } else { @@ -132,26 +149,22 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { } var mr *dns.Msg - // Only cache for single question. - /* - if len(mq.Question) == 1 { - key := newResolverCacheKey(&mq.Question[0]) - mr = r.cache.loadCache(key) - if mr != nil { - log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) - mr.Id = mq.Id - return mr.Pack() - } - - defer func() { - if mr != nil { - r.cache.storeCache(key, mr, r.TTL()) - } - }() + // cache only for single question message. + if len(mq.Question) == 1 { + key := resolver_util.NewCacheKey(&mq.Question[0]) + mr = h.cache.Load(key) + if mr != nil { + h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) + mr.Id = mq.Id + return mr.Pack() } - */ - // r.addSubnetOpt(mq) + defer func() { + if mr != nil { + h.cache.Store(key, mr, h.md.ttl) + } + }() + } query, err := mq.Pack() if err != nil { @@ -169,7 +182,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { h.logger.Error(err) } if err != nil { - h.logger.Error(err) return nil, err } diff --git a/pkg/handler/dns/metadata.go b/pkg/handler/dns/metadata.go index c077d7d..26ef5e2 100644 --- a/pkg/handler/dns/metadata.go +++ b/pkg/handler/dns/metadata.go @@ -1,6 +1,8 @@ package dns import ( + "net" + "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -11,10 +13,10 @@ type metadata struct { retryCount int ttl time.Duration timeout time.Duration - prefer string - clientIP string + clientIP net.IP // nameservers servers []string + dns []string // compatible with v2 } func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -23,9 +25,9 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { retryCount = "retry" ttl = "ttl" timeout = "timeout" - prefer = "prefer" clientIP = "clientIP" servers = "servers" + dns = "dns" ) h.md.readTimeout = mdata.GetDuration(md, readTimeout) @@ -35,9 +37,15 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { if h.md.timeout <= 0 { h.md.timeout = 5 * time.Second } - h.md.prefer = mdata.GetString(md, prefer) - h.md.clientIP = mdata.GetString(md, clientIP) + sip := mdata.GetString(md, clientIP) + if sip != "" { + h.md.clientIP = net.ParseIP(sip) + } h.md.servers = mdata.GetStrings(md, servers) + h.md.dns = strings.Split(mdata.GetString(md, dns), ",") + if len(h.md.dns) > 0 { + h.md.servers = append(h.md.servers, h.md.dns...) + } return } diff --git a/pkg/internal/util/resolver/cache.go b/pkg/internal/util/resolver/cache.go new file mode 100644 index 0000000..eef779e --- /dev/null +++ b/pkg/internal/util/resolver/cache.go @@ -0,0 +1,86 @@ +package resolver + +import ( + "fmt" + "sync" + "time" + + "github.com/go-gost/gost/pkg/logger" + "github.com/miekg/dns" +) + +type CacheKey string + +// NewCacheKey generates resolver cache key from question of dns query. +func NewCacheKey(q *dns.Question) CacheKey { + if q == nil { + return "" + } + key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String()) + return CacheKey(key) +} + +type cacheItem struct { + msg *dns.Msg + ts time.Time + ttl time.Duration +} + +type Cache struct { + m sync.Map + logger logger.Logger +} + +func NewCache() *Cache { + return &Cache{} +} + +func (c *Cache) WithLogger(logger logger.Logger) *Cache { + c.logger = logger + return c +} + +func (c *Cache) Load(key CacheKey) *dns.Msg { + v, ok := c.m.Load(key) + if !ok { + return nil + } + + item, ok := v.(*cacheItem) + if !ok { + return nil + } + + elapsed := time.Since(item.ts) + if item.ttl > 0 { + if elapsed > item.ttl { + c.m.Delete(key) + return nil + } + } else { + for _, rr := range item.msg.Answer { + if elapsed > time.Duration(rr.Header().Ttl)*time.Second { + c.m.Delete(key) + return nil + } + } + } + + c.logger.Debugf("resolver cache hit %s", key) + + return item.msg.Copy() +} + +func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) { + if key == "" || mr == nil || ttl < 0 { + return + } + + c.m.Store(key, &cacheItem{ + msg: mr.Copy(), + ts: time.Now(), + ttl: ttl, + }) + + c.logger.Debugf("resolver cache store %s", key) +} diff --git a/pkg/internal/util/resolver/resolver.go b/pkg/internal/util/resolver/resolver.go new file mode 100644 index 0000000..74ec536 --- /dev/null +++ b/pkg/internal/util/resolver/resolver.go @@ -0,0 +1,30 @@ +package resolver + +import ( + "net" + + "github.com/miekg/dns" +) + +func AddSubnetOpt(m *dns.Msg, ip net.IP) { + if m == nil || ip == nil { + return + } + + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + if ip := ip.To4(); ip != nil { + e.Family = 1 + e.SourceNetmask = 24 + e.Address = ip + } else { + e.Family = 2 + e.SourceNetmask = 128 + e.Address = ip.To16() + } + opt.Option = append(opt.Option, e) + m.Extra = append(m.Extra, opt) +} diff --git a/pkg/resolver/exchanger/exchanger.go b/pkg/resolver/exchanger/exchanger.go index feba5e0..c0c6d94 100644 --- a/pkg/resolver/exchanger/exchanger.go +++ b/pkg/resolver/exchanger/exchanger.go @@ -107,7 +107,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) { } } ex.network = "tcp" - case "doh": + case "https": ex.addr = addr if ex.options.tlsConfig == nil { ex.options.tlsConfig = &tls.Config{ @@ -134,7 +134,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) { } func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) { - if ex.network == "doh" { + if ex.network == "https" { return ex.dohExchange(ctx, msg) } return ex.exchange(ctx, msg) diff --git a/pkg/resolver/ns.go b/pkg/resolver/ns.go deleted file mode 100644 index 4e1c8a8..0000000 --- a/pkg/resolver/ns.go +++ /dev/null @@ -1,13 +0,0 @@ -package resolver - -import ( - "time" -) - -type NameServer struct { - Addr string - Protocol string - Hostname string // for TLS handshake verification - Exchanger Exchanger - Timeout time.Duration -} diff --git a/pkg/resolver/resolver.go b/pkg/resolver/resolver.go index 6b6108f..548868a 100644 --- a/pkg/resolver/resolver.go +++ b/pkg/resolver/resolver.go @@ -3,9 +3,178 @@ package resolver import ( "context" "net" + "strings" + "time" + + "github.com/go-gost/gost/pkg/chain" + resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/resolver/exchanger" + "github.com/miekg/dns" ) type Resolver interface { // Resolve returns a slice of the host's IPv4 and IPv6 addresses. Resolve(ctx context.Context, host string) ([]net.IP, error) } + +type NameServer struct { + Addr string + Chain *chain.Chain + TTL time.Duration + Timeout time.Duration + ClientIP net.IP + Prefer string + Hostname string // for TLS handshake verification + exchanger exchanger.Exchanger +} + +type resolverOptions struct { + domain string + logger logger.Logger +} + +type ResolverOption func(opts *resolverOptions) + +func DomainResolverOption(domain string) ResolverOption { + return func(opts *resolverOptions) { + opts.domain = domain + } +} + +func LoggerResolverOption(logger logger.Logger) ResolverOption { + return func(opts *resolverOptions) { + opts.logger = logger + } +} + +type resolver struct { + servers []NameServer + cache *resolver_util.Cache + options resolverOptions + logger logger.Logger +} + +func NewResolver(nameservers []NameServer, opts ...ResolverOption) (Resolver, error) { + options := resolverOptions{} + for _, opt := range opts { + opt(&options) + } + + var servers []NameServer + for _, server := range nameservers { + addr := strings.TrimSpace(server.Addr) + if addr == "" { + continue + } + ex, err := exchanger.NewExchanger( + addr, + exchanger.ChainOption(server.Chain), + exchanger.TimeoutOption(server.Timeout), + exchanger.LoggerOption(options.logger), + ) + if err != nil { + options.logger.Warnf("parse %s: %v", server, err) + continue + } + + server.exchanger = ex + servers = append(servers, server) + } + cache := resolver_util.NewCache(). + WithLogger(options.logger) + + return &resolver{ + servers: servers, + cache: cache, + options: options, + logger: options.logger, + }, nil +} + +func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + + if r.options.domain != "" && + !strings.Contains(host, ".") { + host = host + "." + r.options.domain + } + + for _, server := range r.servers { + ips, err = r.resolve(ctx, &server, host) + if err != nil { + r.logger.Error(err) + continue + } + + r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips) + + if len(ips) > 0 { + break + } + } + + return +} + +func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { + if server == nil { + return + } + + if server.Prefer == "ipv6" { // prefer ipv6 + mq := dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) + ips, err = r.resolveIPs(ctx, server, &mq) + if err != nil || len(ips) > 0 { + return + } + } + + // fallback to ipv4 + mq := dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeA) + return r.resolveIPs(ctx, server, &mq) +} + +func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { + key := resolver_util.NewCacheKey(&mq.Question[0]) + mr := r.cache.Load(key) + if mr == nil { + resolver_util.AddSubnetOpt(mq, server.ClientIP) + mr, err = r.exchange(ctx, server.exchanger, mq) + if err != nil { + return + } + r.cache.Store(key, mr, server.TTL) + } + + for _, ans := range mr.Answer { + if ar, _ := ans.(*dns.AAAA); ar != nil { + ips = append(ips, ar.AAAA) + } + if ar, _ := ans.(*dns.A); ar != nil { + ips = append(ips, ar.A) + } + } + + return +} + +func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { + query, err := mq.Pack() + if err != nil { + return + } + reply, err := ex.Exchange(ctx, query) + if err != nil { + return + } + + mr = &dns.Msg{} + err = mr.Unpack(reply) + + return +}