package dns
import (
"context"
"math"
"net"
"sync"
"time"
)
// NewCachingResolver returns a pure-Go [net.Resolver] that answers from a
// cache when it can and otherwise forwards queries through parent.
// A nil parent behaves like the zero net.Resolver. The StrictErrors setting
// of parent is carried over.
func NewCachingResolver(parent *net.Resolver, options ...CacheOption) *net.Resolver {
	resolver := &net.Resolver{PreferGo: true}
	var upstream DialFunc
	if parent != nil {
		resolver.StrictErrors = parent.StrictErrors
		upstream = parent.Dial
	}
	resolver.Dial = NewCachingDialer(upstream, options...)
	return resolver
}
// NewCachingDialer wraps a [net.Resolver.Dial] function with a response cache.
//
// All connections returned by the resulting DialFunc share one cache.
// Negative (name-error) answers are cached unless disabled with
// [NegativeCache]. The size limit defaults to [DefaultMaxCacheEntries].
// A nil parent dials with a plain [net.Dialer].
func NewCachingDialer(parent DialFunc, options ...CacheOption) DialFunc {
	c := &cache{dial: parent, negative: true}
	for _, opt := range options {
		opt.apply(c)
	}
	if c.maxEntries == 0 {
		c.maxEntries = DefaultMaxCacheEntries
	}

	return func(ctx context.Context, network, address string) (net.Conn, error) {
		return &dnsConn{roundTrip: cachingRoundTrip(c, network, address)}, nil
	}
}
// DefaultMaxCacheEntries is the cache size limit used when no
// MaxCacheEntries option is given, or when it is given zero.
const DefaultMaxCacheEntries = 150
// A CacheOption customizes the resolver cache.
//
// Options are created by MaxCacheEntries, MaxCacheTTL, MinCacheTTL and
// NegativeCache. They are accepted by NewCachingDialer and
// NewCachingResolver, and forwarded through DoHCache and DoTCache.
type CacheOption interface {
	// apply writes the option's setting into the cache configuration.
	apply(*cache)
}
// Concrete CacheOption implementations. Each apply copies the option's
// value into the matching field of the cache.
type (
	maxEntriesOption    int
	maxTTLOption        time.Duration
	minTTLOption        time.Duration
	negativeCacheOption bool
)

func (o maxEntriesOption) apply(c *cache) {
	c.maxEntries = int(o)
}

func (o maxTTLOption) apply(c *cache) {
	c.maxTTL = time.Duration(o)
}

func (o minTTLOption) apply(c *cache) {
	c.minTTL = time.Duration(o)
}

func (o negativeCacheOption) apply(c *cache) {
	c.negative = bool(o)
}
// MaxCacheEntries limits how many entries the cache holds.
// Zero selects [DefaultMaxCacheEntries]; a negative n removes the limit.
func MaxCacheEntries(n int) CacheOption {
	return maxEntriesOption(n)
}

// MaxCacheTTL caps how long an entry stays in the cache; zero means no cap.
// When both are set, MaxCacheTTL takes precedence over [MinCacheTTL].
func MaxCacheTTL(d time.Duration) CacheOption {
	return maxTTLOption(d)
}

// MinCacheTTL sets a lower bound on how long an entry stays in the cache.
func MinCacheTTL(d time.Duration) CacheOption {
	return minTTLOption(d)
}

// NegativeCache sets whether name-error (NXDOMAIN) answers are cached.
// They are cached by default.
func NegativeCache(b bool) CacheOption {
	return negativeCacheOption(b)
}
// cache is the state shared by all connections created by one
// NewCachingDialer. The embedded RWMutex guards entries.
type cache struct {
	sync.RWMutex
	// dial opens upstream connections; nil means a plain net.Dialer.
	dial DialFunc
	// entries maps a request, minus its 2-byte message ID, to its answer.
	entries map[string]cacheEntry
	// maxEntries limits the number of entries; negative means unlimited.
	maxEntries int
	// maxTTL caps entry lifetimes; zero means no cap.
	maxTTL time.Duration
	// minTTL is the lower bound for entry lifetimes.
	minTTL time.Duration
	// negative reports whether name-error (rcode 3) answers are cached.
	negative bool
}

