package cache
import (
        "fmt"
        "time"
        "github.com/contentsquare/chproxy/clients"
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/log"
        "github.com/redis/go-redis/v9"
)
// AsyncCache is a transactional cache able to serve results to concurrent queries.
// When queries A and B are identical and query B arrives no later than the configured [[graceTime]] after query A,
// query B will wait for the result of query A for at most:
// max_awaiting_time = graceTime - (arrivalB - arrivalA)
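//
// For example (hypothetical numbers): with graceTime = 5s, if query B arrives 2s after an
// identical, still-running query A, then B polls the transaction registry for up to 3s
// before giving up and executing the query itself.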
type AsyncCache struct {
        Cache
        TransactionRegistry
        graceTime time.Duration
        MaxPayloadSize     config.ByteSize
        SharedWithAllUsers bool
}
func (c *AsyncCache) Close() error {
        if c.TransactionRegistry != nil {
                c.TransactionRegistry.Close()
        }
        if c.Cache != nil {
                c.Cache.Close()
        }
        return nil
}
func (c *AsyncCache) AwaitForConcurrentTransaction(key *Key) (TransactionStatus, error) {
        startTime := time.Now()
        seenState := transactionAbsent
        for {
                elapsedTime := time.Since(startTime)
                if elapsedTime > c.graceTime {
                        // The entry didn't appear within the grace-time deadline.
                        // Let the caller create it.
                        return TransactionStatus{State: seenState}, nil
                }
                status, err := c.TransactionRegistry.Status(key)
                if err != nil {
                        return TransactionStatus{State: seenState}, err
                }
                if !status.State.IsPending() {
                        return status, nil
                }
                // Wait for deadline in the hope the entry will appear
                // in the cache.
                //
                // This should protect from thundering herd problem when
                // a single slow query is executed from concurrent requests.
                d := 100 * time.Millisecond
                if d > c.graceTime {
                        d = c.graceTime
                }
                time.Sleep(d)
        }
}
func NewAsyncCache(cfg config.Cache, maxExecutionTime time.Duration) (*AsyncCache, error) {
        graceTime := time.Duration(cfg.GraceTime)
        if graceTime > 0 {
                log.Errorf("[DEPRECATED] detected grace time configuration %s. It will be removed in the new version",
                        graceTime)
        }
        if graceTime == 0 {
                // Default grace time.
                graceTime = maxExecutionTime
        }
        if graceTime < 0 {
                // Disable protection from `dogpile effect`.
                graceTime = 0
        }
        var cache Cache
        var transaction TransactionRegistry
        var err error
        // transaction will be kept until we're sure there's no possible concurrent query running
        transactionDeadline := 2 * graceTime
        switch cfg.Mode {
        case "file_system":
                cache, err = newFilesSystemCache(cfg, graceTime)
                transaction = newInMemoryTransactionRegistry(transactionDeadline, transactionEndedTTL)
        case "redis":
                var redisClient redis.UniversalClient
                redisClient, err = clients.NewRedisClient(cfg.Redis)
                cache = newRedisCache(redisClient, cfg)
                transaction = newRedisTransactionRegistry(redisClient, transactionDeadline, transactionEndedTTL)
        default:
                return nil, fmt.Errorf("unknown config mode")
        }
        if err != nil {
                return nil, err
        }
        maxPayloadSize := cfg.MaxPayloadSize
        return &AsyncCache{
                Cache:               cache,
                TransactionRegistry: transaction,
                graceTime:           graceTime,
                MaxPayloadSize:      maxPayloadSize,
                SharedWithAllUsers:  cfg.SharedWithAllUsers,
        }, nil
}
		
package cache
import (
        "fmt"
        "io"
        "math/rand"
        "os"
        "path/filepath"
        "regexp"
        "strconv"
        "sync"
        "sync/atomic"
        "time"
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/log"
)
var cachefileRegexp = regexp.MustCompile(`^[0-9a-f]{32}$`)
// fileSystemCache represents a file cache.
type fileSystemCache struct {
        // name is cache name.
        name string
        dir     string
        maxSize uint64
        expire  time.Duration
        grace   time.Duration
        stats   Stats
        wg     sync.WaitGroup
        stopCh chan struct{}
}
// newFilesSystemCache returns a new file-system cache for the given cfg.
func newFilesSystemCache(cfg config.Cache, graceTime time.Duration) (*fileSystemCache, error) {
        if len(cfg.FileSystem.Dir) == 0 {
                return nil, fmt.Errorf("`dir` cannot be empty")
        }
        if cfg.FileSystem.MaxSize <= 0 {
                return nil, fmt.Errorf("`max_size` must be positive")
        }
        if cfg.Expire <= 0 {
                return nil, fmt.Errorf("`expire` must be positive")
        }
        c := &fileSystemCache{
                name: cfg.Name,
                dir:     cfg.FileSystem.Dir,
                maxSize: uint64(cfg.FileSystem.MaxSize),
                expire:  time.Duration(cfg.Expire),
                grace:   graceTime,
                stopCh:  make(chan struct{}),
        }
        if err := os.MkdirAll(c.dir, 0700); err != nil {
                return nil, fmt.Errorf("cannot create %q: %w", c.dir, err)
        }
        c.wg.Add(1)
        go func() {
                log.Debugf("cache %q: cleaner start", c.Name())
                c.cleaner()
                log.Debugf("cache %q: cleaner stop", c.Name())
                c.wg.Done()
        }()
        return c, nil
}
func (f *fileSystemCache) Name() string {
        return f.name
}
func (f *fileSystemCache) Close() error {
        log.Debugf("cache %q: stopping", f.Name())
        close(f.stopCh)
        f.wg.Wait()
        log.Debugf("cache %q: stopped", f.Name())
        return nil
}
func (f *fileSystemCache) Stats() Stats {
        var s Stats
        s.Size = atomic.LoadUint64(&f.stats.Size)
        s.Items = atomic.LoadUint64(&f.stats.Items)
        return s
}
func (f *fileSystemCache) Get(key *Key) (*CachedData, error) {
        fp := key.filePath(f.dir)
        file, err := os.Open(fp)
        if err != nil {
                return nil, ErrMissing
        }
        // the file will be closed once it has been read as an io.ReadCloser.
        // This ReadCloser is stored in the returned CachedData.
        fi, err := file.Stat()
        if err != nil {
                file.Close()
                return nil, fmt.Errorf("cache %q: cannot stat %q: %w", f.Name(), fp, err)
        }
        mt := fi.ModTime()
        age := time.Since(mt)
        if age > f.expire {
                // check if file exceeded expiration time + grace time
                if age > f.expire+f.grace {
                        file.Close()
                        return nil, ErrMissing
                }
                // Serve expired file in the hope it will be substituted
                // with the fresh file during deadline.
        }
        metadata, err := decodeHeader(file)
        if err != nil {
                file.Close()
                return nil, err
        }
        value := &CachedData{
                ContentMetadata: *metadata,
                Data:            file,
                Ttl:             f.expire - age,
        }
        return value, nil
}
// decodeHeader decodes header from raw byte stream. Data is encoded as follows:
// length(contentType)|contentType|length(contentEncoding)|contentEncoding|length(contentLength)|contentLength|cachedData
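// For example (hypothetical values), the metadata {Type: "text/plain", Encoding: "gzip", Length: 1234} is laid out as:
// 0x0000000a|"text/plain"|0x00000004|"gzip"|0x00000004|"1234"|<cached data>
// where each 4-byte length prefix is big-endian (see writeHeader/readHeader).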
func decodeHeader(reader io.Reader) (*ContentMetadata, error) {
        contentType, err := readHeader(reader)
        if err != nil {
                return nil, fmt.Errorf("cannot read Content-Type from provided reader: %w", err)
        }
        contentEncoding, err := readHeader(reader)
        if err != nil {
                return nil, fmt.Errorf("cannot read Content-Encoding from provided reader: %w", err)
        }
        contentLengthStr, err := readHeader(reader)
        if err != nil {
                return nil, fmt.Errorf("cannot read Content-Encoding from provided reader: %w", err)
        }
        contentLength, err := strconv.Atoi(contentLengthStr)
        if err != nil {
                log.Errorf("found corrupted content length %s", err)
                contentLength = 0
        }
        return &ContentMetadata{
                Length:   int64(contentLength),
                Type:     contentType,
                Encoding: contentEncoding,
        }, nil
}
func (f *fileSystemCache) Put(r io.Reader, contentMetadata ContentMetadata, key *Key) (time.Duration, error) {
        fp := key.filePath(f.dir)
        file, err := os.Create(fp)
        if err != nil {
                return 0, fmt.Errorf("cache %q: cannot create file: %s : %w", f.Name(), key, err)
        }
        defer file.Close()
        if err := writeHeader(file, contentMetadata.Type); err != nil {
                fn := file.Name()
                return 0, fmt.Errorf("cannot write Content-Type to %q: %w", fn, err)
        }
        if err := writeHeader(file, contentMetadata.Encoding); err != nil {
                fn := file.Name()
                return 0, fmt.Errorf("cannot write Content-Encoding to %q: %w", fn, err)
        }
        if err := writeHeader(file, fmt.Sprintf("%d", contentMetadata.Length)); err != nil {
                fn := file.Name()
                return 0, fmt.Errorf("cannot write Content-Encoding to %q: %w", fn, err)
        }
        cnt, err := io.Copy(file, r)
        if err != nil {
                return 0, fmt.Errorf("cache %q: cannot write results to file: %s : %w", f.Name(), key, err)
        }
        atomic.AddUint64(&f.stats.Size, uint64(cnt))
        atomic.AddUint64(&f.stats.Items, 1)
        return f.expire, nil
}
func (f *fileSystemCache) cleaner() {
        d := f.expire / 2
        if d < time.Minute {
                d = time.Minute
        }
        if d > time.Hour {
                d = time.Hour
        }
        forceCleanCh := time.After(d)
        f.clean()
        for {
                select {
                case <-time.After(time.Second):
                        // Clean cache only on cache size overflow.
                        stats := f.Stats()
                        if stats.Size > f.maxSize {
                                f.clean()
                        }
                case <-forceCleanCh:
                        // Forcibly clean cache from expired items.
                        f.clean()
                        forceCleanCh = time.After(d)
                case <-f.stopCh:
                        return
                }
        }
}
func (f *fileSystemCache) fileInfoPath(fi os.FileInfo) string {
        return filepath.Join(f.dir, fi.Name())
}
func (f *fileSystemCache) clean() {
        currentTime := time.Now()
        log.Debugf("cache %q: start cleaning dir %q", f.Name(), f.dir)
        // Remove cached files after a deadline from their expiration,
        // so they may be served until they are substituted with fresh files.
        expire := f.expire + f.grace
        // Calculate total cache size and remove expired files.
        var totalSize uint64
        var totalItems uint64
        var removedSize uint64
        var removedItems uint64
        err := walkDir(f.dir, func(fi os.FileInfo) {
                mt := fi.ModTime()
                fs := uint64(fi.Size())
                if currentTime.Sub(mt) > expire {
                        fn := f.fileInfoPath(fi)
                        err := os.Remove(fn)
                        if err == nil {
                                removedSize += fs
                                removedItems++
                                return
                        }
                        log.Errorf("cache %q: cannot remove file %q: %s", f.Name(), fn, err)
                        // Return skipped intentionally.
                }
                totalSize += fs
                totalItems++
        })
        if err != nil {
                log.Errorf("cache %q: %s", f.Name(), err)
                return
        }
        loopsCount := 0
        // Use dedicated random generator instead of global one from math/rand,
        // since the global generator is slow due to locking.
        //
        // Seed the generator with the current time in order to randomize
        // set of files to be removed below.
        // nolint:gosec // not security sensitive, only used internally.
        rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
        for totalSize > f.maxSize && loopsCount < 3 {
                // Remove some files in order to reduce cache size.
                excessSize := totalSize - f.maxSize
                p := int32(float64(excessSize) / float64(totalSize) * 100)
                // Add 10% to the removal probability so we overshoot slightly and avoid extra cleanup loops.
                p += 10
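                // For example (hypothetical sizes): with totalSize = 150 MiB and maxSize = 100 MiB,
                // excessSize = 50 MiB, p = 33 + 10 = 43, so each walked file is removed with a probability of roughly 43%.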
                err := walkDir(f.dir, func(fi os.FileInfo) {
                        if rnd.Int31n(100) > p {
                                return
                        }
                        fs := uint64(fi.Size())
                        fn := f.fileInfoPath(fi)
                        if err := os.Remove(fn); err != nil {
                                log.Errorf("cache %q: cannot remove file %q: %s", f.Name(), fn, err)
                                return
                        }
                        removedSize += fs
                        removedItems++
                        totalSize -= fs
                        totalItems--
                })
                if err != nil {
                        log.Errorf("cache %q: %s", f.Name(), err)
                        return
                }
                // This should protect from infinite loop.
                loopsCount++
        }
        atomic.StoreUint64(&f.stats.Size, totalSize)
        atomic.StoreUint64(&f.stats.Items, totalItems)
        log.Debugf("cache %q: final size %d; final items %d; removed size %d; removed items %d",
                f.Name(), totalSize, totalItems, removedSize, removedItems)
        log.Debugf("cache %q: finish cleaning dir %q", f.Name(), f.dir)
}
// writeHeader writes s to w prefixed with its length encoded as a 4-byte big-endian integer
func writeHeader(w io.Writer, s string) error {
        n := uint32(len(s))
        b := make([]byte, 0, n+4)
        b = append(b, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
        b = append(b, s...)
        _, err := w.Write(b)
        return err
}
// readHeader reads a string prefixed with its 4-byte big-endian length, as written by writeHeader
func readHeader(r io.Reader) (string, error) {
        b := make([]byte, 4)
        if _, err := io.ReadFull(r, b); err != nil {
                return "", fmt.Errorf("cannot read header length: %w", err)
        }
        n := uint32(b[3]) | (uint32(b[2]) << 8) | (uint32(b[1]) << 16) | (uint32(b[0]) << 24)
        s := make([]byte, n)
        if _, err := io.ReadFull(r, s); err != nil {
                return "", fmt.Errorf("cannot read header value with length %d: %w", n, err)
        }
        return string(s), nil
}
		
package cache
import (
        "crypto/sha256"
        "encoding/hex"
        "fmt"
        "net/url"
        "path/filepath"
)
// Version must be increased with each backward-incompatible change
// in the cache storage.
const Version = 5
// Key is the key for use in the cache.
type Key struct {
        // Query must contain full request query.
        Query []byte
        // AcceptEncoding must contain 'Accept-Encoding' request header value.
        AcceptEncoding string
        // DefaultFormat must contain `default_format` query arg.
        DefaultFormat string
        // Database must contain `database` query arg.
        Database string
        // Compress must contain `compress` query arg.
        Compress string
        // EnableHTTPCompression must contain `enable_http_compression` query arg.
        EnableHTTPCompression string
        // Namespace is an optional cache namespace.
        Namespace string
        // MaxResultRows must contain `max_result_rows` query arg
        MaxResultRows string
        // Extremes must contain `extremes` query arg
        Extremes string
        // ResultOverflowMode must contain `result_overflow_mode` query arg
        ResultOverflowMode string
        // UserParamsHash must contain hashed value of users params
        UserParamsHash uint32
        // Version represents data encoding version number
        Version int
        // QueryParamsHash must contain hashed value of query params
        QueryParamsHash uint32
        // UserCredentialHash must contain hashed value of username & password
        UserCredentialHash uint32
}
// NewKey constructs a cache key from the provided parameters with the default version number.
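//
// A minimal usage sketch (the parameter values are hypothetical):
//
//	params := url.Values{"database": {"default"}, "default_format": {"JSON"}}
//	key := NewKey([]byte("SELECT 1"), params, "gzip", 0, 0, 0)
//	name := key.String() // 32-character hex digest, used e.g. as the cache file name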
func NewKey(query []byte, originParams url.Values, acceptEncoding string, userParamsHash uint32, queryParamsHash uint32, userCredentialHash uint32) *Key {
        return &Key{
                Query:                 query,
                AcceptEncoding:        acceptEncoding,
                DefaultFormat:         originParams.Get("default_format"),
                Database:              originParams.Get("database"),
                Compress:              originParams.Get("compress"),
                EnableHTTPCompression: originParams.Get("enable_http_compression"),
                Namespace:             originParams.Get("cache_namespace"),
                Extremes:              originParams.Get("extremes"),
                MaxResultRows:         originParams.Get("max_result_rows"),
                ResultOverflowMode:    originParams.Get("result_overflow_mode"),
                UserParamsHash:        userParamsHash,
                Version:               Version,
                QueryParamsHash:       queryParamsHash,
                UserCredentialHash:    userCredentialHash,
        }
}
func (k *Key) filePath(dir string) string {
        return filepath.Join(dir, k.String())
}
// String returns string representation of the key.
func (k *Key) String() string {
        s := fmt.Sprintf("V%d; Query=%q; AcceptEncoding=%q; DefaultFormat=%q; Database=%q; Compress=%q; EnableHTTPCompression=%q; Namespace=%q; MaxResultRows=%q; Extremes=%q; ResultOverflowMode=%q; UserParams=%d; QueryParams=%d; UserCredentialHash=%d",
                k.Version, k.Query, k.AcceptEncoding, k.DefaultFormat, k.Database, k.Compress, k.EnableHTTPCompression, k.Namespace,
                k.MaxResultRows, k.Extremes, k.ResultOverflowMode, k.UserParamsHash, k.QueryParamsHash, k.UserCredentialHash)
        h := sha256.Sum256([]byte(s))
        // The first 16 bytes of the hash should be enough
        // for collision prevention :)
        return hex.EncodeToString(h[:16])
}
		
package cache
import (
        "bytes"
        "context"
        "errors"
        "fmt"
        "io"
        "math/rand"
        "os"
        "regexp"
        "strconv"
        "time"
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/log"
        "github.com/redis/go-redis/v9"
)
type redisCache struct {
        name   string
        client redis.UniversalClient
        expire time.Duration
}
const getTimeout = 2 * time.Second
const removeTimeout = 1 * time.Second
const renameTimeout = 1 * time.Second
const putTimeout = 2 * time.Second
const statsTimeout = 500 * time.Millisecond
// minTTLForRedisStreamingReader selects whether the result should be streamed
// directly from redis to the http response, or whether chproxy should first copy the
// result from redis into a temporary file before sending it to the http response.
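// For example, a large cached entry with only 10s of TTL left is first copied into a temporary file,
// while one with 30s left is streamed straight from redis (see readResultsAboveLimit).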
const minTTLForRedisStreamingReader = 15 * time.Second
// tmpDir is the temporary path used to store ongoing query results
const tmpDir = "/tmp"
const redisTmpFilePrefix = "chproxyRedisTmp"
func newRedisCache(client redis.UniversalClient, cfg config.Cache) *redisCache {
        redisCache := &redisCache{
                name:   cfg.Name,
                expire: time.Duration(cfg.Expire),
                client: client,
        }
        return redisCache
}
func (r *redisCache) Close() error {
        return r.client.Close()
}
var usedMemoryRegexp = regexp.MustCompile(`used_memory:([0-9]+)\r\n`)
// Stats makes two calls to redis.
// The first one fetches the number of keys stored in redis (DBSize).
// The second one fetches the memory info, which is parsed to extract used_memory.
// NOTE: we can only fetch the database size, not the cache size.
func (r *redisCache) Stats() Stats {
        return Stats{
                Items: r.nbOfKeys(),
                Size:  r.nbOfBytes(),
        }
}
func (r *redisCache) nbOfKeys() uint64 {
        ctx, cancelFunc := context.WithTimeout(context.Background(), statsTimeout)
        defer cancelFunc()
        nbOfKeys, err := r.client.DBSize(ctx).Result()
        if err != nil {
                log.Errorf("failed to fetch nb of keys in redis: %s", err)
        }
        return uint64(nbOfKeys)
}
func (r *redisCache) nbOfBytes() uint64 {
        ctx, cancelFunc := context.WithTimeout(context.Background(), statsTimeout)
        defer cancelFunc()
        memoryInfo, err := r.client.Info(ctx, "memory").Result()
        if err != nil {
                log.Errorf("failed to fetch nb of bytes in redis: %s", err)
        }
        matches := usedMemoryRegexp.FindStringSubmatch(memoryInfo)
        var cacheSize int
        if len(matches) > 1 {
                cacheSize, err = strconv.Atoi(matches[1])
                if err != nil {
                        log.Errorf("failed to parse memory usage with error %s", err)
                }
        }
        return uint64(cacheSize)
}
func (r *redisCache) Get(key *Key) (*CachedData, error) {
        ctx, cancelFunc := context.WithTimeout(context.Background(), getTimeout)
        defer cancelFunc()
        nbBytesToFetch := int64(100 * 1024)
        stringKey := key.String()
        // fetching 100 KiB from redis to be sure to have the full metadata and,
        // for most queries that return little data, the cached results as well
        val, err := r.client.GetRange(ctx, stringKey, 0, nbBytesToFetch).Result()
        // errors, such as timeouts
        if err != nil {
                log.Errorf("failed to get key %s with error: %s", stringKey, err)
                return nil, ErrMissing
        }
        // if key not found in cache
        if len(val) == 0 {
                return nil, ErrMissing
        }
        ttl, err := r.client.TTL(ctx, stringKey).Result()
        if err != nil {
                return nil, fmt.Errorf("failed to ttl of key %s with error: %w", stringKey, err)
        }
        b := []byte(val)
        metadata, offset, err := r.decodeMetadata(b)
        if err != nil {
                if errors.Is(err, &RedisCacheCorruptionError{}) {
                        log.Errorf("an error happened while handling redis key =%s, err=%s", stringKey, err)
                }
                return nil, err
        }
        if (int64(offset) + metadata.Length) < nbBytesToFetch {
                // this condition is true only if the fetched bytes contain both the metadata and the full cached result,
                // so we extract the cached result from the remaining bytes
                payload := b[offset:]
                reader := &ioReaderDecorator{Reader: bytes.NewReader(payload)}
                value := &CachedData{
                        ContentMetadata: *metadata,
                        Data:            reader,
                        Ttl:             ttl,
                }
                return value, nil
        }
        return r.readResultsAboveLimit(offset, stringKey, metadata, ttl)
}
func (r *redisCache) readResultsAboveLimit(offset int, stringKey string, metadata *ContentMetadata, ttl time.Duration) (*CachedData, error) {
        // since the cached results in redis are too big, we can't fetch them all at once because of the memory overhead.
        // We create an io.Reader that fetches from redis chunk by chunk to reduce memory usage.
        redisStreamreader := newRedisStreamReader(uint64(offset), r.client, stringKey, metadata.Length)
        // But since reading through that reader could take time and the object in redis could expire between 2 fetches,
        // we need to make sure the TTL is long enough to avoid nasty side effects.
        // If the TTL is too short, we put all the data into a temporary file and use that as the streamer.
        // nb: it would be better to retry the flow if such a failure happened, but this requires a huge refactoring of proxy.go
        if ttl <= minTTLForRedisStreamingReader {
                fileStream, err := newFileWriterReader(tmpDir)
                if err != nil {
                        return nil, err
                }
                _, err = io.Copy(fileStream, redisStreamreader)
                if err != nil {
                        return nil, err
                }
                err = fileStream.resetOffset()
                if err != nil {
                        return nil, err
                }
                value := &CachedData{
                        ContentMetadata: *metadata,
                        Data:            fileStream,
                        Ttl:             ttl,
                }
                return value, nil
        }
        value := &CachedData{
                ContentMetadata: *metadata,
                Data:            &ioReaderDecorator{Reader: redisStreamreader},
                Ttl:             ttl,
        }
        return value, nil
}
// ioReaderDecorator exists because CachedData requires an io.ReadCloser,
// but the logic in the Get function only produces an io.Reader
type ioReaderDecorator struct {
        io.Reader
}
func (m ioReaderDecorator) Close() error {
        return nil
}
func (r *redisCache) encodeString(s string) []byte {
        n := uint32(len(s))
        b := make([]byte, 0, n+4)
        b = append(b, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
        b = append(b, s...)
        return b
}
func (r *redisCache) decodeString(bytes []byte) (string, int, error) {
        if len(bytes) < 4 {
                return "", 0, &RedisCacheCorruptionError{}
        }
        b := bytes[:4]
        n := uint32(b[3]) | (uint32(b[2]) << 8) | (uint32(b[1]) << 16) | (uint32(b[0]) << 24)
        if len(bytes) < int(4+n) {
                return "", 0, &RedisCacheCorruptionError{}
        }
        s := bytes[4 : 4+n]
        return string(s), int(4 + n), nil
}
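// encodeMetadata serializes ContentMetadata as:
// 8-byte big-endian Length | length-prefixed Type | length-prefixed Encoding,
// where each string is prefixed with its length as a 4-byte big-endian integer (see encodeString).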
func (r *redisCache) encodeMetadata(contentMetadata *ContentMetadata) []byte {
        cLength := contentMetadata.Length
        cType := r.encodeString(contentMetadata.Type)
        cEncoding := r.encodeString(contentMetadata.Encoding)
        b := make([]byte, 0, len(cEncoding)+len(cType)+8)
        b = append(b, byte(cLength>>56), byte(cLength>>48), byte(cLength>>40), byte(cLength>>32), byte(cLength>>24), byte(cLength>>16), byte(cLength>>8), byte(cLength))
        b = append(b, cType...)
        b = append(b, cEncoding...)
        return b
}
func (r *redisCache) decodeMetadata(b []byte) (*ContentMetadata, int, error) {
        if len(b) < 8 {
                return nil, 0, &RedisCacheCorruptionError{}
        }
        cLength := uint64(b[7]) | (uint64(b[6]) << 8) | (uint64(b[5]) << 16) | (uint64(b[4]) << 24) | uint64(b[3])<<32 | (uint64(b[2]) << 40) | (uint64(b[1]) << 48) | (uint64(b[0]) << 56)
        offset := 8
        cType, sizeCType, err := r.decodeString(b[offset:])
        if err != nil {
                return nil, 0, err
        }
        offset += sizeCType
        cEncoding, sizeCEncoding, err := r.decodeString(b[offset:])
        if err != nil {
                return nil, 0, err
        }
        offset += sizeCEncoding
        metadata := &ContentMetadata{
                Length:   int64(cLength),
                Type:     cType,
                Encoding: cEncoding,
        }
        return metadata, offset, nil
}
func (r *redisCache) Put(reader io.Reader, contentMetadata ContentMetadata, key *Key) (time.Duration, error) {
        metadata := r.encodeMetadata(&contentMetadata)
        stringKey := key.String()
        // in order to make the streaming operation atomic, chproxy streams into a temporary key (only known by the current goroutine)
        // then it switches the full result to the "real" stringKey available for other goroutines
        // nolint:gosec // not security sensitive, only used internally.
        random := strconv.Itoa(rand.Int())
        // Redis RENAME is considered to be a multikey operation. In Cluster mode, both oldkey and renamedkey must be in the same hash slot,
        // Refer Redis Documentation here: https://redis.io/commands/rename/
        // To solve this, we need to force the temporary key to be in the same hash slot. We do this by adding a hash tag to the
        // temporary key: when a key contains a "{...}" pattern, only the substring between the braces "{" and "}"
        // is hashed to obtain the hash slot.
        // Refer the hash tags section of Redis documentation here: https://redis.io/docs/reference/cluster-spec/#hash-tags
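        // For example (hypothetical values): with stringKey "ab12cd" and random suffix "42", the temporary key
        // becomes "{ab12cd}42_tmp", which hashes to the same slot as "ab12cd".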
        stringKeyTmp := "{" + stringKey + "}" + random + "_tmp"
        ctxSet, cancelFuncSet := context.WithTimeout(context.Background(), putTimeout)
        defer cancelFuncSet()
        err := r.client.Set(ctxSet, stringKeyTmp, metadata, r.expire).Err()
        if err != nil {
                return 0, err
        }
        // we don't load the whole reader content into memory; instead we append it to redis chunk by chunk
        // to avoid memory issues when the content is big (which is the case when chproxy users fetch a lot of data)
        buffer := make([]byte, 2*1024*1024)
        totalByteWrittenExpected := len(metadata)
        for {
                n, err := reader.Read(buffer)
                // the reader should return err = io.EOF once it has nothing left to read, or on the last read call that returns content.
                // But this is not the case with this reader, so we check n == 0 to exit the read loop.
                // We keep the err == io.EOF check in the loop in case the behavior of the reader changes.
                if n == 0 {
                        break
                }
                if err != nil && !errors.Is(err, io.EOF) {
                        return 0, err
                }
                ctxAppend, cancelFuncAppend := context.WithTimeout(context.Background(), putTimeout)
                defer cancelFuncAppend()
                totalByteWritten, err := r.client.Append(ctxAppend, stringKeyTmp, string(buffer[:n])).Result()
                if err != nil {
                        // trying to clean redis from this partially inserted item
                        r.clean(stringKeyTmp)
                        return 0, err
                }
                totalByteWrittenExpected += n
                if int(totalByteWritten) != totalByteWrittenExpected {
                        // trying to clean redis from this partially inserted item
                        r.clean(stringKeyTmp)
                        return 0, fmt.Errorf("could not stream the value into redis, only %d bytes were written instead of %d", totalByteWritten, totalByteWrittenExpected)
                }
                if errors.Is(err, io.EOF) {
                        break
                }
        }
        // at this step we know that the item stored in stringKeyTmp is fully written
        // so we can put it to its final stringKey
        ctxRename, cancelFuncRename := context.WithTimeout(context.Background(), renameTimeout)
        defer cancelFuncRename()
        r.client.Rename(ctxRename, stringKeyTmp, stringKey)
        return r.expire, nil
}
func (r *redisCache) clean(stringKey string) {
        delCtx, cancelFunc := context.WithTimeout(context.Background(), removeTimeout)
        defer cancelFunc()
        delErr := r.client.Del(delCtx, stringKey).Err()
        if delErr != nil {
                log.Debugf("redis item was only partially inserted and chproxy couldn't remove the partial result because of %s", delErr)
        } else {
                log.Debugf("redis item was only partially inserted, chproxy was able to remove it")
        }
}
func (r *redisCache) Name() string {
        return r.name
}
type redisStreamReader struct {
        isRedisEOF          bool
        redisOffset         uint64                // the offset in redis marking the beginning of the next chunk to fetch
        key                 string                // the key of the value we want to stream from redis
        buffer              []byte                // internal buffer storing the chunks fetched from redis
        bufferOffset        int                   // the offset in buffer where the next Read() must start copying data
        client              redis.UniversalClient // the redis client
        expectedPayloadSize int                   // the size of the object the streamer is supposed to read
        readPayloadSize     int                   // the number of bytes handed to the caller so far
}
func newRedisStreamReader(offset uint64, client redis.UniversalClient, key string, payloadSize int64) *redisStreamReader {
        bufferSize := uint64(2 * 1024 * 1024)
        return &redisStreamReader{
                isRedisEOF:          false,
                redisOffset:         offset,
                key:                 key,
                bufferOffset:        int(bufferSize),
                buffer:              make([]byte, bufferSize),
                client:              client,
                expectedPayloadSize: int(payloadSize),
        }
}
func (r *redisStreamReader) Read(destBuf []byte) (n int, err error) {
        // the logic is simple:
        // 1) if the buffer still has unread data, write it into destBuf without overflowing destBuf
        // 2) if the buffer only contains already-written data, refresh the buffer with new data from redis
        // 3) if the buffer only contains already-written data & redis has no more data to read, return io.EOF
        bufSize := len(r.buffer)
        bytesWritten := 0
        // case 3) both the buffer & redis were fully consumed, we can tell the reader to stop reading
        if r.bufferOffset >= bufSize && r.isRedisEOF {
                // Because of the way we fetch from redis, we need to do an extra check because we have no way
                // to know if redis is really EOF or if the value was expired from cache while reading it
                if r.readPayloadSize != r.expectedPayloadSize {
                        log.Debugf("error while fetching data from redis payload size doesn't match")
                        return 0, &RedisCacheError{key: r.key, readPayloadSize: r.readPayloadSize, expectedPayloadSize: r.expectedPayloadSize}
                }
                return 0, io.EOF
        }
        // case 2) the buffer only has already written data, we need to refresh it with redis data
        if r.bufferOffset >= bufSize {
                if err := r.readRangeFromRedis(bufSize); err != nil {
                        return bytesWritten, err
                }
        }
        // case 1) the buffer contains data to write into destBuf
        if r.bufferOffset < bufSize {
                bytesWritten = copy(destBuf, r.buffer[r.bufferOffset:])
                r.bufferOffset += bytesWritten
                r.readPayloadSize += bytesWritten
        }
        return bytesWritten, nil
}
func (r *redisStreamReader) readRangeFromRedis(bufSize int) error {
        ctx, cancelFunc := context.WithTimeout(context.Background(), getTimeout)
        defer cancelFunc()
        newBuf, err := r.client.GetRange(ctx, r.key, int64(r.redisOffset), int64(r.redisOffset+uint64(bufSize))).Result()
        r.redisOffset += uint64(len(newBuf))
        if errors.Is(err, redis.Nil) || len(newBuf) == 0 {
                r.isRedisEOF = true
        }
        // if redis gave back less data than asked for, it means we reached the end of the value
        if len(newBuf) < bufSize {
                r.isRedisEOF = true
        }
        // other errors, such as timeouts
        if err != nil && !errors.Is(err, redis.Nil) {
                log.Debugf("failed to get key %s with error: %s", r.key, err)
                err2 := &RedisCacheError{key: r.key, readPayloadSize: r.readPayloadSize,
                        expectedPayloadSize: r.expectedPayloadSize, rootcause: err}
                return err2
        }
        r.bufferOffset = 0
        r.buffer = []byte(newBuf)
        return nil
}
type fileWriterReader struct {
        f *os.File
}
func newFileWriterReader(dir string) (*fileWriterReader, error) {
        f, err := os.CreateTemp(dir, redisTmpFilePrefix)
        if err != nil {
                return nil, fmt.Errorf("cannot create temporary file in %q: %w", dir, err)
        }
        return &fileWriterReader{
                f: f,
        }, nil
}
func (r *fileWriterReader) Close() error {
        err := r.f.Close()
        if err != nil {
                return err
        }
        return os.Remove(r.f.Name())
}
func (r *fileWriterReader) Read(destBuf []byte) (n int, err error) {
        return r.f.Read(destBuf)
}
func (r *fileWriterReader) Write(p []byte) (n int, err error) {
        return r.f.Write(p)
}
func (r *fileWriterReader) resetOffset() error {
        if _, err := r.f.Seek(0, io.SeekStart); err != nil {
                return fmt.Errorf("cannot reset offset in: %w", err)
        }
        return nil
}
type RedisCacheError struct {
        key                 string
        readPayloadSize     int
        expectedPayloadSize int
        rootcause           error
}
func (e *RedisCacheError) Error() string {
        errorMsg := fmt.Sprintf("error while reading cached result in redis for key %s, only %d bytes of %d were fetched",
                e.key, e.readPayloadSize, e.expectedPayloadSize)
        if e.rootcause != nil {
                errorMsg = fmt.Sprintf("%s, root cause:%s", errorMsg, e.rootcause)
        }
        return errorMsg
}
type RedisCacheCorruptionError struct {
}
func (e *RedisCacheCorruptionError) Error() string {
        return "chproxy can't decode the cached result from redis, it seems to have been corrupted"
}
		
package cache
import (
        "bufio"
        "fmt"
        "io"
        "net/http"
        "os"
        "github.com/contentsquare/chproxy/log"
)
// TmpFileResponseWriter caches the ClickHouse response in a temporary file.
// The http headers are kept in memory.
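//
// A rough lifecycle sketch (the surrounding proxy code is responsible for the actual wiring):
// wrap the upstream http.ResponseWriter with NewTmpFileResponseWriter, let the ClickHouse response
// be written into it, then use Reader()/GetCapturedContentLength() to persist the result into the
// cache, and finally call Close() to remove the temporary file.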
type TmpFileResponseWriter struct {
        http.ResponseWriter // the original response writer
        contentLength   int64
        contentType     string
        contentEncoding string
        headersCaptured bool
        statusCode      int
        tmpFile *os.File      // temporary file for response streaming
        bw      *bufio.Writer // buffered writer for the temporary file
}
func NewTmpFileResponseWriter(rw http.ResponseWriter, dir string) (*TmpFileResponseWriter, error) {
        _, ok := rw.(http.CloseNotifier)
        if !ok {
                return nil, fmt.Errorf("the response writer does not implement http.CloseNotifier")
        }
        f, err := os.CreateTemp(dir, "tmp")
        if err != nil {
                return nil, fmt.Errorf("cannot create temporary file in %q: %w", dir, err)
        }
        return &TmpFileResponseWriter{
                ResponseWriter: rw,
                tmpFile: f,
                bw:      bufio.NewWriter(f),
        }, nil
}
func (rw *TmpFileResponseWriter) Close() error {
        rw.tmpFile.Close()
        return os.Remove(rw.tmpFile.Name())
}
func (rw *TmpFileResponseWriter) GetFile() (*os.File, error) {
        if err := rw.bw.Flush(); err != nil {
                fn := rw.tmpFile.Name()
                errTmp := rw.tmpFile.Close()
                if errTmp != nil {
                        log.Errorf("cannot close tmpFile: %s, error: %s", fn, errTmp)
                }
                errTmp = os.Remove(fn)
                if errTmp != nil {
                        log.Errorf("cannot remove tmpFile: %s, error: %s", fn, errTmp)
                }
                return nil, fmt.Errorf("cannot flush data into %q: %w", fn, err)
        }
        return rw.tmpFile, nil
}
func (rw *TmpFileResponseWriter) Reader() (io.Reader, error) {
        f, err := rw.GetFile()
        if err != nil {
                return nil, fmt.Errorf("cannot open tmp file: %w", err)
        }
        return f, nil
}
func (rw *TmpFileResponseWriter) ResetFileOffset() error {
        data, err := rw.GetFile()
        if err != nil {
                return err
        }
        if _, err := data.Seek(0, io.SeekStart); err != nil {
                return fmt.Errorf("cannot reset offset in: %w", err)
        }
        return nil
}
func (rw *TmpFileResponseWriter) captureHeaders() error {
        if rw.headersCaptured {
                return nil
        }
        rw.headersCaptured = true
        h := rw.Header()
        ct := h.Get("Content-Type")
        ce := h.Get("Content-Encoding")
        rw.contentEncoding = ce
        rw.contentType = ct
        // nb: the Content-Length http header is not set by CH so we can't get it
        return nil
}
func (rw *TmpFileResponseWriter) GetCapturedContentType() string {
        return rw.contentType
}
func (rw *TmpFileResponseWriter) GetCapturedContentLength() (int64, error) {
        if rw.contentLength == 0 {
                // Determine Content-Length looking at the file
                data, err := rw.GetFile()
                if err != nil {
                        return 0, fmt.Errorf("GetCapturedContentLength: cannot open tmp file: %w", err)
                }
                end, err := data.Seek(0, io.SeekEnd)
                if err != nil {
                        return 0, fmt.Errorf("GetCapturedContentLength: cannot determine the last position in: %w", err)
                }
                if err := rw.ResetFileOffset(); err != nil {
                        return 0, err
                }
                return end, nil
        }
        return rw.contentLength, nil
}
func (rw *TmpFileResponseWriter) GetCapturedContentEncoding() string {
        return rw.contentEncoding
}
// CloseNotify implements http.CloseNotifier
func (rw *TmpFileResponseWriter) CloseNotify() <-chan bool {
        // nolint:forcetypeassert // it is guaranteed by NewTmpFileResponseWriter
        // rw.ResponseWriter is guaranteed to implement http.CloseNotifier by NewTmpFileResponseWriter.
        return rw.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// WriteHeader captures response status code.
func (rw *TmpFileResponseWriter) WriteHeader(statusCode int) {
        rw.statusCode = statusCode
        // Do not call rw.ResponseWriter.WriteHeader here.
        // It will be called explicitly in Finalize / Unregister.
}
// StatusCode returns captured status code from WriteHeader.
func (rw *TmpFileResponseWriter) StatusCode() int {
        if rw.statusCode == 0 {
                return http.StatusOK
        }
        return rw.statusCode
}
// Write writes b into rw.
func (rw *TmpFileResponseWriter) Write(b []byte) (int, error) {
        if err := rw.captureHeaders(); err != nil {
                return 0, err
        }
        return rw.bw.Write(b)
}
		
package cache
import (
        "io"
        "time"
)
// TransactionRegistry is a registry of ongoing queries identified by Key.
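//
// A typical flow around a cache miss looks roughly like this (sketch with hypothetical
// variable names; error handling elided, the real wiring lives in the proxy code):
//
//	if err := registry.Create(key); err == nil {
//		if _, err := cache.Put(resultReader, metadata, key); err != nil {
//			_ = registry.Fail(key, err.Error())
//		} else {
//			_ = registry.Complete(key)
//		}
//	}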
type TransactionRegistry interface {
        io.Closer
        // Create creates a new transaction record
        Create(key *Key) error
        // Complete completes a transaction for given key
        Complete(key *Key) error
        // Fail fails a transaction for given key
        Fail(key *Key, reason string) error
        // Status checks the status of the transaction
        Status(key *Key) (TransactionStatus, error)
}
// transactionEndedTTL is the amount of time a transaction record is kept after the transaction has ended
const transactionEndedTTL = 500 * time.Millisecond
type TransactionStatus struct {
        State      TransactionState
        FailReason string // filled in only if state of transaction is transactionFailed
}
type TransactionState uint8
const (
        transactionCreated   TransactionState = 0
        transactionCompleted TransactionState = 1
        transactionFailed    TransactionState = 2
        transactionAbsent    TransactionState = 3
)
func (t *TransactionState) IsAbsent() bool {
        if t != nil {
                return *t == transactionAbsent
        }
        return false
}
func (t *TransactionState) IsFailed() bool {
        if t != nil {
                return *t == transactionFailed
        }
        return false
}
func (t *TransactionState) IsCompleted() bool {
        if t != nil {
                return *t == transactionCompleted
        }
        return false
}
func (t *TransactionState) IsPending() bool {
        if t != nil {
                return *t == transactionCreated
        }
        return false
}
		
package cache
import (
        "sync"
        "time"
        "github.com/contentsquare/chproxy/log"
)
type pendingEntry struct {
        deadline     time.Time
        state        TransactionState
        failedReason string
}
type inMemoryTransactionRegistry struct {
        pendingEntriesLock sync.Mutex
        pendingEntries     map[string]pendingEntry
        deadline                 time.Duration
        transactionEndedDeadline time.Duration
        stopCh                   chan struct{}
        wg                       sync.WaitGroup
}
func newInMemoryTransactionRegistry(deadline, transactionEndedDeadline time.Duration) *inMemoryTransactionRegistry {
        transaction := &inMemoryTransactionRegistry{
                pendingEntriesLock:       sync.Mutex{},
                pendingEntries:           make(map[string]pendingEntry),
                deadline:                 deadline,
                transactionEndedDeadline: transactionEndedDeadline,
                stopCh:                   make(chan struct{}),
        }
        transaction.wg.Add(1)
        go func() {
                log.Debugf("inmem transaction: cleaner start")
                transaction.pendingEntriesCleaner()
                transaction.wg.Done()
                log.Debugf("inmem transaction: cleaner stop")
        }()
        return transaction
}
func (i *inMemoryTransactionRegistry) Create(key *Key) error {
        i.pendingEntriesLock.Lock()
        defer i.pendingEntriesLock.Unlock()
        k := key.String()
        _, exists := i.pendingEntries[k]
        if !exists {
                i.pendingEntries[k] = pendingEntry{
                        deadline: time.Now().Add(i.deadline),
                        state:    transactionCreated,
                }
        }
        return nil
}
func (i *inMemoryTransactionRegistry) Complete(key *Key) error {
        i.updateTransactionState(key, transactionCompleted, "")
        return nil
}
func (i *inMemoryTransactionRegistry) Fail(key *Key, reason string) error {
        i.updateTransactionState(key, transactionFailed, reason)
        return nil
}
func (i *inMemoryTransactionRegistry) updateTransactionState(key *Key, state TransactionState, failReason string) {
        i.pendingEntriesLock.Lock()
        defer i.pendingEntriesLock.Unlock()
        k := key.String()
        if entry, ok := i.pendingEntries[k]; ok {
                entry.state = state
                entry.failedReason = failReason
                entry.deadline = time.Now().Add(i.transactionEndedDeadline)
                i.pendingEntries[k] = entry
        } else {
                log.Errorf("[attempt to complete transaction] entry not found for key: %s, registering new entry with %v status", key.String(), state)
                i.pendingEntries[k] = pendingEntry{
                        deadline:     time.Now().Add(i.transactionEndedDeadline),
                        state:        state,
                        failedReason: failReason,
                }
        }
}
func (i *inMemoryTransactionRegistry) Status(key *Key) (TransactionStatus, error) {
        i.pendingEntriesLock.Lock()
        defer i.pendingEntriesLock.Unlock()
        k := key.String()
        if entry, ok := i.pendingEntries[k]; ok {
                return TransactionStatus{State: entry.state, FailReason: entry.failedReason}, nil
        }
        return TransactionStatus{State: transactionAbsent}, nil
}
func (i *inMemoryTransactionRegistry) Close() error {
        close(i.stopCh)
        i.wg.Wait()
        return nil
}
func (i *inMemoryTransactionRegistry) pendingEntriesCleaner() {
        d := i.deadline
        if d < 100*time.Millisecond {
                d = 100 * time.Millisecond
        }
        if d > time.Second {
                d = time.Second
        }
        for {
                currentTime := time.Now()
                // Clear outdated pending entries, since they may remain here
                // forever if the Complete/Fail call is missing.
                i.pendingEntriesLock.Lock()
                for path, pe := range i.pendingEntries {
                        if currentTime.After(pe.deadline) {
                                delete(i.pendingEntries, path)
                        }
                }
                i.pendingEntriesLock.Unlock()
                select {
                case <-time.After(d):
                case <-i.stopCh:
                        return
                }
        }
}
		
package cache
import (
        "context"
        "errors"
        "time"
        "fmt"
        "github.com/contentsquare/chproxy/log"
        "github.com/redis/go-redis/v9"
)
type redisTransactionRegistry struct {
        redisClient redis.UniversalClient
        // deadline specifies the TTL of a transaction record while the transaction is still pending
        deadline time.Duration
        // transactionEndedDeadline specifies the TTL of a transaction record once the transaction has ended (either completed or failed)
        transactionEndedDeadline time.Duration
}
func newRedisTransactionRegistry(redisClient redis.UniversalClient, deadline time.Duration,
        endedDeadline time.Duration) *redisTransactionRegistry {
        return &redisTransactionRegistry{
                redisClient:              redisClient,
                deadline:                 deadline,
                transactionEndedDeadline: endedDeadline,
        }
}
func (r *redisTransactionRegistry) Create(key *Key) error {
        return r.redisClient.Set(context.Background(), toTransactionKey(key),
                []byte{uint8(transactionCreated)}, r.deadline).Err()
}
func (r *redisTransactionRegistry) Complete(key *Key) error {
        return r.updateTransactionState(key, []byte{uint8(transactionCompleted)})
}
func (r *redisTransactionRegistry) Fail(key *Key, reason string) error {
        b := make([]byte, 0, uint32(len(reason))+1)
        b = append(b, byte(transactionFailed))
        b = append(b, []byte(reason)...)
        return r.updateTransactionState(key, b)
}
func (r *redisTransactionRegistry) updateTransactionState(key *Key, value []byte) error {
        return r.redisClient.Set(context.Background(), toTransactionKey(key), value, r.transactionEndedDeadline).Err()
}
func (r *redisTransactionRegistry) Status(key *Key) (TransactionStatus, error) {
        raw, err := r.redisClient.Get(context.Background(), toTransactionKey(key)).Bytes()
        if errors.Is(err, redis.Nil) {
                return TransactionStatus{State: transactionAbsent}, nil
        }
        if err != nil {
                log.Errorf("Failed to fetch transaction status from redis for key: %s", key.String())
                return TransactionStatus{State: transactionAbsent}, err
        }
        if len(raw) == 0 {
                log.Errorf("Failed to fetch transaction status from redis raw value: %s", key.String())
                return TransactionStatus{State: transactionAbsent}, err
        }
        state := TransactionState(uint8(raw[0]))
        var reason string
        if state.IsFailed() && len(raw) > 1 {
                reason = string(raw[1:])
        }
        return TransactionStatus{State: state, FailReason: reason}, nil
}
func (r *redisTransactionRegistry) Close() error {
        return r.redisClient.Close()
}
func toTransactionKey(key *Key) string {
        return fmt.Sprintf("%s-transaction", key.String())
}
		
package cache
import (
        "fmt"
        "io"
        "os"
)
// walkDir calls f on all the cache files in the given dir.
func walkDir(dir string, f func(fi os.FileInfo)) error {
        // Do not use filepath.Walk, since it is inefficient
        // for large number of files.
        // See https://golang.org/pkg/path/filepath/#Walk .
        fd, err := os.Open(dir)
        if err != nil {
                return fmt.Errorf("cannot open %q: %w", dir, err)
        }
        defer fd.Close()
        for {
                fis, err := fd.Readdir(1024)
                if err != nil {
                        if err == io.EOF {
                                return nil
                        }
                        return fmt.Errorf("cannot read files in %q: %w", dir, err)
                }
                for _, fi := range fis {
                        if fi.IsDir() {
                                // Skip subdirectories
                                continue
                        }
                        fn := fi.Name()
                        if !cachefileRegexp.MatchString(fn) {
                                // Skip invalid filenames
                                continue
                        }
                        f(fi)
                }
        }
}
		
package chdecompressor
import (
        "errors"
        "fmt"
        "io"
        "github.com/klauspost/compress/zstd"
        "github.com/pierrec/lz4"
)
// Reader reads clickhouse compressed stream.
// See https://github.com/yandex/ClickHouse/blob/ae8783aee3ef982b6eb7e1721dac4cb3ce73f0fe/dbms/src/IO/CompressedStream.h
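// Each block in the stream is laid out as:
// 16-byte checksum | 1-byte compression type | 4-byte little-endian compressed size (including the 9-byte header) |
// 4-byte little-endian decompressed size | compressed payload.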
type Reader struct {
        src     io.Reader
        data    []byte
        scratch []byte
}
// NewReader returns new clickhouse compressed stream reader reading from src.
func NewReader(src io.Reader) *Reader {
        return &Reader{
                src:     src,
                scratch: make([]byte, 16),
        }
}
// Read reads up to len(buf) bytes from clickhouse compressed stream.
func (r *Reader) Read(buf []byte) (int, error) {
        // refill the internal buffer only once the data left over from the previous Read() is exhausted
        if len(r.data) == 0 {
                if err := r.readNextBlock(); err != nil {
                        return 0, err
                }
        }
        n := copy(buf, r.data)
        r.data = r.data[n:]
        return n, nil
}
func (r *Reader) readNextBlock() error {
        // Skip checksum
        if _, err := io.ReadFull(r.src, r.scratch[:16]); err != nil {
                if errors.Is(err, io.EOF) {
                        return io.EOF
                }
                return fmt.Errorf("cannot read checksum: %w", err)
        }
        // Read compression type
        if _, err := io.ReadFull(r.src, r.scratch[:1]); err != nil {
                return fmt.Errorf("cannot read compression type: %w", err)
        }
        compressionType := r.scratch[0]
        // Read compressed size
        compressedSize, err := r.readUint32()
        if err != nil {
                return fmt.Errorf("cannot read compressed size: %w", err)
        }
        compressedSize -= 9 // minus header length
        // Read decompressed size
        decompressedSize, err := r.readUint32()
        if err != nil {
                return fmt.Errorf("cannot read decompressed size: %w", err)
        }
        // Read compressed block
        block := make([]byte, compressedSize)
        if _, err = io.ReadFull(r.src, block); err != nil {
                return fmt.Errorf("cannot read compressed block: %w", err)
        }
        // Decompress block
        if err := r.decompressBlock(block, compressionType, decompressedSize); err != nil {
                return err
        }
        return nil
}
func (r *Reader) decompressBlock(block []byte, compressionType byte, decompressedSize uint32) error {
        r.data = make([]byte, decompressedSize)
        switch compressionType {
        case noneType:
                r.data = block
        case lz4Type:
                if _, err := lz4.UncompressBlock(block, r.data); err != nil {
                        return fmt.Errorf("cannot decompress lz4 block: %w", err)
                }
        case zstdType:
                decoder, err := zstd.NewReader(nil)
                if err != nil {
                        return fmt.Errorf("cannot create zstd decoder: %w", err)
                }
                defer decoder.Close()
                r.data = r.data[:0] // Wipe the slice but keep the allocated memory
                r.data, err = decoder.DecodeAll(block, r.data)
                if err != nil {
                        return fmt.Errorf("cannot decompress zstd block: %w", err)
                }
        default:
                return fmt.Errorf("unknown compressionType: %X", compressionType)
        }
        return nil
}
const (
        noneType = 0x02
        lz4Type  = 0x82
        zstdType = 0x90
)
func (r *Reader) readUint32() (uint32, error) {
        b := r.scratch[:4]
        _, err := io.ReadFull(r.src, b)
        if err != nil {
                return 0, err
        }
        n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
        return n, nil
}
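// decompressAll is a small usage sketch, not part of the original package: it
// drains an entire ClickHouse-compressed stream into memory via Reader, relying
// only on packages already imported above. Callers handling large responses
// would more likely read in fixed-size chunks instead.
func decompressAll(src io.Reader) ([]byte, error) {
        return io.ReadAll(NewReader(src))
}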
		
		package clients
import (
        "context"
        "fmt"
        "github.com/contentsquare/chproxy/config"
        "github.com/redis/go-redis/v9"
)
func NewRedisClient(cfg config.RedisCacheConfig) (redis.UniversalClient, error) {
        options := &redis.UniversalOptions{
                Addrs:      cfg.Addresses,
                Username:   cfg.Username,
                Password:   cfg.Password,
                MaxRetries: 7, // go-redis default is 3; with MinRetryBackoff = 8 msec & MaxRetryBackoff = 512 msec,
                // the redis client will wait up to 1016 msec in total across the 7 retries
        }
        if len(cfg.Addresses) == 1 {
                options.DB = cfg.DBIndex
        }
        if len(cfg.CertFile) != 0 || len(cfg.KeyFile) != 0 {
                tlsConfig, err := cfg.TLS.BuildTLSConfig(nil)
                if err != nil {
                        return nil, err
                }
                options.TLSConfig = tlsConfig
        }
        r := redis.NewUniversalClient(options)
        err := r.Ping(context.Background()).Err()
        if err != nil {
                return nil, fmt.Errorf("failed to reach redis: %w", err)
        }
        return r, nil
}
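// newLocalRedisClient is a usage sketch, not part of the original package: it
// builds a client for a single local Redis instance. The address and DB index
// below are illustrative assumptions, not chproxy defaults.
func newLocalRedisClient() (redis.UniversalClient, error) {
        return NewRedisClient(config.RedisCacheConfig{
                Addresses: []string{"127.0.0.1:6379"},
                DBIndex:   0,
        })
}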
		
		package config
import (
        "bytes"
        "crypto/tls"
        "fmt"
        "os"
        "regexp"
        "strings"
        "time"
        "github.com/mohae/deepcopy"
        "golang.org/x/crypto/acme/autocert"
        "gopkg.in/yaml.v2"
)
var (
        defaultConfig = Config{
                Clusters: []Cluster{defaultCluster},
        }
        defaultCluster = Cluster{
                Scheme:       "http",
                ClusterUsers: []ClusterUser{defaultClusterUser},
                HeartBeat:    defaultHeartBeat,
                RetryNumber:  defaultRetryNumber,
        }
        defaultClusterUser = ClusterUser{
                Name: "default",
        }
        defaultHeartBeat = HeartBeat{
                Interval: Duration(time.Second * 5),
                Timeout:  Duration(time.Second * 3),
                Request:  "/ping",
                Response: "Ok.\n",
                User:     "",
                Password: "",
        }
        defaultConnectionPool = ConnectionPool{
                MaxIdleConns:        100,
                MaxIdleConnsPerHost: 2,
        }
        defaultExecutionTime = Duration(120 * time.Second)
        defaultMaxPayloadSize = ByteSize(1 << 50)
        defaultRetryNumber = 0
)
// Config describes server configuration, access and proxy rules
type Config struct {
        Server Server `yaml:"server,omitempty"`
        Clusters []Cluster `yaml:"clusters"`
        Users []User `yaml:"users"`
        // Whether to print debug logs
        LogDebug bool `yaml:"log_debug,omitempty"`
        // Whether to ignore security warnings
        HackMePlease bool `yaml:"hack_me_please,omitempty"`
        NetworkGroups []NetworkGroups `yaml:"network_groups,omitempty"`
        Caches []Cache `yaml:"caches,omitempty"`
        ParamGroups []ParamGroup `yaml:"param_groups,omitempty"`
        ConnectionPool ConnectionPool `yaml:"connection_pool,omitempty"`
        networkReg map[string]Networks
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// String implements the Stringer interface
func (c *Config) String() string {
        b, err := yaml.Marshal(withoutSensitiveInfo(c))
        if err != nil {
                panic(err)
        }
        return string(b)
}
func withoutSensitiveInfo(config *Config) *Config {
        const pswPlaceHolder = "XXX"
        // nolint: forcetypeassert // no need to check the assertion: deepcopy.Copy returns the same type as its argument.
        c := deepcopy.Copy(config).(*Config)
        for i := range c.Users {
                c.Users[i].Password = pswPlaceHolder
        }
        for i := range c.Clusters {
                if len(c.Clusters[i].KillQueryUser.Name) > 0 {
                        c.Clusters[i].KillQueryUser.Password = pswPlaceHolder
                }
                for j := range c.Clusters[i].ClusterUsers {
                        c.Clusters[i].ClusterUsers[j].Password = pswPlaceHolder
                }
        }
        for i := range c.Caches {
                if len(c.Caches[i].Redis.Username) > 0 {
                        c.Caches[i].Redis.Password = pswPlaceHolder
                }
        }
        return c
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
        // set c to the defaults and then overwrite it with the input.
        *c = defaultConfig
        type plain Config
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        if err := c.validate(); err != nil {
                return err
        }
        return checkOverflow(c.XXX, "config")
}
func (c *Config) validate() error {
        if len(c.Users) == 0 {
                return fmt.Errorf("`users` must contain at least 1 user")
        }
        if len(c.Clusters) == 0 {
                return fmt.Errorf("`clusters` must contain at least 1 cluster")
        }
        if len(c.Server.HTTP.ListenAddr) == 0 && len(c.Server.HTTPS.ListenAddr) == 0 {
                return fmt.Errorf("neither HTTP nor HTTPS not configured")
        }
        if len(c.Server.HTTPS.ListenAddr) > 0 {
                if len(c.Server.HTTPS.Autocert.CacheDir) == 0 && len(c.Server.HTTPS.CertFile) == 0 && len(c.Server.HTTPS.KeyFile) == 0 {
                        return fmt.Errorf("configuration `https` is missing. " +
                                "Must be specified `https.cache_dir` for autocert " +
                                "OR `https.key_file` and `https.cert_file` for already existing certs")
                }
                if len(c.Server.HTTPS.Autocert.CacheDir) > 0 {
                        c.Server.HTTP.ForceAutocertHandler = true
                }
        }
        return nil
}
func (cfg *Config) setDefaults() error {
        var maxResponseTime time.Duration
        var err error
        for i := range cfg.Clusters {
                c := &cfg.Clusters[i]
                for j := range c.ClusterUsers {
                        u := &c.ClusterUsers[j]
                        cud := time.Duration(u.MaxExecutionTime + u.MaxQueueTime)
                        if cud > maxResponseTime {
                                maxResponseTime = cud
                        }
                        if u.AllowedNetworks, err = cfg.groupToNetwork(u.NetworksOrGroups); err != nil {
                                return err
                        }
                }
        }
        for i := range cfg.Users {
                u := &cfg.Users[i]
                u.setDefaults()
                ud := time.Duration(u.MaxExecutionTime + u.MaxQueueTime)
                if ud > maxResponseTime {
                        maxResponseTime = ud
                }
                if u.AllowedNetworks, err = cfg.groupToNetwork(u.NetworksOrGroups); err != nil {
                        return err
                }
        }
        for i := range cfg.Caches {
                c := &cfg.Caches[i]
                c.setDefaults()
        }
        cfg.setServerMaxResponseTime(maxResponseTime)
        return nil
}
func (cfg *Config) setServerMaxResponseTime(maxResponseTime time.Duration) {
        if maxResponseTime < 0 {
                maxResponseTime = 0
        }
        // Give an additional minute for the maximum response time,
        // so the response body may be sent to the requester.
        maxResponseTime += time.Minute
        if len(cfg.Server.HTTP.ListenAddr) > 0 && cfg.Server.HTTP.WriteTimeout == 0 {
                cfg.Server.HTTP.WriteTimeout = Duration(maxResponseTime)
        }
        if len(cfg.Server.HTTPS.ListenAddr) > 0 && cfg.Server.HTTPS.WriteTimeout == 0 {
                cfg.Server.HTTPS.WriteTimeout = Duration(maxResponseTime)
        }
}
func (c *Config) groupToNetwork(src NetworksOrGroups) (Networks, error) {
        if len(src) == 0 {
                return nil, nil
        }
        dst := make(Networks, 0)
        for _, v := range src {
                group, ok := c.networkReg[v]
                if ok {
                        dst = append(dst, group...)
                } else {
                        ipnet, err := stringToIPnet(v)
                        if err != nil {
                                return nil, err
                        }
                        dst = append(dst, ipnet)
                }
        }
        return dst, nil
}
// Server describes configuration of proxy server
// These settings are immutable and can't be reloaded without restart
type Server struct {
        // Optional HTTP configuration
        HTTP HTTP `yaml:"http,omitempty"`
        // Optional TLS configuration
        HTTPS HTTPS `yaml:"https,omitempty"`
        // Optional metrics handler configuration
        Metrics Metrics `yaml:"metrics,omitempty"`
        // Optional Proxy configuration
        Proxy Proxy `yaml:"proxy,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (s *Server) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain Server
        if err := unmarshal((*plain)(s)); err != nil {
                return err
        }
        return checkOverflow(s.XXX, "server")
}
// TimeoutCfg contains configurable http.Server timeouts
type TimeoutCfg struct {
        // ReadTimeout is the maximum duration for reading the entire
        // request, including the body.
        // Default value is 1m
        ReadTimeout Duration `yaml:"read_timeout,omitempty"`
        // WriteTimeout is the maximum duration before timing out writes of the response.
        // Default is largest MaxExecutionTime + MaxQueueTime value from Users or Clusters
        WriteTimeout Duration `yaml:"write_timeout,omitempty"`
        // IdleTimeout is the maximum amount of time to wait for the next request.
        // Default is 10m
        IdleTimeout Duration `yaml:"idle_timeout,omitempty"`
}
// HTTP describes configuration for server to listen HTTP connections
type HTTP struct {
        // TCP address to listen to for http
        ListenAddr string `yaml:"listen_addr"`
        NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
        // List of networks that access is allowed from
        // Each list item could be IP address or subnet mask
        // if omitted or zero - no limits would be applied
        AllowedNetworks Networks `yaml:"-"`
        // Whether to support Autocert handler for http-01 challenge
        ForceAutocertHandler bool
        TimeoutCfg `yaml:",inline"`
        // Catches all undefined fields and must be empty after parsing.
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *HTTP) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain HTTP
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        if err := c.validate(); err != nil {
                return err
        }
        return checkOverflow(c.XXX, "http")
}
func (c *HTTP) validate() error {
        if c.ReadTimeout == 0 {
                c.ReadTimeout = Duration(time.Minute)
        }
        if c.IdleTimeout == 0 {
                c.IdleTimeout = Duration(time.Minute * 10)
        }
        return nil
}
// TLS describes generic configuration for TLS connections,
// it can be used for both HTTPS and Redis TLS.
type TLS struct {
        // Certificate and key files for client cert authentication to the server
        CertFile           string   `yaml:"cert_file,omitempty"`
        KeyFile            string   `yaml:"key_file,omitempty"`
        Autocert           Autocert `yaml:"autocert,omitempty"`
        InsecureSkipVerify bool     `yaml:"insecure_skip_verify,omitempty"`
}
// BuildTLSConfig builds tls.Config from TLS configuration.
func (c *TLS) BuildTLSConfig(acm *autocert.Manager) (*tls.Config, error) {
        tlsCfg := tls.Config{
                PreferServerCipherSuites: true,
                MinVersion:               tls.VersionTLS12,
                CurvePreferences: []tls.CurveID{
                        tls.CurveP256,
                        tls.X25519,
                },
                InsecureSkipVerify: c.InsecureSkipVerify, // nolint: gosec
        }
        if len(c.KeyFile) > 0 && len(c.CertFile) > 0 {
                cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
                if err != nil {
                        return nil, fmt.Errorf("cannot load cert for `cert_file`=%q, `key_file`=%q: %w",
                                c.CertFile, c.KeyFile, err)
                }
                tlsCfg.Certificates = []tls.Certificate{cert}
        } else {
                if acm == nil {
                        return nil, fmt.Errorf("autocert manager is not configured")
                }
                tlsCfg.GetCertificate = acm.GetCertificate
        }
        return &tlsCfg, nil
}
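// buildStaticTLSConfig is a usage sketch, not part of the original source: it
// shows how a TLS section with static certificate files yields a *tls.Config.
// The file paths are hypothetical; no autocert manager is needed in this case.
func buildStaticTLSConfig() (*tls.Config, error) {
        t := TLS{
                CertFile: "/etc/chproxy/server.crt",
                KeyFile:  "/etc/chproxy/server.key",
        }
        return t.BuildTLSConfig(nil)
}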
// HTTPS describes configuration for server to listen HTTPS connections
// It can be autocert with letsencrypt
// or custom certificate
type HTTPS struct {
        // TCP address to listen to for https
        // Default is `:443`
        ListenAddr string `yaml:"listen_addr,omitempty"`
        // TLS configuration
        TLS `yaml:",inline"`
        NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
        // List of networks that access is allowed from
        // Each list item could be IP address or subnet mask
        // if omitted or zero - no limits would be applied
        AllowedNetworks Networks `yaml:"-"`
        TimeoutCfg `yaml:",inline"`
        // Catches all undefined fields and must be empty after parsing.
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *HTTPS) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain HTTPS
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        if err := c.validate(); err != nil {
                return err
        }
        return checkOverflow(c.XXX, "https")
}
func (c *HTTPS) validate() error {
        if c.ReadTimeout == 0 {
                c.ReadTimeout = Duration(time.Minute)
        }
        if c.IdleTimeout == 0 {
                c.IdleTimeout = Duration(time.Minute * 10)
        }
        if len(c.ListenAddr) == 0 {
                c.ListenAddr = ":443"
        }
        if err := c.validateCertConfig(); err != nil {
                return err
        }
        return nil
}
func (c *HTTPS) validateCertConfig() error {
        if len(c.Autocert.CacheDir) > 0 {
                if len(c.CertFile) > 0 || len(c.KeyFile) > 0 {
                        return fmt.Errorf("it is forbidden to specify certificate and `https.autocert` at the same time. Choose one way")
                }
                if len(c.NetworksOrGroups) > 0 {
                        return fmt.Errorf("`letsencrypt` specification requires https server to be without `allowed_networks` limits. " +
                                "Otherwise, certificates will be impossible to generate")
                }
        }
        if len(c.CertFile) > 0 && len(c.KeyFile) == 0 {
                return fmt.Errorf("`https.key_file` must be specified")
        }
        if len(c.KeyFile) > 0 && len(c.CertFile) == 0 {
                return fmt.Errorf("`https.cert_file` must be specified")
        }
        return nil
}
// Autocert configuration via letsencrypt
// It requires port :80 to be open
// see https://community.letsencrypt.org/t/2018-01-11-update-regarding-acme-tls-sni-and-shared-hosting-infrastructure/50188
type Autocert struct {
        // Path to the directory where autocert certs are cached
        CacheDir string `yaml:"cache_dir,omitempty"`
        // List of host names to which proxy is allowed to respond to
        // see https://godoc.org/golang.org/x/crypto/acme/autocert#HostPolicy
        AllowedHosts []string `yaml:"allowed_hosts,omitempty"`
        // Catches all undefined fields and must be empty after parsing.
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Autocert) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain Autocert
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        return checkOverflow(c.XXX, "autocert")
}
// Metrics describes configuration to access metrics endpoint
type Metrics struct {
        NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
        // List of networks that access is allowed from
        // Each list item could be IP address or subnet mask
        // if omitted or zero - no limits would be applied
        AllowedNetworks Networks `yaml:"-"`
        // Prometheus metric namespace
        Namespace string `yaml:"namespace,omitempty"`
        // Catches all undefined fields and must be empty after parsing.
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Metrics) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain Metrics
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        return checkOverflow(c.XXX, "metrics")
}
type Proxy struct {
        // Enable enables parsing proxy headers. In proxy mode, CHProxy will try to
        // parse the X-Forwarded-For, X-Real-IP or Forwarded header to extract the IP. If another header is configured
        // in the proxy settings, CHProxy will use that header instead.
        Enable bool `yaml:"enable,omitempty"`
        // Header allows for configuring an alternative header to parse the remote IP from, e.g.
        // CF-Connecting-IP. If this is set, Enable must be set to true otherwise this setting
        // will be ignored.
        Header string `yaml:"header,omitempty"`
        // Catches all undefined fields and must be empty after parsing.
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Proxy) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain Proxy
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        if !c.Enable && c.Header != "" {
                return fmt.Errorf("`proxy_header` cannot be set without enabling proxy settings")
        }
        return checkOverflow(c.XXX, "proxy")
}
// Cluster describes CH cluster configuration
// The simplest configuration consists of:
//
//        cluster description - see <remote_servers> section in CH config.xml
//        and users - see <users> section in CH users.xml
type Cluster struct {
        // Name of ClickHouse cluster
        Name string `yaml:"name"`
        // Scheme: `http` or `https`; would be applied to all nodes
        // default value is `http`
        Scheme string `yaml:"scheme,omitempty"`
        // Nodes contains cluster nodes.
        //
        // Either Nodes or Replicas must be set, but not both.
        Nodes []string `yaml:"nodes,omitempty"`
        // Replicas contains replicas.
        //
        // Either Replicas or Nodes must be set, but not both.
        Replicas []Replica `yaml:"replicas,omitempty"`
        // ClusterUsers - list of ClickHouse users
        ClusterUsers []ClusterUser `yaml:"users"`
        // KillQueryUser - user configuration for killing timed out queries.
        // By default timed out queries are killed under `default` user.
        KillQueryUser KillQueryUser `yaml:"kill_query_user,omitempty"`
        // HeartBeat - user configuration for heart beat requests
        HeartBeat HeartBeat `yaml:"heartbeat,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
        // Retry number for query - how many times a query can retry after receiving a recoverable but failed response from Clickhouse node
        RetryNumber int `yaml:"retry_number,omitempty"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Cluster) UnmarshalYAML(unmarshal func(interface{}) error) error {
        *c = defaultCluster
        type plain Cluster
        if err := unmarshal((*plain)(c)); err != nil {
                return err
        }
        if err := c.validate(); err != nil {
                return err
        }
        return checkOverflow(c.XXX, fmt.Sprintf("cluster %q", c.Name))
}
func (c *Cluster) validate() error {
        if len(c.Name) == 0 {
                return fmt.Errorf("`cluster.name` cannot be empty")
        }
        if err := c.validateMinimumRequirements(); err != nil {
                return err
        }
        if c.Scheme != "http" && c.Scheme != "https" {
                return fmt.Errorf("`cluster.scheme` must be `http` or `https`, got %q instead for %q", c.Scheme, c.Name)
        }
        if c.HeartBeat.Interval == 0 && c.HeartBeat.Timeout == 0 && c.HeartBeat.Response == "" {
                return fmt.Errorf("`cluster.heartbeat` cannot be unset for %q", c.Name)
        }
        return nil
}
func (c *Cluster) validateMinimumRequirements() error {
        if len(c.Nodes) == 0 && len(c.Replicas) == 0 {
                return fmt.Errorf("either `cluster.nodes` or `cluster.replicas` must be set for %q", c.Name)
        }
        if len(c.Nodes) > 0 && len(c.Replicas) > 0 {
                return fmt.Errorf("`cluster.nodes` cannot be simultaneously set with `cluster.replicas` for %q", c.Name)
        }
        if len(c.ClusterUsers) == 0 {
                return fmt.Errorf("`cluster.users` must contain at least 1 user for %q", c.Name)
        }
        return nil
}
// Replica contains ClickHouse replica configuration.
type Replica struct {
        // Name is replica name.
        Name string `yaml:"name"`
        // Nodes contains replica nodes.
        Nodes []string `yaml:"nodes"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (r *Replica) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain Replica
        if err := unmarshal((*plain)(r)); err != nil {
                return err
        }
        if len(r.Name) == 0 {
                return fmt.Errorf("`replica.name` cannot be empty")
        }
        if len(r.Nodes) == 0 {
                return fmt.Errorf("`replica.nodes` cannot be empty for %q", r.Name)
        }
        return checkOverflow(r.XXX, fmt.Sprintf("replica %q", r.Name))
}
// KillQueryUser - user configuration for killing timed out queries.
type KillQueryUser struct {
        // User name
        Name string `yaml:"name"`
        // User password to access CH with basic auth
        Password string `yaml:"password,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (u *KillQueryUser) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain KillQueryUser
        if err := unmarshal((*plain)(u)); err != nil {
                return err
        }
        if len(u.Name) == 0 {
                return fmt.Errorf("`cluster.kill_query_user.name` must be specified")
        }
        return checkOverflow(u.XXX, "kill_query_user")
}
// HeartBeat - configuration for heartbeat.
type HeartBeat struct {
        // Interval is an interval of checking
        // all cluster nodes for availability
        // if omitted or zero - interval will be set to 5s
        Interval Duration `yaml:"interval,omitempty"`
        // Timeout is a timeout of wait response from cluster nodes
        // if omitted or zero - timeout will be set to 3s
        Timeout Duration `yaml:"timeout,omitempty"`
        // Request is a query
        // default value is `/ping`
        Request string `yaml:"request,omitempty"`
        // Reference response from clickhouse on health check request
        // default value is `Ok.\n`
        Response string `yaml:"response,omitempty"`
        // Credentials to send heartbeat requests
        // for anything except '/ping'.
        // If not specified, the first cluster user's credentials are used
        User     string `yaml:"user,omitempty"`
        Password string `yaml:"password,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (h *HeartBeat) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain HeartBeat
        if err := unmarshal((*plain)(h)); err != nil {
                return err
        }
        return checkOverflow(h.XXX, "heartbeat")
}
// User describes list of allowed users
// which requests will be proxied to ClickHouse
type User struct {
        // User name
        Name string `yaml:"name"`
        // User password to access proxy with basic auth
        Password string `yaml:"password,omitempty"`
        // ToCluster is the name of cluster where requests
        // will be proxied
        ToCluster string `yaml:"to_cluster"`
        // ToUser is the name of cluster_user from cluster's ToCluster
        // whom credentials will be used for proxying request to CH
        ToUser string `yaml:"to_user"`
        // Maximum number of concurrently running queries for user
        // if omitted or zero - no limits would be applied
        MaxConcurrentQueries uint32 `yaml:"max_concurrent_queries,omitempty"`
        // Maximum duration of query execution for user
        // if omitted or zero - limit is set to 120 seconds
        MaxExecutionTime Duration `yaml:"max_execution_time,omitempty"`
        // Maximum number of requests per minute for user
        // if omitted or zero - no limits would be applied
        ReqPerMin uint32 `yaml:"requests_per_minute,omitempty"`
        // The burst of request packet size token bucket for user
        // if omitted or zero - no limits would be applied
        ReqPacketSizeTokensBurst ByteSize `yaml:"request_packet_size_tokens_burst,omitempty"`
        // The request packet size tokens produced rate per second for user
        // if omitted or zero - no limits would be applied
        ReqPacketSizeTokensRate ByteSize `yaml:"request_packet_size_tokens_rate,omitempty"`
        // Maximum number of queries waiting for execution in the queue
        // if omitted or zero - queries are executed without waiting
        // in the queue
        MaxQueueSize uint32 `yaml:"max_queue_size,omitempty"`
        // Maximum duration the query may wait in the queue
        // if omitted or zero - 10s duration is used
        MaxQueueTime Duration `yaml:"max_queue_time,omitempty"`
        NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
        // List of networks that access is allowed from
        // Each list item could be IP address or subnet mask
        // if omitted or zero - no limits would be applied
        AllowedNetworks Networks `yaml:"-"`
        // Whether to deny http connections for this user
        DenyHTTP bool `yaml:"deny_http,omitempty"`
        // Whether to deny https connections for this user
        DenyHTTPS bool `yaml:"deny_https,omitempty"`
        // Whether to allow CORS requests for this user
        AllowCORS bool `yaml:"allow_cors,omitempty"`
        // Name of Cache configuration to use for responses of this user
        Cache string `yaml:"cache,omitempty"`
        // Name of ParamGroup to use
        Params string `yaml:"params,omitempty"`
        // Whether the user name is a wildcard pattern: 'prefix*', '*suffix' or '*'
        IsWildcarded bool `yaml:"is_wildcarded,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (u *User) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain User
        if err := unmarshal((*plain)(u)); err != nil {
                return err
        }
        if err := u.validate(); err != nil {
                return err
        }
        return checkOverflow(u.XXX, fmt.Sprintf("user %q", u.Name))
}
func (u *User) validate() error {
        if len(u.Name) == 0 {
                return fmt.Errorf("`user.name` cannot be empty")
        }
        if len(u.ToUser) == 0 {
                return fmt.Errorf("`user.to_user` cannot be empty for %q", u.Name)
        }
        if len(u.ToCluster) == 0 {
                return fmt.Errorf("`user.to_cluster` cannot be empty for %q", u.Name)
        }
        if u.DenyHTTP && u.DenyHTTPS {
                return fmt.Errorf("`deny_http` and `deny_https` cannot be simultaneously set to `true` for %q", u.Name)
        }
        if err := u.validateWildcarded(); err != nil {
                return err
        }
        if err := u.validateRateLimitConfig(); err != nil {
                return err
        }
        return nil
}
func (u *User) validateWildcarded() error {
        if u.IsWildcarded {
                if s := strings.Split(u.Name, "*"); !(len(s) == 2 && (s[0] == "" || s[1] == "")) {
                        return fmt.Errorf("user name %q marked 'is_wildcared' does not match 'prefix*' or '*suffix' or '*'", u.Name)
                }
        }
        return nil
}
func (u *User) validateRateLimitConfig() error {
        if u.MaxQueueTime > 0 && u.MaxQueueSize == 0 {
                return fmt.Errorf("`max_queue_size` must be set if `max_queue_time` is set for %q", u.Name)
        }
        if u.ReqPacketSizeTokensBurst > 0 && u.ReqPacketSizeTokensRate == 0 {
                return fmt.Errorf("`request_packet_size_tokens_rate` must be set if `request_packet_size_tokens_burst` is set for %q", u.Name)
        }
        return nil
}
func (u *User) validateSecurity(hasHTTP, hasHTTPS bool) error {
        if len(u.Password) == 0 {
                if !u.DenyHTTPS && hasHTTPS {
                        return fmt.Errorf("https: user %q has neither password nor `allowed_networks` on `user` or `server.http` level",
                                u.Name)
                }
                if !u.DenyHTTP && hasHTTP {
                        return fmt.Errorf("http: user %q has neither password nor `allowed_networks` on `user` or `server.http` level",
                                u.Name)
                }
        }
        if len(u.Password) > 0 && hasHTTP {
                return fmt.Errorf("http: user %q is allowed to connect via http, but not limited by `allowed_networks` "+
                        "on `user` or `server.http` level - password could be stolen", u.Name)
        }
        return nil
}
func (u *User) setDefaults() {
        if u.MaxExecutionTime == 0 {
                u.MaxExecutionTime = defaultExecutionTime
        }
}
// NetworkGroups describes a named Networks lists
type NetworkGroups struct {
        // Name of the group
        Name string `yaml:"name"`
        // List of networks
        // Each list item could be IP address or subnet mask
        Networks Networks `yaml:"networks"`
        // Catches all undefined fields and must be empty after parsing.
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (ng *NetworkGroups) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain NetworkGroups
        if err := unmarshal((*plain)(ng)); err != nil {
                return err
        }
        if len(ng.Name) == 0 {
                return fmt.Errorf("`network_group.name` must be specified")
        }
        if len(ng.Networks) == 0 {
                return fmt.Errorf("`network_group.networks` must contain at least one network")
        }
        return checkOverflow(ng.XXX, fmt.Sprintf("network_group %q", ng.Name))
}
// NetworksOrGroups is a list of strings with names of NetworkGroups
// or just Networks
type NetworksOrGroups []string
// Cache describes configuration options for caching
// responses from CH clusters
type Cache struct {
        // Mode of cache (file_system, redis)
        // todo make it an enum
        Mode string `yaml:"mode"`
        // Name of configuration for further assign
        Name string `yaml:"name"`
        // Expiration period for cached response
        // Files which are older than expiration period will be deleted
        // on new request and re-cached
        Expire Duration `yaml:"expire,omitempty"`
        // Deprecated: GraceTime duration before the expired entry is deleted from the cache.
        // It's deprecated and in future versions it'll be replaced by user's MaxExecutionTime.
        // It's already the case today if value of GraceTime is omitted.
        GraceTime Duration `yaml:"grace_time,omitempty"`
        FileSystem FileSystemCacheConfig `yaml:"file_system,omitempty"`
        Redis RedisCacheConfig `yaml:"redis,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
        // Maximum total size of request payload for caching
        MaxPayloadSize ByteSize `yaml:"max_payload_size,omitempty"`
        // Whether a query cached by a user could be used by another user
        SharedWithAllUsers bool `yaml:"shared_with_all_users,omitempty"`
}
func (c *Cache) setDefaults() {
        if c.MaxPayloadSize <= 0 {
                c.MaxPayloadSize = defaultMaxPayloadSize
        }
}
type FileSystemCacheConfig struct {
        // Path to directory where cached files will be saved
        Dir string `yaml:"dir"`
        // Maximum total size of all cached to Dir files
        // If size is exceeded - the oldest files in Dir will be deleted
        // until total size becomes normal
        MaxSize ByteSize `yaml:"max_size"`
}
type RedisCacheConfig struct {
        TLS `yaml:",inline"`
        Username  string                 `yaml:"username,omitempty"`
        Password  string                 `yaml:"password,omitempty"`
        Addresses []string               `yaml:"addresses"`
        DBIndex   int                    `yaml:"db_index,omitempty"`
        XXX       map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Cache) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
        type plain Cache
        if err = unmarshal((*plain)(c)); err != nil {
                return err
        }
        if len(c.Name) == 0 {
                return fmt.Errorf("`cache.name` must be specified")
        }
        switch c.Mode {
        case "file_system":
                err = c.checkFileSystemConfig()
        case "redis":
                err = c.checkRedisConfig()
        default:
                err = fmt.Errorf("not supported cache type %v. Supported types: [file_system]", c.Mode)
        }
        if err != nil {
                return fmt.Errorf("failed to configure cache for %q", c.Name)
        }
        return checkOverflow(c.XXX, fmt.Sprintf("cache %q", c.Name))
}
func (c *Cache) checkFileSystemConfig() error {
        if len(c.FileSystem.Dir) == 0 {
                return fmt.Errorf("`cache.filesystem.dir` must be specified for %q", c.Name)
        }
        if c.FileSystem.MaxSize <= 0 {
                return fmt.Errorf("`cache.filesystem.max_size` must be specified for %q", c.Name)
        }
        return nil
}
func (c *Cache) checkRedisConfig() error {
        if len(c.Redis.Addresses) == 0 {
                return fmt.Errorf("`cache.redis.addresses` must be specified for %q", c.Name)
        }
        return nil
}
// ParamGroup describes named group of GET params
// for sending with each query
type ParamGroup struct {
        // Name of configuration for further assign
        Name string `yaml:"name"`
        // Params contains a list of GET params
        Params []Param `yaml:"params"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (pg *ParamGroup) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain ParamGroup
        if err := unmarshal((*plain)(pg)); err != nil {
                return err
        }
        if len(pg.Name) == 0 {
                return fmt.Errorf("`param_group.name` must be specified")
        }
        if len(pg.Params) == 0 {
                return fmt.Errorf("`param_group.params` must contain at least one param")
        }
        return checkOverflow(pg.XXX, fmt.Sprintf("param_group %q", pg.Name))
}
// Param describes URL param value
type Param struct {
        // Key is a name of params
        Key string `yaml:"key"`
        // Value is a value of param
        Value string `yaml:"value"`
}
// ConnectionPool describes pool of connection with ClickHouse
// settings
type ConnectionPool struct {
        // Maximum total number of idle connections between chproxy and all ClickHouse instances
        MaxIdleConns int `yaml:"max_idle_conns,omitempty"`
        // Maximum number of idle connections between chproxy and a particular ClickHouse instance
        MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host,omitempty"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (cp *ConnectionPool) UnmarshalYAML(unmarshal func(interface{}) error) error {
        *cp = defaultConnectionPool
        type plain ConnectionPool
        if err := unmarshal((*plain)(cp)); err != nil {
                return err
        }
        if cp.MaxIdleConnsPerHost > cp.MaxIdleConns || cp.MaxIdleConns < 0 {
                return fmt.Errorf("inconsistent ConnectionPool settings")
        }
        return checkOverflow(cp.XXX, "connection_pool")
}
// ClusterUser describes simplest <users> configuration
type ClusterUser struct {
        // User name in ClickHouse users.xml config
        Name string `yaml:"name"`
        // User password in ClickHouse users.xml config
        Password string `yaml:"password,omitempty"`
        // Maximum number of concurrently running queries for user
        // if omitted or zero - no limits would be applied
        MaxConcurrentQueries uint32 `yaml:"max_concurrent_queries,omitempty"`
        // Maximum duration of query execution for user
        // if omitted or zero - limit is set to 120 seconds
        MaxExecutionTime Duration `yaml:"max_execution_time,omitempty"`
        // Maximum number of requests per minute for user
        // if omitted or zero - no limits would be applied
        ReqPerMin uint32 `yaml:"requests_per_minute,omitempty"`
        // The burst of request packet size token bucket for user
        // if omitted or zero - no limits would be applied
        ReqPacketSizeTokensBurst ByteSize `yaml:"request_packet_size_tokens_burst,omitempty"`
        // The request packet size tokens produced rate for user
        // if omitted or zero - no limits would be applied
        ReqPacketSizeTokensRate ByteSize `yaml:"request_packet_size_tokens_rate,omitempty"`
        // Maximum number of queries waiting for execution in the queue
        // if omitted or zero - queries are executed without waiting
        // in the queue
        MaxQueueSize uint32 `yaml:"max_queue_size,omitempty"`
        // Maximum duration the query may wait in the queue
        // if omitted or zero - 10s duration is used
        MaxQueueTime Duration `yaml:"max_queue_time,omitempty"`
        NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
        // List of networks that access is allowed from
        // Each list item could be IP address or subnet mask
        // if omitted or zero - no limits would be applied
        AllowedNetworks Networks `yaml:"-"`
        // Catches all undefined fields
        XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (cu *ClusterUser) UnmarshalYAML(unmarshal func(interface{}) error) error {
        type plain ClusterUser
        if err := unmarshal((*plain)(cu)); err != nil {
                return err
        }
        if len(cu.Name) == 0 {
                return fmt.Errorf("`cluster.user.name` cannot be empty")
        }
        if cu.MaxQueueTime > 0 && cu.MaxQueueSize == 0 {
                return fmt.Errorf("`max_queue_size` must be set if `max_queue_time` is set for %q", cu.Name)
        }
        if cu.ReqPacketSizeTokensBurst > 0 && cu.ReqPacketSizeTokensRate == 0 {
                return fmt.Errorf("`request_packet_size_tokens_rate` must be set if `request_packet_size_tokens_burst` is set for %q", cu.Name)
        }
        return checkOverflow(cu.XXX, fmt.Sprintf("cluster.user %q", cu.Name))
}
// LoadFile loads and validates configuration from provided .yml file
func LoadFile(filename string) (*Config, error) {
        content, err := os.ReadFile(filename)
        if err != nil {
                return nil, err
        }
        content = findAndReplacePlaceholders(content)
        cfg := &Config{}
        if err := yaml.Unmarshal(content, cfg); err != nil {
                return nil, err
        }
        cfg.networkReg = make(map[string]Networks, len(cfg.NetworkGroups))
        for _, ng := range cfg.NetworkGroups {
                if _, ok := cfg.networkReg[ng.Name]; ok {
                        return nil, fmt.Errorf("duplicate `network_groups.name` %q", ng.Name)
                }
                cfg.networkReg[ng.Name] = ng.Networks
        }
        if cfg.Server.HTTP.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.HTTP.NetworksOrGroups); err != nil {
                return nil, err
        }
        if cfg.Server.HTTPS.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.HTTPS.NetworksOrGroups); err != nil {
                return nil, err
        }
        if cfg.Server.Metrics.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.Metrics.NetworksOrGroups); err != nil {
                return nil, err
        }
        if err := cfg.setDefaults(); err != nil {
                return nil, err
        }
        if err := cfg.checkVulnerabilities(); err != nil {
                return nil, fmt.Errorf("security breach: %w\nSet option `hack_me_please=true` to disable security errors", err)
        }
        return cfg, nil
}
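// loadConfigExample is a usage sketch, not part of the original source: it loads
// a configuration file and prints its sanitized form (String() masks passwords).
// The path below is hypothetical.
func loadConfigExample() error {
        cfg, err := LoadFile("/etc/chproxy/config.yml")
        if err != nil {
                return err
        }
        fmt.Println(cfg)
        return nil
}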
var envVarRegex = regexp.MustCompile(`\${([a-zA-Z_][a-zA-Z0-9_]*)}`)
// findAndReplacePlaceholders finds all environment variables placeholders in the config.
// Each placeholder is a string like ${VAR_NAME}. They will be replaced with the value of the
// corresponding environment variable. It returns the new content with replaced placeholders.
func findAndReplacePlaceholders(content []byte) []byte {
        for _, match := range envVarRegex.FindAllSubmatch(content, -1) {
                envVar := os.Getenv(string(match[1]))
                if envVar != "" {
                        content = bytes.ReplaceAll(content, match[0], []byte(envVar))
                }
        }
        return content
}
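// placeholderExample is a small sketch, not part of the original source: with
// CH_PASSWORD set in the environment, the ${CH_PASSWORD} placeholder below is
// replaced with its value; unset variables are left untouched. The variable
// name and value are illustrative.
func placeholderExample() string {
        _ = os.Setenv("CH_PASSWORD", "secret")
        return string(findAndReplacePlaceholders([]byte("password: ${CH_PASSWORD}")))
}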
func (c Config) checkVulnerabilities() error {
        if c.HackMePlease {
                return nil
        }
        hasHTTPS := len(c.Server.HTTPS.ListenAddr) > 0 && len(c.Server.HTTPS.NetworksOrGroups) == 0
        hasHTTP := len(c.Server.HTTP.ListenAddr) > 0 && len(c.Server.HTTP.NetworksOrGroups) == 0
        for _, u := range c.Users {
                if len(u.NetworksOrGroups) != 0 {
                        continue
                }
                if err := u.validateSecurity(hasHTTP, hasHTTPS); err != nil {
                        return err
                }
        }
        return nil
}
		
		package config
import (
        "fmt"
        "math"
        "net"
        "regexp"
        "strconv"
        "strings"
        "time"
)
// ByteSize holds size in bytes.
//
// May be used in yaml for parsing byte size values.
type ByteSize uint64
var byteSizeRegexp = regexp.MustCompile(`^(\d+(?:\.\d+)?)\s*([KMGTP]?)B?$`)
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (bs *ByteSize) UnmarshalYAML(unmarshal func(interface{}) error) error {
        var s string
        if err := unmarshal(&s); err != nil {
                return err
        }
        s = strings.ToUpper(s)
        value, unit, err := parseStringParts(s)
        if err != nil {
                return err
        }
        if value <= 0 {
                return fmt.Errorf("byte size %q must be positive", s)
        }
        k := float64(1)
        switch unit {
        case "P":
                k = 1 << 50
        case "T":
                k = 1 << 40
        case "G":
                k = 1 << 30
        case "M":
                k = 1 << 20
        case "K":
                k = 1 << 10
        }
        value *= k
        *bs = ByteSize(value)
        // check for overflow
        e := math.Abs(float64(*bs)-value) / value
        if e > 1e-6 {
                return fmt.Errorf("byte size %q is too big", s)
        }
        return nil
}
func parseStringParts(s string) (float64, string, error) {
        s = strings.ToUpper(s)
        parts := byteSizeRegexp.FindStringSubmatch(strings.TrimSpace(s))
        if len(parts) < 3 {
                return -1, "", fmt.Errorf("cannot parse byte size %q: it must be positive float followed by optional units. For example, 1.5Gb, 3T", s)
        }
        value, err := strconv.ParseFloat(parts[1], 64)
        if err != nil {
                return -1, "", fmt.Errorf("cannot parse byte size %q: it must be positive float followed by optional units. For example, 1.5Gb, 3T; err: %w", s, err)
        }
        unit := parts[2]
        return value, unit, nil
}
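// parseStringPartsExample is a small sketch, not part of the original source:
// "1.5G" splits into the value 1.5 and the unit "G"; UnmarshalYAML then applies
// the 1<<30 multiplier, yielding 1610612736 bytes.
func parseStringPartsExample() (float64, string, error) {
        return parseStringParts("1.5G")
}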
// Networks is a list of IPNet entities
type Networks []*net.IPNet
// MarshalYAML implements yaml.Marshaler interface.
//
// It prettifies yaml output for Networks.
func (n Networks) MarshalYAML() (interface{}, error) {
        var a []string
        for _, x := range n {
                a = append(a, x.String())
        }
        return a, nil
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (n *Networks) UnmarshalYAML(unmarshal func(interface{}) error) error {
        var s []string
        if err := unmarshal(&s); err != nil {
                return err
        }
        networks := make(Networks, len(s))
        for i, s := range s {
                ipnet, err := stringToIPnet(s)
                if err != nil {
                        return err
                }
                networks[i] = ipnet
        }
        *n = networks
        return nil
}
// Contains checks whether passed addr is in the range of networks
func (n Networks) Contains(addr string) bool {
        if len(n) == 0 {
                return true
        }
        h, _, err := net.SplitHostPort(addr)
        if err != nil {
                // addr contains only an IP address without a port; this happens when the proxy middleware is enabled.
                h = addr
        }
        ip := net.ParseIP(h)
        if ip == nil {
                panic(fmt.Sprintf("BUG: unexpected error while parsing IP: %s", h))
        }
        for _, ipnet := range n {
                if ipnet.Contains(ip) {
                        return true
                }
        }
        return false
}
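// containsExample is a small sketch, not part of the original source: it checks
// whether a client address falls inside an allowed subnet. The subnet is
// illustrative; note that Contains panics if addr is not a valid IP.
func containsExample(addr string) (bool, error) {
        _, ipnet, err := net.ParseCIDR("10.0.0.0/8")
        if err != nil {
                return false, err
        }
        return Networks{ipnet}.Contains(addr), nil
}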
// Duration wraps time.Duration. It is used to parse the custom duration format
type Duration time.Duration
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
        var s string
        if err := unmarshal(&s); err != nil {
                return err
        }
        dur, err := StringToDuration(s)
        if err != nil {
                return err
        }
        *d = dur
        return nil
}
// String implements the Stringer interface.
func (d Duration) String() string {
        factors := map[string]time.Duration{
                "w":  time.Hour * 24 * 7,
                "d":  time.Hour * 24,
                "h":  time.Hour,
                "m":  time.Minute,
                "s":  time.Second,
                "ms": time.Millisecond,
                "µs": time.Microsecond,
                "ns": 1,
        }
        var t = time.Duration(d)
        unit := "ns"
        //nolint:exhaustive // Custom duration counter that doesn't switch on the duration.
        switch time.Duration(0) {
        case t % factors["w"]:
                unit = "w"
        case t % factors["d"]:
                unit = "d"
        case t % factors["h"]:
                unit = "h"
        case t % factors["m"]:
                unit = "m"
        case t % factors["s"]:
                unit = "s"
        case t % factors["ms"]:
                unit = "ms"
        case t % factors["µs"]:
                unit = "µs"
        }
        return fmt.Sprintf("%d%v", t/factors[unit], unit)
}
// MarshalYAML implements the yaml.Marshaler interface.
func (d Duration) MarshalYAML() (interface{}, error) {
        return d.String(), nil
}
// borrowed from github.com/prometheus/prometheus
var durationRE = regexp.MustCompile("^([0-9]+)(w|d|h|m|s|ms|µs|ns)$")
// StringToDuration parses a string into a time.Duration,
// assuming that a week always has 7d, and a day always has 24h.
func StringToDuration(durationStr string) (Duration, error) {
        n, unit, err := parseDurationParts(durationStr)
        if err != nil {
                return 0, err
        }
        return calculateDuration(n, unit)
}
func parseDurationParts(s string) (int, string, error) {
        matches := durationRE.FindStringSubmatch(s)
        if len(matches) != 3 {
                return 0, "", fmt.Errorf("not a valid duration string: %q", s)
        }
        n, err := strconv.Atoi(matches[1])
        if err != nil {
                return 0, "", fmt.Errorf("duration too long: %q", matches[1])
        }
        unit := matches[2]
        return n, unit, nil
}
func calculateDuration(n int, unit string) (Duration, error) {
        dur := time.Duration(n)
        switch unit {
        case "w":
                dur *= time.Hour * 24 * 7
        case "d":
                dur *= time.Hour * 24
        case "h":
                dur *= time.Hour
        case "m":
                dur *= time.Minute
        case "s":
                dur *= time.Second
        case "ms":
                dur *= time.Millisecond
        case "µs":
                dur *= time.Microsecond
        case "ns":
        default:
                return 0, fmt.Errorf("invalid time unit in duration string: %q", unit)
        }
        return Duration(dur), nil
}
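// durationExample is a small sketch, not part of the original source: "1w" is
// interpreted as 7*24h = 168h, matching the week/day assumptions documented above.
func durationExample() (Duration, error) {
        return StringToDuration("1w")
}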
		
		package config
import (
        "fmt"
        "net"
        "strings"
)
const entireIPv4 = "0.0.0.0/0"
func stringToIPnet(s string) (*net.IPNet, error) {
        if s == entireIPv4 {
                return nil, fmt.Errorf("suspicious mask specified \"0.0.0.0/0\". " +
                        "If you want to allow all then just omit `allowed_networks` field")
        }
        ip := s
        if !strings.Contains(ip, `/`) {
                ip += "/32"
        }
        _, ipnet, err := net.ParseCIDR(ip)
        if err != nil {
                return nil, fmt.Errorf("wrong network group name or address %q: %w", s, err)
        }
        return ipnet, nil
}
func checkOverflow(m map[string]interface{}, ctx string) error {
        if len(m) > 0 {
                var keys []string
                for k := range m {
                        keys = append(keys, k)
                }
                return fmt.Errorf("unknown fields in %s: %s", ctx, strings.Join(keys, ", "))
        }
        return nil
}
		
		package counter
import "sync/atomic"
// Counter is a goroutine-safe uint32 counter.
type Counter struct {
        value atomic.Uint32
}
// Store sets the counter value to n.
func (c *Counter) Store(n uint32) { c.value.Store(n) }
// Load returns the current counter value.
func (c *Counter) Load() uint32 { return c.value.Load() }
// Dec decrements the counter by one (adding ^uint32(0) is a wrapping -1).
func (c *Counter) Dec() { c.value.Add(^uint32(0)) }
// Inc increments the counter by one and returns the new value.
func (c *Counter) Inc() uint32 { return c.value.Add(1) }
		
		package heartbeat
import (
        "context"
        "fmt"
        "io"
        "net/http"
        "time"
        "github.com/contentsquare/chproxy/config"
)
var errUnexpectedResponse = fmt.Errorf("unexpected response")
// HeartBeat probes a ClickHouse node for liveness at a fixed interval.
type HeartBeat interface {
        IsHealthy(ctx context.Context, addr string) error
        Interval() time.Duration
}
type heartBeatOpts struct {
        defaultUser     string
        defaultPassword string
}
type Option interface {
        apply(*heartBeatOpts)
}
type defaultUser struct {
        defaultUser     string
        defaultPassword string
}
func (o defaultUser) apply(opts *heartBeatOpts) {
        opts.defaultUser = o.defaultUser
        opts.defaultPassword = o.defaultPassword
}
func WithDefaultUser(user, password string) Option {
        return defaultUser{
                defaultUser:     user,
                defaultPassword: password,
        }
}
type heartBeat struct {
        interval time.Duration
        timeout  time.Duration
        request  string
        response string
        user     string
        password string
}
// User credentials are not needed
const defaultEndpoint string = "/ping"
// NewHeartbeat builds a HeartBeat from the given configuration and options.
func NewHeartbeat(c config.HeartBeat, options ...Option) HeartBeat {
        opts := &heartBeatOpts{}
        for _, o := range options {
                o.apply(opts)
        }
        newHB := &heartBeat{
                interval: time.Duration(c.Interval),
                timeout:  time.Duration(c.Timeout),
                request:  c.Request,
                response: c.Response,
        }
        if c.Request != defaultEndpoint {
                if c.User != "" {
                        newHB.user = c.User
                        newHB.password = c.Password
                } else {
                        newHB.user = opts.defaultUser
                        newHB.password = opts.defaultPassword
                }
        }
        if newHB.request != defaultEndpoint && newHB.user == "" {
                panic("BUG: user is empty, no default user provided")
        }
        return newHB
}
func (hb *heartBeat) IsHealthy(ctx context.Context, addr string) error {
        req, err := http.NewRequest("GET", addr+hb.request, nil)
        if err != nil {
                return err
        }
        if hb.request != defaultEndpoint {
                req.SetBasicAuth(hb.user, hb.password)
        }
        ctx, cancel := context.WithTimeout(ctx, hb.timeout)
        defer cancel()
        req = req.WithContext(ctx)
        startTime := time.Now()
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
                return fmt.Errorf("cannot send request in %s: %w", time.Since(startTime), err)
        }
        defer resp.Body.Close()
        if resp.StatusCode != http.StatusOK {
                return fmt.Errorf("non-200 status code: %s", resp.Status)
        }
        body, err := io.ReadAll(resp.Body)
        if err != nil {
                return fmt.Errorf("cannot read response in %s: %w", time.Since(startTime), err)
        }
        r := string(body)
        if r != hb.response {
                return fmt.Errorf("%w: %s", errUnexpectedResponse, r)
        }
        return nil
}
func (hb *heartBeat) Interval() time.Duration {
        return hb.interval
}
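// newPingHeartBeat is a usage sketch, not part of the original package: it builds
// a heartbeat polling the default /ping endpoint every 5s with a 3s timeout;
// /ping needs no credentials, so no default user option is required here.
func newPingHeartBeat() HeartBeat {
        return NewHeartbeat(config.HeartBeat{
                Interval: config.Duration(5 * time.Second),
                Timeout:  config.Duration(3 * time.Second),
                Request:  "/ping",
                Response: "Ok.\n",
        })
}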
		
		package topology
// TODO this is only here to avoid recursive imports. We should have a separate package for metrics.
import (
        "github.com/contentsquare/chproxy/config"
        "github.com/prometheus/client_golang/prometheus"
)
var (
        HostHealth    *prometheus.GaugeVec
        HostPenalties *prometheus.CounterVec
)
func initMetrics(cfg *config.Config) {
        namespace := cfg.Server.Metrics.Namespace
        HostHealth = prometheus.NewGaugeVec(
                prometheus.GaugeOpts{
                        Namespace: namespace,
                        Name:      "host_health",
                        Help:      "Health state of hosts by clusters",
                },
                []string{"cluster", "replica", "cluster_node"},
        )
        HostPenalties = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "host_penalties_total",
                        Help:      "Total number of given penalties by host",
                },
                []string{"cluster", "replica", "cluster_node"},
        )
}
func RegisterMetrics(cfg *config.Config) {
        initMetrics(cfg)
        prometheus.MustRegister(HostHealth, HostPenalties)
}
func reportNodeHealthMetric(clusterName, replicaName, nodeName string, active bool) {
        label := prometheus.Labels{
                "cluster":      clusterName,
                "replica":      replicaName,
                "cluster_node": nodeName,
        }
        if active {
                HostHealth.With(label).Set(1)
        } else {
                HostHealth.With(label).Set(0)
        }
}
func incrementPenaltiesMetric(clusterName, replicaName, nodeName string) {
        label := prometheus.Labels{
                "cluster":      clusterName,
                "replica":      replicaName,
                "cluster_node": nodeName,
        }
        HostPenalties.With(label).Inc()
}
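// RegisterMetrics is meant to be called once at startup, before any Node
// reports health or penalties, e.g. (illustrative; cfg is loaded elsewhere):
//
//	topology.RegisterMetrics(cfg)
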
		
package topology
import (
        "context"
        "net/url"
        "sync/atomic"
        "time"
        "github.com/contentsquare/chproxy/internal/counter"
        "github.com/contentsquare/chproxy/internal/heartbeat"
        "github.com/contentsquare/chproxy/log"
)
const (
        // prevents excess goroutine creation while penalizing an overloaded host
        DefaultPenaltySize     = 5
        DefaultMaxSize         = 300
        DefaultPenaltyDuration = time.Second * 10
)
type nodeOpts struct {
        defaultActive   bool
        penaltySize     uint32
        penaltyMaxSize  uint32
        penaltyDuration time.Duration
}
func defaultNodeOpts() nodeOpts {
        return nodeOpts{
                penaltySize:     DefaultPenaltySize,
                penaltyMaxSize:  DefaultMaxSize,
                penaltyDuration: DefaultPenaltyDuration,
        }
}
type NodeOption interface {
        apply(*nodeOpts)
}
type defaultActive struct {
        active bool
}
func (o defaultActive) apply(opts *nodeOpts) {
        opts.defaultActive = o.active
}
func WithDefaultActiveState(active bool) NodeOption {
        return defaultActive{
                active: active,
        }
}
type Node struct {
        // Node Address.
        addr *url.URL
        // Whether this node is alive.
        active atomic.Bool
        // Counter of currently running connections.
        connections counter.Counter
        // Counter of unsuccessful requests used to decrease host priority.
        penalty atomic.Uint32
        // Heartbeat function
        hb heartbeat.HeartBeat
        // TODO These fields are only used for labels in prometheus. We should have a different way to pass the labels.
        // For metrics only
        clusterName string
        replicaName string
        // Additional configuration options
        opts nodeOpts
}
func NewNode(addr *url.URL, hb heartbeat.HeartBeat, clusterName, replicaName string, opts ...NodeOption) *Node {
        nodeOpts := defaultNodeOpts()
        for _, opt := range opts {
                opt.apply(&nodeOpts)
        }
        n := &Node{
                addr:        addr,
                hb:          hb,
                clusterName: clusterName,
                replicaName: replicaName,
                opts:        nodeOpts,
        }
        if n.opts.defaultActive {
                n.SetIsActive(true)
        }
        return n
}
func (n *Node) IsActive() bool {
        return n.active.Load()
}
func (n *Node) SetIsActive(active bool) {
        n.active.Store(active)
}
// StartHeartbeat runs the heartbeat healthcheck against the node
// until the done channel is closed.
// If the heartbeat fails, the active status of the node is changed.
func (n *Node) StartHeartbeat(done <-chan struct{}) {
        ctx, cancel := context.WithCancel(context.Background())
        for {
                n.heartbeat(ctx)
                select {
                case <-done:
                        cancel()
                        return
                case <-time.After(n.hb.Interval()):
                }
        }
}
func (n *Node) heartbeat(ctx context.Context) {
        if err := n.hb.IsHealthy(ctx, n.addr.String()); err == nil {
                n.active.Store(true)
                reportNodeHealthMetric(n.clusterName, n.replicaName, n.Host(), true)
        } else {
                log.Errorf("error while health-checking %q host: %s", n.Host(), err)
                n.active.Store(false)
                reportNodeHealthMetric(n.clusterName, n.replicaName, n.Host(), false)
        }
}
// Penalize increases the node's penalty counter when a request to it fails,
// which lowers the node's priority in load balancing.
// If the penalty is already at the maximum allowed size this function
// will not penalize the node further.
// A function is scheduled to run after the penalty duration to restore
// the priority again.
func (n *Node) Penalize() {
        penalty := n.penalty.Load()
        if penalty >= n.opts.penaltyMaxSize {
                return
        }
        incrementPenaltiesMetric(n.clusterName, n.replicaName, n.Host())
        n.penalty.Add(n.opts.penaltySize)
        time.AfterFunc(n.opts.penaltyDuration, func() {
                // Adding the two's complement ^(x-1) atomically subtracts x,
                // so the penalty is lifted again after penaltyDuration.
                n.penalty.Add(^uint32(n.opts.penaltySize - 1))
        })
}
// CurrentLoad returns the node's current load: the number of open connections
// plus the penalty.
func (n *Node) CurrentLoad() uint32 {
        c := n.connections.Load()
        p := n.penalty.Load()
        return c + p
}
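// Worked example (illustrative): with the defaults above each penalty adds
// DefaultPenaltySize (5) to the load for DefaultPenaltyDuration (10s), so a
// node with 3 open connections and 2 recent penalties reports
// CurrentLoad() == 3 + 2*5 == 13 until the penalties expire.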
func (n *Node) CurrentConnections() uint32 {
        return n.connections.Load()
}
func (n *Node) CurrentPenalty() uint32 {
        return n.penalty.Load()
}
func (n *Node) IncrementConnections() {
        n.connections.Inc()
}
func (n *Node) DecrementConnections() {
        n.connections.Dec()
}
func (n *Node) Scheme() string {
        return n.addr.Scheme
}
func (n *Node) Host() string {
        return n.addr.Host
}
func (n *Node) ReplicaName() string {
        return n.replicaName
}
func (n *Node) String() string {
        return n.addr.String()
}
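// A minimal wiring sketch for a Node (assumptions: hb comes from the
// heartbeat package and the address is illustrative):
//
//	u, _ := url.Parse("http://10.0.0.1:8123")
//	node := NewNode(u, hb, "example-cluster", "replica1", WithDefaultActiveState(true))
//	done := make(chan struct{})
//	go node.StartHeartbeat(done)
//	// ... on shutdown:
//	close(done)
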
		
package main
import (
        "errors"
        "fmt"
        "io"
        "net/http"
        "sync"
        "time"
        "github.com/contentsquare/chproxy/cache"
        "github.com/contentsquare/chproxy/log"
        "github.com/prometheus/client_golang/prometheus"
)
type ResponseWriterWithCode interface {
        http.ResponseWriter
        StatusCode() int
}
type StatResponseWriter interface {
        http.ResponseWriter
        http.CloseNotifier
        StatusCode() int
        SetStatusCode(code int)
}
var _ StatResponseWriter = &statResponseWriter{}
// statResponseWriter collects the amount of bytes written.
//
// The wrapped ResponseWriter must implement http.CloseNotifier.
//
// Additionally it caches response status code.
type statResponseWriter struct {
        http.ResponseWriter
        statusCode int
        // wroteHeader tells whether the header's been written to
        // the original ResponseWriter
        wroteHeader bool
        bytesWritten prometheus.Counter
}
const (
        XCacheHit  = "HIT"
        XCacheMiss = "MISS"
        XCacheNA   = "N/A"
)
func RespondWithData(rw http.ResponseWriter, data io.Reader, metadata cache.ContentMetadata, ttl time.Duration, cacheHit string, statusCode int, labels prometheus.Labels) error {
        h := rw.Header()
        if len(metadata.Type) > 0 {
                h.Set("Content-Type", metadata.Type)
        }
        if len(metadata.Encoding) > 0 {
                h.Set("Content-Encoding", metadata.Encoding)
        }
        h.Set("Content-Length", fmt.Sprintf("%d", metadata.Length))
        if ttl > 0 {
                expireSeconds := uint(ttl / time.Second)
                h.Set("Cache-Control", fmt.Sprintf("max-age=%d", expireSeconds))
        }
        h.Set("X-Cache", cacheHit)
        rw.WriteHeader(statusCode)
        if _, err := io.Copy(rw, data); err != nil {
                var perr *cache.RedisCacheError
                if errors.As(err, &perr) {
                        cacheCorruptedFetch.With(labels).Inc()
                        log.Debugf("redis cache error")
                }
                log.Errorf("cannot send response to client: %s", err)
                return fmt.Errorf("cannot send response to client: %w", err)
        }
        return nil
}
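// Illustrative call of RespondWithData (a sketch; rw, labels and the payload
// below are assumptions for the example):
//
//	payload := strings.NewReader("1\n")
//	meta := cache.ContentMetadata{Length: 2, Type: "text/plain; charset=UTF-8"}
//	_ = RespondWithData(rw, payload, meta, 30*time.Second, XCacheHit, http.StatusOK, labels)
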
func (rw *statResponseWriter) SetStatusCode(code int) {
        rw.statusCode = code
}
func (rw *statResponseWriter) StatusCode() int {
        if rw.statusCode == 0 {
                return http.StatusOK
        }
        return rw.statusCode
}
func (rw *statResponseWriter) Write(b []byte) (int, error) {
        if rw.statusCode == 0 {
                rw.statusCode = http.StatusOK
        }
        if !rw.wroteHeader {
                rw.ResponseWriter.WriteHeader(rw.statusCode)
                rw.wroteHeader = true
        }
        n, err := rw.ResponseWriter.Write(b)
        rw.bytesWritten.Add(float64(n))
        return n, err
}
func (rw *statResponseWriter) WriteHeader(statusCode int) {
        // cache the statusCode so it can still be changed before the header is actually written
        rw.statusCode = statusCode
}
// CloseNotify implements http.CloseNotifier
func (rw *statResponseWriter) CloseNotify() <-chan bool {
        // The rw.ResponseWriter must implement http.CloseNotifier
        rwc, ok := rw.ResponseWriter.(http.CloseNotifier)
        if !ok {
                panic("BUG: the wrapped ResponseWriter must implement http.CloseNotifier")
        }
        return rwc.CloseNotify()
}
var _ io.ReadCloser = &statReadCloser{}
// statReadCloser collects the amount of bytes read.
type statReadCloser struct {
        io.ReadCloser
        bytesRead prometheus.Counter
}
func (src *statReadCloser) Read(p []byte) (int, error) {
        n, err := src.ReadCloser.Read(p)
        src.bytesRead.Add(float64(n))
        return n, err
}
var _ io.ReadCloser = &cachedReadCloser{}
// cachedReadCloser caches the first 1KB from the wrapped ReadCloser.
type cachedReadCloser struct {
        io.ReadCloser
        // bLock protects b from concurrent access when Read and String
        // are called from concurrent goroutines.
        bLock sync.Mutex
        // b holds up to 1Kb of the initial data read from ReadCloser.
        b []byte
}
func (crc *cachedReadCloser) Read(p []byte) (int, error) {
        n, err := crc.ReadCloser.Read(p)
        crc.bLock.Lock()
        if len(crc.b) < 1024 {
                crc.b = append(crc.b, p[:n]...)
                if len(crc.b) >= 1024 {
                        crc.b = append(crc.b[:1024], "..."...)
                }
        }
        crc.bLock.Unlock()
        // Do not cache the last read operation, since it slows down
        // reading large amounts of data such as large INSERT queries.
        return n, err
}
func (crc *cachedReadCloser) String() string {
        crc.bLock.Lock()
        s := string(crc.b)
        crc.bLock.Unlock()
        return s
}
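// Sketch of how cachedReadCloser is typically used to keep a snippet of the
// request body for error messages (illustrative; req is an assumption):
//
//	req.Body = &cachedReadCloser{ReadCloser: req.Body}
//	// ... after downstream code has read (part of) the body:
//	log.Errorf("query start: %q", req.Body.(*cachedReadCloser).String())
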
		
package log
import (
        "fmt"
        "io"
        "log"
        "os"
        "sync/atomic"
)
var (
        stdLogFlags     = log.LstdFlags | log.Lshortfile | log.LUTC
        outputCallDepth = 2
        debugLogger = log.New(os.Stderr, "DEBUG: ", stdLogFlags)
        infoLogger  = log.New(os.Stderr, "INFO: ", stdLogFlags)
        errorLogger = log.New(os.Stderr, "ERROR: ", stdLogFlags)
        fatalLogger = log.New(os.Stderr, "FATAL: ", stdLogFlags)
        // NilLogger suppresses all the log messages.
        NilLogger = log.New(io.Discard, "", stdLogFlags)
)
// SuppressOutput suppresses all output from logs if `suppress` is true
// used while testing
func SuppressOutput(suppress bool) {
        if suppress {
                debugLogger.SetOutput(io.Discard)
                infoLogger.SetOutput(io.Discard)
                errorLogger.SetOutput(io.Discard)
        } else {
                debugLogger.SetOutput(os.Stderr)
                infoLogger.SetOutput(os.Stderr)
                errorLogger.SetOutput(os.Stderr)
        }
}
var debug uint32
// SetDebug sets output into debug mode if true passed
func SetDebug(val bool) {
        if val {
                atomic.StoreUint32(&debug, 1)
        } else {
                atomic.StoreUint32(&debug, 0)
        }
}
// Debugf prints debug message according to a format
func Debugf(format string, args ...interface{}) {
        if atomic.LoadUint32(&debug) == 0 {
                return
        }
        s := fmt.Sprintf(format, args...)
        debugLogger.Output(outputCallDepth, s) // nolint
}
// Infof prints info message according to a format
func Infof(format string, args ...interface{}) {
        s := fmt.Sprintf(format, args...)
        infoLogger.Output(outputCallDepth, s) // nolint
}
// Errorf prints error message according to a format
func Errorf(format string, args ...interface{}) {
        s := fmt.Sprintf(format, args...)
        errorLogger.Output(outputCallDepth, s) // nolint
}
// ErrorWithCallDepth prints err into error log using the given callDepth.
func ErrorWithCallDepth(err error, callDepth int) {
        s := err.Error()
        errorLogger.Output(outputCallDepth+callDepth, s) //nolint
}
// Fatalf prints fatal message according to a format and exits program
func Fatalf(format string, args ...interface{}) {
        s := fmt.Sprintf(format, args...)
        fatalLogger.Output(outputCallDepth, s) // nolint
        os.Exit(1)
}
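// Typical usage (illustrative): Debugf is a no-op until debug mode is enabled.
//
//	log.SetDebug(true)
//	log.Debugf("cache hit for user %q", "default")
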
		
package main
import (
        "context"
        "crypto/tls"
        "flag"
        "fmt"
        "net"
        "net/http"
        "os"
        "os/signal"
        "strings"
        "sync/atomic"
        "syscall"
        "time"
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/log"
        "github.com/prometheus/client_golang/prometheus/promhttp"
        "golang.org/x/crypto/acme/autocert"
)
var (
        configFile = flag.String("config", "", "Proxy configuration filename")
        version    = flag.Bool("version", false, "Prints current version and exits")
        enableTCP6 = flag.Bool("enableTCP6", false, "Whether to enable listening for IPv6 TCP ports. "+
                "By default only IPv4 TCP ports are listened")
)
var (
        proxy *reverseProxy
        // networks allow lists
        allowedNetworksHTTP    atomic.Value
        allowedNetworksHTTPS   atomic.Value
        allowedNetworksMetrics atomic.Value
        proxyHandler           atomic.Value
)
func main() {
        flag.Parse()
        if *version {
                fmt.Printf("%s\n", versionString())
                os.Exit(0)
        }
        log.Infof("%s", versionString())
        log.Infof("Loading config: %s", *configFile)
        cfg, err := loadConfig()
        if err != nil {
                log.Fatalf("error while loading config: %s", err)
        }
        registerMetrics(cfg)
        if err = applyConfig(cfg); err != nil {
                log.Fatalf("error while applying config: %s", err)
        }
        configSuccess.Set(1)
        configSuccessTime.Set(float64(time.Now().Unix()))
        log.Infof("Loading config %q: successful", *configFile)
        setupReloadConfigWatch()
        server := cfg.Server
        if len(server.HTTP.ListenAddr) == 0 && len(server.HTTPS.ListenAddr) == 0 {
                panic("BUG: broken config validation - `listen_addr` is not configured")
        }
        if server.HTTP.ForceAutocertHandler {
                autocertManager = newAutocertManager(server.HTTPS.Autocert)
        }
        if len(server.HTTPS.ListenAddr) != 0 {
                go serveTLS(server.HTTPS)
        }
        if len(server.HTTP.ListenAddr) != 0 {
                go serve(server.HTTP)
        }
        select {}
}
func setupReloadConfigWatch() {
        c := make(chan os.Signal, 1)
        signal.Notify(c, syscall.SIGHUP)
        go func() {
                for {
                        if <-c == syscall.SIGHUP {
                                log.Infof("SIGHUP received. Going to reload config %s ...", *configFile)
                                if err := reloadConfig(); err != nil {
                                        log.Errorf("error while reloading config: %s", err)
                                        continue
                                }
                                log.Infof("Reloading config %s: successful", *configFile)
                        }
                }
        }()
}
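// The watcher above reacts to SIGHUP. For a quick self-test the signal can be
// sent to the current process (illustrative sketch):
//
//	p, _ := os.FindProcess(os.Getpid())
//	_ = p.Signal(syscall.SIGHUP) // triggers reloadConfig in the goroutine above
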
var autocertManager *autocert.Manager
func newAutocertManager(cfg config.Autocert) *autocert.Manager {
        if len(cfg.CacheDir) > 0 {
                if err := os.MkdirAll(cfg.CacheDir, 0o700); err != nil {
                        log.Fatalf("error while creating folder %q: %s", cfg.CacheDir, err)
                }
        }
        var hp autocert.HostPolicy
        if len(cfg.AllowedHosts) != 0 {
                allowedHosts := make(map[string]struct{}, len(cfg.AllowedHosts))
                for _, v := range cfg.AllowedHosts {
                        allowedHosts[v] = struct{}{}
                }
                hp = func(_ context.Context, host string) error {
                        if _, ok := allowedHosts[host]; ok {
                                return nil
                        }
                        return fmt.Errorf("host %q doesn't match `host_policy` configuration", host)
                }
        }
        return &autocert.Manager{
                Prompt:     autocert.AcceptTOS,
                Cache:      autocert.DirCache(cfg.CacheDir),
                HostPolicy: hp,
        }
}
func newListener(listenAddr string) net.Listener {
        network := "tcp4"
        if *enableTCP6 {
                // Enable listening on both tcp4 and tcp6
                network = "tcp"
        }
        ln, err := net.Listen(network, listenAddr)
        if err != nil {
                log.Fatalf("cannot listen for %q: %s", listenAddr, err)
        }
        return ln
}
func serveTLS(cfg config.HTTPS) {
        ln := newListener(cfg.ListenAddr)
        h := http.HandlerFunc(serveHTTP)
        tlsCfg, err := cfg.TLS.BuildTLSConfig(autocertManager)
        if err != nil {
                log.Fatalf("cannot build TLS config: %s", err)
        }
        tln := tls.NewListener(ln, tlsCfg)
        log.Infof("Serving https on %q", cfg.ListenAddr)
        if err := listenAndServe(tln, h, cfg.TimeoutCfg); err != nil {
                log.Fatalf("TLS server error on %q: %s", cfg.ListenAddr, err)
        }
}
func serve(cfg config.HTTP) {
        var h http.Handler
        ln := newListener(cfg.ListenAddr)
        h = http.HandlerFunc(serveHTTP)
        if cfg.ForceAutocertHandler {
                if autocertManager == nil {
                        panic("BUG: autocertManager is not inited")
                }
                addr := ln.Addr().String()
                parts := strings.Split(addr, ":")
                if parts[len(parts)-1] != "80" {
                        log.Fatalf("`letsencrypt` specification requires http server to listen on :80 port to satisfy http-01 challenge. " +
                                "Otherwise, certificates will be impossible to generate")
                }
                h = autocertManager.HTTPHandler(h)
        }
        log.Infof("Serving http on %q", cfg.ListenAddr)
        if err := listenAndServe(ln, h, cfg.TimeoutCfg); err != nil {
                log.Fatalf("HTTP server error on %q: %s", cfg.ListenAddr, err)
        }
}
func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Server {
        // nolint:gosec // We already configured ReadTimeout, so no need to set ReadHeaderTimeout as well.
        return &http.Server{
                TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
                Handler:      h,
                ReadTimeout:  time.Duration(cfg.ReadTimeout),
                WriteTimeout: time.Duration(cfg.WriteTimeout),
                IdleTimeout:  time.Duration(cfg.IdleTimeout),
                // Suppress error logging from the server, since chproxy
                // must handle all these errors in the code.
                ErrorLog: log.NilLogger,
        }
}
func listenAndServe(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) error {
        s := newServer(ln, h, cfg)
        return s.Serve(ln)
}
var promHandler = promhttp.Handler()
//nolint:cyclop //TODO reduce complexity here.
func serveHTTP(rw http.ResponseWriter, r *http.Request) {
        switch r.Method {
        case http.MethodGet, http.MethodPost:
                // Only GET and POST methods are supported.
        case http.MethodOptions:
                // This is required for CORS preflight requests.
                rw.Header().Set("Allow", "GET,POST")
                return
        default:
                err := fmt.Errorf("%q: unsupported method %q", r.RemoteAddr, r.Method)
                rw.Header().Set("Connection", "close")
                respondWith(rw, err, http.StatusMethodNotAllowed)
                return
        }
        switch r.URL.Path {
        case "/favicon.ico":
        case "/metrics":
                // nolint:forcetypeassert // We will cover this by tests as we control what is stored.
                an := allowedNetworksMetrics.Load().(*config.Networks)
                if !an.Contains(r.RemoteAddr) {
                        err := fmt.Errorf("connections to /metrics are not allowed from %s", r.RemoteAddr)
                        rw.Header().Set("Connection", "close")
                        respondWith(rw, err, http.StatusForbidden)
                        return
                }
                proxy.refreshCacheMetrics()
                promHandler.ServeHTTP(rw, r)
        case "/", "/query":
                var err error
                // nolint:forcetypeassert // We will cover this by tests as we control what is stored.
                proxyHandler := proxyHandler.Load().(*ProxyHandler)
                r.RemoteAddr = proxyHandler.GetRemoteAddr(r)
                var an *config.Networks
                if r.TLS != nil {
                        // nolint:forcetypeassert // We will cover this by tests as we control what is stored.
                        an = allowedNetworksHTTPS.Load().(*config.Networks)
                        err = fmt.Errorf("https connections are not allowed from %s", r.RemoteAddr)
                } else {
                        // nolint:forcetypeassert // We will cover this by tests as we control what is stored.
                        an = allowedNetworksHTTP.Load().(*config.Networks)
                        err = fmt.Errorf("http connections are not allowed from %s", r.RemoteAddr)
                }
                if !an.Contains(r.RemoteAddr) {
                        rw.Header().Set("Connection", "close")
                        respondWith(rw, err, http.StatusForbidden)
                        return
                }
                proxy.ServeHTTP(rw, r)
        default:
                badRequest.Inc()
                err := fmt.Errorf("%q: unsupported path: %q", r.RemoteAddr, r.URL.Path)
                rw.Header().Set("Connection", "close")
                respondWith(rw, err, http.StatusBadRequest)
        }
}
func loadConfig() (*config.Config, error) {
        if *configFile == "" {
                log.Fatalf("Missing -config flag")
        }
        cfg, err := config.LoadFile(*configFile)
        if err != nil {
                return nil, fmt.Errorf("can't load config %q: %w", *configFile, err)
        }
        return cfg, nil
}
// proxyConfigChanged reports whether a configuration parameter used during
// proxy initialization has changed.
func proxyConfigChanged(cfgCp *config.ConnectionPool, rp *reverseProxy) bool {
        return cfgCp.MaxIdleConns != rp.maxIdleConns ||
                cfgCp.MaxIdleConnsPerHost != rp.maxIdleConnsPerHost
}
func applyConfig(cfg *config.Config) error {
        if proxy == nil || proxyConfigChanged(&cfg.ConnectionPool, proxy) {
                proxy = newReverseProxy(&cfg.ConnectionPool)
        }
        if err := proxy.applyConfig(cfg); err != nil {
                return err
        }
        allowedNetworksHTTP.Store(&cfg.Server.HTTP.AllowedNetworks)
        allowedNetworksHTTPS.Store(&cfg.Server.HTTPS.AllowedNetworks)
        allowedNetworksMetrics.Store(&cfg.Server.Metrics.AllowedNetworks)
        proxyHandler.Store(NewProxyHandler(&cfg.Server.Proxy))
        log.SetDebug(cfg.LogDebug)
        log.Infof("Loaded config:\n%s", cfg)
        return nil
}
func reloadConfig() error {
        cfg, err := loadConfig()
        if err != nil {
                return err
        }
        return applyConfig(cfg)
}
var (
        buildTag      = "unknown"
        buildRevision = "unknown"
        buildTime     = "unknown"
)
func versionString() string {
        ver := buildTag
        if len(ver) == 0 {
                ver = "unknown"
        }
        return fmt.Sprintf("chproxy ver. %s, rev. %s, built at %s", ver, buildRevision, buildTime)
}
		
package main
import (
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/internal/topology"
        "github.com/prometheus/client_golang/prometheus"
)
var (
        statusCodes                    *prometheus.CounterVec
        requestSum                     *prometheus.CounterVec
        requestSuccess                 *prometheus.CounterVec
        limitExcess                    *prometheus.CounterVec
        concurrentQueries              *prometheus.GaugeVec
        requestQueueSize               *prometheus.GaugeVec
        userQueueOverflow              *prometheus.CounterVec
        clusterUserQueueOverflow       *prometheus.CounterVec
        requestBodyBytes               *prometheus.CounterVec
        responseBodyBytes              *prometheus.CounterVec
        cacheFailedInsert              *prometheus.CounterVec
        cacheCorruptedFetch            *prometheus.CounterVec
        cacheHit                       *prometheus.CounterVec
        cacheMiss                      *prometheus.CounterVec
        cacheSize                      *prometheus.GaugeVec
        cacheItems                     *prometheus.GaugeVec
        cacheSkipped                   *prometheus.CounterVec
        requestDuration                *prometheus.SummaryVec
        proxiedResponseDuration        *prometheus.SummaryVec
        cachedResponseDuration         *prometheus.SummaryVec
        canceledRequest                *prometheus.CounterVec
        cacheHitFromConcurrentQueries  *prometheus.CounterVec
        cacheMissFromConcurrentQueries *prometheus.CounterVec
        killedRequests                 *prometheus.CounterVec
        timeoutRequest                 *prometheus.CounterVec
        configSuccess                  prometheus.Gauge
        configSuccessTime              prometheus.Gauge
        badRequest                     prometheus.Counter
        retryRequest                   *prometheus.CounterVec
)
func initMetrics(cfg *config.Config) {
        namespace := cfg.Server.Metrics.Namespace
        statusCodes = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "status_codes_total",
                        Help:      "Distribution by status codes",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node", "code"},
        )
        requestSum = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "request_sum_total",
                        Help:      "Total number of sent requests",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        requestSuccess = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "request_success_total",
                        Help:      "Total number of sent success requests",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        limitExcess = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "concurrent_limit_excess_total",
                        Help:      "Total number of max_concurrent_queries excess",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        concurrentQueries = prometheus.NewGaugeVec(
                prometheus.GaugeOpts{
                        Namespace: namespace,
                        Name:      "concurrent_queries",
                        Help:      "The number of concurrent queries at current time",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        requestQueueSize = prometheus.NewGaugeVec(
                prometheus.GaugeOpts{
                        Namespace: namespace,
                        Name:      "request_queue_size",
                        Help:      "Request queue sizes at the current time",
                },
                []string{"user", "cluster", "cluster_user"},
        )
        userQueueOverflow = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "user_queue_overflow_total",
                        Help:      "The number of overflows for per-user request queues",
                },
                []string{"user", "cluster", "cluster_user"},
        )
        clusterUserQueueOverflow = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cluster_user_queue_overflow_total",
                        Help:      "The number of overflows for per-cluster_user request queues",
                },
                []string{"user", "cluster", "cluster_user"},
        )
        requestBodyBytes = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "request_body_bytes_total",
                        Help:      "The amount of bytes read from request bodies",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        responseBodyBytes = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "response_body_bytes_total",
                        Help:      "The amount of bytes written to response bodies",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        cacheFailedInsert = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_insertion_failures_total",
                        Help:      "The number of insertion in the cache that didn't work out",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        cacheCorruptedFetch = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_get_corrutpion_total",
                        Help:      "The number of time a data fetching from redis was corrupted",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        cacheHit = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_hits_total",
                        Help:      "The amount of cache hits",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        cacheMiss = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_miss_total",
                        Help:      "The amount of cache misses",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        cacheSize = prometheus.NewGaugeVec(
                prometheus.GaugeOpts{
                        Namespace: namespace,
                        Name:      "cache_size",
                        Help:      "Cache size at the current time",
                },
                []string{"cache"},
        )
        cacheItems = prometheus.NewGaugeVec(
                prometheus.GaugeOpts{
                        Namespace: namespace,
                        Name:      "cache_items",
                        Help:      "Cache items at the current time",
                },
                []string{"cache"},
        )
        cacheSkipped = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_payloadsize_too_big_total",
                        Help:      "The amount of too big payloads to be cached",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        requestDuration = prometheus.NewSummaryVec(
                prometheus.SummaryOpts{
                        Namespace:  namespace,
                        Name:       "request_duration_seconds",
                        Help:       "Request duration. Includes possible wait time in the queue",
                        Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5},
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        proxiedResponseDuration = prometheus.NewSummaryVec(
                prometheus.SummaryOpts{
                        Namespace:  namespace,
                        Name:       "proxied_response_duration_seconds",
                        Help:       "Response duration proxied from clickhouse",
                        Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5},
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        cachedResponseDuration = prometheus.NewSummaryVec(
                prometheus.SummaryOpts{
                        Namespace:  namespace,
                        Name:       "cached_response_duration_seconds",
                        Help:       "Response duration served from the cache",
                        Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5},
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        canceledRequest = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "canceled_request_total",
                        Help:      "The number of requests canceled by remote client",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        cacheHitFromConcurrentQueries = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_hit_concurrent_query_total",
                        Help:      "The amount of cache hits after having awaited concurrently executed queries",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        cacheMissFromConcurrentQueries = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "cache_miss_concurrent_query_total",
                        Help:      "The amount of cache misses, even if previously reported as queries available in the cache, after having awaited concurrently executed queries",
                },
                []string{"cache", "user", "cluster", "cluster_user"},
        )
        killedRequests = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "killed_request_total",
                        Help:      "The number of requests killed by proxy",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        timeoutRequest = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "timeout_request_total",
                        Help:      "The number of timed out requests",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
        configSuccess = prometheus.NewGauge(prometheus.GaugeOpts{
                Namespace: namespace,
                Name:      "config_last_reload_successful",
                Help:      "Whether the last configuration reload attempt was successful.",
        })
        configSuccessTime = prometheus.NewGauge(prometheus.GaugeOpts{
                Namespace: namespace,
                Name:      "config_last_reload_success_timestamp_seconds",
                Help:      "Timestamp of the last successful configuration reload.",
        })
        badRequest = prometheus.NewCounter(prometheus.CounterOpts{
                Namespace: namespace,
                Name:      "bad_requests_total",
                Help:      "Total number of unsupported requests",
        })
        retryRequest = prometheus.NewCounterVec(
                prometheus.CounterOpts{
                        Namespace: namespace,
                        Name:      "retry_request_total",
                        Help:      "The number of retry requests",
                },
                []string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
        )
}
func registerMetrics(cfg *config.Config) {
        topology.RegisterMetrics(cfg)
        initMetrics(cfg)
        prometheus.MustRegister(statusCodes, requestSum, requestSuccess,
                limitExcess, concurrentQueries,
                requestQueueSize, userQueueOverflow, clusterUserQueueOverflow,
                requestBodyBytes, responseBodyBytes, cacheFailedInsert, cacheCorruptedFetch,
                cacheHit, cacheMiss, cacheSize, cacheItems, cacheSkipped,
                requestDuration, proxiedResponseDuration, cachedResponseDuration,
                canceledRequest, timeoutRequest, killedRequests,
                cacheHitFromConcurrentQueries, cacheMissFromConcurrentQueries,
                configSuccess, configSuccessTime, badRequest, retryRequest)
}
		
package main
import (
        "bytes"
        "context"
        "errors"
        "fmt"
        "io"
        "net"
        "net/http"
        "net/http/httputil"
        "net/url"
        "strconv"
        "strings"
        "sync"
        "time"
        "github.com/contentsquare/chproxy/cache"
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/internal/topology"
        "github.com/contentsquare/chproxy/log"
        "github.com/prometheus/client_golang/prometheus"
)
// tmpDir temporary path to store ongoing queries results
const tmpDir = "/tmp"
// failedTransactionPrefix prefix added to the failed reason for concurrent queries registry
const failedTransactionPrefix = "[concurrent query failed]"
type reverseProxy struct {
        rp *httputil.ReverseProxy
        // configLock serializes access to applyConfig.
        // It protects reload* fields.
        configLock sync.Mutex
        reloadSignal chan struct{}
        reloadWG     sync.WaitGroup
        // lock protects users, clusters and caches.
        // RWMutex enables concurrent access to getScope.
        lock sync.RWMutex
        users               map[string]*user
        clusters            map[string]*cluster
        caches              map[string]*cache.AsyncCache
        hasWildcarded       bool
        maxIdleConns        int
        maxIdleConnsPerHost int
}
func newReverseProxy(cfgCp *config.ConnectionPool) *reverseProxy {
        transport := &http.Transport{
                Proxy: http.ProxyFromEnvironment,
                DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
                        dialer := &net.Dialer{
                                Timeout:   30 * time.Second,
                                KeepAlive: 30 * time.Second,
                        }
                        return dialer.DialContext(ctx, network, addr)
                },
                ForceAttemptHTTP2:     true,
                MaxIdleConns:          cfgCp.MaxIdleConns,
                MaxIdleConnsPerHost:   cfgCp.MaxIdleConnsPerHost,
                IdleConnTimeout:       90 * time.Second,
                TLSHandshakeTimeout:   10 * time.Second,
                ExpectContinueTimeout: 1 * time.Second,
        }
        return &reverseProxy{
                rp: &httputil.ReverseProxy{
                        Director:  func(*http.Request) {},
                        Transport: transport,
                        // Suppress error logging in ReverseProxy, since all the errors
                        // are handled and logged in the code below.
                        ErrorLog: log.NilLogger,
                },
                reloadSignal:        make(chan struct{}),
                reloadWG:            sync.WaitGroup{},
                maxIdleConns:        cfgCp.MaxIdleConns,
                maxIdleConnsPerHost: cfgCp.MaxIdleConnsPerHost,
        }
}
func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
        startTime := time.Now()
        s, status, err := rp.getScope(req)
        if err != nil {
                q := getQuerySnippet(req)
                err = fmt.Errorf("%q: %w; query: %q", req.RemoteAddr, err, q)
                respondWith(rw, err, status)
                return
        }
        // WARNING: don't use s.labels before s.incQueued,
        // since `replica` and `cluster_node` may change inside incQueued.
        if err := s.incQueued(); err != nil {
                limitExcess.With(s.labels).Inc()
                q := getQuerySnippet(req)
                err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                respondWith(rw, err, http.StatusTooManyRequests)
                return
        }
        defer s.dec()
        log.Debugf("%s: request start", s)
        requestSum.With(s.labels).Inc()
        if s.user.allowCORS {
                origin := req.Header.Get("Origin")
                if len(origin) == 0 {
                        origin = "*"
                }
                rw.Header().Set("Access-Control-Allow-Origin", origin)
        }
        req.Body = &statReadCloser{
                ReadCloser: req.Body,
                bytesRead:  requestBodyBytes.With(s.labels),
        }
        srw := &statResponseWriter{
                ResponseWriter: rw,
                bytesWritten:   responseBodyBytes.With(s.labels),
        }
        req, origParams := s.decorateRequest(req)
        // wrap body into cachedReadCloser, so we could obtain the original
        // request on error.
        req.Body = &cachedReadCloser{
                ReadCloser: req.Body,
        }
        // publish session_id if needed
        if s.sessionId != "" {
                rw.Header().Set("X-ClickHouse-Server-Session-Id", s.sessionId)
        }
        q, shouldReturnFromCache, err := shouldRespondFromCache(s, origParams, req)
        if err != nil {
                respondWith(srw, err, http.StatusBadRequest)
                return
        }
        if shouldReturnFromCache {
                rp.serveFromCache(s, srw, req, origParams, q)
        } else {
                rp.proxyRequest(s, srw, srw, req)
        }
        // It is safe calling getQuerySnippet here, since the request
        // has been already read in proxyRequest or serveFromCache.
        query := getQuerySnippet(req)
        if srw.statusCode == http.StatusOK {
                requestSuccess.With(s.labels).Inc()
                log.Debugf("%s: request success; query: %q; Method: %s; URL: %q", s, query, req.Method, req.URL.String())
        } else {
                log.Debugf("%s: request failure: non-200 status code %d; query: %q; Method: %s; URL: %q", s, srw.statusCode, query, req.Method, req.URL.String())
        }
        statusCodes.With(
                prometheus.Labels{
                        "user":         s.user.name,
                        "cluster":      s.cluster.name,
                        "cluster_user": s.clusterUser.name,
                        "replica":      s.host.ReplicaName(),
                        "cluster_node": s.host.Host(),
                        "code":         strconv.Itoa(srw.statusCode),
                },
        ).Inc()
        since := time.Since(startTime).Seconds()
        requestDuration.With(s.labels).Observe(since)
}
func shouldRespondFromCache(s *scope, origParams url.Values, req *http.Request) ([]byte, bool, error) {
        if s.user.cache == nil || s.user.cache.Cache == nil {
                return nil, false, nil
        }
        noCache := origParams.Get("no_cache")
        if noCache == "1" || noCache == "true" {
                return nil, false, nil
        }
        q, err := getFullQuery(req)
        if err != nil {
                return nil, false, fmt.Errorf("%s: cannot read query: %w", s, err)
        }
        return q, canCacheQuery(q), nil
}
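// Per-request cache bypass (illustrative): a client can skip the cache by
// adding no_cache to the query string, e.g.
//
//	q := url.Values{"query": {"SELECT 1"}, "no_cache": {"1"}}
//	// GET /?query=SELECT+1&no_cache=1 is proxied straight to ClickHouse.
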
func executeWithRetry(
        ctx context.Context,
        s *scope,
        maxRetry int,
        rp func(http.ResponseWriter, *http.Request),
        rw ResponseWriterWithCode,
        srw StatResponseWriter,
        req *http.Request,
        monitorDuration func(float64),
        monitorRetryRequestInc func(prometheus.Labels),
) (float64, error) {
        startTime := time.Now()
        var since float64
        // keep the request body
        body, err := io.ReadAll(req.Body)
        req.Body.Close()
        if err != nil {
                since = time.Since(startTime).Seconds()
                return since, err
        }
        numRetry := 0
        for {
                // update body
                req.Body = io.NopCloser(bytes.NewBuffer(body))
                req.Body.Close()
                rp(rw, req)
                err := ctx.Err()
                if err != nil {
                        since = time.Since(startTime).Seconds()
                        return since, err
                }
                // The request has been successfully proxied.
                srw.SetStatusCode(rw.StatusCode())
                // StatusBadGateway response is returned by http.ReverseProxy when
                // it cannot establish connection to remote host.
                if rw.StatusCode() == http.StatusBadGateway {
                        log.Debugf("the invalid host is: %s", s.host)
                        s.host.Penalize()
                        // comment s.host.dec() line to avoid double increment; issue #322
                        // s.host.dec()
                        s.host.SetIsActive(false)
                        nextHost := s.cluster.getHost()
                        // The query could be retried if it has no stickiness to a certain server
                        if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" {
                                // the query execution failed, so account for the retry
                                monitorRetryRequestInc(s.labels)
                                currentHost := s.host
                                // decrement the connection counter of the failed host and increment it
                                // on the new host, since the scope decrements the new host's counter
                                // when it is closed at the end of the request.
                                // See https://github.com/ContentSquare/chproxy/pull/357
                                if currentHost != nextHost {
                                        currentHost.DecrementConnections()
                                        nextHost.IncrementConnections()
                                }
                                // update host
                                s.host = nextHost
                                req.URL.Host = s.host.Host()
                                req.URL.Scheme = s.host.Scheme()
                                log.Debugf("the valid host is: %s", s.host)
                        } else {
                                since = time.Since(startTime).Seconds()
                                monitorDuration(since)
                                q := getQuerySnippet(req)
                                err1 := fmt.Errorf("%s: cannot reach %s; query: %q", s, s.host.Host(), q)
                                respondWith(srw, err1, srw.StatusCode())
                                break
                        }
                } else {
                        since = time.Since(startTime).Seconds()
                        break
                }
                numRetry++
        }
        return since, nil
}
// proxyRequest proxies the given request to clickhouse and sends response
// to rw.
//
// srw is required only for setting non-200 status codes on timeouts
// or on client connection disconnects.
func (rp *reverseProxy) proxyRequest(s *scope, rw ResponseWriterWithCode, srw *statResponseWriter, req *http.Request) {
        // wrap body into cachedReadCloser, so we could obtain the original
        // request on error.
        if _, ok := req.Body.(*cachedReadCloser); !ok {
                req.Body = &cachedReadCloser{
                        ReadCloser: req.Body,
                }
        }
        timeout, timeoutErrMsg := s.getTimeoutWithErrMsg()
        ctx := context.Background()
        if timeout > 0 {
                var cancel context.CancelFunc
                ctx, cancel = context.WithTimeout(ctx, timeout)
                defer cancel()
        }
        // Cancel the ctx if client closes the remote connection,
        // so the proxied query may be killed instantly.
        ctx, ctxCancel := listenToCloseNotify(ctx, rw)
        defer ctxCancel()
        req = req.WithContext(ctx)
        startTime := time.Now()
        executeDuration, err := executeWithRetry(ctx, s, s.cluster.retryNumber, rp.rp.ServeHTTP, rw, srw, req, func(duration float64) {
                proxiedResponseDuration.With(s.labels).Observe(duration)
        }, func(labels prometheus.Labels) { retryRequest.With(labels).Inc() })
        switch {
        case err == nil:
                return
        case errors.Is(err, context.Canceled):
                canceledRequest.With(s.labels).Inc()
                q := getQuerySnippet(req)
                log.Debugf("%s: remote client closed the connection in %s; query: %q", s, time.Since(startTime), q)
                if err := s.killQuery(); err != nil {
                        log.Errorf("%s: cannot kill query: %s; query: %q", s, err, q)
                }
                srw.statusCode = 499 // See https://httpstatuses.com/499 .
        case errors.Is(err, context.DeadlineExceeded):
                timeoutRequest.With(s.labels).Inc()
                // Penalize host with the timed out query, because it may be overloaded.
                s.host.Penalize()
                q := getQuerySnippet(req)
                log.Debugf("%s: query timeout in %f; query: %q", s, executeDuration, q)
                if err := s.killQuery(); err != nil {
                        log.Errorf("%s: cannot kill query: %s; query: %q", s, err, q)
                }
                err = fmt.Errorf("%s: %w; query: %q", s, timeoutErrMsg, q)
                respondWith(rw, err, http.StatusGatewayTimeout)
                srw.statusCode = http.StatusGatewayTimeout
        default:
                panic(fmt.Sprintf("BUG: context.Context.Err() returned unexpected error: %s", err))
        }
}
func listenToCloseNotify(ctx context.Context, rw ResponseWriterWithCode) (context.Context, context.CancelFunc) {
        // Cancel the ctx if client closes the remote connection,
        // so the proxied query may be killed instantly.
        ctx, ctxCancel := context.WithCancel(ctx)
        // rw must implement http.CloseNotifier.
        rwc, ok := rw.(http.CloseNotifier)
        if !ok {
                panic("BUG: the wrapped ResponseWriter must implement http.CloseNotifier")
        }
        ch := rwc.CloseNotify()
        go func() {
                select {
                case <-ch:
                        ctxCancel()
                case <-ctx.Done():
                }
        }()
        return ctx, ctxCancel
}
//nolint:cyclop //TODO refactor this method, most likely requires some work.
func (rp *reverseProxy) serveFromCache(s *scope, srw *statResponseWriter, req *http.Request, origParams url.Values, q []byte) {
        labels := makeCacheLabels(s)
        key := newCacheKey(s, origParams, q, req)
        startTime := time.Now()
        userCache := s.user.cache
        // Try to serve from cache
        cachedData, err := userCache.Get(key)
        if err == nil {
                // The response has been successfully served from cache.
                defer cachedData.Data.Close()
                cacheHit.With(labels).Inc()
                cachedResponseDuration.With(labels).Observe(time.Since(startTime).Seconds())
                log.Debugf("%s: cache hit", s)
                _ = RespondWithData(srw, cachedData.Data, cachedData.ContentMetadata, cachedData.Ttl, XCacheHit, http.StatusOK, labels)
                return
        }
        // Await for potential result from concurrent query
        transactionStatus, err := userCache.AwaitForConcurrentTransaction(key)
        if err != nil {
                // log and continue processing
                log.Errorf("failed to await for concurrent transaction due to: %v", err)
        } else {
                if transactionStatus.State.IsCompleted() {
                        cachedData, err := userCache.Get(key)
                        if err == nil {
                                defer cachedData.Data.Close()
                                _ = RespondWithData(srw, cachedData.Data, cachedData.ContentMetadata, cachedData.Ttl, XCacheHit, http.StatusOK, labels)
                                cacheHitFromConcurrentQueries.With(labels).Inc()
                                log.Debugf("%s: cache hit after awaiting concurrent query", s)
                                return
                        } else {
                                cacheMissFromConcurrentQueries.With(labels).Inc()
                                log.Debugf("%s: cache miss after awaiting concurrent query", s)
                        }
                } else if transactionStatus.State.IsFailed() {
                        respondWith(srw, errors.New(transactionStatus.FailReason), http.StatusInternalServerError)
                        return
                }
        }
        // The response wasn't found in the cache.
        // Request it from clickhouse.
        tmpFileRespWriter, err := cache.NewTmpFileResponseWriter(srw, tmpDir)
        if err != nil {
                err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                respondWith(srw, err, http.StatusInternalServerError)
                return
        }
        defer tmpFileRespWriter.Close()
        // Initialise transaction
        err = userCache.Create(key)
        if err != nil {
                log.Errorf("%s: %s; query: %q - failed to register transaction", s, err, q)
        }
        // proxy request and capture response along with headers to [[TmpFileResponseWriter]]
        rp.proxyRequest(s, tmpFileRespWriter, srw, req)
        contentEncoding := tmpFileRespWriter.GetCapturedContentEncoding()
        contentType := tmpFileRespWriter.GetCapturedContentType()
        contentLength, err := tmpFileRespWriter.GetCapturedContentLength()
        if err != nil {
                log.Errorf("%s: %s; query: %q - failed to get contentLength of query", s, err, q)
                respondWith(srw, err, http.StatusInternalServerError)
                return
        }
        reader, err := tmpFileRespWriter.Reader()
        if err != nil {
                log.Errorf("%s: %s; query: %q - failed to get Reader from tmp file", s, err, q)
                respondWith(srw, err, http.StatusInternalServerError)
                return
        }
        contentMetadata := cache.ContentMetadata{Length: contentLength, Encoding: contentEncoding, Type: contentType}
        statusCode := tmpFileRespWriter.StatusCode()
        if statusCode != http.StatusOK || s.canceled {
                // Do not cache non-200 or cancelled responses.
                // Restore the original status code set by proxyRequest, if any.
                if srw.statusCode != 0 {
                        tmpFileRespWriter.WriteHeader(srw.statusCode)
                }
                errString, err := toString(reader)
                if err != nil {
                        log.Errorf("%s failed to get error reason: %s", s, err.Error())
                }
                errReason := fmt.Sprintf("%s %s", failedTransactionPrefix, errString)
                rp.completeTransaction(s, statusCode, userCache, key, q, errReason)
                // Reset the file offset: the reader of tmpFileRespWriter was already
                // consumed by toString(...) while extracting the error reason.
                err = tmpFileRespWriter.ResetFileOffset()
                if err != nil {
                        err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                        respondWith(srw, err, http.StatusInternalServerError)
                        return
                }
                err = RespondWithData(srw, reader, contentMetadata, 0*time.Second, XCacheMiss, statusCode, labels)
                if err != nil {
                        err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                        respondWith(srw, err, http.StatusInternalServerError)
                }
        } else {
                // Do not cache responses greater than max payload size.
                if contentLength > int64(s.user.cache.MaxPayloadSize) {
                        cacheSkipped.With(labels).Inc()
                        log.Infof("%s: Request will not be cached. Content length (%d) is greater than max payload size (%d)", s, contentLength, s.user.cache.MaxPayloadSize)
                        rp.completeTransaction(s, statusCode, userCache, key, q, "")
                        err = RespondWithData(srw, reader, contentMetadata, 0*time.Second, XCacheNA, tmpFileRespWriter.StatusCode(), labels)
                        if err != nil {
                                err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                                respondWith(srw, err, http.StatusInternalServerError)
                        }
                        return
                }
                cacheMiss.With(labels).Inc()
                log.Debugf("%s: cache miss", s)
                expiration, err := userCache.Put(reader, contentMetadata, key)
                if err != nil {
                        cacheFailedInsert.With(labels).Inc()
                        log.Errorf("%s: %s; query: %q - failed to put response in the cache", s, err, q)
                }
                rp.completeTransaction(s, statusCode, userCache, key, q, "")
                // Reset the file offset: the reader of tmpFileRespWriter was already
                // consumed by userCache.Put(...) above.
                err = tmpFileRespWriter.ResetFileOffset()
                if err != nil {
                        err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                        respondWith(srw, err, http.StatusInternalServerError)
                        return
                }
                err = RespondWithData(srw, reader, contentMetadata, expiration, XCacheMiss, statusCode, labels)
                if err != nil {
                        err = fmt.Errorf("%s: %w; query: %q", s, err, q)
                        respondWith(srw, err, http.StatusInternalServerError)
                        return
                }
        }
}
func makeCacheLabels(s *scope) prometheus.Labels {
        // Do not store `replica` and `cluster_node` in labels, since they
        // make no sense for cache metrics.
        return prometheus.Labels{
                "cache":        s.user.cache.Name(),
                "user":         s.labels["user"],
                "cluster":      s.labels["cluster"],
                "cluster_user": s.labels["cluster_user"],
        }
}
func newCacheKey(s *scope, origParams url.Values, q []byte, req *http.Request) *cache.Key {
        var userParamsHash uint32
        if s.user.params != nil {
                userParamsHash = s.user.params.key
        }
        queryParamsHash := calcQueryParamsHash(origParams)
        credHash, err := uint32(0), error(nil)
        if !s.user.cache.SharedWithAllUsers {
                credHash, err = calcCredentialHash(s.clusterUser.name, s.clusterUser.password)
        }
        if err != nil {
                log.Errorf("fail to calc hash on credentials for user %s", s.user.name)
                credHash = 0
        }
        return cache.NewKey(
                skipLeadingComments(q),
                origParams,
                sortHeader(req.Header.Get("Accept-Encoding")),
                userParamsHash,
                queryParamsHash,
                credHash,
        )
}
func toString(stream io.Reader) (string, error) {
        buf := new(bytes.Buffer)
        _, err := buf.ReadFrom(stream)
        if err != nil {
                return "", err
        }
        return buf.String(), nil
}
// clickhouseRecoverableStatusCodes is the set of recoverable HTTP status codes returned by ClickHouse.
// When one of them occurs, the transaction is marked as completed so that a concurrent query can hit another ClickHouse shard.
// Possible HTTP error codes in ClickHouse: https://github.com/ClickHouse/ClickHouse/blob/master/src/Server/HTTPHandler.cpp
var clickhouseRecoverableStatusCodes = map[int]struct{}{http.StatusServiceUnavailable: {}, http.StatusRequestTimeout: {}}
func (rp *reverseProxy) completeTransaction(s *scope, statusCode int, userCache *cache.AsyncCache, key *cache.Key,
        q []byte,
        failReason string,
) {
        // complete successful transactions or those with empty fail reason
        if statusCode < 300 || failReason == "" {
                if err := userCache.Complete(key); err != nil {
                        log.Errorf("%s: %s; query: %q", s, err, q)
                }
                return
        }
        if _, ok := clickhouseRecoverableStatusCodes[statusCode]; ok {
                if err := userCache.Complete(key); err != nil {
                        log.Errorf("%s: %s; query: %q", s, err, q)
                }
        } else {
                if err := userCache.Fail(key, failReason); err != nil {
                        log.Errorf("%s: %s; query: %q", s, err, q)
                }
        }
}
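// Example for completeTransaction (hypothetical statuses): a 503 or 408 from a shard still
// completes the transaction, so concurrent waiters re-run the query (possibly on another shard),
// while any other non-recoverable error status with a non-empty fail reason marks the transaction
// as failed and waiters receive the error instead of re-running the query.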
func calcQueryParamsHash(origParams url.Values) uint32 {
        queryParams := make(map[string]string)
        for param := range origParams {
                if strings.HasPrefix(param, "param_") {
                        queryParams[param] = origParams.Get(param)
                }
        }
        queryParamsHash, err := calcMapHash(queryParams)
        if err != nil {
                log.Errorf("fail to calc hash for params %s; %s", origParams, err)
                return 0
        }
        return queryParamsHash
}
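// Illustrative sketch (not part of the original source): only args carrying the
// "param_" prefix (ClickHouse parameterized queries) contribute to the hash;
// everything else is ignored. The parameter names below are hypothetical.
func exampleCalcQueryParamsHash() {
        params := url.Values{}
        params.Set("param_user_id", "42")    // contributes to the hash
        params.Set("max_result_rows", "100") // ignored by calcQueryParamsHash
        _ = calcQueryParamsHash(params)
}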
// applyConfig applies the given cfg to reverseProxy.
//
// The new config is applied only if a nil error is returned.
// Otherwise the old config version is kept.
func (rp *reverseProxy) applyConfig(cfg *config.Config) error {
        // configLock protects from concurrent calls to applyConfig
        // by serializing such calls.
        // configLock shouldn't be used in other places.
        rp.configLock.Lock()
        defer rp.configLock.Unlock()
        clusters, err := newClusters(cfg.Clusters)
        if err != nil {
                return err
        }
        caches := make(map[string]*cache.AsyncCache, len(cfg.Caches))
        defer func() {
                // caches is swapped with the old caches from rp.caches
                // on successful config reload - see the end of applyConfig.
                for _, tmpCache := range caches {
                        // Speed up applyConfig by closing caches in background,
                        // since the process of cache closing may be lengthy
                        // due to cleaning.
                        go tmpCache.Close()
                }
        }()
        // transactionsTimeout is used to create the transaction registry inside the async cache.
        // It is set to the highest configured execution time across all users, so that users sharing
        // the same cache but having different maxExecutionTime values do not trigger the `dogpile effect`.
        transactionsTimeout := config.Duration(0)
        for _, user := range cfg.Users {
                if user.MaxExecutionTime > transactionsTimeout {
                        transactionsTimeout = user.MaxExecutionTime
                }
                if user.IsWildcarded {
                        rp.hasWildcarded = true
                }
        }
        if err := initTempCaches(caches, transactionsTimeout, cfg.Caches); err != nil {
                return err
        }
        params, err := paramsFromConfig(cfg.ParamGroups)
        if err != nil {
                return err
        }
        profile := &usersProfile{
                cfg:      cfg.Users,
                clusters: clusters,
                caches:   caches,
                params:   params,
        }
        users, err := profile.newUsers()
        if err != nil {
                return err
        }
        if err := validateNoWildcardedUserForHeartbeat(clusters, cfg.Clusters); err != nil {
                return err
        }
        // New configs have been successfully prepared.
        // Stop the previous service goroutines and restart them with the new configs.
        close(rp.reloadSignal)
        rp.reloadWG.Wait()
        rp.reloadSignal = make(chan struct{})
        rp.restartWithNewConfig(caches, clusters, users)
        // Substitute old configs with the new configs in rp.
        // All the currently running requests will continue with old configs,
        // while all the new requests will use new configs.
        rp.lock.Lock()
        rp.clusters = clusters
        rp.users = users
        // Swap is needed for deferred closing of old caches.
        // See the code above where new caches are created.
        caches, rp.caches = rp.caches, caches
        rp.lock.Unlock()
        return nil
}
func initTempCaches(caches map[string]*cache.AsyncCache, transactionsTimeout config.Duration, cfg []config.Cache) error {
        for _, cc := range cfg {
                if _, ok := caches[cc.Name]; ok {
                        return fmt.Errorf("duplicate config for cache %q", cc.Name)
                }
                tmpCache, err := cache.NewAsyncCache(cc, time.Duration(transactionsTimeout))
                if err != nil {
                        return err
                }
                caches[cc.Name] = tmpCache
        }
        return nil
}
func paramsFromConfig(cfg []config.ParamGroup) (map[string]*paramsRegistry, error) {
        params := make(map[string]*paramsRegistry, len(cfg))
        for _, p := range cfg {
                if _, ok := params[p.Name]; ok {
                        return nil, fmt.Errorf("duplicate config for ParamGroups %q", p.Name)
                }
                registry, err := newParamsRegistry(p.Params)
                if err != nil {
                        return nil, fmt.Errorf("cannot initialize params %q: %w", p.Name, err)
                }
                params[p.Name] = registry
        }
        return params, nil
}
func validateNoWildcardedUserForHeartbeat(clusters map[string]*cluster, cfg []config.Cluster) error {
        for c := range cfg {
                cfgcl := cfg[c]
                clname := cfgcl.Name
                cuname := cfgcl.ClusterUsers[0].Name
                heartbeat := cfg[c].HeartBeat
                cl := clusters[clname]
                cu := cl.users[cuname]
                if cu.isWildcarded {
                        if heartbeat.Request != "/ping" && len(heartbeat.User) == 0 {
                                return fmt.Errorf(
                                        "`cluster.heartbeat.user` cannot be unset for %q because a wildcarded user cannot send heartbeat",
                                        clname,
                                )
                        }
                }
        }
        return nil
}
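// Example for validateNoWildcardedUserForHeartbeat (hypothetical config): a cluster whose first
// cluster user is wildcarded and whose heartbeat request is anything other than /ping must set
// `cluster.heartbeat.user`, since a wildcarded user has no concrete credentials to send heartbeats with.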
func (rp *reverseProxy) restartWithNewConfig(caches map[string]*cache.AsyncCache, clusters map[string]*cluster, users map[string]*user) {
        // Reset metrics from the previous configs, which may become irrelevant
        // with new configs.
        // Counters and Summary metrics are always relevant.
        // Gauge metrics may become irrelevant, since they may freeze at a non-zero
        // value after config reload.
        topology.HostHealth.Reset()
        cacheSize.Reset()
        cacheItems.Reset()
        // Start service goroutines with new configs.
        for _, c := range clusters {
                for _, r := range c.replicas {
                        for _, h := range r.hosts {
                                rp.reloadWG.Add(1)
                                go func(h *topology.Node) {
                                        h.StartHeartbeat(rp.reloadSignal)
                                        rp.reloadWG.Done()
                                }(h)
                        }
                }
                for _, cu := range c.users {
                        rp.reloadWG.Add(1)
                        go func(cu *clusterUser) {
                                cu.rateLimiter.run(rp.reloadSignal)
                                rp.reloadWG.Done()
                        }(cu)
                }
        }
        for _, u := range users {
                rp.reloadWG.Add(1)
                go func(u *user) {
                        u.rateLimiter.run(rp.reloadSignal)
                        rp.reloadWG.Done()
                }(u)
        }
}
// refreshCacheMetrics refreshes cacheSize and cacheItems metrics.
func (rp *reverseProxy) refreshCacheMetrics() {
        rp.lock.RLock()
        defer rp.lock.RUnlock()
        for _, c := range rp.caches {
                stats := c.Stats()
                labels := prometheus.Labels{
                        "cache": c.Name(),
                }
                cacheSize.With(labels).Set(float64(stats.Size))
                cacheItems.With(labels).Set(float64(stats.Items))
        }
}
// getUser finds the user, cluster and clusterUser.
// In case of a wildcarded user, the cluster user is crafted to use the original credentials.
func (rp *reverseProxy) getUser(name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) {
        rp.lock.RLock()
        defer rp.lock.RUnlock()
        found = false
        u = rp.users[name]
        switch {
        case u != nil:
                found = (u.password == password)
                // existence of c and cu for toCluster is guaranteed by applyConfig
                c = rp.clusters[u.toCluster]
                cu = c.users[u.toUser]
        case name == "" || name == defaultUser:
                // default user can't work with the wildcarded feature for security reasons
                found = false
        case rp.hasWildcarded:
                // check if we have wildcarded users and if the username matches one of the 3 possible patterns
                found, u, c, cu = rp.findWildcardedUserInformation(name, password)
        }
        return found, u, c, cu
}
func (rp *reverseProxy) findWildcardedUserInformation(name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) {
        // Per the validation in config.go, wildcarded names must contain either a bare wildcard,
        // a prefix or a suffix:
        //   - "*"          (bare wildcard)
        //   - "*[suffix]"
        //   - "[prefix]*"
        for _, user := range rp.users {
                if user.isWildcarded {
                        s := strings.Split(user.name, "*")
                        switch {
                        case s[0] == "" && s[1] == "":
                                return rp.generateWildcardedUserInformation(user, name, password)
                        case s[0] == "":
                                suffix := s[1]
                                if strings.HasSuffix(name, suffix) {
                                        return rp.generateWildcardedUserInformation(user, name, password)
                                }
                        case s[1] == "":
                                prefix := s[0]
                                if strings.HasPrefix(name, prefix) {
                                        return rp.generateWildcardedUserInformation(user, name, password)
                                }
                        }
                }
        }
        return false, nil, nil, nil
}
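// Matching examples for findWildcardedUserInformation (hypothetical user names):
//
//	"*"       matches any incoming name (bare wildcard)
//	"team_*"  matches "team_alice", "team_bob", ... (prefix form)
//	"*_ro"    matches "alice_ro", "bob_ro", ...     (suffix form)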
func (rp *reverseProxy) generateWildcardedUserInformation(user *user, name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) {
        found = false
        c = rp.clusters[user.toCluster]
        wildcardedCu := c.users[user.toUser]
        if wildcardedCu != nil {
                newCU := deepCopy(wildcardedCu)
                found = true
                u = user
                cu = newCU
                cu.name = name
                cu.password = password
                // TODO: improve the following behavior.
                // The wildcarded user feature creates side effects on clusterUser limitations (like max_concurrent_queries)
                // because a deep copy of the clusterUser is used. The impact should be limited since the limitations still
                // apply per user. The deep copy is needed because the clusterUser's name & password are overwritten; reusing
                // the same instance across calls could let a query run by user A on the chproxy side trigger a query in
                // ClickHouse as user B, which would be a security issue. A clean fix would require a large refactoring.
        }
        return
}
func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
        name, password := getAuth(req)
        sessionId := getSessionId(req)
        sessionTimeout := getSessionTimeout(req)
        var (
                u  *user
                c  *cluster
                cu *clusterUser
        )
        found, u, c, cu := rp.getUser(name, password)
        if !found {
                return nil, http.StatusUnauthorized, fmt.Errorf("invalid username or password for user %q", name)
        }
        if u.denyHTTP && req.TLS == nil {
                return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access via http", u.name)
        }
        if u.denyHTTPS && req.TLS != nil {
                return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access via https", u.name)
        }
        if !u.allowedNetworks.Contains(req.RemoteAddr) {
                return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access", u.name)
        }
        if !cu.allowedNetworks.Contains(req.RemoteAddr) {
                return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
        }
        s := newScope(req, u, c, cu, sessionId, sessionTimeout)
        q, err := getFullQuery(req)
        if err != nil {
                return nil, http.StatusBadRequest, fmt.Errorf("%s: cannot read query: %w", s, err)
        }
        s.requestPacketSize = len(q)
        return s, 0, nil
}
		
package main
import (
        "net"
        "net/http"
        "strings"
        "github.com/contentsquare/chproxy/config"
)
const (
        xForwardedForHeader = "X-Forwarded-For"
        xRealIPHeader       = "X-Real-Ip"
        forwardedHeader     = "Forwarded"
)
type ProxyHandler struct {
        proxy *config.Proxy
}
func NewProxyHandler(proxy *config.Proxy) *ProxyHandler {
        return &ProxyHandler{
                proxy: proxy,
        }
}
func (m *ProxyHandler) GetRemoteAddr(r *http.Request) string {
        if m.proxy.Enable {
                var addr string
                if m.proxy.Header != "" {
                        addr = r.Header.Get(m.proxy.Header)
                } else {
                        addr = parseDefaultProxyHeaders(r)
                }
                if isValidAddr(addr) {
                        return addr
                }
        }
        return r.RemoteAddr
}
// isValidAddr checks if the Addr is a valid IP or IP:port.
func isValidAddr(addr string) bool {
        if addr == "" {
                return false
        }
        ip, _, err := net.SplitHostPort(addr)
        if err != nil {
                return net.ParseIP(addr) != nil
        }
        return net.ParseIP(ip) != nil
}
func parseDefaultProxyHeaders(r *http.Request) string {
        var addr string
        if fwd := r.Header.Get(xForwardedForHeader); fwd != "" {
                addr = extractFirstMatchFromIPList(fwd)
        } else if fwd := r.Header.Get(xRealIPHeader); fwd != "" {
                addr = extractFirstMatchFromIPList(fwd)
        } else if fwd := r.Header.Get(forwardedHeader); fwd != "" {
                // See: https://tools.ietf.org/html/rfc7239.
                addr = parseForwardedHeader(fwd)
        }
        return addr
}
func extractFirstMatchFromIPList(ipList string) string {
        if ipList == "" {
                return ""
        }
        s := strings.Index(ipList, ",")
        if s == -1 {
                s = len(ipList)
        }
        return ipList[:s]
}
func parseForwardedHeader(fwd string) string {
        splits := strings.Split(fwd, ";")
        if len(splits) == 0 {
                return ""
        }
        for _, split := range splits {
                trimmed := strings.TrimSpace(split)
                if strings.HasPrefix(strings.ToLower(trimmed), "for=") {
                        forSplits := strings.Split(trimmed, ",")
                        if len(forSplits) == 0 {
                                return ""
                        }
                        addr := forSplits[0][4:]
                        trimmedAddr := strings.
                                NewReplacer("\"", "", "[", "", "]", "").
                                Replace(addr) // If IpV6, remove brackets and quotes.
                        return trimmedAddr
                }
        }
        return ""
}
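// Illustrative sketch (not part of the original source): how the default proxy
// headers resolve to a client address. Header values below are hypothetical.
func exampleProxyHeaderParsing() {
        // X-Forwarded-For may carry a comma-separated chain; only the first
        // (client) entry is taken.
        client := extractFirstMatchFromIPList("203.0.113.7, 198.51.100.2")
        // RFC 7239 Forwarded headers are reduced to their "for=" pair; quotes
        // and IPv6 brackets are stripped.
        fwd := parseForwardedHeader(`for="[2001:db8::1]";proto=https`)
        _, _ = client, fwd // "203.0.113.7", "2001:db8::1"
}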
		
package main
import (
        "context"
        "fmt"
        "io"
        "net"
        "net/http"
        "net/url"
        "regexp"
        "strconv"
        "strings"
        "sync/atomic"
        "time"
        "github.com/contentsquare/chproxy/cache"
        "github.com/contentsquare/chproxy/config"
        "github.com/contentsquare/chproxy/internal/heartbeat"
        "github.com/contentsquare/chproxy/internal/topology"
        "github.com/contentsquare/chproxy/log"
        "github.com/prometheus/client_golang/prometheus"
        "golang.org/x/time/rate"
)
type scopeID uint64
func (sid scopeID) String() string {
        return fmt.Sprintf("%08X", uint64(sid))
}
func newScopeID() scopeID {
        sid := atomic.AddUint64(&nextScopeID, 1)
        return scopeID(sid)
}
var nextScopeID = uint64(time.Now().UnixNano())
type scope struct {
        startTime   time.Time
        id          scopeID
        host        *topology.Node
        cluster     *cluster
        user        *user
        clusterUser *clusterUser
        sessionId      string
        sessionTimeout int
        remoteAddr string
        localAddr  string
        // is true when KillQuery has been called
        canceled bool
        labels prometheus.Labels
        requestPacketSize int
}
func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
        h := c.getHost()
        if sessionId != "" {
                h = c.getHostSticky(sessionId)
        }
        var localAddr string
        if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
                localAddr = addr.String()
        }
        s := &scope{
                startTime:      time.Now(),
                id:             newScopeID(),
                host:           h,
                cluster:        c,
                user:           u,
                clusterUser:    cu,
                sessionId:      sessionId,
                sessionTimeout: sessionTimeout,
                remoteAddr: req.RemoteAddr,
                localAddr:  localAddr,
                labels: prometheus.Labels{
                        "user":         u.name,
                        "cluster":      c.name,
                        "cluster_user": cu.name,
                        "replica":      h.ReplicaName(),
                        "cluster_node": h.Host(),
                },
        }
        return s
}
func (s *scope) String() string {
        return fmt.Sprintf("[ Id: %s; User %q(%d) proxying as %q(%d) to %q(%d); RemoteAddr: %q; LocalAddr: %q; Duration: %d μs]",
                s.id,
                s.user.name, s.user.queryCounter.load(),
                s.clusterUser.name, s.clusterUser.queryCounter.load(),
                s.host.Host(), s.host.CurrentLoad(),
                s.remoteAddr, s.localAddr, time.Since(s.startTime).Nanoseconds()/1000.0)
}
//nolint:cyclop // TODO abstract user queues to reduce complexity here.
func (s *scope) incQueued() error {
        if s.user.queueCh == nil && s.clusterUser.queueCh == nil {
                // Request queues in the current scope are disabled.
                return s.inc()
        }
        // Do not store `replica` and `cluster_node` in labels, since they
        // make no sense for queue metrics.
        labels := prometheus.Labels{
                "user":         s.labels["user"],
                "cluster":      s.labels["cluster"],
                "cluster_user": s.labels["cluster_user"],
        }
        if s.user.queueCh != nil {
                select {
                case s.user.queueCh <- struct{}{}:
                        defer func() {
                                <-s.user.queueCh
                        }()
                default:
                        // Per-user request queue is full.
                        // Give the request the last chance to run.
                        err := s.inc()
                        if err != nil {
                                userQueueOverflow.With(labels).Inc()
                        }
                        return err
                }
        }
        if s.clusterUser.queueCh != nil {
                select {
                case s.clusterUser.queueCh <- struct{}{}:
                        defer func() {
                                <-s.clusterUser.queueCh
                        }()
                default:
                        // Per-clusterUser request queue is full.
                        // Give the request the last chance to run.
                        err := s.inc()
                        if err != nil {
                                clusterUserQueueOverflow.With(labels).Inc()
                        }
                        return err
                }
        }
        // The request has been successfully queued.
        queueSize := requestQueueSize.With(labels)
        queueSize.Inc()
        defer queueSize.Dec()
        // Try starting the request during the given duration.
        sleep, deadline := s.calculateQueueDeadlineAndSleep()
        return s.waitUntilAllowStart(sleep, deadline, labels)
}
func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, labels prometheus.Labels) error {
        for {
                err := s.inc()
                if err == nil {
                        // The request is allowed to start.
                        return nil
                }
                dLeft := time.Until(deadline)
                if dLeft <= 0 {
                        // Give up: the request exceeded its wait time
                        // in the queue :(
                        return err
                }
                // The request has dLeft remaining time to wait in the queue.
                // Sleep for a bit and try starting it again.
                if sleep > dLeft {
                        time.Sleep(dLeft)
                } else {
                        time.Sleep(sleep)
                }
                var h *topology.Node
                // Choose new host, since the previous one may become obsolete
                // after sleeping.
                if s.sessionId == "" {
                        h = s.cluster.getHost()
                } else {
                        // if request has session_id, set same host
                        h = s.cluster.getHostSticky(s.sessionId)
                }
                s.host = h
                s.labels["replica"] = h.ReplicaName()
                s.labels["cluster_node"] = h.Host()
        }
}
func (s *scope) calculateQueueDeadlineAndSleep() (time.Duration, time.Time) {
        d := s.maxQueueTime()
        dSleep := d / 10
        if dSleep > time.Second {
                dSleep = time.Second
        }
        if dSleep < time.Millisecond {
                dSleep = time.Millisecond
        }
        deadline := time.Now().Add(d)
        return dSleep, deadline
}
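// Example for calculateQueueDeadlineAndSleep (hypothetical queue times): a 10s max queue time
// yields a 1s polling sleep (d/10, capped at 1s), 5s yields 500ms, and 10ms or less yields the
// 1ms minimum.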
func (s *scope) inc() error {
        uQueries := s.user.queryCounter.inc()
        cQueries := s.clusterUser.queryCounter.inc()
        var err error
        if s.user.maxConcurrentQueries > 0 && uQueries > s.user.maxConcurrentQueries {
                err = fmt.Errorf("limits for user %q are exceeded: max_concurrent_queries limit: %d",
                        s.user.name, s.user.maxConcurrentQueries)
        }
        if s.clusterUser.maxConcurrentQueries > 0 && cQueries > s.clusterUser.maxConcurrentQueries {
                err = fmt.Errorf("limits for cluster user %q are exceeded: max_concurrent_queries limit: %d",
                        s.clusterUser.name, s.clusterUser.maxConcurrentQueries)
        }
        err2 := s.checkTokenFreeRateLimiters()
        if err2 != nil {
                err = err2
        }
        if err != nil {
                s.user.queryCounter.dec()
                s.clusterUser.queryCounter.dec()
                // Decrement rate limiter here, so it doesn't count requests
                // that didn't start due to limits overflow.
                s.user.rateLimiter.dec()
                s.clusterUser.rateLimiter.dec()
                return err
        }
        s.host.IncrementConnections()
        concurrentQueries.With(s.labels).Inc()
        return nil
}
func (s *scope) checkTokenFreeRateLimiters() error {
        var err error
        uRPM := s.user.rateLimiter.inc()
        cRPM := s.clusterUser.rateLimiter.inc()
        // int32(xRPM) > 0 check is required to detect races when RPM
        // is decremented on error below after per-minute zeroing
        // in rateLimiter.run.
        // These races become harmless with the given check.
        if s.user.reqPerMin > 0 && int32(uRPM) > 0 && uRPM > s.user.reqPerMin {
                err = fmt.Errorf("rate limit for user %q is exceeded: requests_per_minute limit: %d",
                        s.user.name, s.user.reqPerMin)
        }
        if s.clusterUser.reqPerMin > 0 && int32(cRPM) > 0 && cRPM > s.clusterUser.reqPerMin {
                err = fmt.Errorf("rate limit for cluster user %q is exceeded: requests_per_minute limit: %d",
                        s.clusterUser.name, s.clusterUser.reqPerMin)
        }
        err2 := s.checkTokenFreePacketSizeRateLimiters()
        if err2 != nil {
                err = err2
        }
        return err
}
func (s *scope) checkTokenFreePacketSizeRateLimiters() error {
        var err error
        // reserve s.requestPacketSize tokens from the packet-size limiters
        if s.user.reqPacketSizeTokensBurst > 0 {
                tl := s.user.reqPacketSizeTokenLimiter
                ok := tl.AllowN(time.Now(), s.requestPacketSize)
                if !ok {
                        err = fmt.Errorf("limits for user %q are exceeded: request_packet_size_tokens_burst limit: %d",
                                s.user.name, s.user.reqPacketSizeTokensBurst)
                }
        }
        if s.clusterUser.reqPacketSizeTokensBurst > 0 {
                tl := s.clusterUser.reqPacketSizeTokenLimiter
                ok := tl.AllowN(time.Now(), s.requestPacketSize)
                if !ok {
                        err = fmt.Errorf("limits for cluster user %q are exceeded: request_packet_size_tokens_burst limit: %d",
                                s.clusterUser.name, s.clusterUser.reqPacketSizeTokensBurst)
                }
        }
        return err
}
func (s *scope) dec() {
        // There is no need for rateLimiter.dec here, since the rate limiter
        // is automatically zeroed every minute in rateLimiter.run.
        s.user.queryCounter.dec()
        s.clusterUser.queryCounter.dec()
        s.host.DecrementConnections()
        concurrentQueries.With(s.labels).Dec()
}
const killQueryTimeout = time.Second * 30
func (s *scope) killQuery() error {
        log.Debugf("killing the query with query_id=%s", s.id)
        killedRequests.With(s.labels).Inc()
        s.canceled = true
        query := fmt.Sprintf("KILL QUERY WHERE query_id = '%s'", s.id)
        r := strings.NewReader(query)
        addr := s.host.String()
        req, err := http.NewRequest("POST", addr, r)
        if err != nil {
                return fmt.Errorf("error while creating kill query request to %s: %w", addr, err)
        }
        ctx, cancel := context.WithTimeout(context.Background(), killQueryTimeout)
        defer cancel()
        req = req.WithContext(ctx)
        // send request as kill_query_user
        userName := s.cluster.killQueryUserName
        if len(userName) == 0 {
                userName = defaultUser
        }
        req.SetBasicAuth(userName, s.cluster.killQueryUserPassword)
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
                return fmt.Errorf("error while executing clickhouse query %q at %q: %w", query, addr, err)
        }
        defer resp.Body.Close()
        if resp.StatusCode != http.StatusOK {
                responseBody, _ := io.ReadAll(resp.Body)
                return fmt.Errorf("unexpected status code returned from query %q at %q: %d. Response body: %q",
                        query, addr, resp.StatusCode, responseBody)
        }
        respBody, err := io.ReadAll(resp.Body)
        if err != nil {
                return fmt.Errorf("cannot read response body for the query %q: %w", query, err)
        }
        log.Debugf("killed the query with query_id=%s; respBody: %q", s.id, respBody)
        return nil
}
// allowedParams contains query args allowed to be proxied.
// See https://clickhouse.com/docs/en/operations/settings/
//
// All the other params passed via query args are stripped before
// proxying the request. This is for the sake of security.
var allowedParams = []string{
        "query",
        "database",
        "default_format",
        // if `compress=1`, CH will compress the data it sends you
        "compress",
        // if `decompress=1`, CH will decompress the same data that you pass in the POST method
        "decompress",
        // compress the result if the HTTP client indicated that it understands gzip- or deflate-compressed data
        "enable_http_compression",
        // limit on the number of rows in the result
        "max_result_rows",
        // whether to count extreme values
        "extremes",
        // what to do if the volume of the result exceeds one of the limits
        "result_overflow_mode",
        // session stickiness
        "session_id",
        // session timeout
        "session_timeout",
}
// This regexp must match params needed to describe a way to use external data
// @see https://clickhouse.yandex/docs/en/table_engines/external_data/
var externalDataParams = regexp.MustCompile(`(_types|_structure|_format)$`)
func (s *scope) decorateRequest(req *http.Request) (*http.Request, url.Values) {
        // Make new params to purify URL.
        params := make(url.Values)
        // Set user params
        if s.user.params != nil {
                for _, param := range s.user.params.params {
                        params.Set(param.Key, param.Value)
                }
        }
        // Keep allowed params.
        origParams := req.URL.Query()
        for _, param := range allowedParams {
                val := origParams.Get(param)
                if len(val) > 0 {
                        params.Set(param, val)
                }
        }
        // Keep parametrized queries params
        for param := range origParams {
                if strings.HasPrefix(param, "param_") {
                        params.Set(param, origParams.Get(param))
                }
        }
        // Keep external_data params
        if req.Method == "POST" {
                s.decoratePostRequest(req, origParams, params)
        }
        // Set query_id to the scope id so the query can be killed later if needed.
        params.Set("query_id", s.id.String())
        // Set session_timeout, the idle timeout for the session.
        params.Set("session_timeout", strconv.Itoa(s.sessionTimeout))
        req.URL.RawQuery = params.Encode()
        // Rewrite possible previous Basic Auth and send request
        // as cluster user.
        req.SetBasicAuth(s.clusterUser.name, s.clusterUser.password)
        // Delete possible X-ClickHouse headers,
        // it is not allowed to use X-ClickHouse HTTP headers and other authentication methods simultaneously
        req.Header.Del("X-ClickHouse-User")
        req.Header.Del("X-ClickHouse-Key")
        // Send request to the chosen host from cluster.
        req.URL.Scheme = s.host.Scheme()
        req.URL.Host = s.host.Host()
        // Extend ua with additional info, so it may be queried
        // via system.query_log.http_user_agent.
        ua := fmt.Sprintf("RemoteAddr: %s; LocalAddr: %s; CHProxy-User: %s; CHProxy-ClusterUser: %s; %s",
                s.remoteAddr, s.localAddr, s.user.name, s.clusterUser.name, req.UserAgent())
        req.Header.Set("User-Agent", ua)
        return req, origParams
}
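// Illustrative effect of decorateRequest (hypothetical incoming URL): from
// ?query=SELECT+1&foo=bar&param_id=42 the unknown arg "foo" is stripped, "query" and the
// parameterized-query arg "param_id" are kept, and "query_id" (the scope id) plus
// "session_timeout" are always set before the request is proxied.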
func (s *scope) decoratePostRequest(req *http.Request, origParams, params url.Values) {
        ct := req.Header.Get("Content-Type")
        if strings.Contains(ct, "multipart/form-data") {
                for key := range origParams {
                        if externalDataParams.MatchString(key) {
                                params.Set(key, origParams.Get(key))
                        }
                }
                // disable cache for external_data queries
                origParams.Set("no_cache", "1")
                log.Debugf("external data params detected - cache will be disabled")
        }
}
func (s *scope) getTimeoutWithErrMsg() (time.Duration, error) {
        var (
                timeout       time.Duration
                timeoutErrMsg error
        )
        if s.user.maxExecutionTime > 0 {
                timeout = s.user.maxExecutionTime
                timeoutErrMsg = fmt.Errorf("timeout for user %q exceeded: %v", s.user.name, timeout)
        }
        if timeout == 0 || (s.clusterUser.maxExecutionTime > 0 && s.clusterUser.maxExecutionTime < timeout) {
                timeout = s.clusterUser.maxExecutionTime
                timeoutErrMsg = fmt.Errorf("timeout for cluster user %q exceeded: %v", s.clusterUser.name, timeout)
        }
        return timeout, timeoutErrMsg
}
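// Example for getTimeoutWithErrMsg (hypothetical limits): with a user max_execution_time of 30s
// and a cluster-user max_execution_time of 10s the effective timeout is 10s and the cluster-user
// error message is used; if only the user limit is set, it applies as-is.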
func (s *scope) maxQueueTime() time.Duration {
        d := s.user.maxQueueTime
        if d <= 0 || s.clusterUser.maxQueueTime > 0 && s.clusterUser.maxQueueTime < d {
                d = s.clusterUser.maxQueueTime
        }
        if d <= 0 {
                // Default queue time.
                d = 10 * time.Second
        }
        return d
}
type paramsRegistry struct {
        // key is a hashed concatenation of the params list
        key uint32
        params []config.Param
}
func newParamsRegistry(params []config.Param) (*paramsRegistry, error) {
        if len(params) == 0 {
                return nil, fmt.Errorf("params can't be empty")
        }
        paramsMap := make(map[string]string, len(params))
        for _, k := range params {
                paramsMap[k.Key] = k.Value
        }
        key, err := calcMapHash(paramsMap)
        if err != nil {
                return nil, err
        }
        return &paramsRegistry{
                key:    key,
                params: params,
        }, nil
}
type user struct {
        name     string
        password string
        toCluster string
        toUser    string
        maxConcurrentQueries uint32
        queryCounter         counter
        maxExecutionTime time.Duration
        reqPerMin   uint32
        rateLimiter rateLimiter
        reqPacketSizeTokenLimiter *rate.Limiter
        reqPacketSizeTokensBurst  config.ByteSize
        reqPacketSizeTokensRate   config.ByteSize
        queueCh      chan struct{}
        maxQueueTime time.Duration
        allowedNetworks config.Networks
        denyHTTP     bool
        denyHTTPS    bool
        allowCORS    bool
        isWildcarded bool
        cache  *cache.AsyncCache
        params *paramsRegistry
}
type usersProfile struct {
        cfg      []config.User
        clusters map[string]*cluster
        caches   map[string]*cache.AsyncCache
        params   map[string]*paramsRegistry
}
func (up usersProfile) newUsers() (map[string]*user, error) {
        users := make(map[string]*user, len(up.cfg))
        for _, u := range up.cfg {
                if _, ok := users[u.Name]; ok {
                        return nil, fmt.Errorf("duplicate config for user %q", u.Name)
                }
                tmpU, err := up.newUser(u)
                if err != nil {
                        return nil, fmt.Errorf("cannot initialize user %q: %w", u.Name, err)
                }
                users[u.Name] = tmpU
        }
        return users, nil
}
func (up usersProfile) newUser(u config.User) (*user, error) {
        c, ok := up.clusters[u.ToCluster]
        if !ok {
                return nil, fmt.Errorf("unknown `to_cluster` %q", u.ToCluster)
        }
        var cu *clusterUser
        if cu, ok = c.users[u.ToUser]; !ok {
                return nil, fmt.Errorf("unknown `to_user` %q in cluster %q", u.ToUser, u.ToCluster)
        } else if u.IsWildcarded {
                // A wildcarded user is mapped to this cluster user; the flag is
                // used to check whether a proper user exists to send heartbeats.
                cu.isWildcarded = true
        }
        var queueCh chan struct{}
        if u.MaxQueueSize > 0 {
                queueCh = make(chan struct{}, u.MaxQueueSize)
        }
        var cc *cache.AsyncCache
        if len(u.Cache) > 0 {
                cc = up.caches[u.Cache]
                if cc == nil {
                        return nil, fmt.Errorf("unknown `cache` %q", u.Cache)
                }
        }
        var params *paramsRegistry
        if len(u.Params) > 0 {
                params = up.params[u.Params]
                if params == nil {
                        return nil, fmt.Errorf("unknown `params` %q", u.Params)
                }
        }
        return &user{
                name:                      u.Name,
                password:                  u.Password,
                toCluster:                 u.ToCluster,
                toUser:                    u.ToUser,
                maxConcurrentQueries:      u.MaxConcurrentQueries,
                maxExecutionTime:          time.Duration(u.MaxExecutionTime),
                reqPerMin:                 u.ReqPerMin,
                queueCh:                   queueCh,
                maxQueueTime:              time.Duration(u.MaxQueueTime),
                reqPacketSizeTokenLimiter: rate.NewLimiter(rate.Limit(u.ReqPacketSizeTokensRate), int(u.ReqPacketSizeTokensBurst)),
                reqPacketSizeTokensBurst:  u.ReqPacketSizeTokensBurst,
                reqPacketSizeTokensRate:   u.ReqPacketSizeTokensRate,
                allowedNetworks:           u.AllowedNetworks,
                denyHTTP:                  u.DenyHTTP,
                denyHTTPS:                 u.DenyHTTPS,
                allowCORS:                 u.AllowCORS,
                isWildcarded:              u.IsWildcarded,
                cache:                     cc,
                params:                    params,
        }, nil
}
type clusterUser struct {
        name     string
        password string
        maxConcurrentQueries uint32
        queryCounter         counter
        maxExecutionTime time.Duration
        reqPerMin   uint32
        rateLimiter rateLimiter
        queueCh      chan struct{}
        maxQueueTime time.Duration
        reqPacketSizeTokenLimiter *rate.Limiter
        reqPacketSizeTokensBurst  config.ByteSize
        reqPacketSizeTokensRate   config.ByteSize
        allowedNetworks config.Networks
        isWildcarded    bool
}
func deepCopy(cu *clusterUser) *clusterUser {
        var queueCh chan struct{}
        if cu.queueCh != nil {
                // Mirror the capacity of the original queue (max_queue_size),
                // not the queue time.
                queueCh = make(chan struct{}, cap(cu.queueCh))
        }
        return &clusterUser{
                name:                 cu.name,
                password:             cu.password,
                maxConcurrentQueries: cu.maxConcurrentQueries,
                maxExecutionTime:     time.Duration(cu.maxExecutionTime),
                reqPerMin:            cu.reqPerMin,
                queueCh:              queueCh,
                maxQueueTime:         time.Duration(cu.maxQueueTime),
                allowedNetworks:      cu.allowedNetworks,
        }
}
func newClusterUser(cu config.ClusterUser) *clusterUser {
        var queueCh chan struct{}
        if cu.MaxQueueSize > 0 {
                queueCh = make(chan struct{}, cu.MaxQueueSize)
        }
        return &clusterUser{
                name:                      cu.Name,
                password:                  cu.Password,
                maxConcurrentQueries:      cu.MaxConcurrentQueries,
                maxExecutionTime:          time.Duration(cu.MaxExecutionTime),
                reqPerMin:                 cu.ReqPerMin,
                reqPacketSizeTokenLimiter: rate.NewLimiter(rate.Limit(cu.ReqPacketSizeTokensRate), int(cu.ReqPacketSizeTokensBurst)),
                reqPacketSizeTokensBurst:  cu.ReqPacketSizeTokensBurst,
                reqPacketSizeTokensRate:   cu.ReqPacketSizeTokensRate,
                queueCh:                   queueCh,
                maxQueueTime:              time.Duration(cu.MaxQueueTime),
                allowedNetworks:           cu.AllowedNetworks,
        }
}
type replica struct {
        cluster *cluster
        name string
        hosts       []*topology.Node
        nextHostIdx uint32
}
func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c *cluster) ([]*replica, error) {
        if len(nodes) > 0 {
                // No replicas, just flat nodes. Create default replica
                // containing all the nodes.
                r := &replica{
                        cluster: c,
                        name:    "default",
                }
                hosts, err := newNodes(nodes, scheme, r)
                if err != nil {
                        return nil, err
                }
                r.hosts = hosts
                return []*replica{r}, nil
        }
        replicas := make([]*replica, len(replicasCfg))
        for i, rCfg := range replicasCfg {
                r := &replica{
                        cluster: c,
                        name:    rCfg.Name,
                }
                hosts, err := newNodes(rCfg.Nodes, scheme, r)
                if err != nil {
                        return nil, fmt.Errorf("cannot initialize replica %q: %w", rCfg.Name, err)
                }
                r.hosts = hosts
                replicas[i] = r
        }
        return replicas, nil
}
func newNodes(nodes []string, scheme string, r *replica) ([]*topology.Node, error) {
        hosts := make([]*topology.Node, len(nodes))
        for i, node := range nodes {
                addr, err := url.Parse(fmt.Sprintf("%s://%s", scheme, node))
                if err != nil {
                        return nil, fmt.Errorf("cannot parse `node` %q with `scheme` %q: %w", node, scheme, err)
                }
                hosts[i] = topology.NewNode(addr, r.cluster.heartBeat, r.cluster.name, r.name)
        }
        return hosts, nil
}
func (r *replica) isActive() bool {
        // The replica is active if at least a single host is active.
        for _, h := range r.hosts {
                if h.IsActive() {
                        return true
                }
        }
        return false
}
func (r *replica) load() uint32 {
        var reqs uint32
        for _, h := range r.hosts {
                reqs += h.CurrentLoad()
        }
        return reqs
}
type cluster struct {
        name string
        replicas       []*replica
        nextReplicaIdx uint32
        users map[string]*clusterUser
        killQueryUserName     string
        killQueryUserPassword string
        heartBeat heartbeat.HeartBeat
        retryNumber int
}
func newCluster(c config.Cluster) (*cluster, error) {
        clusterUsers := make(map[string]*clusterUser, len(c.ClusterUsers))
        for _, cu := range c.ClusterUsers {
                if _, ok := clusterUsers[cu.Name]; ok {
                        return nil, fmt.Errorf("duplicate config for cluster user %q", cu.Name)
                }
                clusterUsers[cu.Name] = newClusterUser(cu)
        }
        heartBeat := heartbeat.NewHeartbeat(c.HeartBeat, heartbeat.WithDefaultUser(c.ClusterUsers[0].Name, c.ClusterUsers[0].Password))
        newC := &cluster{
                name:                  c.Name,
                users:                 clusterUsers,
                killQueryUserName:     c.KillQueryUser.Name,
                killQueryUserPassword: c.KillQueryUser.Password,
                heartBeat:             heartBeat,
                retryNumber:           c.RetryNumber,
        }
        replicas, err := newReplicas(c.Replicas, c.Nodes, c.Scheme, newC)
        if err != nil {
                return nil, fmt.Errorf("cannot initialize replicas: %w", err)
        }
        newC.replicas = replicas
        return newC, nil
}
func newClusters(cfg []config.Cluster) (map[string]*cluster, error) {
        clusters := make(map[string]*cluster, len(cfg))
        for _, c := range cfg {
                if _, ok := clusters[c.Name]; ok {
                        return nil, fmt.Errorf("duplicate config for cluster %q", c.Name)
                }
                tmpC, err := newCluster(c)
                if err != nil {
                        return nil, fmt.Errorf("cannot initialize cluster %q: %w", c.Name, err)
                }
                clusters[c.Name] = tmpC
        }
        return clusters, nil
}
// getReplica returns least loaded + round-robin replica from the cluster.
//
// Always returns non-nil.
func (c *cluster) getReplica() *replica {
        idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
        n := uint32(len(c.replicas))
        if n == 1 {
                return c.replicas[0]
        }
        idx %= n
        r := c.replicas[idx]
        reqs := r.load()
        // Set least priority to inactive replica.
        if !r.isActive() {
                reqs = ^uint32(0)
        }
        if reqs == 0 {
                return r
        }
        // Scan all the replicas for the least loaded replica.
        for i := uint32(1); i < n; i++ {
                tmpIdx := (idx + i) % n
                tmpR := c.replicas[tmpIdx]
                if !tmpR.isActive() {
                        continue
                }
                tmpReqs := tmpR.load()
                if tmpReqs == 0 {
                        return tmpR
                }
                if tmpReqs < reqs {
                        r = tmpR
                        reqs = tmpReqs
                }
        }
        // The returned replica may be inactive. This is OK,
        // since this means all the replicas are inactive,
        // so let's try proxying the request to any replica.
        return r
}
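// Example for getReplica (hypothetical loads): with three active replicas currently serving
// 3, 0 and 5 requests, the idle replica is returned; if none is idle, the least loaded active
// replica wins, and the plain round-robin pick is used only when every replica is inactive.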
func (c *cluster) getReplicaSticky(sessionId string) *replica {
        idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
        n := uint32(len(c.replicas))
        if n == 1 {
                return c.replicas[0]
        }
        idx %= n
        r := c.replicas[idx]
        for i := uint32(1); i < n; i++ {
                // handling sticky session
                sessionId := hash(sessionId)
                tmpIdx := (sessionId) % n
                tmpRSticky := c.replicas[tmpIdx]
                log.Debugf("Sticky replica candidate is: %s", tmpRSticky.name)
                if !tmpRSticky.isActive() {
                        log.Debugf("Sticky session replica has been picked up, but it is not available")
                        continue
                }
                log.Debugf("Sticky session replica is: %s, session_id: %d, replica_idx: %d, max replicas in pool: %d", tmpRSticky.name, sessionId, tmpIdx, n)
                return tmpRSticky
        }
        // The returned replica may be inactive. This is OK,
        // since this means all the replicas are inactive,
        // so let's try proxying the request to any replica.
        return r
}
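// Example for getReplicaSticky (hypothetical session): hash(session_id) % n always maps the same
// session_id to the same replica index, so a session sticks to one replica while that replica is
// active; otherwise the round-robin pick is used as a fallback.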
// getHostSticky returns host by stickiness from replica.
//
// Always returns non-nil.
func (r *replica) getHostSticky(sessionId string) *topology.Node {
        idx := atomic.AddUint32(&r.nextHostIdx, 1)
        n := uint32(len(r.hosts))
        if n == 1 {
                return r.hosts[0]
        }
        idx %= n
        h := r.hosts[idx]
        // Pick the host mapped from the session id, falling back to the round-robin pick if it is inactive.
        for i := uint32(1); i < n; i++ {
                // handling sticky session
                sessionId := hash(sessionId)
                tmpIdx := (sessionId) % n
                tmpHSticky := r.hosts[tmpIdx]
                log.Debugf("Sticky server candidate is: %s", tmpHSticky)
                if !tmpHSticky.IsActive() {
                        log.Debugf("Sticky session server has been picked up, but it is not available")
                        continue
                }
                log.Debugf("Sticky session server is: %s, session_id: %d, server_idx: %d, max nodes in pool: %d", tmpHSticky, sessionId, tmpIdx, n)
                return tmpHSticky
        }
        // The returned host may be inactive. This is OK,
        // since this means all the hosts are inactive,
        // so let's try proxying the request to any host.
        return h
}
// getHost returns least loaded + round-robin host from replica.
//
// Always returns non-nil.
func (r *replica) getHost() *topology.Node {
        idx := atomic.AddUint32(&r.nextHostIdx, 1)
        n := uint32(len(r.hosts))
        if n == 1 {
                return r.hosts[0]
        }
        idx %= n
        h := r.hosts[idx]
        reqs := h.CurrentLoad()
        // Set least priority to inactive host.
        if !h.IsActive() {
                reqs = ^uint32(0)
        }
        if reqs == 0 {
                return h
        }
        // Scan all the hosts for the least loaded host.
        for i := uint32(1); i < n; i++ {
                tmpIdx := (idx + i) % n
                tmpH := r.hosts[tmpIdx]
                if !tmpH.IsActive() {
                        continue
                }
                tmpReqs := tmpH.CurrentLoad()
                if tmpReqs == 0 {
                        return tmpH
                }
                if tmpReqs < reqs {
                        h = tmpH
                        reqs = tmpReqs
                }
        }
        // The returned host may be inactive. This is OK,
        // since this means all the hosts are inactive,
        // so let's try proxying the request to any host.
        return h
}
// getHostSticky returns host based on stickiness from cluster.
//
// Always returns non-nil.
func (c *cluster) getHostSticky(sessionId string) *topology.Node {
        r := c.getReplicaSticky(sessionId)
        return r.getHostSticky(sessionId)
}
// getHost returns least loaded + round-robin host from cluster.
//
// Always returns non-nil.
func (c *cluster) getHost() *topology.Node {
        r := c.getReplica()
        return r.getHost()
}
type rateLimiter struct {
        counter
}
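// run resets the counter every minute until done is closed.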
func (rl *rateLimiter) run(done <-chan struct{}) {
        for {
                select {
                case <-done:
                        return
                case <-time.After(time.Minute):
                        rl.store(0)
                }
        }
}
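// counter is a uint32 counter safe for concurrent use via atomic operations.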
type counter struct {
        value uint32
}
func (c *counter) store(n uint32) { atomic.StoreUint32(&c.value, n) }
func (c *counter) load() uint32 { return atomic.LoadUint32(&c.value) }
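// dec decrements the counter by one: adding ^uint32(0) (all bits set)
// is the two's-complement equivalent of subtracting 1.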
func (c *counter) dec() { atomic.AddUint32(&c.value, ^uint32(0)) }
func (c *counter) inc() uint32 { return atomic.AddUint32(&c.value, 1) }

package main
import (
        "bytes"
        "compress/gzip"
        "fmt"
        "hash/fnv"
        "io"
        "net/http"
        "sort"
        "strconv"
        "strings"
        "github.com/contentsquare/chproxy/chdecompressor"
        "github.com/contentsquare/chproxy/log"
)
func respondWith(rw http.ResponseWriter, err error, status int) {
        log.ErrorWithCallDepth(err, 1)
        rw.WriteHeader(status)
        fmt.Fprintf(rw, "%s\n", err)
}
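// defaultUser is the user assumed when the request carries no credentials.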
var defaultUser = "default"
// getAuth retrieves auth credentials from the request
// according to the ClickHouse HTTP interface documentation: https://clickhouse.yandex/docs/en/interfaces/http/
func getAuth(req *http.Request) (string, string) {
        // check X-ClickHouse- headers
        name := req.Header.Get("X-ClickHouse-User")
        pass := req.Header.Get("X-ClickHouse-Key")
        if name != "" {
                return name, pass
        }
        // if header is empty - check basicAuth
        if name, pass, ok := req.BasicAuth(); ok {
                return name, pass
        }
        // if basicAuth is empty - check URL params `user` and `password`
        params := req.URL.Query()
        if name := params.Get("user"); name != "" {
                pass := params.Get("password")
                return name, pass
        }
        // if still no credentials - treat it as `default` user request
        return defaultUser, ""
}
// getSessionId retrieves session id
func getSessionId(req *http.Request) string {
        params := req.URL.Query()
        sessionId := params.Get("session_id")
        return sessionId
}
// getSessionTimeout retrieves the session timeout (in seconds) from the request,
// defaulting to 60 when the parameter is absent or invalid.
func getSessionTimeout(req *http.Request) int {
        params := req.URL.Query()
        sessionTimeout, err := strconv.Atoi(params.Get("session_timeout"))
        if err == nil && sessionTimeout > 0 {
                return sessionTimeout
        }
        return 60
}
// getQuerySnippet returns query snippet.
//
// getQuerySnippet must be called only for error reporting.
func getQuerySnippet(req *http.Request) string {
        query := req.URL.Query().Get("query")
        body := getQuerySnippetFromBody(req)
        if len(query) != 0 && len(body) != 0 {
                query += "\n"
        }
        return query + body
}
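// hash returns the 32-bit FNV-1a hash of s; it is used to map
// session ids onto replicas and hosts for sticky sessions.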
func hash(s string) uint32 {
        h := fnv.New32a()
        h.Write([]byte(s))
        return h.Sum32()
}
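// getQuerySnippetFromBody returns the request body (decompressed when possible)
// for inclusion in error messages.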
func getQuerySnippetFromBody(req *http.Request) string {
        if req.Body == nil {
                return ""
        }
        crc, ok := req.Body.(*cachedReadCloser)
        if !ok {
                crc = &cachedReadCloser{
                        ReadCloser: req.Body,
                }
        }
        // 'Read' the request body so it gets captured by crc.
        // Ignore any errors, since getQuerySnippet is called only
        // during error reporting.
        io.Copy(io.Discard, crc) // nolint
        data := crc.String()
        u := getDecompressor(req)
        if u == nil {
                return data
        }
        bs := bytes.NewBufferString(data)
        b, err := u.decompress(bs)
        if err == nil {
                return string(b)
        }
        // It is better to return partially decompressed data instead of an empty string.
        if len(b) > 0 {
                return string(b)
        }
        // The data failed to be decompressed. Return compressed data
        // instead of an empty string.
        return data
}
// getFullQuery returns full query from req.
func getFullQuery(req *http.Request) ([]byte, error) {
        var result bytes.Buffer
        if req.URL.Query().Get("query") != "" {
                result.WriteString(req.URL.Query().Get("query"))
        }
        body, err := getFullQueryFromBody(req)
        if err != nil {
                return nil, err
        }
        if result.Len() != 0 && len(body) != 0 {
                result.WriteByte('\n')
        }
        result.Write(body)
        return result.Bytes(), nil
}
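// getFullQueryFromBody reads the whole request body, restores it so it can be
// read again later, and decompresses it if the request is compressed.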
func getFullQueryFromBody(req *http.Request) ([]byte, error) {
        if req.Body == nil {
                return nil, nil
        }
        data, err := io.ReadAll(req.Body)
        if err != nil {
                return nil, err
        }
        // restore body for further reading
        req.Body = io.NopCloser(bytes.NewBuffer(data))
        u := getDecompressor(req)
        if u == nil {
                return data, nil
        }
        br := bytes.NewReader(data)
        b, err := u.decompress(br)
        if err != nil {
                return nil, fmt.Errorf("cannot uncompress query: %w", err)
        }
        return b, nil
}
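// Only queries starting with one of these keywords are considered cacheable.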
var cachableStatements = []string{"SELECT", "WITH"}
// canCacheQuery returns true if q can be cached.
func canCacheQuery(q []byte) bool {
        q = skipLeadingComments(q)
        for _, statement := range cachableStatements {
                if len(q) < len(statement) {
                        continue
                }
                l := bytes.ToUpper(q[:len(statement)])
                if bytes.HasPrefix(l, []byte(statement)) {
                        return true
                }
        }
        return false
}
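// skipLeadingComments strips leading whitespace, `-- line` and `/* block */`
// comments from q, e.g. "/* cached */ SELECT 1" becomes "SELECT 1".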
//nolint:cyclop // No clean way to split this.
func skipLeadingComments(q []byte) []byte {
        for len(q) > 0 {
                switch q[0] {
                case '\t', '\n', '\v', '\f', '\r', ' ':
                        q = q[1:]
                case '-':
                        if len(q) < 2 || q[1] != '-' {
                                return q
                        }
                        // skip `-- comment`
                        n := bytes.IndexByte(q, '\n')
                        if n < 0 {
                                return nil
                        }
                        q = q[n+1:]
                case '/':
                        if len(q) < 2 || q[1] != '*' {
                                return q
                        }
                        // skip `/* comment */`
                        for {
                                n := bytes.IndexByte(q, '*')
                                if n < 0 {
                                        return nil
                                }
                                q = q[n+1:]
                                if len(q) == 0 {
                                        return nil
                                }
                                if q[0] == '/' {
                                        q = q[1:]
                                        break
                                }
                        }
                default:
                        return q
                }
        }
        return nil
}
// sortHeader splits a comma-separated header value, trims each part and
// returns the parts sorted and re-joined, e.g. "gzip, br" becomes "br,gzip".
func sortHeader(header string) string {
        h := strings.Split(header, ",")
        for i, v := range h {
                h[i] = strings.TrimSpace(v)
        }
        sort.Strings(h)
        return strings.Join(h, ",")
}
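// getDecompressor returns the decompressor matching the request encoding,
// or nil when the body is not compressed.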
func getDecompressor(req *http.Request) decompressor {
        if req.Header.Get("Content-Encoding") == "gzip" {
                return gzipDecompressor{}
        }
        if req.URL.Query().Get("decompress") == "1" {
                return chDecompressor{}
        }
        return nil
}
type decompressor interface {
        decompress(r io.Reader) ([]byte, error)
}
type gzipDecompressor struct{}
func (dc gzipDecompressor) decompress(r io.Reader) ([]byte, error) {
        gr, err := gzip.NewReader(r)
        if err != nil {
                return nil, fmt.Errorf("cannot ungzip query: %w", err)
        }
        return io.ReadAll(gr)
}
type chDecompressor struct{}
func (dc chDecompressor) decompress(r io.Reader) ([]byte, error) {
        lr := chdecompressor.NewReader(r)
        return io.ReadAll(lr)
}
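// calcMapHash returns a stable FNV-1a hash of m by hashing its key=value
// pairs in sorted key order, so equal maps always produce the same hash.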
func calcMapHash(m map[string]string) (uint32, error) {
        if len(m) == 0 {
                return 0, nil
        }
        var keys []string
        for key := range m {
                keys = append(keys, key)
        }
        sort.Strings(keys)
        h := fnv.New32a()
        for _, k := range keys {
                str := fmt.Sprintf("%s=%s&", k, m[k])
                _, err := h.Write([]byte(str))
                if err != nil {
                        return 0, err
                }
        }
        return h.Sum32(), nil
}
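// calcCredentialHash returns the FNV-1a hash of the user name concatenated
// with the password.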
func calcCredentialHash(user string, pwd string) (uint32, error) {
        h := fnv.New32a()
        _, err := h.Write([]byte(user + pwd))
        return h.Sum32(), err
}