package dns import ( "context" "math" "net" "sync" "time" ) // NewCachingResolver creates a caching [net.Resolver] that uses parent to resolve names. func NewCachingResolver(parent *net.Resolver, options ...CacheOption) *net.Resolver { if parent == nil { parent = &net.Resolver{} } return &net.Resolver{ PreferGo: true, StrictErrors: parent.StrictErrors, Dial: NewCachingDialer(parent.Dial, options...), } } // NewCachingDialer adds caching to a [net.Resolver.Dial] function. func NewCachingDialer(parent DialFunc, options ...CacheOption) DialFunc { var cache = cache{dial: parent, negative: true} for _, o := range options { o.apply(&cache) } if cache.maxEntries == 0 { cache.maxEntries = DefaultMaxCacheEntries } return func(ctx context.Context, network, address string) (net.Conn, error) { conn := &dnsConn{} conn.roundTrip = cachingRoundTrip(&cache, network, address) return conn, nil } } const DefaultMaxCacheEntries = 150 // A CacheOption customizes the resolver cache. type CacheOption interface { apply(*cache) } type maxEntriesOption int type maxTTLOption time.Duration type minTTLOption time.Duration type negativeCacheOption bool func (o maxEntriesOption) apply(c *cache) { c.maxEntries = int(o) } func (o maxTTLOption) apply(c *cache) { c.maxTTL = time.Duration(o) } func (o minTTLOption) apply(c *cache) { c.minTTL = time.Duration(o) } func (o negativeCacheOption) apply(c *cache) { c.negative = bool(o) } // MaxCacheEntries sets the maximum number of entries to cache. // If zero, [DefaultMaxCacheEntries] is used; negative means no limit. func MaxCacheEntries(n int) CacheOption { return maxEntriesOption(n) } // MaxCacheTTL sets the maximum time-to-live for entries in the cache. func MaxCacheTTL(d time.Duration) CacheOption { return maxTTLOption(d) } // MinCacheTTL sets the minimum time-to-live for entries in the cache. func MinCacheTTL(d time.Duration) CacheOption { return minTTLOption(d) } // NegativeCache sets whether to cache negative responses. 
func NegativeCache(b bool) CacheOption { return negativeCacheOption(b) } type cache struct { sync.RWMutex dial DialFunc entries map[string]cacheEntry maxEntries int maxTTL time.Duration minTTL time.Duration negative bool } type cacheEntry struct { deadline time.Time value string } func (c *cache) put(req string, res string) { // ignore uncacheable/unparseable answers if invalid(req, res) { return } // ignore errors (if requested) if nameError(res) && !c.negative { return } // ignore uncacheable/unparseable answers ttl := getTTL(res) if ttl <= 0 { return } // adjust TTL if ttl < c.minTTL { ttl = c.minTTL } // maxTTL overrides minTTL if ttl > c.maxTTL && c.maxTTL != 0 { ttl = c.maxTTL } c.Lock() defer c.Unlock() if c.entries == nil { c.entries = make(map[string]cacheEntry) } // do some cache evition var tested, evicted int for k, e := range c.entries { if time.Until(e.deadline) <= 0 { // delete expired entry delete(c.entries, k) evicted++ } tested++ if tested < 8 { continue } if evicted == 0 && c.maxEntries > 0 && len(c.entries) >= c.maxEntries { // delete at least one entry delete(c.entries, k) } break } // remove message IDs c.entries[req[2:]] = cacheEntry{ deadline: time.Now().Add(ttl), value: res[2:], } } func (c *cache) get(req string) (res string) { // ignore invalid messages if len(req) < 12 { return "" } if req[2] >= 0x7f { return "" } c.RLock() defer c.RUnlock() if c.entries == nil { return "" } // remove message ID entry, ok := c.entries[req[2:]] if ok && time.Until(entry.deadline) > 0 { // prepend correct ID return req[:2] + entry.value } return "" } func invalid(req string, res string) bool { if len(req) < 12 || len(res) < 12 { // header size return true } if req[0] != res[0] || req[1] != res[1] { // IDs match return true } if req[2] >= 0x7f || res[2] < 0x7f { // query, response return true } if req[2]&0x7a != 0 || res[2]&0x7a != 0 { // standard query, not truncated return true } if res[3]&0xf != 0 && res[3]&0xf != 3 { // no error, or name error return 
true } return false } func nameError(res string) bool { return res[3]&0xf == 3 } func getTTL(msg string) time.Duration { ttl := math.MaxInt32 qdcount := getUint16(msg[4:]) ancount := getUint16(msg[6:]) nscount := getUint16(msg[8:]) arcount := getUint16(msg[10:]) rdcount := ancount + nscount + arcount msg = msg[12:] // skip header // skip questions for i := 0; i < qdcount; i++ { name := getNameLen(msg) if name < 0 || name+4 > len(msg) { return -1 } msg = msg[name+4:] } // parse records for i := 0; i < rdcount; i++ { name := getNameLen(msg) if name < 0 || name+10 > len(msg) { return -1 } rtyp := getUint16(msg[name+0:]) rttl := getUint32(msg[name+4:]) rlen := getUint16(msg[name+8:]) if name+10+rlen > len(msg) { return -1 } // skip EDNS OPT since it doesn't have a TTL if rtyp != 41 && rttl < ttl { ttl = rttl } msg = msg[name+10+rlen:] } return time.Duration(ttl) * time.Second } func getNameLen(msg string) int { i := 0 for i < len(msg) { if msg[i] == 0 { // end of name i += 1 break } if msg[i] >= 0xc0 { // compressed name i += 2 break } if msg[i] >= 0x40 { // reserved return -1 } i += int(msg[i] + 1) } return i } func getUint16(s string) int { return int(s[1]) | int(s[0])<<8 } func getUint32(s string) int { return int(s[3]) | int(s[2])<<8 | int(s[1])<<16 | int(s[0])<<24 } func cachingRoundTrip(cache *cache, network, address string) roundTripper { return func(ctx context.Context, req string) (res string, err error) { // check cache if res := cache.get(req); res != "" { return res, nil } // dial connection var conn net.Conn if cache.dial != nil { conn, err = cache.dial(ctx, network, address) } else { var d net.Dialer conn, err = d.DialContext(ctx, network, address) } if err != nil { return "", err } ctx, cancel := context.WithCancel(ctx) go func() { <-ctx.Done() conn.Close() }() defer cancel() if t, ok := ctx.Deadline(); ok { err = conn.SetDeadline(t) if err != nil { return "", err } } // send request err = writeMessage(conn, req) if err != nil { return "", err } // read 
response res, err = readMessage(conn) if err != nil { return "", err } // cache response cache.put(req, res) return res, nil } }
package dns

import (
	"bytes"
	"context"
	"io"
	"net"
	"strings"
	"sync"
	"time"
)

// dnsConn is a [net.Conn] that isn't backed by a network connection.
// Messages written to it (each prefixed by a 2-byte big-endian length)
// are handed to roundTrip by the next Read, and the response is served
// back to the reader with the same framing.
type dnsConn struct {
	sync.Mutex

	ibuf bytes.Buffer // written requests, not yet processed
	obuf bytes.Buffer // response bytes, not yet read

	ctx      context.Context // parent of all round trips, canceled by Close
	cancel   context.CancelFunc
	deadline time.Time // read deadline; zero means none

	roundTrip roundTripper
}

// A roundTripper exchanges a raw DNS request for a raw DNS response.
type roundTripper func(ctx context.Context, req string) (res string, err error)

// Read serves pending response bytes; if there are none, it takes the next
// request from the input buffer, performs the round trip, and serves the
// start of the response.
func (c *dnsConn) Read(b []byte) (n int, err error) {
	imsg, n, err := c.drainBuffers(b)
	if n != 0 || err != nil {
		return n, err
	}

	ctx, cancel := c.childContext()
	omsg, err := c.roundTrip(ctx, imsg)
	cancel()
	if err != nil {
		return 0, err
	}

	return c.fillBuffer(b, omsg)
}

// Write buffers b until the next Read; it never blocks.
func (c *dnsConn) Write(b []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	return c.ibuf.Write(b)
}

// Close cancels the round trip context, if one was created. It always returns nil.
func (c *dnsConn) Close() error {
	c.Lock()
	cancel := c.cancel
	c.Unlock()

	if cancel != nil {
		cancel()
	}
	return nil
}

// LocalAddr always returns nil.
func (c *dnsConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil.
func (c *dnsConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline sets the read deadline (writes have none). It always returns nil.
func (c *dnsConn) SetDeadline(t time.Time) error {
	c.SetReadDeadline(t)
	c.SetWriteDeadline(t)
	return nil
}

// SetReadDeadline sets the deadline applied to subsequent round trips.
// A zero value for t means round trips will not time out.
func (c *dnsConn) SetReadDeadline(t time.Time) error {
	c.Lock()
	defer c.Unlock()
	c.deadline = t
	return nil
}

// SetWriteDeadline does nothing.
func (c *dnsConn) SetWriteDeadline(t time.Time) error {
	// writes do not timeout
	return nil
}

// drainBuffers copies pending response bytes into b and returns their
// count; if there are none, it returns the next request message instead.
func (c *dnsConn) drainBuffers(b []byte) (string, int, error) {
	c.Lock()
	defer c.Unlock()

	// drain the output buffer
	if c.obuf.Len() > 0 {
		n, err := c.obuf.Read(b)
		return "", n, err
	}

	// otherwise, get the next message from the input buffer
	sz := c.ibuf.Next(2)
	if len(sz) < 2 {
		return "", 0, io.ErrUnexpectedEOF
	}

	size := int64(sz[0])<<8 | int64(sz[1])

	var str strings.Builder
	_, err := io.CopyN(&str, &c.ibuf, size)
	if err == io.EOF {
		return "", 0, io.ErrUnexpectedEOF
	}
	if err != nil {
		return "", 0, err
	}
	return str.String(), 0, nil
}

// fillBuffer queues str, prefixed by its 2-byte length, in the output
// buffer, and reads as much of it as fits into b.
func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) {
	c.Lock()
	defer c.Unlock()
	c.obuf.WriteByte(byte(len(str) >> 8))
	c.obuf.WriteByte(byte(len(str)))
	c.obuf.WriteString(str)
	return c.obuf.Read(b)
}

// childContext returns a context for a single round trip:
// canceled by Close, and bounded by the read deadline if one is set.
func (c *dnsConn) childContext() (context.Context, context.CancelFunc) {
	c.Lock()
	defer c.Unlock()
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	// A zero deadline means no deadline. Handing it to WithDeadline
	// would yield a context that has already expired.
	if c.deadline.IsZero() {
		return context.WithCancel(c.ctx)
	}
	return context.WithDeadline(c.ctx, c.deadline)
}

// writeMessage sends msg over conn with a single write.
// Unless conn is a [net.PacketConn], msg is prefixed by its 2-byte length.
func writeMessage(conn net.Conn, msg string) error {
	var buf []byte
	if _, ok := conn.(net.PacketConn); ok {
		buf = []byte(msg)
	} else {
		buf = make([]byte, len(msg)+2)
		buf[0] = byte(len(msg) >> 8)
		buf[1] = byte(len(msg))
		copy(buf[2:], msg)
	}
	// SHOULD do a single write on TCP (RFC 7766, section 8).
	// MUST do a single write on UDP.
	_, err := conn.Write(buf)
	return err
}

// readMessage reads one message from c: a single packet of up to 4096
// bytes if c is a [net.PacketConn], a length-prefixed message otherwise.
func readMessage(c net.Conn) (string, error) {
	if _, ok := c.(net.PacketConn); ok {
		// RFC 1035 specifies 512 as the maximum message size for DNS over UDP.
		// RFC 6891 OTOH suggests 4096 as the maximum payload size for EDNS.
		b := make([]byte, 4096)
		n, err := c.Read(b)
		if err != nil {
			return "", err
		}
		return string(b[:n]), nil
	}

	var sz [2]byte
	_, err := io.ReadFull(c, sz[:])
	if err != nil {
		return "", err
	}

	size := int64(sz[0])<<8 | int64(sz[1])

	var str strings.Builder
	_, err = io.CopyN(&str, c, size)
	if err == io.EOF {
		return "", io.ErrUnexpectedEOF
	}
	if err != nil {
		return "", err
	}
	return str.String(), nil
}
// Package dns provides [net.Resolver] instances implementing caching,
// opportunistic encryption, and DNS over TLS/HTTPS.
//
// To replace the [net.DefaultResolver] with a caching DNS over HTTPS instance
// using the Google Public DNS resolver:
//
//	resolver, err := dns.NewDoHResolver(
//		"https://dns.google/dns-query",
//		dns.DoHCache())
//	if err != nil {
//		log.Fatal(err)
//	}
//	net.DefaultResolver = resolver
package dns

import (
	"context"
	"crypto/tls"
	"net"
	"sync"
	"time"
)

// OpportunisticResolver opportunistically tries encrypted DNS over TLS
// using the local resolver.
var OpportunisticResolver = &net.Resolver{
	Dial:     opportunisticDial,
	PreferGo: true,
}

// opportunisticDial dials address, first trying DNS over TLS on port 853 of
// the same host when address uses the standard DNS port, the server isn't
// known to have failed before, and ctx leaves more than two seconds to spare.
// If the TLS attempt fails, it falls back to a plain connection.
func opportunisticDial(ctx context.Context, network, address string) (net.Conn, error) {
	// On error port is empty, which skips the TLS attempt:
	// the plain dial below reports the malformed address.
	host, port, _ := net.SplitHostPort(address)
	if (port == "53" || port == "domain") && notBadServer(address) {
		deadline, ok := ctx.Deadline()
		if ok && deadline.After(time.Now().Add(2*time.Second)) {
			// The timeout bounds connection and handshake as a whole;
			// using ctx also lets the caller cancel the attempt.
			// Certificates are not verified: this is opportunistic.
			td := tls.Dialer{
				NetDialer: &net.Dialer{Timeout: time.Second},
				Config:    &tls.Config{InsecureSkipVerify: true},
			}
			conn, err := td.DialContext(ctx, "tcp", net.JoinHostPort(host, "853"))
			if err == nil {
				return conn, nil
			}
			// don't blame the server if the caller gave up
			if ctx.Err() == nil {
				addBadServer(address)
			}
		}
	}
	var d net.Dialer
	return d.DialContext(ctx, network, address)
}

// badServers remembers the last few servers for which the TLS attempt
// failed, in a fixed-size ring.
var badServers struct {
	sync.Mutex
	next int
	list [4]string
}

// notBadServer reports whether address is absent from the bad servers list.
func notBadServer(address string) bool {
	badServers.Lock()
	defer badServers.Unlock()
	for _, a := range badServers.list {
		if a == address {
			return false
		}
	}
	return true
}

// addBadServer adds address to the bad servers list, if not yet there,
// overwriting the oldest entry.
func addBadServer(address string) {
	badServers.Lock()
	defer badServers.Unlock()
	for _, a := range badServers.list {
		if a == address {
			return
		}
	}
	badServers.list[badServers.next] = address
	badServers.next = (badServers.next + 1) % len(badServers.list)
}

// DialFunc is a [net.Resolver.Dial] function.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
package dns import ( "context" "errors" "io" "net" "net/http" "net/url" "strings" "sync/atomic" "time" ) // NewDoHResolver creates a DNS over HTTPS resolver. // The uri may be an URI Template. func NewDoHResolver(uri string, options ...DoHOption) (*net.Resolver, error) { // parse the uri template into a url uri, err := parseURITemplate(uri) if err != nil { return nil, err } url, err := url.Parse(uri) if err != nil { return nil, err } port := url.Port() if port == "" { port = url.Scheme } // apply options var opts dohOpts for _, o := range options { o.apply(&opts) } // resolve server network addresses if len(opts.addrs) == 0 { ips, err := OpportunisticResolver.LookupIPAddr(context.Background(), url.Hostname()) if err != nil { return nil, err } opts.addrs = make([]string, len(ips)) for i, ip := range ips { opts.addrs[i] = net.JoinHostPort(ip.String(), port) } } else { for i, a := range opts.addrs { if net.ParseIP(a) != nil { opts.addrs[i] = net.JoinHostPort(a, port) } } } // setup the http transport if opts.transport == nil { opts.transport = &http.Transport{ MaxIdleConns: http.DefaultMaxIdleConnsPerHost, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ForceAttemptHTTP2: true, } } else { opts.transport = opts.transport.Clone() } // setup the http client client := http.Client{ Transport: opts.transport, } // create the resolver var resolver = net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { conn := &dnsConn{} conn.roundTrip = dohRoundTrip(uri, &client) return conn, nil }, } // setup dialer var index atomic.Uint32 opts.transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { var d net.Dialer i := index.Load() conn, err := d.DialContext(ctx, network, opts.addrs[i]) if err != nil { index.CompareAndSwap(i, (i+1)%uint32(len(opts.addrs))) } return conn, err } // setup caching if opts.cache { resolver.Dial = NewCachingDialer(resolver.Dial, 
opts.cacheOpts...) } return &resolver, nil } // A DoHOption customizes the DNS over HTTPS resolver. type DoHOption interface { apply(*dohOpts) } type dohOpts struct { transport *http.Transport addrs []string cache bool cacheOpts []CacheOption } type ( dohTransport http.Transport dohAddresses []string dohCache []CacheOption ) func (o *dohTransport) apply(t *dohOpts) { t.transport = (*http.Transport)(o) } func (o dohAddresses) apply(t *dohOpts) { t.addrs = ([]string)(o) } func (o dohCache) apply(t *dohOpts) { t.cache = true; t.cacheOpts = ([]CacheOption)(o) } // DoHTransport sets the http.Transport used by the resolver. func DoHTransport(transport *http.Transport) DoHOption { return (*dohTransport)(transport) } // DoHAddresses sets the network addresses of the resolver. // These should be IP addresses, or network addresses of the form "IP:port". // This avoids having to resolve the resolver's addresses, improving performance and privacy. func DoHAddresses(addresses ...string) DoHOption { return dohAddresses(addresses) } // DoHCache adds caching to the resolver, with the given options. 
func DoHCache(options ...CacheOption) DoHOption { return dohCache(options) } func dohRoundTrip(uri string, client *http.Client) roundTripper { return func(ctx context.Context, msg string) (string, error) { // prepare request req, err := http.NewRequestWithContext(ctx, http.MethodPost, uri, strings.NewReader(msg)) if err != nil { return "", err } req.Header.Set("Content-Type", "application/dns-message") // send request res, err := client.Do(req) if err != nil { return "", err } defer res.Body.Close() if res.StatusCode != http.StatusOK { return "", errors.New(http.StatusText(res.StatusCode)) } // read response var str strings.Builder _, err = io.Copy(&str, res.Body) if err != nil { return "", err } return str.String(), nil } } func parseURITemplate(uri string) (string, error) { var str strings.Builder var exp bool for i := 0; i < len(uri); i++ { switch c := uri[i]; c { case '{': if exp { return "", errors.New("uri: invalid syntax") } exp = true case '}': if !exp { return "", errors.New("uri: invalid syntax") } exp = false default: if !exp { str.WriteByte(c) } } } return str.String(), nil }
package dns import ( "context" "crypto/tls" "net" "sync/atomic" ) // NewDoTResolver creates a DNS over TLS resolver. // The server can be an IP address, a host name, or a network address of the form "host:port". func NewDoTResolver(server string, options ...DoTOption) (*net.Resolver, error) { // look for a custom port host, port, err := net.SplitHostPort(server) if err != nil { port = "853" } else { server = host } // apply options var opts dotOpts for _, o := range options { o.apply(&opts) } // resolve server network addresses if len(opts.addrs) == 0 { ips, err := OpportunisticResolver.LookupIPAddr(context.Background(), server) if err != nil { return nil, err } opts.addrs = make([]string, len(ips)) for i, ip := range ips { opts.addrs[i] = net.JoinHostPort(ip.String(), port) } } else { for i, a := range opts.addrs { if net.ParseIP(a) != nil { opts.addrs[i] = net.JoinHostPort(a, port) } } } // setup TLS config if opts.config == nil { opts.config = &tls.Config{ ClientSessionCache: tls.NewLRUClientSessionCache(len(opts.addrs)), } } else { opts.config = opts.config.Clone() } if opts.config.ServerName == "" { opts.config.ServerName = server } // setup the dialFunc if opts.dialFunc == nil { var d net.Dialer opts.dialFunc = d.DialContext } // create the resolver var resolver = net.Resolver{PreferGo: true} // setup dialer var index atomic.Uint32 resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { i := index.Load() conn, err := opts.dialFunc(ctx, "tcp", opts.addrs[i]) if err != nil { index.CompareAndSwap(i, (i+1)%uint32(len(opts.addrs))) return nil, err } return tls.Client(conn, opts.config), nil } // setup caching if opts.cache { resolver.Dial = NewCachingDialer(resolver.Dial, opts.cacheOpts...) } return &resolver, nil } // A DoTOption customizes the DNS over TLS resolver. 
type DoTOption interface { apply(*dotOpts) } type dotOpts struct { config *tls.Config addrs []string cache bool cacheOpts []CacheOption dialFunc DialFunc } type ( dotConfig tls.Config dotAddresses []string dotCache []CacheOption dotDialFunc DialFunc ) func (o *dotConfig) apply(t *dotOpts) { t.config = (*tls.Config)(o) } func (o dotAddresses) apply(t *dotOpts) { t.addrs = ([]string)(o) } func (o dotCache) apply(t *dotOpts) { t.cache = true; t.cacheOpts = ([]CacheOption)(o) } func (o dotDialFunc) apply(t *dotOpts) { t.dialFunc = (DialFunc)(o) } // DoTConfig sets the tls.Config used by the resolver. func DoTConfig(config *tls.Config) DoTOption { return (*dotConfig)(config) } // DoTAddresses sets the network addresses of the resolver. // These should be IP addresses, or network addresses of the form "IP:port". // This avoids having to resolve the resolver's addresses, improving performance and privacy. func DoTAddresses(addresses ...string) DoTOption { return dotAddresses(addresses) } // DoTCache adds caching to the resolver, with the given options. func DoTCache(options ...CacheOption) DoTOption { return dotCache(options) } // DoTDialFunc sets the DialFunc used by the resolver. // By default [net.Dialer.DialContext] is used. func DoTDialFunc(f DialFunc) DoTOption { return dotDialFunc(f) }