// cacheEntry is a cached answer, valid until deadline.
type cacheEntry struct {
	// deadline is when the entry expires.
	deadline time.Time
	// value is the answer message without its 2-byte message ID.
	value string
}
// put stores res as the cached answer to req, unless the exchange is not
// cacheable.
//
// An exchange is not cacheable when the messages are malformed or mismatched
// (see invalid), when it is a name error and negative caching is disabled, or
// when the answer has no positive TTL. The TTL is clamped to minTTL, then to
// maxTTL when that is non-zero, so maxTTL wins. Entries are keyed and stored
// without their 2-byte message ID.
//
// The cache never holds more than maxEntries entries (when positive).
// Expired entries are dropped opportunistically from a small sample. If the
// cache is still full, arbitrary entries are evicted to make room for a new
// key.
func (c *cache) put(req string, res string) {
	// ignore uncacheable/unparseable exchanges
	if invalid(req, res) {
		return
	}
	// ignore name errors (if requested)
	if nameError(res) && !c.negative {
		return
	}
	// ignore answers without a usable TTL
	ttl := getTTL(res)
	if ttl <= 0 {
		return
	}
	// adjust TTL
	if ttl < c.minTTL {
		ttl = c.minTTL
	}
	// maxTTL overrides minTTL
	if c.maxTTL != 0 && ttl > c.maxTTL {
		ttl = c.maxTTL
	}

	key := req[2:] // drop the message ID

	c.Lock()
	defer c.Unlock()

	if c.entries == nil {
		c.entries = make(map[string]cacheEntry)
	}

	now := time.Now()

	// Sweep a few entries (in map order), dropping any that have expired.
	sampled := 0
	for k, e := range c.entries {
		if !e.deadline.After(now) {
			delete(c.entries, k)
		}
		sampled++
		if sampled >= 8 {
			break
		}
	}

	// Enforce the size limit before adding a new key. Replacing an existing
	// key does not grow the map. Victims are arbitrary (map order).
	if _, exists := c.entries[key]; !exists && c.maxEntries > 0 {
		for k := range c.entries {
			if len(c.entries) < c.maxEntries {
				break
			}
			delete(c.entries, k)
		}
	}

	c.entries[key] = cacheEntry{
		deadline: now.Add(ttl),
		value:    res[2:],
	}
}
// get returns the cached answer to req with req's message ID spliced in.
// It returns "" on a miss, on an expired entry, or when req is shorter than
// a DNS header or is not a query.
func (c *cache) get(req string) (res string) {
	// only well-formed queries can hit
	if len(req) < 12 || req[2] >= 0x7f {
		return ""
	}

	c.RLock()
	defer c.RUnlock()

	// entries are keyed without the message ID; a nil map simply misses
	entry, found := c.entries[req[2:]]
	if !found || time.Until(entry.deadline) <= 0 {
		return ""
	}
	return req[:2] + entry.value
}
// invalid reports whether the req/res exchange must not be cached.
//
// An exchange is rejected if either message is shorter than a DNS header,
// the IDs differ, req is not a query or res is not a response, either uses a
// non-standard opcode, res is truncated, or res carries an rcode other than
// 0 (no error) or 3 (name error).
func invalid(req string, res string) bool {
	// both messages need a full 12-byte header before any indexing
	if len(req) < 12 || len(res) < 12 {
		return true
	}
	switch {
	case req[0] != res[0], req[1] != res[1]: // IDs differ
		return true
	case req[2] >= 0x7f, res[2] < 0x7f: // not a query / not a response
		return true
	case req[2]&0x7a != 0, res[2]&0x7a != 0: // opcode set or truncated
		return true
	}
	rcode := res[3] & 0xf
	return rcode != 0 && rcode != 3
}
// nameError reports whether res carries rcode 3 (name error).
// res must hold at least 4 bytes.
func nameError(res string) bool {
	rcode := res[3] & 0x0f
	return rcode == 3
}
// getTTL returns how long msg may be cached: the smallest TTL among its
// answer, authority and additional records, skipping EDNS OPT pseudo-records.
//
// It returns a negative duration if msg cannot be parsed. It returns zero if
// no record carries a TTL; such answers should not be cached (RFC 2308,
// section 5), and a zero TTL makes the caller skip them. A TTL with the top
// bit set counts as zero (RFC 2181, section 8). msg must hold at least a
// 12-byte header.
func getTTL(msg string) time.Duration {
	ttl := math.MaxInt32
	found := false // whether any record carried a TTL

	qdcount := getUint16(msg[4:])
	ancount := getUint16(msg[6:])
	nscount := getUint16(msg[8:])
	arcount := getUint16(msg[10:])
	rdcount := ancount + nscount + arcount

	msg = msg[12:] // skip header

	// skip questions: name, type (2 bytes), class (2 bytes)
	for i := 0; i < qdcount; i++ {
		name := getNameLen(msg)
		if name < 0 || name+4 > len(msg) {
			return -1
		}
		msg = msg[name+4:]
	}

	// parse records: name, type, class, TTL, RDLENGTH, RDATA
	for i := 0; i < rdcount; i++ {
		name := getNameLen(msg)
		if name < 0 || name+10 > len(msg) {
			return -1
		}
		rtyp := getUint16(msg[name+0:])
		rttl := getUint32(msg[name+4:])
		rlen := getUint16(msg[name+8:])
		if name+10+rlen > len(msg) {
			return -1
		}
		// skip EDNS OPT (type 41): its TTL field holds flags, not a TTL
		if rtyp != 41 {
			if rttl < 0 || rttl > math.MaxInt32 {
				rttl = 0
			}
			if rttl < ttl {
				ttl = rttl
			}
			found = true
		}
		msg = msg[name+10+rlen:]
	}

	// nothing bounds the lifetime (e.g. NODATA without SOA): don't cache
	if !found {
		return 0
	}
	return time.Duration(ttl) * time.Second
}
// getNameLen returns the number of bytes the (possibly compressed) domain
// name at the start of msg occupies. It returns -1 for a reserved label
// type. If msg ends before the name does, the result may exceed len(msg),
// so callers must bounds-check it.
func getNameLen(msg string) int {
	n := 0
	for n < len(msg) {
		switch b := msg[n]; {
		case b == 0: // root label terminates the name
			return n + 1
		case b >= 0xc0: // two-byte compression pointer terminates the name
			return n + 2
		case b >= 0x40: // 0x40 and 0x80 label types are reserved
			return -1
		default: // length byte followed by that many label bytes
			n += int(b) + 1
		}
	}
	return n
}
// getUint16 decodes the big-endian 16-bit value in the first two bytes of s.
func getUint16(s string) int {
	hi, lo := int(s[0]), int(s[1])
	return hi<<8 | lo
}
// getUint32 decodes the big-endian 32-bit value in the first four bytes of s.
func getUint32(s string) int {
	v := 0
	for i := 0; i < 4; i++ {
		v = v<<8 | int(s[i])
	}
	return v
}
// cachingRoundTrip returns a roundTripper that answers from cache when it
// can. Otherwise it opens a connection to address over network, using
// cache.dial or, when that is nil, a plain net.Dialer. It exchanges one
// message over that connection and offers the answer to the cache.
// The connection is closed when the exchange returns or ctx ends.
func cachingRoundTrip(cache *cache, network, address string) roundTripper {
	return func(ctx context.Context, req string) (res string, err error) {
		// serve from the cache on a hit
		if hit := cache.get(req); hit != "" {
			return hit, nil
		}

		// open a connection to the upstream server
		dial := cache.dial
		if dial == nil {
			var d net.Dialer
			dial = d.DialContext
		}
		conn, err := dial(ctx, network, address)
		if err != nil {
			return "", err
		}

		// close the connection as soon as we return, or ctx is canceled
		ctx, cancel := context.WithCancel(ctx)
		defer cancel()
		go func() {
			<-ctx.Done()
			conn.Close()
		}()

		if deadline, ok := ctx.Deadline(); ok {
			if err := conn.SetDeadline(deadline); err != nil {
				return "", err
			}
		}

		// exchange the message
		if err := writeMessage(conn, req); err != nil {
			return "", err
		}
		if res, err = readMessage(conn); err != nil {
			return "", err
		}

		cache.put(req, res)
		return res, nil
	}
}
package dns
import (
"bytes"
"context"
"io"
"net"
"strings"
"sync"
"time"
)
// dnsConn is an in-memory [net.Conn] that carries DNS messages in both
// directions with the DNS over TCP framing (a two-byte, big-endian length
// prefix).
//
// Messages written to it are buffered. A Read that finds no pending output
// takes the next framed message from the input buffer and hands it, unframed,
// to roundTrip. The answer is then framed for this and later Reads.
type dnsConn struct {
	sync.Mutex
	ibuf      bytes.Buffer // framed requests written by the caller
	obuf      bytes.Buffer // framed answers not yet read by the caller
	ctx       context.Context
	cancel    context.CancelFunc
	deadline  time.Time // read deadline; zero means none
	roundTrip roundTripper
}

// roundTripper sends one unframed DNS request and returns the unframed answer.
type roundTripper func(ctx context.Context, req string) (res string, err error)

// Read returns pending answer bytes if there are any. Otherwise it performs
// the exchange for the next buffered request and returns the start of the
// framed answer.
func (c *dnsConn) Read(b []byte) (n int, err error) {
	// io.Reader: a zero-length read must not trigger an exchange
	if len(b) == 0 {
		return 0, nil
	}

	imsg, n, err := c.drainBuffers(b)
	if n != 0 || err != nil {
		return n, err
	}

	ctx, cancel := c.childContext()
	omsg, err := c.roundTrip(ctx, imsg)
	cancel()
	if err != nil {
		return 0, err
	}

	return c.fillBuffer(b, omsg)
}

// Write buffers framed request data; it never blocks or fails.
func (c *dnsConn) Write(b []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	return c.ibuf.Write(b)
}

// Close cancels any in-flight exchange. Later exchanges receive an already
// canceled context, even if Close is called before the first Read.
func (c *dnsConn) Close() error {
	c.Lock()
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	cancel := c.cancel
	c.Unlock()
	cancel()
	return nil
}

// LocalAddr returns nil: the connection has no network endpoint.
func (c *dnsConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil: the connection has no network endpoint.
func (c *dnsConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline sets the read deadline; writes never time out.
func (c *dnsConn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}

// SetReadDeadline bounds each later exchange by t; a zero t means no deadline.
func (c *dnsConn) SetReadDeadline(t time.Time) error {
	c.Lock()
	defer c.Unlock()
	c.deadline = t
	return nil
}

// SetWriteDeadline is a no-op: writes only append to a buffer.
func (c *dnsConn) SetWriteDeadline(t time.Time) error {
	// writes do not timeout
	return nil
}

// drainBuffers copies pending answer bytes into b, if any. Otherwise it
// removes the next framed request from the input buffer and returns it
// unframed, with n == 0.
func (c *dnsConn) drainBuffers(b []byte) (string, int, error) {
	c.Lock()
	defer c.Unlock()

	// drain the output buffer
	if c.obuf.Len() > 0 {
		n, err := c.obuf.Read(b)
		return "", n, err
	}

	// otherwise, get the next message from the input buffer
	sz := c.ibuf.Next(2)
	if len(sz) < 2 {
		return "", 0, io.ErrUnexpectedEOF
	}
	size := int64(sz[0])<<8 | int64(sz[1])

	var str strings.Builder
	_, err := io.CopyN(&str, &c.ibuf, size)
	if err == io.EOF {
		return "", 0, io.ErrUnexpectedEOF
	}
	if err != nil {
		return "", 0, err
	}
	return str.String(), 0, nil
}

// fillBuffer frames str into the output buffer and reads the first bytes into b.
func (c *dnsConn) fillBuffer(b []byte, str string) (int, error) {
	c.Lock()
	defer c.Unlock()
	c.obuf.WriteByte(byte(len(str) >> 8))
	c.obuf.WriteByte(byte(len(str)))
	c.obuf.WriteString(str)
	return c.obuf.Read(b)
}

// childContext returns the context for one exchange. It derives from the
// connection's context (canceled by Close) and is bounded by the read
// deadline, if one is set. Per net.Conn, a zero deadline means no timeout.
func (c *dnsConn) childContext() (context.Context, context.CancelFunc) {
	c.Lock()
	defer c.Unlock()
	if c.ctx == nil {
		c.ctx, c.cancel = context.WithCancel(context.Background())
	}
	if c.deadline.IsZero() {
		return context.WithCancel(c.ctx)
	}
	return context.WithDeadline(c.ctx, c.deadline)
}
// writeMessage sends msg over conn in a single Write. Datagram connections
// (net.PacketConn) get the bare message. Stream connections get it prefixed
// with its two-byte, big-endian length.
func writeMessage(conn net.Conn, msg string) error {
	var buf []byte
	switch conn.(type) {
	case net.PacketConn:
		// MUST do a single write on UDP.
		buf = []byte(msg)
	default:
		// SHOULD do a single write on TCP (RFC 7766, section 8),
		// so the length prefix and the message share one buffer.
		buf = make([]byte, 0, 2+len(msg))
		buf = append(buf, byte(len(msg)>>8), byte(len(msg)))
		buf = append(buf, msg...)
	}
	_, err := conn.Write(buf)
	return err
}
// readMessage reads one DNS message from c. On datagram connections
// (net.PacketConn) a single Read of up to 4096 bytes is the message. On
// stream connections the message is preceded by its two-byte, big-endian
// length. A stream that ends mid-message yields io.ErrUnexpectedEOF.
func readMessage(c net.Conn) (string, error) {
	if _, datagram := c.(net.PacketConn); datagram {
		// RFC 1035 specifies 512 as the maximum message size for DNS over UDP.
		// RFC 6891 OTOH suggests 4096 as the maximum payload size for EDNS.
		packet := make([]byte, 4096)
		n, err := c.Read(packet)
		if err != nil {
			return "", err
		}
		return string(packet[:n]), nil
	}

	var prefix [2]byte
	if _, err := io.ReadFull(c, prefix[:]); err != nil {
		return "", err
	}
	length := int64(prefix[0])<<8 | int64(prefix[1])

	var msg strings.Builder
	switch _, err := io.CopyN(&msg, c, length); {
	case err == io.EOF:
		return "", io.ErrUnexpectedEOF
	case err != nil:
		return "", err
	}
	return msg.String(), nil
}
// Package dns provides [net.Resolver] instances implementing caching,
// opportunistic encryption, and DNS over TLS/HTTPS.
//
// To replace the [net.DefaultResolver] with a caching DNS over HTTPS instance
// using the Google Public DNS resolver:
//
//	net.DefaultResolver, err = dns.NewDoHResolver(
//		"https://dns.google/dns-query",
//		dns.DoHCache())
package dns
import (
"context"
"crypto/tls"
"net"
"sync"
"time"
)
// OpportunisticResolver is a pure-Go [net.Resolver] that uses the local
// resolver's servers. For servers on the standard DNS port, it may first try
// an unauthenticated DNS over TLS connection to the same host. It falls back
// to plain DNS when that is not attempted or fails.
var OpportunisticResolver = &net.Resolver{
	PreferGo: true,
	Dial:     opportunisticDial,
}
// opportunisticDial is the Dial function of [OpportunisticResolver].
//
// It first tries DNS over TLS on port 853 of the same host, with a one
// second timeout, when all of these hold:
//   - address is on the standard DNS port;
//   - address is not in the bad-server list;
//   - ctx leaves more than two seconds.
//
// Certificates are not verified, so this only protects against passive
// eavesdropping. If the TLS attempt fails for a reason other than ctx
// ending, address is added to the bad-server list. In every other case it
// dials address over network in the clear.
func opportunisticDial(ctx context.Context, network, address string) (net.Conn, error) {
	// on a malformed address port is "", which skips the TLS attempt
	host, port, _ := net.SplitHostPort(address)
	if (port == "53" || port == "domain") && notBadServer(address) {
		deadline, ok := ctx.Deadline()
		if ok && deadline.After(time.Now().Add(2*time.Second)) {
			d := tls.Dialer{
				NetDialer: &net.Dialer{Timeout: time.Second},
				// opportunistic privacy: no authentication (RFC 7858, section 4.1)
				Config: &tls.Config{InsecureSkipVerify: true},
			}
			conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, "853"))
			if err == nil {
				return conn, nil
			}
			// don't blame the server if the caller gave up
			if ctx.Err() == nil {
				addBadServer(address)
			}
		}
	}
	var d net.Dialer
	return d.DialContext(ctx, network, address)
}
// badServers is a small ring buffer of server addresses for which an
// opportunistic DNS over TLS attempt failed. When full, the oldest entry is
// overwritten.
var badServers struct {
	sync.Mutex
	next int
	list [4]string
}

// badServerIndex returns the slot holding address, or -1.
// The caller must hold badServers' lock.
func badServerIndex(address string) int {
	for i := range badServers.list {
		if badServers.list[i] == address {
			return i
		}
	}
	return -1
}

// notBadServer reports whether address is absent from badServers.
func notBadServer(address string) bool {
	badServers.Lock()
	defer badServers.Unlock()
	return badServerIndex(address) < 0
}

// addBadServer records address in badServers, unless it is already there.
func addBadServer(address string) {
	badServers.Lock()
	defer badServers.Unlock()
	if badServerIndex(address) >= 0 {
		return
	}
	slot := badServers.next
	badServers.list[slot] = address
	badServers.next = (slot + 1) % len(badServers.list)
}
// DialFunc is a [net.Resolver.Dial] function.
// NewCachingDialer wraps one to add caching. DoTDialFunc supplies one for
// opening the TCP connections of a DNS over TLS resolver.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
package dns
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync/atomic"
"time"
)
// NewDoHResolver creates a DNS over HTTPS resolver.
// The uri may be a URI Template; its expressions are removed, since queries
// are sent with POST.
//
// Unless [DoHAddresses] is used, the host in uri is resolved with
// [OpportunisticResolver]. The addresses passed to DoHAddresses are never
// modified.
func NewDoHResolver(uri string, options ...DoHOption) (*net.Resolver, error) {
	// parse the uri template into a url
	uri, err := parseURITemplate(uri)
	if err != nil {
		return nil, err
	}
	url, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}
	// without an explicit port, the scheme name (e.g. "https") serves as
	// the service name, which the dialer maps to a port number
	port := url.Port()
	if port == "" {
		port = url.Scheme
	}

	// apply options
	var opts dohOpts
	for _, o := range options {
		o.apply(&opts)
	}

	// resolve server network addresses
	if len(opts.addrs) == 0 {
		ips, err := OpportunisticResolver.LookupIPAddr(context.Background(), url.Hostname())
		if err != nil {
			return nil, err
		}
		opts.addrs = make([]string, len(ips))
		for i, ip := range ips {
			opts.addrs[i] = net.JoinHostPort(ip.String(), port)
		}
	} else {
		// work on a copy: opts.addrs may be the caller's own slice
		// (DoHAddresses(addrs...)), which must not be rewritten
		addrs := make([]string, len(opts.addrs))
		for i, a := range opts.addrs {
			// bare IP addresses get the port from uri
			if net.ParseIP(a) != nil {
				a = net.JoinHostPort(a, port)
			}
			addrs[i] = a
		}
		opts.addrs = addrs
	}

	// setup the http transport
	if opts.transport == nil {
		opts.transport = &http.Transport{
			MaxIdleConns:        http.DefaultMaxIdleConnsPerHost,
			IdleConnTimeout:     90 * time.Second,
			TLSHandshakeTimeout: 10 * time.Second,
			ForceAttemptHTTP2:   true,
		}
	} else {
		// clone, since DialContext is overwritten below
		opts.transport = opts.transport.Clone()
	}

	// setup the http client
	client := http.Client{
		Transport: opts.transport,
	}

	// create the resolver: each connection it dials is an in-memory dnsConn
	// whose exchanges are POSTed to uri
	var resolver = net.Resolver{
		PreferGo: true,
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			conn := &dnsConn{}
			conn.roundTrip = dohRoundTrip(uri, &client)
			return conn, nil
		},
	}

	// setup dialer: ignore the address the transport asks for and use the
	// current entry of opts.addrs, moving to the next after a failed dial
	var index atomic.Uint32
	opts.transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
		var d net.Dialer
		i := index.Load()
		conn, err := d.DialContext(ctx, network, opts.addrs[i])
		if err != nil {
			index.CompareAndSwap(i, (i+1)%uint32(len(opts.addrs)))
		}
		return conn, err
	}

	// setup caching
	if opts.cache {
		resolver.Dial = NewCachingDialer(resolver.Dial, opts.cacheOpts...)
	}
	return &resolver, nil
}
// A DoHOption customizes the DNS over HTTPS resolver built by [NewDoHResolver].
type DoHOption interface {
	apply(*dohOpts)
}

// dohOpts collects the settings applied by DoHOption values.
type dohOpts struct {
	transport *http.Transport // nil: NewDoHResolver builds a default transport
	addrs     []string        // empty: the URI host is resolved
	cache     bool            // wrap the resolver with NewCachingDialer
	cacheOpts []CacheOption   // options forwarded to NewCachingDialer
}

// Concrete DoHOption implementations.
type (
	dohTransport http.Transport
	dohAddresses []string
	dohCache     []CacheOption
)

func (o *dohTransport) apply(opts *dohOpts) {
	opts.transport = (*http.Transport)(o)
}

func (o dohAddresses) apply(opts *dohOpts) {
	opts.addrs = ([]string)(o)
}

func (o dohCache) apply(opts *dohOpts) {
	opts.cache = true
	opts.cacheOpts = ([]CacheOption)(o)
}

// DoHTransport sets the [http.Transport] used by the resolver.
// NewDoHResolver works on a clone, whose DialContext it replaces.
func DoHTransport(transport *http.Transport) DoHOption {
	return (*dohTransport)(transport)
}

// DoHAddresses sets the network addresses of the resolver.
// These should be IP addresses, or network addresses of the form "IP:port".
// Bare IP addresses use the port of the resolver's URI.
// This avoids having to resolve the resolver's addresses, improving performance and privacy.
func DoHAddresses(addresses ...string) DoHOption {
	return dohAddresses(addresses)
}

// DoHCache adds caching to the resolver, configured by options.
func DoHCache(options ...CacheOption) DoHOption {
	return dohCache(options)
}
// dohRoundTrip returns a roundTripper that POSTs each DNS message to uri
// with client, as specified by RFC 8484, and returns the response body.
//
// Non-200 responses become errors. So do bodies larger than a DNS message
// can be (65535 bytes). Such bodies could not be re-framed with a two-byte
// length anyway.
func dohRoundTrip(uri string, client *http.Client) roundTripper {
	const maxMessageSize = 65535

	return func(ctx context.Context, msg string) (string, error) {
		// prepare request
		req, err := http.NewRequestWithContext(ctx,
			http.MethodPost, uri, strings.NewReader(msg))
		if err != nil {
			return "", err
		}
		req.Header.Set("Content-Type", "application/dns-message")
		// RFC 8484, section 4.1: clients SHOULD say which media type they accept
		req.Header.Set("Accept", "application/dns-message")

		// send request
		res, err := client.Do(req)
		if err != nil {
			return "", err
		}
		defer res.Body.Close()
		if res.StatusCode != http.StatusOK {
			return "", errors.New(http.StatusText(res.StatusCode))
		}

		// read response; one extra byte tells an oversized body apart
		var str strings.Builder
		n, err := io.Copy(&str, io.LimitReader(res.Body, maxMessageSize+1))
		if err != nil {
			return "", err
		}
		if n > maxMessageSize {
			return "", errors.New("doh: response message too large")
		}
		return str.String(), nil
	}
}
// parseURITemplate turns a URI Template into a plain URI by dropping every
// {expression}; e.g. "https://dns.google/dns-query{?dns}" becomes
// "https://dns.google/dns-query". It reports an error for a nested '{', for
// a '}' without a matching '{', and for an expression left unterminated.
func parseURITemplate(uri string) (string, error) {
	var str strings.Builder
	var exp bool // inside a {...} expression
	for i := 0; i < len(uri); i++ {
		switch c := uri[i]; c {
		case '{':
			if exp {
				return "", errors.New("uri: invalid syntax")
			}
			exp = true
		case '}':
			if !exp {
				return "", errors.New("uri: invalid syntax")
			}
			exp = false
		default:
			if !exp {
				str.WriteByte(c)
			}
		}
	}
	if exp {
		// an opening '{' was never closed
		return "", errors.New("uri: invalid syntax")
	}
	return str.String(), nil
}
package dns
import (
"context"
"crypto/tls"
"net"
"sync/atomic"
)
// NewDoTResolver creates a DNS over TLS resolver.
// The server can be an IP address, a host name, or a network address of the form "host:port".
// Without a port, 853 is used.
//
// Unless [DoTAddresses] is used, server's host is resolved with
// [OpportunisticResolver]. The addresses passed to DoTAddresses are never
// modified. Unless the [DoTConfig] sets one, server's host is also the TLS
// ServerName.
func NewDoTResolver(server string, options ...DoTOption) (*net.Resolver, error) {
	// look for a custom port
	host, port, err := net.SplitHostPort(server)
	if err != nil {
		port = "853"
	} else {
		server = host
	}

	// apply options
	var opts dotOpts
	for _, o := range options {
		o.apply(&opts)
	}

	// resolve server network addresses
	if len(opts.addrs) == 0 {
		ips, err := OpportunisticResolver.LookupIPAddr(context.Background(), server)
		if err != nil {
			return nil, err
		}
		opts.addrs = make([]string, len(ips))
		for i, ip := range ips {
			opts.addrs[i] = net.JoinHostPort(ip.String(), port)
		}
	} else {
		// work on a copy: opts.addrs may be the caller's own slice
		// (DoTAddresses(addrs...)), which must not be rewritten
		addrs := make([]string, len(opts.addrs))
		for i, a := range opts.addrs {
			// bare IP addresses get the resolver's port
			if net.ParseIP(a) != nil {
				a = net.JoinHostPort(a, port)
			}
			addrs[i] = a
		}
		opts.addrs = addrs
	}

	// setup TLS config
	if opts.config == nil {
		opts.config = &tls.Config{
			ClientSessionCache: tls.NewLRUClientSessionCache(len(opts.addrs)),
		}
	} else {
		// clone, since ServerName may be filled in below
		opts.config = opts.config.Clone()
	}
	if opts.config.ServerName == "" {
		opts.config.ServerName = server
	}

	// setup the dialFunc
	if opts.dialFunc == nil {
		var d net.Dialer
		opts.dialFunc = d.DialContext
	}

	// create the resolver
	var resolver = net.Resolver{PreferGo: true}

	// setup dialer: always TCP to the current entry of opts.addrs, moving to
	// the next after a failed dial; the TLS handshake runs on first use
	var index atomic.Uint32
	resolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
		i := index.Load()
		conn, err := opts.dialFunc(ctx, "tcp", opts.addrs[i])
		if err != nil {
			index.CompareAndSwap(i, (i+1)%uint32(len(opts.addrs)))
			return nil, err
		}
		return tls.Client(conn, opts.config), nil
	}

	// setup caching
	if opts.cache {
		resolver.Dial = NewCachingDialer(resolver.Dial, opts.cacheOpts...)
	}
	return &resolver, nil
}
// A DoTOption customizes the DNS over TLS resolver built by [NewDoTResolver].
type DoTOption interface {
	apply(*dotOpts)
}

// dotOpts collects the settings applied by DoTOption values.
type dotOpts struct {
	config    *tls.Config   // nil: NewDoTResolver builds a default config
	addrs     []string      // empty: the server name is resolved
	cache     bool          // wrap the resolver with NewCachingDialer
	cacheOpts []CacheOption // options forwarded to NewCachingDialer
	dialFunc  DialFunc      // nil: a plain net.Dialer is used
}

// Concrete DoTOption implementations.
type (
	dotConfig    tls.Config
	dotAddresses []string
	dotCache     []CacheOption
	dotDialFunc  DialFunc
)

func (o *dotConfig) apply(opts *dotOpts) {
	opts.config = (*tls.Config)(o)
}

func (o dotAddresses) apply(opts *dotOpts) {
	opts.addrs = ([]string)(o)
}

func (o dotCache) apply(opts *dotOpts) {
	opts.cache = true
	opts.cacheOpts = ([]CacheOption)(o)
}

func (o dotDialFunc) apply(opts *dotOpts) {
	opts.dialFunc = (DialFunc)(o)
}

// DoTConfig sets the [tls.Config] used by the resolver.
// NewDoTResolver works on a clone, and fills in ServerName when it is empty.
func DoTConfig(config *tls.Config) DoTOption {
	return (*dotConfig)(config)
}

// DoTAddresses sets the network addresses of the resolver.
// These should be IP addresses, or network addresses of the form "IP:port".
// Bare IP addresses use the resolver's port.
// This avoids having to resolve the resolver's addresses, improving performance and privacy.
func DoTAddresses(addresses ...string) DoTOption {
	return dotAddresses(addresses)
}

// DoTCache adds caching to the resolver, configured by options.
func DoTCache(options ...CacheOption) DoTOption {
	return dotCache(options)
}

// DoTDialFunc sets the DialFunc used to open the resolver's TCP connections.
// By default [net.Dialer.DialContext] is used.
func DoTDialFunc(f DialFunc) DoTOption {
	return dotDialFunc(f)
}