package cache import ( "fmt" "time" "github.com/contentsquare/chproxy/clients" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/log" "github.com/redis/go-redis/v9" ) // AsyncCache is a transactional cache enabled to serve the results from concurrent queries. // When query A and B are equal, and query B arrives after query A with no more than defined deadline interval [[graceTime]], // query B will await for the results of query B for the max time equal to: // max_awaiting_time = graceTime - (arrivalB - arrivalA) type AsyncCache struct { Cache TransactionRegistry graceTime time.Duration MaxPayloadSize config.ByteSize SharedWithAllUsers bool } func (c *AsyncCache) Close() error { if c.TransactionRegistry != nil { c.TransactionRegistry.Close() } if c.Cache != nil { c.Cache.Close() } return nil } func (c *AsyncCache) AwaitForConcurrentTransaction(key *Key) (TransactionStatus, error) { startTime := time.Now() seenState := transactionAbsent for { elapsedTime := time.Since(startTime) if elapsedTime > c.graceTime { // The entry didn't appear during deadline. // Let the caller creating it. return TransactionStatus{State: seenState}, nil } status, err := c.TransactionRegistry.Status(key) if err != nil { return TransactionStatus{State: seenState}, err } if !status.State.IsPending() { return status, nil } // Wait for deadline in the hope the entry will appear // in the cache. // // This should protect from thundering herd problem when // a single slow query is executed from concurrent requests. d := 100 * time.Millisecond if d > c.graceTime { d = c.graceTime } time.Sleep(d) } } func NewAsyncCache(cfg config.Cache, maxExecutionTime time.Duration) (*AsyncCache, error) { graceTime := time.Duration(cfg.GraceTime) if graceTime > 0 { log.Errorf("[DEPRECATED] detected grace time configuration %s. It will be removed in the new version", graceTime) } if graceTime == 0 { // Default grace time. graceTime = maxExecutionTime } if graceTime < 0 { // Disable protection from `dogpile effect`. graceTime = 0 } var cache Cache var transaction TransactionRegistry var err error // transaction will be kept until we're sure there's no possible concurrent query running transactionDeadline := 2 * graceTime switch cfg.Mode { case "file_system": cache, err = newFilesSystemCache(cfg, graceTime) transaction = newInMemoryTransactionRegistry(transactionDeadline, transactionEndedTTL) case "redis": var redisClient redis.UniversalClient redisClient, err = clients.NewRedisClient(cfg.Redis) cache = newRedisCache(redisClient, cfg) transaction = newRedisTransactionRegistry(redisClient, transactionDeadline, transactionEndedTTL) default: return nil, fmt.Errorf("unknown config mode") } if err != nil { return nil, err } maxPayloadSize := cfg.MaxPayloadSize return &AsyncCache{ Cache: cache, TransactionRegistry: transaction, graceTime: graceTime, MaxPayloadSize: maxPayloadSize, SharedWithAllUsers: cfg.SharedWithAllUsers, }, nil }
package cache import ( "fmt" "io" "math/rand" "os" "path/filepath" "regexp" "strconv" "sync" "sync/atomic" "time" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/log" ) var cachefileRegexp = regexp.MustCompile(`^[0-9a-f]{32}$`) // fileSystemCache represents a file cache. type fileSystemCache struct { // name is cache name. name string dir string maxSize uint64 expire time.Duration grace time.Duration stats Stats wg sync.WaitGroup stopCh chan struct{} } // newFilesSystemCache returns new cache for the given cfg. func newFilesSystemCache(cfg config.Cache, graceTime time.Duration) (*fileSystemCache, error) { if len(cfg.FileSystem.Dir) == 0 { return nil, fmt.Errorf("`dir` cannot be empty") } if cfg.FileSystem.MaxSize <= 0 { return nil, fmt.Errorf("`max_size` must be positive") } if cfg.Expire <= 0 { return nil, fmt.Errorf("`expire` must be positive") } c := &fileSystemCache{ name: cfg.Name, dir: cfg.FileSystem.Dir, maxSize: uint64(cfg.FileSystem.MaxSize), expire: time.Duration(cfg.Expire), grace: graceTime, stopCh: make(chan struct{}), } if err := os.MkdirAll(c.dir, 0700); err != nil { return nil, fmt.Errorf("cannot create %q: %w", c.dir, err) } c.wg.Add(1) go func() { log.Debugf("cache %q: cleaner start", c.Name()) c.cleaner() log.Debugf("cache %q: cleaner stop", c.Name()) c.wg.Done() }() return c, nil } func (f *fileSystemCache) Name() string { return f.name } func (f *fileSystemCache) Close() error { log.Debugf("cache %q: stopping", f.Name()) close(f.stopCh) f.wg.Wait() log.Debugf("cache %q: stopped", f.Name()) return nil } func (f *fileSystemCache) Stats() Stats { var s Stats s.Size = atomic.LoadUint64(&f.stats.Size) s.Items = atomic.LoadUint64(&f.stats.Items) return s } func (f *fileSystemCache) Get(key *Key) (*CachedData, error) { fp := key.filePath(f.dir) file, err := os.Open(fp) if err != nil { return nil, ErrMissing } // the file will be closed once it's read as an io.ReaderCloser // This ReaderCloser is stored in the returned CachedData fi, err := file.Stat() if err != nil { return nil, fmt.Errorf("cache %q: cannot stat %q: %w", f.Name(), fp, err) } mt := fi.ModTime() age := time.Since(mt) if age > f.expire { // check if file exceeded expiration time + grace time if age > f.expire+f.grace { file.Close() return nil, ErrMissing } // Serve expired file in the hope it will be substituted // with the fresh file during deadline. } if err != nil { file.Close() return nil, fmt.Errorf("failed to read file content from %q: %w", f.Name(), err) } metadata, err := decodeHeader(file) if err != nil { return nil, err } value := &CachedData{ ContentMetadata: *metadata, Data: file, Ttl: f.expire - age, } return value, nil } // decodeHeader decodes header from raw byte stream. 
Data is encoded as follows: // length(contentType)|contentType|length(contentEncoding)|contentEncoding|length(contentLength)|contentLength|cachedData func decodeHeader(reader io.Reader) (*ContentMetadata, error) { contentType, err := readHeader(reader) if err != nil { return nil, fmt.Errorf("cannot read Content-Type from provided reader: %w", err) } contentEncoding, err := readHeader(reader) if err != nil { return nil, fmt.Errorf("cannot read Content-Encoding from provided reader: %w", err) } contentLengthStr, err := readHeader(reader) if err != nil { return nil, fmt.Errorf("cannot read Content-Length from provided reader: %w", err) } contentLength, err := strconv.Atoi(contentLengthStr) if err != nil { log.Errorf("found corrupted content length: %s", err) contentLength = 0 } return &ContentMetadata{ Length: int64(contentLength), Type: contentType, Encoding: contentEncoding, }, nil } func (f *fileSystemCache) Put(r io.Reader, contentMetadata ContentMetadata, key *Key) (time.Duration, error) { fp := key.filePath(f.dir) file, err := os.Create(fp) if err != nil { return 0, fmt.Errorf("cache %q: cannot create file: %s : %w", f.Name(), key, err) } if err := writeHeader(file, contentMetadata.Type); err != nil { fn := file.Name() return 0, fmt.Errorf("cannot write Content-Type to %q: %w", fn, err) } if err := writeHeader(file, contentMetadata.Encoding); err != nil { fn := file.Name() return 0, fmt.Errorf("cannot write Content-Encoding to %q: %w", fn, err) } if err := writeHeader(file, fmt.Sprintf("%d", contentMetadata.Length)); err != nil { fn := file.Name() return 0, fmt.Errorf("cannot write Content-Length to %q: %w", fn, err) } cnt, err := io.Copy(file, r) if err != nil { return 0, fmt.Errorf("cache %q: cannot write results to file: %s : %w", f.Name(), key, err) } atomic.AddUint64(&f.stats.Size, uint64(cnt)) atomic.AddUint64(&f.stats.Items, 1) return f.expire, nil } func (f *fileSystemCache) cleaner() { d := f.expire / 2 if d < time.Minute { d = time.Minute } if d > time.Hour { d = time.Hour } forceCleanCh := time.After(d) f.clean() for { select { case <-time.After(time.Second): // Clean cache only on cache size overflow. stats := f.Stats() if stats.Size > f.maxSize { f.clean() } case <-forceCleanCh: // Forcibly clean cache from expired items. f.clean() forceCleanCh = time.After(d) case <-f.stopCh: return } } } func (f *fileSystemCache) fileInfoPath(fi os.FileInfo) string { return filepath.Join(f.dir, fi.Name()) } func (f *fileSystemCache) clean() { currentTime := time.Now() log.Debugf("cache %q: start cleaning dir %q", f.Name(), f.dir) // Remove cached files after a deadline from their expiration, // so they may be served until they are substituted with fresh files. expire := f.expire + f.grace // Calculate total cache size and remove expired files. var totalSize uint64 var totalItems uint64 var removedSize uint64 var removedItems uint64 err := walkDir(f.dir, func(fi os.FileInfo) { mt := fi.ModTime() fs := uint64(fi.Size()) if currentTime.Sub(mt) > expire { fn := f.fileInfoPath(fi) err := os.Remove(fn) if err == nil { removedSize += fs removedItems++ return } log.Errorf("cache %q: cannot remove file %q: %s", f.Name(), fn, err) // Return skipped intentionally. } totalSize += fs totalItems++ }) if err != nil { log.Errorf("cache %q: %s", f.Name(), err) return } loopsCount := 0 // Use dedicated random generator instead of global one from math/rand, // since the global generator is slow due to locking.
// // Seed the generator with the current time in order to randomize // set of files to be removed below. // nolint:gosec // not security sensitive, only used internally. rnd := rand.New(rand.NewSource(time.Now().UnixNano())) for totalSize > f.maxSize && loopsCount < 3 { // Remove some files in order to reduce cache size. excessSize := totalSize - f.maxSize p := int32(float64(excessSize) / float64(totalSize) * 100) // Remove an extra 10% of files. p += 10 err := walkDir(f.dir, func(fi os.FileInfo) { if rnd.Int31n(100) > p { return } fs := uint64(fi.Size()) fn := f.fileInfoPath(fi) if err := os.Remove(fn); err != nil { log.Errorf("cache %q: cannot remove file %q: %s", f.Name(), fn, err) return } removedSize += fs removedItems++ totalSize -= fs totalItems-- }) if err != nil { log.Errorf("cache %q: %s", f.Name(), err) return } // This should protect from an infinite loop. loopsCount++ } atomic.StoreUint64(&f.stats.Size, totalSize) atomic.StoreUint64(&f.stats.Items, totalItems) log.Debugf("cache %q: final size %d; final items %d; removed size %d; removed items %d", f.Name(), totalSize, totalItems, removedSize, removedItems) log.Debugf("cache %q: finish cleaning dir %q", f.Name(), f.dir) } // writeHeader writes a header value prefixed with its length encoded as a big-endian uint32 func writeHeader(w io.Writer, s string) error { n := uint32(len(s)) b := make([]byte, 0, n+4) b = append(b, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) b = append(b, s...) _, err := w.Write(b) return err } // readHeader reads a header value prefixed with its length encoded as a big-endian uint32 func readHeader(r io.Reader) (string, error) { b := make([]byte, 4) if _, err := io.ReadFull(r, b); err != nil { return "", fmt.Errorf("cannot read header length: %w", err) } n := uint32(b[3]) | (uint32(b[2]) << 8) | (uint32(b[1]) << 16) | (uint32(b[0]) << 24) s := make([]byte, n) if _, err := io.ReadFull(r, s); err != nil { return "", fmt.Errorf("cannot read header value with length %d: %w", n, err) } return string(s), nil }
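// Hedged sketch (same package): round-tripping the length-prefixed header layout
// described above, using a bytes.Buffer as the backing store. The header values
// are illustrative and the "bytes" import is assumed.
func exampleHeaderRoundTrip() (*ContentMetadata, error) {
	var buf bytes.Buffer
	// Headers are written in the same order decodeHeader reads them:
	// Content-Type, Content-Encoding, then Content-Length as a decimal string.
	for _, h := range []string{"text/tab-separated-values", "gzip", "12345"} {
		if err := writeHeader(&buf, h); err != nil {
			return nil, err
		}
	}
	// Expected result: Type "text/tab-separated-values", Encoding "gzip", Length 12345.
	return decodeHeader(&buf)
}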
package cache import ( "crypto/sha256" "encoding/hex" "fmt" "net/url" "path/filepath" ) // Version must be increased with each backward-incompatible change // in the cache storage. const Version = 5 // Key is the key for use in the cache. type Key struct { // Query must contain full request query. Query []byte // AcceptEncoding must contain 'Accept-Encoding' request header value. AcceptEncoding string // DefaultFormat must contain `default_format` query arg. DefaultFormat string // Database must contain `database` query arg. Database string // Compress must contain `compress` query arg. Compress string // EnableHTTPCompression must contain `enable_http_compression` query arg. EnableHTTPCompression string // Namespace is an optional cache namespace. Namespace string // MaxResultRows must contain `max_result_rows` query arg MaxResultRows string // Extremes must contain `extremes` query arg Extremes string // ResultOverflowMode must contain `result_overflow_mode` query arg ResultOverflowMode string // UserParamsHash must contain hashed value of users params UserParamsHash uint32 // Version represents data encoding version number Version int // QueryParamsHash must contain hashed value of query params QueryParamsHash uint32 // UserCredentialHash must contain hashed value of username & password UserCredentialHash uint32 } // NewKey construct cache key from provided parameters with default version number func NewKey(query []byte, originParams url.Values, acceptEncoding string, userParamsHash uint32, queryParamsHash uint32, userCredentialHash uint32) *Key { return &Key{ Query: query, AcceptEncoding: acceptEncoding, DefaultFormat: originParams.Get("default_format"), Database: originParams.Get("database"), Compress: originParams.Get("compress"), EnableHTTPCompression: originParams.Get("enable_http_compression"), Namespace: originParams.Get("cache_namespace"), Extremes: originParams.Get("extremes"), MaxResultRows: originParams.Get("max_result_rows"), ResultOverflowMode: originParams.Get("result_overflow_mode"), UserParamsHash: userParamsHash, Version: Version, QueryParamsHash: queryParamsHash, UserCredentialHash: userCredentialHash, } } func (k *Key) filePath(dir string) string { return filepath.Join(dir, k.String()) } // String returns string representation of the key. func (k *Key) String() string { s := fmt.Sprintf("V%d; Query=%q; AcceptEncoding=%q; DefaultFormat=%q; Database=%q; Compress=%q; EnableHTTPCompression=%q; Namespace=%q; MaxResultRows=%q; Extremes=%q; ResultOverflowMode=%q; UserParams=%d; QueryParams=%d; UserCredentialHash=%d", k.Version, k.Query, k.AcceptEncoding, k.DefaultFormat, k.Database, k.Compress, k.EnableHTTPCompression, k.Namespace, k.MaxResultRows, k.Extremes, k.ResultOverflowMode, k.UserParamsHash, k.QueryParamsHash, k.UserCredentialHash) h := sha256.Sum256([]byte(s)) // The first 16 bytes of the hash should be enough // for collision prevention :) return hex.EncodeToString(h[:16]) }
package cache import ( "bytes" "context" "errors" "fmt" "io" "math/rand" "os" "regexp" "strconv" "time" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/log" "github.com/redis/go-redis/v9" ) type redisCache struct { name string client redis.UniversalClient expire time.Duration } const getTimeout = 2 * time.Second const removeTimeout = 1 * time.Second const renameTimeout = 1 * time.Second const putTimeout = 2 * time.Second const statsTimeout = 500 * time.Millisecond // this variable is key to select whether the result should be streamed // from redis to the http response or if chproxy should first put the // result from redis in a temporary files before sending it to the http response const minTTLForRedisStreamingReader = 15 * time.Second // tmpDir temporary path to store ongoing queries results const tmpDir = "/tmp" const redisTmpFilePrefix = "chproxyRedisTmp" func newRedisCache(client redis.UniversalClient, cfg config.Cache) *redisCache { redisCache := &redisCache{ name: cfg.Name, expire: time.Duration(cfg.Expire), client: client, } return redisCache } func (r *redisCache) Close() error { return r.client.Close() } var usedMemoryRegexp = regexp.MustCompile(`used_memory:([0-9]+)\r\n`) // Stats will make two calls to redis. // First one fetches the number of keys stored in redis (DBSize) // Second one will fetch memory info that will be parsed to fetch the used_memory // NOTE : we can only fetch database size, not cache size func (r *redisCache) Stats() Stats { return Stats{ Items: r.nbOfKeys(), Size: r.nbOfBytes(), } } func (r *redisCache) nbOfKeys() uint64 { ctx, cancelFunc := context.WithTimeout(context.Background(), statsTimeout) defer cancelFunc() nbOfKeys, err := r.client.DBSize(ctx).Result() if err != nil { log.Errorf("failed to fetch nb of keys in redis: %s", err) } return uint64(nbOfKeys) } func (r *redisCache) nbOfBytes() uint64 { ctx, cancelFunc := context.WithTimeout(context.Background(), statsTimeout) defer cancelFunc() memoryInfo, err := r.client.Info(ctx, "memory").Result() if err != nil { log.Errorf("failed to fetch nb of bytes in redis: %s", err) } matches := usedMemoryRegexp.FindStringSubmatch(memoryInfo) var cacheSize int if len(matches) > 1 { cacheSize, err = strconv.Atoi(matches[1]) if err != nil { log.Errorf("failed to parse memory usage with error %s", err) } } return uint64(cacheSize) } func (r *redisCache) Get(key *Key) (*CachedData, error) { ctx, cancelFunc := context.WithTimeout(context.Background(), getTimeout) defer cancelFunc() nbBytesToFetch := int64(100 * 1024) stringKey := key.String() // fetching 100kBytes from redis to be sure to have the full metadata and, // for most of the queries that fetch a few data, the cached results val, err := r.client.GetRange(ctx, stringKey, 0, nbBytesToFetch).Result() // errors, such as timeouts if err != nil { log.Errorf("failed to get key %s with error: %s", stringKey, err) return nil, ErrMissing } // if key not found in cache if len(val) == 0 { return nil, ErrMissing } ttl, err := r.client.TTL(ctx, stringKey).Result() if err != nil { return nil, fmt.Errorf("failed to ttl of key %s with error: %w", stringKey, err) } b := []byte(val) metadata, offset, err := r.decodeMetadata(b) if err != nil { if errors.Is(err, &RedisCacheCorruptionError{}) { log.Errorf("an error happened while handling redis key =%s, err=%s", stringKey, err) } return nil, err } if (int64(offset) + metadata.Length) < nbBytesToFetch { // the condition is true ony if the bytes fetched contain the metadata + the cached results // so we 
extract from the remaining bytes the cached results payload := b[offset:] reader := &ioReaderDecorator{Reader: bytes.NewReader(payload)} value := &CachedData{ ContentMetadata: *metadata, Data: reader, Ttl: ttl, } return value, nil } return r.readResultsAboveLimit(offset, stringKey, metadata, ttl) } func (r *redisCache) readResultsAboveLimit(offset int, stringKey string, metadata *ContentMetadata, ttl time.Duration) (*CachedData, error) { // since the cached results in redis are too big, we can't fetch all of them because of the memory overhead. // We will create an io.reader that will fetch redis bulk by bulk to reduce the memory usage. redisStreamreader := newRedisStreamReader(uint64(offset), r.client, stringKey, metadata.Length) // But before that, since the usage of the reader could take time and the object in redis could disappear btw 2 fetches // we need to make sure the TTL will be long enough to avoid nasty side effects // if the TTL is too short we will put all the data into a file and use it as a streamer // nb: it would be better to retry the flow if such a failure happened but this requires a huge refactoring of proxy.go if ttl <= minTTLForRedisStreamingReader { fileStream, err := newFileWriterReader(tmpDir) if err != nil { return nil, err } _, err = io.Copy(fileStream, redisStreamreader) if err != nil { return nil, err } err = fileStream.resetOffset() if err != nil { return nil, err } value := &CachedData{ ContentMetadata: *metadata, Data: fileStream, Ttl: ttl, } return value, nil } value := &CachedData{ ContentMetadata: *metadata, Data: &ioReaderDecorator{Reader: redisStreamreader}, Ttl: ttl, } return value, nil } // this struct is here because CachedData requires an io.ReadCloser // but logic in the Get function generates only an io.Reader type ioReaderDecorator struct { io.Reader } func (m ioReaderDecorator) Close() error { return nil } func (r *redisCache) encodeString(s string) []byte { n := uint32(len(s)) b := make([]byte, 0, n+4) b = append(b, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) b = append(b, s...) return b } func (r *redisCache) decodeString(bytes []byte) (string, int, error) { if len(bytes) < 4 { return "", 0, &RedisCacheCorruptionError{} } b := bytes[:4] n := uint32(b[3]) | (uint32(b[2]) << 8) | (uint32(b[1]) << 16) | (uint32(b[0]) << 24) if len(bytes) < int(4+n) { return "", 0, &RedisCacheCorruptionError{} } s := bytes[4 : 4+n] return string(s), int(4 + n), nil } func (r *redisCache) encodeMetadata(contentMetadata *ContentMetadata) []byte { cLength := contentMetadata.Length cType := r.encodeString(contentMetadata.Type) cEncoding := r.encodeString(contentMetadata.Encoding) b := make([]byte, 0, len(cEncoding)+len(cType)+8) b = append(b, byte(cLength>>56), byte(cLength>>48), byte(cLength>>40), byte(cLength>>32), byte(cLength>>24), byte(cLength>>16), byte(cLength>>8), byte(cLength)) b = append(b, cType...) b = append(b, cEncoding...) 
return b } func (r *redisCache) decodeMetadata(b []byte) (*ContentMetadata, int, error) { if len(b) < 8 { return nil, 0, &RedisCacheCorruptionError{} } cLength := uint64(b[7]) | (uint64(b[6]) << 8) | (uint64(b[5]) << 16) | (uint64(b[4]) << 24) | uint64(b[3])<<32 | (uint64(b[2]) << 40) | (uint64(b[1]) << 48) | (uint64(b[0]) << 56) offset := 8 cType, sizeCType, err := r.decodeString(b[offset:]) if err != nil { return nil, 0, err } offset += sizeCType cEncoding, sizeCEncoding, err := r.decodeString(b[offset:]) if err != nil { return nil, 0, err } offset += sizeCEncoding metadata := &ContentMetadata{ Length: int64(cLength), Type: cType, Encoding: cEncoding, } return metadata, offset, nil } func (r *redisCache) Put(reader io.Reader, contentMetadata ContentMetadata, key *Key) (time.Duration, error) { encodedMetadata := r.encodeMetadata(&contentMetadata) stringKey := key.String() // in order to make the streaming operation atomic, chproxy streams into a temporary key (only known by the current goroutine) // then it switches the full result to the "real" stringKey available for other goroutines // nolint:gosec // not security sensitive, only used internally. random := strconv.Itoa(rand.Int()) // Redis RENAME is considered to be a multikey operation. In Cluster mode, both oldkey and renamedkey must be in the same hash slot, // Refer to the Redis documentation here: https://redis.io/commands/rename/ // To solve this, we need to force the temporary key to be in the same hash slot. We can do this by adding a hash tag to the // actual part of the temporary key. When the key contains a "{...}" pattern, only the substring between the braces, "{" and "}," // is hashed to obtain the hash slot. // Refer to the hash tags section of the Redis documentation here: https://redis.io/docs/reference/cluster-spec/#hash-tags stringKeyTmp := "{" + stringKey + "}" + random + "_tmp" ctxSet, cancelFuncSet := context.WithTimeout(context.Background(), putTimeout) defer cancelFuncSet() err := r.client.Set(ctxSet, stringKeyTmp, encodedMetadata, r.expire).Err() if err != nil { return 0, err } // we don't load the whole reader content at once; instead we stream it to redis chunk by chunk to avoid memory issues // when the content is big (which is the case when chproxy users are fetching a lot of data) buffer := make([]byte, 2*1024*1024) totalByteWrittenExpected := len(encodedMetadata) for { n, err := reader.Read(buffer) // the reader should return err = io.EOF once it has nothing left to read, or on the last read call that returns content. // But this is not the case with this reader, so we check the condition n == 0 to exit the read loop.
// We kept the err == io.EOF in the loop in case the behavior of the reader changes if n == 0 { break } if err != nil && !errors.Is(err, io.EOF) { return 0, err } ctxAppend, cancelFuncAppend := context.WithTimeout(context.Background(), putTimeout) defer cancelFuncAppend() totalByteWritten, err := r.client.Append(ctxAppend, stringKeyTmp, string(buffer[:n])).Result() if err != nil { // trying to clean redis from this partially inserted item r.clean(stringKeyTmp) return 0, err } totalByteWrittenExpected += n if int(totalByteWritten) != totalByteWrittenExpected { // trying to clean redis from this partially inserted item r.clean(stringKeyTmp) return 0, fmt.Errorf("could not stream the value into redis, only %d bytes were written instead of %d", totalByteWritten, totalByteWrittenExpected) } if errors.Is(err, io.EOF) { break } } // at this step we know that the item stored in stringKeyTmp is fully written // so we can put it to its final stringKey ctxRename, cancelFuncRename := context.WithTimeout(context.Background(), renameTimeout) defer cancelFuncRename() r.client.Rename(ctxRename, stringKeyTmp, stringKey) return r.expire, nil } func (r *redisCache) clean(stringKey string) { delCtx, cancelFunc := context.WithTimeout(context.Background(), removeTimeout) defer cancelFunc() delErr := r.client.Del(delCtx, stringKey).Err() if delErr != nil { log.Debugf("redis item was only partially inserted and chproxy couldn't remove the partial result because of %s", delErr) } else { log.Debugf("redis item was only partially inserted, chproxy was able to remove it") } } func (r *redisCache) Name() string { return r.name } type redisStreamReader struct { isRedisEOF bool redisOffset uint64 // the redisOffset that gives the beginning of the next bulk to fetch key string // the key of the value we want to stream from redis buffer []byte // internal buffer to store the bulks fetched from redis bufferOffset int // the offset of the buffer that keep were the read() need to start copying data client redis.UniversalClient // the redis client expectedPayloadSize int // the size of the object the streamer is supposed to read. 
readPayloadSize int // the size of the object currently written by the reader } func newRedisStreamReader(offset uint64, client redis.UniversalClient, key string, payloadSize int64) *redisStreamReader { bufferSize := uint64(2 * 1024 * 1024) return &redisStreamReader{ isRedisEOF: false, redisOffset: offset, key: key, bufferOffset: int(bufferSize), buffer: make([]byte, bufferSize), client: client, expectedPayloadSize: int(payloadSize), } } func (r *redisStreamReader) Read(destBuf []byte) (n int, err error) { // the logic is simple: // 1) if the buffer still has data to write, it writes it into destBuf without overflowing destBuf // 2) if the buffer only has already written data, the StreamRedis refresh the buffer with new data from redis // 3) if the buffer only has already written data & redis has no more data to read then StreamRedis sends an EOF err bufSize := len(r.buffer) bytesWritten := 0 // case 3) both the buffer & redis were fully consumed, we can tell the reader to stop reading if r.bufferOffset >= bufSize && r.isRedisEOF { // Because of the way we fetch from redis, we need to do an extra check because we have no way // to know if redis is really EOF or if the value was expired from cache while reading it if r.readPayloadSize != r.expectedPayloadSize { log.Debugf("error while fetching data from redis payload size doesn't match") return 0, &RedisCacheError{key: r.key, readPayloadSize: r.readPayloadSize, expectedPayloadSize: r.expectedPayloadSize} } return 0, io.EOF } // case 2) the buffer only has already written data, we need to refresh it with redis datas if r.bufferOffset >= bufSize { if err := r.readRangeFromRedis(bufSize); err != nil { return bytesWritten, err } } // case 1) the buffer contains data to write into destBuf if r.bufferOffset < bufSize { bytesWritten = copy(destBuf, r.buffer[r.bufferOffset:]) r.bufferOffset += bytesWritten r.readPayloadSize += bytesWritten } return bytesWritten, nil } func (r *redisStreamReader) readRangeFromRedis(bufSize int) error { ctx, cancelFunc := context.WithTimeout(context.Background(), getTimeout) defer cancelFunc() newBuf, err := r.client.GetRange(ctx, r.key, int64(r.redisOffset), int64(r.redisOffset+uint64(bufSize))).Result() r.redisOffset += uint64(len(newBuf)) if errors.Is(err, redis.Nil) || len(newBuf) == 0 { r.isRedisEOF = true } // if redis gave less data than asked it means that it reached the end of the value if len(newBuf) < bufSize { r.isRedisEOF = true } // others errors, such as timeouts if err != nil && !errors.Is(err, redis.Nil) { log.Debugf("failed to get key %s with error: %s", r.key, err) err2 := &RedisCacheError{key: r.key, readPayloadSize: r.readPayloadSize, expectedPayloadSize: r.expectedPayloadSize, rootcause: err} return err2 } r.bufferOffset = 0 r.buffer = []byte(newBuf) return nil } type fileWriterReader struct { f *os.File } func newFileWriterReader(dir string) (*fileWriterReader, error) { f, err := os.CreateTemp(dir, redisTmpFilePrefix) if err != nil { return nil, fmt.Errorf("cannot create temporary file in %q: %w", dir, err) } return &fileWriterReader{ f: f, }, nil } func (r *fileWriterReader) Close() error { err := r.f.Close() if err != nil { return err } return os.Remove(r.f.Name()) } func (r *fileWriterReader) Read(destBuf []byte) (n int, err error) { return r.f.Read(destBuf) } func (r *fileWriterReader) Write(p []byte) (n int, err error) { return r.f.Write(p) } func (r *fileWriterReader) resetOffset() error { if _, err := r.f.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("cannot reset offset in: %w", 
err) } return nil } type RedisCacheError struct { key string readPayloadSize int expectedPayloadSize int rootcause error } func (e *RedisCacheError) Error() string { errorMsg := fmt.Sprintf("error while reading cached result in redis for key %s, only %d bytes of %d were fetched", e.key, e.readPayloadSize, e.expectedPayloadSize) if e.rootcause != nil { errorMsg = fmt.Sprintf("%s, root cause:%s", errorMsg, e.rootcause) } return errorMsg } type RedisCacheCorruptionError struct { } func (e *RedisCacheCorruptionError) Error() string { return "chproxy can't decode the cached result from redis, it seems to have been corrupted" }
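// Hedged sketch (same package): the value stored in redis starts with an 8-byte
// big-endian content length, followed by the length-prefixed Type and Encoding,
// then the cached payload. encodeMetadata/decodeMetadata round-trip that header;
// the metadata values below are illustrative.
func exampleRedisMetadataRoundTrip(r *redisCache) (*ContentMetadata, int, error) {
	in := ContentMetadata{Length: 1 << 20, Type: "application/json", Encoding: "gzip"}
	header := r.encodeMetadata(&in)
	// decodeMetadata returns the metadata plus the offset at which the cached
	// payload starts (here, len(header)).
	return r.decodeMetadata(header)
}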
package cache import ( "bufio" "fmt" "io" "net/http" "os" "github.com/contentsquare/chproxy/log" ) // TmpFileResponseWriter caches Clickhouse response. // the http header are kept in memory type TmpFileResponseWriter struct { http.ResponseWriter // the original response writer contentLength int64 contentType string contentEncoding string headersCaptured bool statusCode int tmpFile *os.File // temporary file for response streaming bw *bufio.Writer // buffered writer for the temporary file } func NewTmpFileResponseWriter(rw http.ResponseWriter, dir string) (*TmpFileResponseWriter, error) { _, ok := rw.(http.CloseNotifier) if !ok { return nil, fmt.Errorf("the response writer does not implement http.CloseNotifier") } f, err := os.CreateTemp(dir, "tmp") if err != nil { return nil, fmt.Errorf("cannot create temporary file in %q: %w", dir, err) } return &TmpFileResponseWriter{ ResponseWriter: rw, tmpFile: f, bw: bufio.NewWriter(f), }, nil } func (rw *TmpFileResponseWriter) Close() error { rw.tmpFile.Close() return os.Remove(rw.tmpFile.Name()) } func (rw *TmpFileResponseWriter) GetFile() (*os.File, error) { if err := rw.bw.Flush(); err != nil { fn := rw.tmpFile.Name() errTmp := rw.tmpFile.Close() if errTmp != nil { log.Errorf("cannot close tmpFile: %s, error: %s", fn, errTmp) } errTmp = os.Remove(fn) if errTmp != nil { log.Errorf("cannot remove tmpFile: %s, error: %s", fn, errTmp) } return nil, fmt.Errorf("cannot flush data into %q: %w", fn, err) } return rw.tmpFile, nil } func (rw *TmpFileResponseWriter) Reader() (io.Reader, error) { f, err := rw.GetFile() if err != nil { return nil, fmt.Errorf("cannot open tmp file: %w", err) } return f, nil } func (rw *TmpFileResponseWriter) ResetFileOffset() error { data, err := rw.GetFile() if err != nil { return err } if _, err := data.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("cannot reset offset in: %w", err) } return nil } func (rw *TmpFileResponseWriter) captureHeaders() error { if rw.headersCaptured { return nil } rw.headersCaptured = true h := rw.Header() ct := h.Get("Content-Type") ce := h.Get("Content-Encoding") rw.contentEncoding = ce rw.contentType = ct // nb: the Content-Length http header is not set by CH so we can't get it return nil } func (rw *TmpFileResponseWriter) GetCapturedContentType() string { return rw.contentType } func (rw *TmpFileResponseWriter) GetCapturedContentLength() (int64, error) { if rw.contentLength == 0 { // Determine Content-Length looking at the file data, err := rw.GetFile() if err != nil { return 0, fmt.Errorf("GetCapturedContentLength: cannot open tmp file: %w", err) } end, err := data.Seek(0, io.SeekEnd) if err != nil { return 0, fmt.Errorf("GetCapturedContentLength: cannot determine the last position in: %w", err) } if err := rw.ResetFileOffset(); err != nil { return 0, err } return end - 0, nil } return rw.contentLength, nil } func (rw *TmpFileResponseWriter) GetCapturedContentEncoding() string { return rw.contentEncoding } // CloseNotify implements http.CloseNotifier func (rw *TmpFileResponseWriter) CloseNotify() <-chan bool { // nolint:forcetypeassert // it is guaranteed by NewTmpFileResponseWriter // The rw.FSResponseWriter must implement http.CloseNotifier. return rw.ResponseWriter.(http.CloseNotifier).CloseNotify() } // WriteHeader captures response status code. func (rw *TmpFileResponseWriter) WriteHeader(statusCode int) { rw.statusCode = statusCode // Do not call rw.ClickhouseResponseWriter.WriteHeader here // It will be called explicitly in Finalize / Unregister. 
} // StatusCode returns captured status code from WriteHeader. func (rw *TmpFileResponseWriter) StatusCode() int { if rw.statusCode == 0 { return http.StatusOK } return rw.statusCode } // Write writes b into rw. func (rw *TmpFileResponseWriter) Write(b []byte) (int, error) { if err := rw.captureHeaders(); err != nil { return 0, err } return rw.bw.Write(b) }
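// Hedged sketch: buffering an upstream ClickHouse response through a
// TmpFileResponseWriter before deciding whether to cache it. The temp directory
// and copy step are illustrative; rw must implement http.CloseNotifier, as
// enforced by NewTmpFileResponseWriter.
func bufferResponse(rw http.ResponseWriter, upstreamBody io.Reader) (*TmpFileResponseWriter, int64, error) {
	trw, err := NewTmpFileResponseWriter(rw, os.TempDir())
	if err != nil {
		return nil, 0, err
	}
	if _, err := io.Copy(trw, upstreamBody); err != nil {
		trw.Close()
		return nil, 0, err
	}
	// GetCapturedContentLength flushes the buffered writer, measures the file and
	// rewinds it, so a subsequent trw.Reader() yields the whole body from the start.
	length, err := trw.GetCapturedContentLength()
	if err != nil {
		trw.Close()
		return nil, 0, err
	}
	return trw, length, nil // the caller reads via trw.Reader() and calls trw.Close() when done
}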
package cache import ( "io" "time" ) // TransactionRegistry is a registry of ongoing queries identified by Key. type TransactionRegistry interface { io.Closer // Create creates a new transaction record Create(key *Key) error // Complete completes a transaction for given key Complete(key *Key) error // Fail fails a transaction for given key Fail(key *Key, reason string) error // Status checks the status of the transaction Status(key *Key) (TransactionStatus, error) } // transactionEndedTTL amount of time transaction record is kept after being updated const transactionEndedTTL = 500 * time.Millisecond type TransactionStatus struct { State TransactionState FailReason string // filled in only if state of transaction is transactionFailed } type TransactionState uint8 const ( transactionCreated TransactionState = 0 transactionCompleted TransactionState = 1 transactionFailed TransactionState = 2 transactionAbsent TransactionState = 3 ) func (t *TransactionState) IsAbsent() bool { if t != nil { return *t == transactionAbsent } return false } func (t *TransactionState) IsFailed() bool { if t != nil { return *t == transactionFailed } return false } func (t *TransactionState) IsCompleted() bool { if t != nil { return *t == transactionCompleted } return false } func (t *TransactionState) IsPending() bool { if t != nil { return *t == transactionCreated } return false }
package cache import ( "sync" "time" "github.com/contentsquare/chproxy/log" ) type pendingEntry struct { deadline time.Time state TransactionState failedReason string } type inMemoryTransactionRegistry struct { pendingEntriesLock sync.Mutex pendingEntries map[string]pendingEntry deadline time.Duration transactionEndedDeadline time.Duration stopCh chan struct{} wg sync.WaitGroup } func newInMemoryTransactionRegistry(deadline, transactionEndedDeadline time.Duration) *inMemoryTransactionRegistry { transaction := &inMemoryTransactionRegistry{ pendingEntriesLock: sync.Mutex{}, pendingEntries: make(map[string]pendingEntry), deadline: deadline, transactionEndedDeadline: transactionEndedDeadline, stopCh: make(chan struct{}), } transaction.wg.Add(1) go func() { log.Debugf("inmem transaction: cleaner start") transaction.pendingEntriesCleaner() transaction.wg.Done() log.Debugf("inmem transaction: cleaner stop") }() return transaction } func (i *inMemoryTransactionRegistry) Create(key *Key) error { i.pendingEntriesLock.Lock() defer i.pendingEntriesLock.Unlock() k := key.String() _, exists := i.pendingEntries[k] if !exists { i.pendingEntries[k] = pendingEntry{ deadline: time.Now().Add(i.deadline), state: transactionCreated, } } return nil } func (i *inMemoryTransactionRegistry) Complete(key *Key) error { i.updateTransactionState(key, transactionCompleted, "") return nil } func (i *inMemoryTransactionRegistry) Fail(key *Key, reason string) error { i.updateTransactionState(key, transactionFailed, reason) return nil } func (i *inMemoryTransactionRegistry) updateTransactionState(key *Key, state TransactionState, failReason string) { i.pendingEntriesLock.Lock() defer i.pendingEntriesLock.Unlock() k := key.String() if entry, ok := i.pendingEntries[k]; ok { entry.state = state entry.failedReason = failReason entry.deadline = time.Now().Add(i.transactionEndedDeadline) i.pendingEntries[k] = entry } else { log.Errorf("[attempt to complete transaction] entry not found for key: %s, registering new entry with %v status", key.String(), state) i.pendingEntries[k] = pendingEntry{ deadline: time.Now().Add(i.transactionEndedDeadline), state: state, failedReason: failReason, } } } func (i *inMemoryTransactionRegistry) Status(key *Key) (TransactionStatus, error) { i.pendingEntriesLock.Lock() defer i.pendingEntriesLock.Unlock() k := key.String() if entry, ok := i.pendingEntries[k]; ok { return TransactionStatus{State: entry.state, FailReason: entry.failedReason}, nil } return TransactionStatus{State: transactionAbsent}, nil } func (i *inMemoryTransactionRegistry) Close() error { close(i.stopCh) i.wg.Wait() return nil } func (i *inMemoryTransactionRegistry) pendingEntriesCleaner() { d := i.deadline if d < 100*time.Millisecond { d = 100 * time.Millisecond } if d > time.Second { d = time.Second } for { currentTime := time.Now() // Clear outdated pending entries, since they may remain here // forever if unregisterPendingEntry call is missing. i.pendingEntriesLock.Lock() for path, pe := range i.pendingEntries { if currentTime.After(pe.deadline) { delete(i.pendingEntries, path) } } i.pendingEntriesLock.Unlock() select { case <-time.After(d): case <-i.stopCh: return } } }
package cache import ( "context" "errors" "time" "fmt" "github.com/contentsquare/chproxy/log" "github.com/redis/go-redis/v9" ) type redisTransactionRegistry struct { redisClient redis.UniversalClient // deadline specifies TTL of the record to be kept deadline time.Duration // transactionEndedDeadline specifies TTL of the record to be kept that has been ended (either completed or failed) transactionEndedDeadline time.Duration } func newRedisTransactionRegistry(redisClient redis.UniversalClient, deadline time.Duration, endedDeadline time.Duration) *redisTransactionRegistry { return &redisTransactionRegistry{ redisClient: redisClient, deadline: deadline, transactionEndedDeadline: endedDeadline, } } func (r *redisTransactionRegistry) Create(key *Key) error { return r.redisClient.Set(context.Background(), toTransactionKey(key), []byte{uint8(transactionCreated)}, r.deadline).Err() } func (r *redisTransactionRegistry) Complete(key *Key) error { return r.updateTransactionState(key, []byte{uint8(transactionCompleted)}) } func (r *redisTransactionRegistry) Fail(key *Key, reason string) error { b := make([]byte, 0, uint32(len(reason))+1) b = append(b, byte(transactionFailed)) b = append(b, []byte(reason)...) return r.updateTransactionState(key, b) } func (r *redisTransactionRegistry) updateTransactionState(key *Key, value []byte) error { return r.redisClient.Set(context.Background(), toTransactionKey(key), value, r.transactionEndedDeadline).Err() } func (r *redisTransactionRegistry) Status(key *Key) (TransactionStatus, error) { raw, err := r.redisClient.Get(context.Background(), toTransactionKey(key)).Bytes() if errors.Is(err, redis.Nil) { return TransactionStatus{State: transactionAbsent}, nil } if err != nil { log.Errorf("Failed to fetch transaction status from redis for key: %s", key.String()) return TransactionStatus{State: transactionAbsent}, err } if len(raw) == 0 { log.Errorf("Failed to fetch transaction status from redis raw value: %s", key.String()) return TransactionStatus{State: transactionAbsent}, err } state := TransactionState(uint8(raw[0])) var reason string if state.IsFailed() && len(raw) > 1 { reason = string(raw[1:]) } return TransactionStatus{State: state, FailReason: reason}, nil } func (r *redisTransactionRegistry) Close() error { return r.redisClient.Close() } func toTransactionKey(key *Key) string { return fmt.Sprintf("%s-transaction", key.String()) }
package cache import ( "fmt" "io" "os" ) // walkDir calls f on all the cache files in the given dir. func walkDir(dir string, f func(fi os.FileInfo)) error { // Do not use filepath.Walk, since it is inefficient // for large number of files. // See https://golang.org/pkg/path/filepath/#Walk . fd, err := os.Open(dir) if err != nil { return fmt.Errorf("cannot open %q: %w", dir, err) } defer fd.Close() for { fis, err := fd.Readdir(1024) if err != nil { if err == io.EOF { return nil } return fmt.Errorf("cannot read files in %q: %w", dir, err) } for _, fi := range fis { if fi.IsDir() { // Skip subdirectories continue } fn := fi.Name() if !cachefileRegexp.MatchString(fn) { // Skip invalid filenames continue } f(fi) } } }
package chdecompressor import ( "errors" "fmt" "io" "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4" ) // Reader reads clickhouse compressed stream. // See https://github.com/yandex/ClickHouse/blob/ae8783aee3ef982b6eb7e1721dac4cb3ce73f0fe/dbms/src/IO/CompressedStream.h type Reader struct { src io.Reader data []byte scratch []byte } // NewReader returns new clickhouse compressed stream reader reading from src. func NewReader(src io.Reader) *Reader { return &Reader{ src: src, scratch: make([]byte, 16), } } // Read reads up to len(buf) bytes from clickhouse compressed stream. func (r *Reader) Read(buf []byte) (int, error) { // exhaust remaining data from previous Read() if len(r.data) == 0 { if err := r.readNextBlock(); err != nil { return 0, err } } n := copy(buf, r.data) r.data = r.data[n:] return n, nil } func (r *Reader) readNextBlock() error { // Skip checksum if _, err := io.ReadFull(r.src, r.scratch[:16]); err != nil { if errors.Is(err, io.EOF) { return io.EOF } return fmt.Errorf("cannot read checksum: %w", err) } // Read compression type if _, err := io.ReadFull(r.src, r.scratch[:1]); err != nil { return fmt.Errorf("cannot read compression type: %w", err) } compressionType := r.scratch[0] // Read compressed size compressedSize, err := r.readUint32() if err != nil { return fmt.Errorf("cannot read compressed size: %w", err) } compressedSize -= 9 // minus header length // Read decompressed size decompressedSize, err := r.readUint32() if err != nil { return fmt.Errorf("cannot read decompressed size: %w", err) } // Read compressed block block := make([]byte, compressedSize) if _, err = io.ReadFull(r.src, block); err != nil { return fmt.Errorf("cannot read compressed block: %w", err) } // Decompress block if err := r.decompressBlock(block, compressionType, decompressedSize); err != nil { return err } return nil } func (r *Reader) decompressBlock(block []byte, compressionType byte, decompressedSize uint32) error { var err error r.data = make([]byte, decompressedSize) var decoder, _ = zstd.NewReader(nil) switch compressionType { case noneType: r.data = block case lz4Type: if _, err := lz4.UncompressBlock(block, r.data); err != nil { return fmt.Errorf("cannot decompress lz4 block: %w", err) } case zstdType: r.data = r.data[:0] // Wipe the slice but keep allocated memory r.data, err = decoder.DecodeAll(block, r.data) if err != nil { return fmt.Errorf("cannot decompress zstd block: %w", err) } default: return fmt.Errorf("unknown compressionType: %X", compressionType) } return nil } const ( noneType = 0x02 lz4Type = 0x82 zstdType = 0x90 ) func (r *Reader) readUint32() (uint32, error) { b := r.scratch[:4] _, err := io.ReadFull(r.src, b) if err != nil { return 0, err } n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) return n, nil }
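// Hedged sketch (from a consuming package): reading a ClickHouse native-compressed
// stream into plain bytes; assumes imports of "io" and
// "github.com/contentsquare/chproxy/chdecompressor". Read returns io.EOF once the
// underlying stream is exhausted, so io.ReadAll terminates normally.
func decompressAll(src io.Reader) ([]byte, error) {
	return io.ReadAll(chdecompressor.NewReader(src))
}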
package clients import ( "context" "fmt" "github.com/contentsquare/chproxy/config" "github.com/redis/go-redis/v9" ) func NewRedisClient(cfg config.RedisCacheConfig) (redis.UniversalClient, error) { options := &redis.UniversalOptions{ Addrs: cfg.Addresses, Username: cfg.Username, Password: cfg.Password, MaxRetries: 7, // default value = 3; since MinRetryBackoff = 8 msec & MaxRetryBackoff = 512 msec, // the redis client will wait up to 1016 msec in total between the 7 retries } if len(cfg.Addresses) == 1 { options.DB = cfg.DBIndex } if len(cfg.CertFile) != 0 || len(cfg.KeyFile) != 0 { tlsConfig, err := cfg.TLS.BuildTLSConfig(nil) if err != nil { return nil, err } options.TLSConfig = tlsConfig } r := redis.NewUniversalClient(options) err := r.Ping(context.Background()).Err() if err != nil { return nil, fmt.Errorf("failed to reach redis: %w", err) } return r, nil }
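// Hedged sketch (same package): building a client from a minimal RedisCacheConfig.
// The address and database index are illustrative; NewRedisClient pings the server
// before returning, so it fails fast when redis is unreachable.
func exampleRedisClientFromConfig() (redis.UniversalClient, error) {
	return NewRedisClient(config.RedisCacheConfig{
		Addresses: []string{"127.0.0.1:6379"},
		DBIndex:   1,
	})
}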
package config import ( "bytes" "crypto/tls" "fmt" "os" "regexp" "strings" "time" "github.com/mohae/deepcopy" "golang.org/x/crypto/acme/autocert" "gopkg.in/yaml.v2" ) var ( defaultConfig = Config{ Clusters: []Cluster{defaultCluster}, } defaultCluster = Cluster{ Scheme: "http", ClusterUsers: []ClusterUser{defaultClusterUser}, HeartBeat: defaultHeartBeat, RetryNumber: defaultRetryNumber, } defaultClusterUser = ClusterUser{ Name: "default", } defaultHeartBeat = HeartBeat{ Interval: Duration(time.Second * 5), Timeout: Duration(time.Second * 3), Request: "/ping", Response: "Ok.\n", User: "", Password: "", } defaultConnectionPool = ConnectionPool{ MaxIdleConns: 100, MaxIdleConnsPerHost: 2, } defaultExecutionTime = Duration(120 * time.Second) defaultMaxPayloadSize = ByteSize(1 << 50) defaultRetryNumber = 0 ) // Config describes server configuration, access and proxy rules type Config struct { Server Server `yaml:"server,omitempty"` Clusters []Cluster `yaml:"clusters"` Users []User `yaml:"users"` // Whether to print debug logs LogDebug bool `yaml:"log_debug,omitempty"` // Whether to ignore security warnings HackMePlease bool `yaml:"hack_me_please,omitempty"` NetworkGroups []NetworkGroups `yaml:"network_groups,omitempty"` Caches []Cache `yaml:"caches,omitempty"` ParamGroups []ParamGroup `yaml:"param_groups,omitempty"` ConnectionPool ConnectionPool `yaml:"connection_pool,omitempty"` networkReg map[string]Networks // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // String implements the Stringer interface func (c *Config) String() string { b, err := yaml.Marshal(withoutSensitiveInfo(c)) if err != nil { panic(err) } return string(b) } func withoutSensitiveInfo(config *Config) *Config { const pswPlaceHolder = "XXX" // nolint: forcetypeassert // no need to check type, it is specified by function. c := deepcopy.Copy(config).(*Config) for i := range c.Users { c.Users[i].Password = pswPlaceHolder } for i := range c.Clusters { if len(c.Clusters[i].KillQueryUser.Name) > 0 { c.Clusters[i].KillQueryUser.Password = pswPlaceHolder } for j := range c.Clusters[i].ClusterUsers { c.Clusters[i].ClusterUsers[j].Password = pswPlaceHolder } } for i := range c.Caches { if len(c.Caches[i].Redis.Username) > 0 { c.Caches[i].Redis.Password = pswPlaceHolder } } return c } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { // set c to the defaults and then overwrite it with the input. *c = defaultConfig type plain Config if err := unmarshal((*plain)(c)); err != nil { return err } if err := c.validate(); err != nil { return err } return checkOverflow(c.XXX, "config") } func (c *Config) validate() error { if len(c.Users) == 0 { return fmt.Errorf("`users` must contain at least 1 user") } if len(c.Clusters) == 0 { return fmt.Errorf("`clusters` must contain at least 1 cluster") } if len(c.Server.HTTP.ListenAddr) == 0 && len(c.Server.HTTPS.ListenAddr) == 0 { return fmt.Errorf("neither HTTP nor HTTPS not configured") } if len(c.Server.HTTPS.ListenAddr) > 0 { if len(c.Server.HTTPS.Autocert.CacheDir) == 0 && len(c.Server.HTTPS.CertFile) == 0 && len(c.Server.HTTPS.KeyFile) == 0 { return fmt.Errorf("configuration `https` is missing. 
" + "Must be specified `https.cache_dir` for autocert " + "OR `https.key_file` and `https.cert_file` for already existing certs") } if len(c.Server.HTTPS.Autocert.CacheDir) > 0 { c.Server.HTTP.ForceAutocertHandler = true } } return nil } func (cfg *Config) setDefaults() error { var maxResponseTime time.Duration var err error for i := range cfg.Clusters { c := &cfg.Clusters[i] for j := range c.ClusterUsers { u := &c.ClusterUsers[j] cud := time.Duration(u.MaxExecutionTime + u.MaxQueueTime) if cud > maxResponseTime { maxResponseTime = cud } if u.AllowedNetworks, err = cfg.groupToNetwork(u.NetworksOrGroups); err != nil { return err } } } for i := range cfg.Users { u := &cfg.Users[i] u.setDefaults() ud := time.Duration(u.MaxExecutionTime + u.MaxQueueTime) if ud > maxResponseTime { maxResponseTime = ud } if u.AllowedNetworks, err = cfg.groupToNetwork(u.NetworksOrGroups); err != nil { return err } } for i := range cfg.Caches { c := &cfg.Caches[i] c.setDefaults() } cfg.setServerMaxResponseTime(maxResponseTime) return nil } func (cfg *Config) setServerMaxResponseTime(maxResponseTime time.Duration) { if maxResponseTime < 0 { maxResponseTime = 0 } // Give an additional minute for the maximum response time, // so the response body may be sent to the requester. maxResponseTime += time.Minute if len(cfg.Server.HTTP.ListenAddr) > 0 && cfg.Server.HTTP.WriteTimeout == 0 { cfg.Server.HTTP.WriteTimeout = Duration(maxResponseTime) } if len(cfg.Server.HTTPS.ListenAddr) > 0 && cfg.Server.HTTPS.WriteTimeout == 0 { cfg.Server.HTTPS.WriteTimeout = Duration(maxResponseTime) } } func (c *Config) groupToNetwork(src NetworksOrGroups) (Networks, error) { if len(src) == 0 { return nil, nil } dst := make(Networks, 0) for _, v := range src { group, ok := c.networkReg[v] if ok { dst = append(dst, group...) } else { ipnet, err := stringToIPnet(v) if err != nil { return nil, err } dst = append(dst, ipnet) } } return dst, nil } // Server describes configuration of proxy server // These settings are immutable and can't be reloaded without restart type Server struct { // Optional HTTP configuration HTTP HTTP `yaml:"http,omitempty"` // Optional TLS configuration HTTPS HTTPS `yaml:"https,omitempty"` // Optional metrics handler configuration Metrics Metrics `yaml:"metrics,omitempty"` // Optional Proxy configuration Proxy Proxy `yaml:"proxy,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (s *Server) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain Server if err := unmarshal((*plain)(s)); err != nil { return err } return checkOverflow(s.XXX, "server") } // TimeoutCfg contains configurable http.Server timeouts type TimeoutCfg struct { // ReadTimeout is the maximum duration for reading the entire // request, including the body. // Default value is 1m ReadTimeout Duration `yaml:"read_timeout,omitempty"` // WriteTimeout is the maximum duration before timing out writes of the response. // Default is largest MaxExecutionTime + MaxQueueTime value from Users or Clusters WriteTimeout Duration `yaml:"write_timeout,omitempty"` // IdleTimeout is the maximum amount of time to wait for the next request. 
// Default is 10m IdleTimeout Duration `yaml:"idle_timeout,omitempty"` } // HTTP describes configuration for server to listen HTTP connections type HTTP struct { // TCP address to listen to for http ListenAddr string `yaml:"listen_addr"` NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"` // List of networks that access is allowed from // Each list item could be IP address or subnet mask // if omitted or zero - no limits would be applied AllowedNetworks Networks `yaml:"-"` // Whether to support Autocert handler for http-01 challenge ForceAutocertHandler bool TimeoutCfg `yaml:",inline"` // Catches all undefined fields and must be empty after parsing. XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *HTTP) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain HTTP if err := unmarshal((*plain)(c)); err != nil { return err } if err := c.validate(); err != nil { return err } return checkOverflow(c.XXX, "http") } func (c *HTTP) validate() error { if c.ReadTimeout == 0 { c.ReadTimeout = Duration(time.Minute) } if c.IdleTimeout == 0 { c.IdleTimeout = Duration(time.Minute * 10) } return nil } // TLS describes generic configuration for TLS connections, // it can be used for both HTTPS and Redis TLS. type TLS struct { // Certificate and key files for client cert authentication to the server CertFile string `yaml:"cert_file,omitempty"` KeyFile string `yaml:"key_file,omitempty"` Autocert Autocert `yaml:"autocert,omitempty"` InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"` } // BuildTLSConfig builds tls.Config from TLS configuration. func (c *TLS) BuildTLSConfig(acm *autocert.Manager) (*tls.Config, error) { tlsCfg := tls.Config{ PreferServerCipherSuites: true, MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{ tls.CurveP256, tls.X25519, }, InsecureSkipVerify: c.InsecureSkipVerify, // nolint: gosec } if len(c.KeyFile) > 0 && len(c.CertFile) > 0 { cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) if err != nil { return nil, fmt.Errorf("cannot load cert for `cert_file`=%q, `key_file`=%q: %w", c.CertFile, c.KeyFile, err) } tlsCfg.Certificates = []tls.Certificate{cert} } else { if acm == nil { return nil, fmt.Errorf("autocert manager is not configured") } tlsCfg.GetCertificate = acm.GetCertificate } return &tlsCfg, nil } // HTTPS describes configuration for server to listen HTTPS connections // It can be autocert with letsencrypt // or custom certificate type HTTPS struct { // TCP address to listen to for https // Default is `:443` ListenAddr string `yaml:"listen_addr,omitempty"` // TLS configuration TLS `yaml:",inline"` NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"` // List of networks that access is allowed from // Each list item could be IP address or subnet mask // if omitted or zero - no limits would be applied AllowedNetworks Networks `yaml:"-"` TimeoutCfg `yaml:",inline"` // Catches all undefined fields and must be empty after parsing. XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. 
func (c *HTTPS) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain HTTPS if err := unmarshal((*plain)(c)); err != nil { return err } if err := c.validate(); err != nil { return err } return checkOverflow(c.XXX, "https") } func (c *HTTPS) validate() error { if c.ReadTimeout == 0 { c.ReadTimeout = Duration(time.Minute) } if c.IdleTimeout == 0 { c.IdleTimeout = Duration(time.Minute * 10) } if len(c.ListenAddr) == 0 { c.ListenAddr = ":443" } if err := c.validateCertConfig(); err != nil { return err } return nil } func (c *HTTPS) validateCertConfig() error { if len(c.Autocert.CacheDir) > 0 { if len(c.CertFile) > 0 || len(c.KeyFile) > 0 { return fmt.Errorf("it is forbidden to specify certificate and `https.autocert` at the same time. Choose one way") } if len(c.NetworksOrGroups) > 0 { return fmt.Errorf("`letsencrypt` specification requires https server to be without `allowed_networks` limits. " + "Otherwise, certificates will be impossible to generate") } } if len(c.CertFile) > 0 && len(c.KeyFile) == 0 { return fmt.Errorf("`https.key_file` must be specified") } if len(c.KeyFile) > 0 && len(c.CertFile) == 0 { return fmt.Errorf("`https.cert_file` must be specified") } return nil } // Autocert configuration via letsencrypt // It requires port :80 to be open // see https://community.letsencrypt.org/t/2018-01-11-update-regarding-acme-tls-sni-and-shared-hosting-infrastructure/50188 type Autocert struct { // Path to the directory where autocert certs are cached CacheDir string `yaml:"cache_dir,omitempty"` // List of host names to which proxy is allowed to respond to // see https://godoc.org/golang.org/x/crypto/acme/autocert#HostPolicy AllowedHosts []string `yaml:"allowed_hosts,omitempty"` // Catches all undefined fields and must be empty after parsing. XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *Autocert) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain Autocert if err := unmarshal((*plain)(c)); err != nil { return err } return checkOverflow(c.XXX, "autocert") } // Metrics describes configuration to access metrics endpoint type Metrics struct { NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"` // List of networks that access is allowed from // Each list item could be IP address or subnet mask // if omitted or zero - no limits would be applied AllowedNetworks Networks `yaml:"-"` // Prometheus metric namespace Namespace string `yaml:"namespace,omitempty"` // Catches all undefined fields and must be empty after parsing. XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *Metrics) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain Metrics if err := unmarshal((*plain)(c)); err != nil { return err } return checkOverflow(c.XXX, "metrics") } type Proxy struct { // Enable enables parsing proxy headers. In proxy mode, CHProxy will try to // parse the X-Forwarded-For, X-Real-IP or Forwarded header to extract the IP. If an other header is configured // in the proxy settings, CHProxy will use that header instead. Enable bool `yaml:"enable,omitempty"` // Header allows for configuring an alternative header to parse the remote IP from, e.g. // CF-Connecting-IP. If this is set, Enable must be set to true otherwise this setting // will be ignored. Header string `yaml:"header,omitempty"` // Catches all undefined fields and must be empty after parsing. 
XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *Proxy) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain Proxy if err := unmarshal((*plain)(c)); err != nil { return err } if !c.Enable && c.Header != "" { return fmt.Errorf("`proxy_header` cannot be set without enabling proxy settings") } return checkOverflow(c.XXX, "proxy") } // Cluster describes CH cluster configuration // The simplest configuration consists of: // // cluster description - see <remote_servers> section in CH config.xml // and users - see <users> section in CH users.xml type Cluster struct { // Name of ClickHouse cluster Name string `yaml:"name"` // Scheme: `http` or `https`; would be applied to all nodes // default value is `http` Scheme string `yaml:"scheme,omitempty"` // Nodes contains cluster nodes. // // Either Nodes or Replicas must be set, but not both. Nodes []string `yaml:"nodes,omitempty"` // Replicas contains replicas. // // Either Replicas or Nodes must be set, but not both. Replicas []Replica `yaml:"replicas,omitempty"` // ClusterUsers - list of ClickHouse users ClusterUsers []ClusterUser `yaml:"users"` // KillQueryUser - user configuration for killing timed out queries. // By default timed out queries are killed under `default` user. KillQueryUser KillQueryUser `yaml:"kill_query_user,omitempty"` // HeartBeat - user configuration for heart beat requests HeartBeat HeartBeat `yaml:"heartbeat,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` // Retry number for query - how many times a query can retry after receiving a recoverable but failed response from Clickhouse node RetryNumber int `yaml:"retry_number,omitempty"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *Cluster) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = defaultCluster type plain Cluster if err := unmarshal((*plain)(c)); err != nil { return err } if err := c.validate(); err != nil { return err } return checkOverflow(c.XXX, fmt.Sprintf("cluster %q", c.Name)) } func (c *Cluster) validate() error { if len(c.Name) == 0 { return fmt.Errorf("`cluster.name` cannot be empty") } if err := c.validateMinimumRequirements(); err != nil { return err } if c.Scheme != "http" && c.Scheme != "https" { return fmt.Errorf("`cluster.scheme` must be `http` or `https`, got %q instead for %q", c.Scheme, c.Name) } if c.HeartBeat.Interval == 0 && c.HeartBeat.Timeout == 0 && c.HeartBeat.Response == "" { return fmt.Errorf("`cluster.heartbeat` cannot be unset for %q", c.Name) } return nil } func (c *Cluster) validateMinimumRequirements() error { if len(c.Nodes) == 0 && len(c.Replicas) == 0 { return fmt.Errorf("either `cluster.nodes` or `cluster.replicas` must be set for %q", c.Name) } if len(c.Nodes) > 0 && len(c.Replicas) > 0 { return fmt.Errorf("`cluster.nodes` cannot be simultaneously set with `cluster.replicas` for %q", c.Name) } if len(c.ClusterUsers) == 0 { return fmt.Errorf("`cluster.users` must contain at least 1 user for %q", c.Name) } return nil } // Replica contains ClickHouse replica configuration. type Replica struct { // Name is replica name. Name string `yaml:"name"` // Nodes contains replica nodes. Nodes []string `yaml:"nodes"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. 
func (r *Replica) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain Replica if err := unmarshal((*plain)(r)); err != nil { return err } if len(r.Name) == 0 { return fmt.Errorf("`replica.name` cannot be empty") } if len(r.Nodes) == 0 { return fmt.Errorf("`replica.nodes` cannot be empty for %q", r.Name) } return checkOverflow(r.XXX, fmt.Sprintf("replica %q", r.Name)) } // KillQueryUser - user configuration for killing timed out queries. type KillQueryUser struct { // User name Name string `yaml:"name"` // User password to access CH with basic auth Password string `yaml:"password,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (u *KillQueryUser) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain KillQueryUser if err := unmarshal((*plain)(u)); err != nil { return err } if len(u.Name) == 0 { return fmt.Errorf("`cluster.kill_query_user.name` must be specified") } return checkOverflow(u.XXX, "kill_query_user") } // HeartBeat - configuration for heartbeat. type HeartBeat struct { // Interval is an interval of checking // all cluster nodes for availability // if omitted or zero - interval will be set to 5s Interval Duration `yaml:"interval,omitempty"` // Timeout is the time to wait for a response from cluster nodes // if omitted or zero - timeout will be set to 3s Timeout Duration `yaml:"timeout,omitempty"` // Request is the URI of the health check request // default value is `/ping` Request string `yaml:"request,omitempty"` // Reference response from clickhouse on health check request // default value is `Ok.\n` Response string `yaml:"response,omitempty"` // Credentials to send heartbeat requests // for anything except '/ping'. // If not specified, the first cluster user's credentials are used User string `yaml:"user,omitempty"` Password string `yaml:"password,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface.
func (h *HeartBeat) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain HeartBeat if err := unmarshal((*plain)(h)); err != nil { return err } return checkOverflow(h.XXX, "heartbeat") } // User describes list of allowed users // which requests will be proxied to ClickHouse type User struct { // User name Name string `yaml:"name"` // User password to access proxy with basic auth Password string `yaml:"password,omitempty"` // ToCluster is the name of cluster where requests // will be proxied ToCluster string `yaml:"to_cluster"` // ToUser is the name of cluster_user from cluster's ToCluster // whom credentials will be used for proxying request to CH ToUser string `yaml:"to_user"` // Maximum number of concurrently running queries for user // if omitted or zero - no limits would be applied MaxConcurrentQueries uint32 `yaml:"max_concurrent_queries,omitempty"` // Maximum duration of query execution for user // if omitted or zero - limit is set to 120 seconds MaxExecutionTime Duration `yaml:"max_execution_time,omitempty"` // Maximum number of requests per minute for user // if omitted or zero - no limits would be applied ReqPerMin uint32 `yaml:"requests_per_minute,omitempty"` // The burst of request packet size token bucket for user // if omitted or zero - no limits would be applied ReqPacketSizeTokensBurst ByteSize `yaml:"request_packet_size_tokens_burst,omitempty"` // The request packet size tokens produced rate per second for user // if omitted or zero - no limits would be applied ReqPacketSizeTokensRate ByteSize `yaml:"request_packet_size_tokens_rate,omitempty"` // Maximum number of queries waiting for execution in the queue // if omitted or zero - queries are executed without waiting // in the queue MaxQueueSize uint32 `yaml:"max_queue_size,omitempty"` // Maximum duration the query may wait in the queue // if omitted or zero - 10s duration is used MaxQueueTime Duration `yaml:"max_queue_time,omitempty"` NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"` // List of networks that access is allowed from // Each list item could be IP address or subnet mask // if omitted or zero - no limits would be applied AllowedNetworks Networks `yaml:"-"` // Whether to deny http connections for this user DenyHTTP bool `yaml:"deny_http,omitempty"` // Whether to deny https connections for this user DenyHTTPS bool `yaml:"deny_https,omitempty"` // Whether to allow CORS requests for this user AllowCORS bool `yaml:"allow_cors,omitempty"` // Name of Cache configuration to use for responses of this user Cache string `yaml:"cache,omitempty"` // Name of ParamGroup to use Params string `yaml:"params,omitempty"` // prefix_* IsWildcarded bool `yaml:"is_wildcarded,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. 
func (u *User) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain User if err := unmarshal((*plain)(u)); err != nil { return err } if err := u.validate(); err != nil { return err } return checkOverflow(u.XXX, fmt.Sprintf("user %q", u.Name)) } func (u *User) validate() error { if len(u.Name) == 0 { return fmt.Errorf("`user.name` cannot be empty") } if len(u.ToUser) == 0 { return fmt.Errorf("`user.to_user` cannot be empty for %q", u.Name) } if len(u.ToCluster) == 0 { return fmt.Errorf("`user.to_cluster` cannot be empty for %q", u.Name) } if u.DenyHTTP && u.DenyHTTPS { return fmt.Errorf("`deny_http` and `deny_https` cannot be simultaneously set to `true` for %q", u.Name) } if err := u.validateWildcarded(); err != nil { return err } if err := u.validateRateLimitConfig(); err != nil { return err } return nil } func (u *User) validateWildcarded() error { if u.IsWildcarded { if s := strings.Split(u.Name, "*"); !(len(s) == 2 && (s[0] == "" || s[1] == "")) { return fmt.Errorf("user name %q marked 'is_wildcared' does not match 'prefix*' or '*suffix' or '*'", u.Name) } } return nil } func (u *User) validateRateLimitConfig() error { if u.MaxQueueTime > 0 && u.MaxQueueSize == 0 { return fmt.Errorf("`max_queue_size` must be set if `max_queue_time` is set for %q", u.Name) } if u.ReqPacketSizeTokensBurst > 0 && u.ReqPacketSizeTokensRate == 0 { return fmt.Errorf("`request_packet_size_tokens_rate` must be set if `request_packet_size_tokens_burst` is set for %q", u.Name) } return nil } func (u *User) validateSecurity(hasHTTP, hasHTTPS bool) error { if len(u.Password) == 0 { if !u.DenyHTTPS && hasHTTPS { return fmt.Errorf("https: user %q has neither password nor `allowed_networks` on `user` or `server.http` level", u.Name) } if !u.DenyHTTP && hasHTTP { return fmt.Errorf("http: user %q has neither password nor `allowed_networks` on `user` or `server.http` level", u.Name) } } if len(u.Password) > 0 && hasHTTP { return fmt.Errorf("http: user %q is allowed to connect via http, but not limited by `allowed_networks` "+ "on `user` or `server.http` level - password could be stolen", u.Name) } return nil } func (u *User) setDefaults() { if u.MaxExecutionTime == 0 { u.MaxExecutionTime = defaultExecutionTime } } // NetworkGroups describes a named Networks lists type NetworkGroups struct { // Name of the group Name string `yaml:"name"` // List of networks // Each list item could be IP address or subnet mask Networks Networks `yaml:"networks"` // Catches all undefined fields and must be empty after parsing. XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. 
func (ng *NetworkGroups) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain NetworkGroups if err := unmarshal((*plain)(ng)); err != nil { return err } if len(ng.Name) == 0 { return fmt.Errorf("`network_group.name` must be specified") } if len(ng.Networks) == 0 { return fmt.Errorf("`network_group.networks` must contain at least one network") } return checkOverflow(ng.XXX, fmt.Sprintf("network_group %q", ng.Name)) } // NetworksOrGroups is a list of strings with names of NetworkGroups // or just Networks type NetworksOrGroups []string // Cache describes configuration options for caching // responses from CH clusters type Cache struct { // Mode of cache (file_system, redis) // todo make it an enum Mode string `yaml:"mode"` // Name of configuration for further assign Name string `yaml:"name"` // Expiration period for cached response // Files which are older than expiration period will be deleted // on new request and re-cached Expire Duration `yaml:"expire,omitempty"` // Deprecated: GraceTime duration before the expired entry is deleted from the cache. // It's deprecated and in future versions it'll be replaced by user's MaxExecutionTime. // It's already the case today if value of GraceTime is omitted. GraceTime Duration `yaml:"grace_time,omitempty"` FileSystem FileSystemCacheConfig `yaml:"file_system,omitempty"` Redis RedisCacheConfig `yaml:"redis,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` // Maximum total size of request payload for caching MaxPayloadSize ByteSize `yaml:"max_payload_size,omitempty"` // Whether a query cached by a user could be used by another user SharedWithAllUsers bool `yaml:"shared_with_all_users,omitempty"` } func (c *Cache) setDefaults() { if c.MaxPayloadSize <= 0 { c.MaxPayloadSize = defaultMaxPayloadSize } } type FileSystemCacheConfig struct { //// Path to directory where cached files will be saved Dir string `yaml:"dir"` // Maximum total size of all cached to Dir files // If size is exceeded - the oldest files in Dir will be deleted // until total size becomes normal MaxSize ByteSize `yaml:"max_size"` } type RedisCacheConfig struct { TLS `yaml:",inline"` Username string `yaml:"username,omitempty"` Password string `yaml:"password,omitempty"` Addresses []string `yaml:"addresses"` DBIndex int `yaml:"db_index,omitempty"` XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *Cache) UnmarshalYAML(unmarshal func(interface{}) error) (err error) { type plain Cache if err = unmarshal((*plain)(c)); err != nil { return err } if len(c.Name) == 0 { return fmt.Errorf("`cache.name` must be specified") } switch c.Mode { case "file_system": err = c.checkFileSystemConfig() case "redis": err = c.checkRedisConfig() default: err = fmt.Errorf("not supported cache type %v. 
Supported types: [file_system, redis]", c.Mode) } if err != nil { return fmt.Errorf("failed to configure cache for %q: %w", c.Name, err) } return checkOverflow(c.XXX, fmt.Sprintf("cache %q", c.Name)) } func (c *Cache) checkFileSystemConfig() error { if len(c.FileSystem.Dir) == 0 { return fmt.Errorf("`cache.filesystem.dir` must be specified for %q", c.Name) } if c.FileSystem.MaxSize <= 0 { return fmt.Errorf("`cache.filesystem.max_size` must be specified for %q", c.Name) } return nil } func (c *Cache) checkRedisConfig() error { if len(c.Redis.Addresses) == 0 { return fmt.Errorf("`cache.redis.addresses` must be specified for %q", c.Name) } return nil } // ParamGroup describes a named group of GET params // for sending with each query type ParamGroup struct { // Name of the configuration for further reference Name string `yaml:"name"` // Params contains a list of GET params Params []Param `yaml:"params"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (pg *ParamGroup) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain ParamGroup if err := unmarshal((*plain)(pg)); err != nil { return err } if len(pg.Name) == 0 { return fmt.Errorf("`param_group.name` must be specified") } if len(pg.Params) == 0 { return fmt.Errorf("`param_group.params` must contain at least one param") } return checkOverflow(pg.XXX, fmt.Sprintf("param_group %q", pg.Name)) } // Param describes a URL param value type Param struct { // Key is the name of the param Key string `yaml:"key"` // Value is the value of the param Value string `yaml:"value"` } // ConnectionPool describes the pool of connections with ClickHouse // settings type ConnectionPool struct { // Maximum total number of idle connections between chproxy and all ClickHouse instances MaxIdleConns int `yaml:"max_idle_conns,omitempty"` // Maximum number of idle connections between chproxy and a particular ClickHouse instance MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host,omitempty"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface.
func (cp *ConnectionPool) UnmarshalYAML(unmarshal func(interface{}) error) error { *cp = defaultConnectionPool type plain ConnectionPool if err := unmarshal((*plain)(cp)); err != nil { return err } if cp.MaxIdleConnsPerHost > cp.MaxIdleConns || cp.MaxIdleConns < 0 { return fmt.Errorf("inconsistent ConnectionPool settings") } return checkOverflow(cp.XXX, "connection_pool") } // ClusterUser describes simplest <users> configuration type ClusterUser struct { // User name in ClickHouse users.xml config Name string `yaml:"name"` // User password in ClickHouse users.xml config Password string `yaml:"password,omitempty"` // Maximum number of concurrently running queries for user // if omitted or zero - no limits would be applied MaxConcurrentQueries uint32 `yaml:"max_concurrent_queries,omitempty"` // Maximum duration of query execution for user // if omitted or zero - limit is set to 120 seconds MaxExecutionTime Duration `yaml:"max_execution_time,omitempty"` // Maximum number of requests per minute for user // if omitted or zero - no limits would be applied ReqPerMin uint32 `yaml:"requests_per_minute,omitempty"` // The burst of request packet size token bucket for user // if omitted or zero - no limits would be applied ReqPacketSizeTokensBurst ByteSize `yaml:"request_packet_size_tokens_burst,omitempty"` // The request packet size tokens produced rate for user // if omitted or zero - no limits would be applied ReqPacketSizeTokensRate ByteSize `yaml:"request_packet_size_tokens_rate,omitempty"` // Maximum number of queries waiting for execution in the queue // if omitted or zero - queries are executed without waiting // in the queue MaxQueueSize uint32 `yaml:"max_queue_size,omitempty"` // Maximum duration the query may wait in the queue // if omitted or zero - 10s duration is used MaxQueueTime Duration `yaml:"max_queue_time,omitempty"` NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"` // List of networks that access is allowed from // Each list item could be IP address or subnet mask // if omitted or zero - no limits would be applied AllowedNetworks Networks `yaml:"-"` // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface. 
func (cu *ClusterUser) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain ClusterUser if err := unmarshal((*plain)(cu)); err != nil { return err } if len(cu.Name) == 0 { return fmt.Errorf("`cluster.user.name` cannot be empty") } if cu.MaxQueueTime > 0 && cu.MaxQueueSize == 0 { return fmt.Errorf("`max_queue_size` must be set if `max_queue_time` is set for %q", cu.Name) } if cu.ReqPacketSizeTokensBurst > 0 && cu.ReqPacketSizeTokensRate == 0 { return fmt.Errorf("`request_packet_size_tokens_rate` must be set if `request_packet_size_tokens_burst` is set for %q", cu.Name) } return checkOverflow(cu.XXX, fmt.Sprintf("cluster.user %q", cu.Name)) } // LoadFile loads and validates configuration from provided .yml file func LoadFile(filename string) (*Config, error) { content, err := os.ReadFile(filename) if err != nil { return nil, err } content = findAndReplacePlaceholders(content) cfg := &Config{} if err := yaml.Unmarshal(content, cfg); err != nil { return nil, err } cfg.networkReg = make(map[string]Networks, len(cfg.NetworkGroups)) for _, ng := range cfg.NetworkGroups { if _, ok := cfg.networkReg[ng.Name]; ok { return nil, fmt.Errorf("duplicate `network_groups.name` %q", ng.Name) } cfg.networkReg[ng.Name] = ng.Networks } if cfg.Server.HTTP.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.HTTP.NetworksOrGroups); err != nil { return nil, err } if cfg.Server.HTTPS.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.HTTPS.NetworksOrGroups); err != nil { return nil, err } if cfg.Server.Metrics.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.Metrics.NetworksOrGroups); err != nil { return nil, err } if err := cfg.setDefaults(); err != nil { return nil, err } if err := cfg.checkVulnerabilities(); err != nil { return nil, fmt.Errorf("security breach: %w\nSet option `hack_me_please=true` to disable security errors", err) } return cfg, nil } var envVarRegex = regexp.MustCompile(`\${([a-zA-Z_][a-zA-Z0-9_]*)}`) // findAndReplacePlaceholders finds all environment variables placeholders in the config. // Each placeholder is a string like ${VAR_NAME}. They will be replaced with the value of the // corresponding environment variable. It returns the new content with replaced placeholders. func findAndReplacePlaceholders(content []byte) []byte { for _, match := range envVarRegex.FindAllSubmatch(content, -1) { envVar := os.Getenv(string(match[1])) if envVar != "" { content = bytes.ReplaceAll(content, match[0], []byte(envVar)) } } return content } func (c Config) checkVulnerabilities() error { if c.HackMePlease { return nil } hasHTTPS := len(c.Server.HTTPS.ListenAddr) > 0 && len(c.Server.HTTPS.NetworksOrGroups) == 0 hasHTTP := len(c.Server.HTTP.ListenAddr) > 0 && len(c.Server.HTTP.NetworksOrGroups) == 0 for _, u := range c.Users { if len(u.NetworksOrGroups) != 0 { continue } if err := u.validateSecurity(hasHTTP, hasHTTPS); err != nil { return err } } return nil }
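// Illustrative sketch (not part of the chproxy tree): how the ${VAR} placeholder
// expansion performed by findAndReplacePlaceholders behaves. That helper is
// unexported, so this standalone example re-implements the same logic; all names
// and values below are made up for demonstration.
package main

import (
	"bytes"
	"fmt"
	"os"
	"regexp"
)

var envVarRegex = regexp.MustCompile(`\${([a-zA-Z_][a-zA-Z0-9_]*)}`)

// expand mirrors findAndReplacePlaceholders: every ${VAR} whose environment
// variable is set and non-empty is replaced; everything else is left as-is.
func expand(content []byte) []byte {
	for _, match := range envVarRegex.FindAllSubmatch(content, -1) {
		if v := os.Getenv(string(match[1])); v != "" {
			content = bytes.ReplaceAll(content, match[0], []byte(v))
		}
	}
	return content
}

func main() {
	os.Setenv("CH_PASSWORD", "secret")
	in := []byte("password: ${CH_PASSWORD}\nother: ${UNSET_VAR}\n")
	fmt.Print(string(expand(in)))
	// password: secret
	// other: ${UNSET_VAR}
}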
package config import ( "fmt" "math" "net" "regexp" "strconv" "strings" "time" ) // ByteSize holds size in bytes. // // May be used in yaml for parsing byte size values. type ByteSize uint64 var byteSizeRegexp = regexp.MustCompile(`^(\d+(?:\.\d+)?)\s*([KMGTP]?)B?$`) // UnmarshalYAML implements the yaml.Unmarshaler interface. func (bs *ByteSize) UnmarshalYAML(unmarshal func(interface{}) error) error { var s string if err := unmarshal(&s); err != nil { return err } s = strings.ToUpper(s) value, unit, err := parseStringParts(s) if err != nil { return err } if value <= 0 { return fmt.Errorf("byte size %q must be positive", s) } k := float64(1) switch unit { case "P": k = 1 << 50 case "T": k = 1 << 40 case "G": k = 1 << 30 case "M": k = 1 << 20 case "K": k = 1 << 10 } value *= k *bs = ByteSize(value) // check for overflow e := math.Abs(float64(*bs)-value) / value if e > 1e-6 { return fmt.Errorf("byte size %q is too big", s) } return nil } func parseStringParts(s string) (float64, string, error) { s = strings.ToUpper(s) parts := byteSizeRegexp.FindStringSubmatch(strings.TrimSpace(s)) if len(parts) < 3 { return -1, "", fmt.Errorf("cannot parse byte size %q: it must be positive float followed by optional units. For example, 1.5Gb, 3T", s) } value, err := strconv.ParseFloat(parts[1], 64) if err != nil { return -1, "", fmt.Errorf("cannot parse byte size %q: it must be positive float followed by optional units. For example, 1.5Gb, 3T; err: %w", s, err) } unit := parts[2] return value, unit, nil } // Networks is a list of IPNet entities type Networks []*net.IPNet // MarshalYAML implements yaml.Marshaler interface. // // It prettifies yaml output for Networks. func (n Networks) MarshalYAML() (interface{}, error) { var a []string for _, x := range n { a = append(a, x.String()) } return a, nil } // UnmarshalYAML implements the yaml.Unmarshaler interface. func (n *Networks) UnmarshalYAML(unmarshal func(interface{}) error) error { var s []string if err := unmarshal(&s); err != nil { return err } networks := make(Networks, len(s)) for i, s := range s { ipnet, err := stringToIPnet(s) if err != nil { return err } networks[i] = ipnet } *n = networks return nil } // Contains checks whether passed addr is in the range of networks func (n Networks) Contains(addr string) bool { if len(n) == 0 { return true } h, _, err := net.SplitHostPort(addr) if err != nil { // If we only have an IP address. This happens when the proxy middleware is enabled. h = addr } ip := net.ParseIP(h) if ip == nil { panic(fmt.Sprintf("BUG: unexpected error while parsing IP: %s", h)) } for _, ipnet := range n { if ipnet.Contains(ip) { return true } } return false } // Duration wraps time.Duration. It is used to parse the custom duration format type Duration time.Duration // UnmarshalYAML implements the yaml.Unmarshaler interface. func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { var s string if err := unmarshal(&s); err != nil { return err } dur, err := StringToDuration(s) if err != nil { return err } *d = dur return nil } // String implements the Stringer interface. func (d Duration) String() string { factors := map[string]time.Duration{ "w": time.Hour * 24 * 7, "d": time.Hour * 24, "h": time.Hour, "m": time.Minute, "s": time.Second, "ms": time.Millisecond, "µs": time.Microsecond, "ns": 1, } var t = time.Duration(d) unit := "ns" //nolint:exhaustive // Custom duration counter that doesn't switch on the duration. 
switch time.Duration(0) { case t % factors["w"]: unit = "w" case t % factors["d"]: unit = "d" case t % factors["h"]: unit = "h" case t % factors["m"]: unit = "m" case t % factors["s"]: unit = "s" case t % factors["ms"]: unit = "ms" case t % factors["µs"]: unit = "µs" } return fmt.Sprintf("%d%v", t/factors[unit], unit) } // MarshalYAML implements the yaml.Marshaler interface. func (d Duration) MarshalYAML() (interface{}, error) { return d.String(), nil } // borrowed from github.com/prometheus/prometheus var durationRE = regexp.MustCompile("^([0-9]+)(w|d|h|m|s|ms|µs|ns)$") // StringToDuration parses a string into a time.Duration, // assuming that a week always has 7d, and a day always has 24h. func StringToDuration(durationStr string) (Duration, error) { n, unit, err := parseDurationParts(durationStr) if err != nil { return 0, err } return calculateDuration(n, unit) } func parseDurationParts(s string) (int, string, error) { matches := durationRE.FindStringSubmatch(s) if len(matches) != 3 { return 0, "", fmt.Errorf("not a valid duration string: %q", s) } n, err := strconv.Atoi(matches[1]) if err != nil { return 0, "", fmt.Errorf("duration too long: %q", matches[1]) } unit := matches[2] return n, unit, nil } func calculateDuration(n int, unit string) (Duration, error) { dur := time.Duration(n) switch unit { case "w": dur *= time.Hour * 24 * 7 case "d": dur *= time.Hour * 24 case "h": dur *= time.Hour case "m": dur *= time.Minute case "s": dur *= time.Second case "ms": dur *= time.Millisecond case "µs": dur *= time.Microsecond case "ns": default: return 0, fmt.Errorf("invalid time unit in duration string: %q", unit) } return Duration(dur), nil }
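// Illustrative sketch: round-tripping the custom duration format defined above.
// StringToDuration accepts a single integer followed by one unit (w, d, h, m, s,
// ms, µs, ns); Duration.String picks the largest unit that divides the value
// evenly. Composite values such as "1h30m" are intentionally rejected.
package main

import (
	"fmt"

	"github.com/contentsquare/chproxy/config"
)

func main() {
	d, err := config.StringToDuration("2d")
	if err != nil {
		panic(err)
	}
	fmt.Println(d.String()) // 2d

	if _, err := config.StringToDuration("1h30m"); err != nil {
		fmt.Println("rejected:", err) // not a valid duration string: "1h30m"
	}
}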
package config import ( "fmt" "net" "strings" ) const entireIPv4 = "0.0.0.0/0" func stringToIPnet(s string) (*net.IPNet, error) { if s == entireIPv4 { return nil, fmt.Errorf("suspicious mask specified \"0.0.0.0/0\". " + "If you want to allow all then just omit `allowed_networks` field") } ip := s if !strings.Contains(ip, `/`) { ip += "/32" } _, ipnet, err := net.ParseCIDR(ip) if err != nil { return nil, fmt.Errorf("wrong network group name or address %q: %w", s, err) } return ipnet, nil } func checkOverflow(m map[string]interface{}, ctx string) error { if len(m) > 0 { var keys []string for k := range m { keys = append(keys, k) } return fmt.Errorf("unknown fields in %s: %s", ctx, strings.Join(keys, ", ")) } return nil }
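// Illustrative sketch: how allowed_networks entries are parsed. Bare IPs are
// widened to /32 by stringToIPnet and 0.0.0.0/0 is rejected outright. The YAML
// library (gopkg.in/yaml.v2) is inferred from the unmarshal-function style used
// above; the addresses below are made up.
package main

import (
	"fmt"

	"github.com/contentsquare/chproxy/config"
	"gopkg.in/yaml.v2"
)

func main() {
	var nets config.Networks
	if err := yaml.Unmarshal([]byte(`["10.0.0.1", "192.168.0.0/16"]`), &nets); err != nil {
		panic(err)
	}
	fmt.Println(nets.Contains("10.0.0.1:34567"))    // true: bare IP became 10.0.0.1/32
	fmt.Println(nets.Contains("192.168.10.4:8123")) // true: inside 192.168.0.0/16
	fmt.Println(nets.Contains("8.8.8.8:80"))        // false

	// Allowing everything is done by omitting allowed_networks, not via 0.0.0.0/0.
	if err := yaml.Unmarshal([]byte(`["0.0.0.0/0"]`), &nets); err != nil {
		fmt.Println("rejected:", err)
	}
}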
package counter import "sync/atomic" type Counter struct { value atomic.Uint32 } func (c *Counter) Store(n uint32) { c.value.Store(n) } func (c *Counter) Load() uint32 { return c.value.Load() } func (c *Counter) Dec() { c.value.Add(^uint32(0)) } func (c *Counter) Inc() uint32 { return c.value.Add(1) }
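// Illustrative sketch: the Counter above is a thin wrapper around atomic.Uint32;
// Dec relies on unsigned wrap-around (adding ^uint32(0) is the same as
// subtracting one).
package main

import (
	"fmt"

	"github.com/contentsquare/chproxy/internal/counter"
)

func main() {
	var c counter.Counter
	c.Inc()
	c.Inc()
	c.Dec()
	fmt.Println(c.Load()) // 1
}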
package heartbeat import ( "context" "fmt" "io" "net/http" "time" "github.com/contentsquare/chproxy/config" ) var errUnexpectedResponse = fmt.Errorf("unexpected response") type HeartBeat interface { IsHealthy(ctx context.Context, addr string) error Interval() time.Duration } type heartBeatOpts struct { defaultUser string defaultPassword string } type Option interface { apply(*heartBeatOpts) } type defaultUser struct { defaultUser string defaultPassword string } func (o defaultUser) apply(opts *heartBeatOpts) { opts.defaultUser = o.defaultUser opts.defaultPassword = o.defaultPassword } func WithDefaultUser(user, password string) Option { return defaultUser{ defaultUser: user, defaultPassword: password, } } type heartBeat struct { interval time.Duration timeout time.Duration request string response string user string password string } // User credentials are not needed const defaultEndpoint string = "/ping" func NewHeartbeat(c config.HeartBeat, options ...Option) HeartBeat { opts := &heartBeatOpts{} for _, o := range options { o.apply(opts) } newHB := &heartBeat{ interval: time.Duration(c.Interval), timeout: time.Duration(c.Timeout), request: c.Request, response: c.Response, } if c.Request != defaultEndpoint { if c.User != "" { newHB.user = c.User newHB.password = c.Password } else { newHB.user = opts.defaultUser newHB.password = opts.defaultPassword } } if newHB.request != defaultEndpoint && newHB.user == "" { panic("BUG: user is empty, no default user provided") } return newHB } func (hb *heartBeat) IsHealthy(ctx context.Context, addr string) error { req, err := http.NewRequest("GET", addr+hb.request, nil) if err != nil { return err } if hb.request != defaultEndpoint { req.SetBasicAuth(hb.user, hb.password) } ctx, cancel := context.WithTimeout(ctx, hb.timeout) defer cancel() req = req.WithContext(ctx) startTime := time.Now() resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("cannot send request in %s: %w", time.Since(startTime), err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("non-200 status code: %s", resp.Status) } body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("cannot read response in %s: %w", time.Since(startTime), err) } r := string(body) if r != hb.response { return fmt.Errorf("%w: %s", errUnexpectedResponse, r) } return nil } func (hb *heartBeat) Interval() time.Duration { return hb.interval }
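// Illustrative sketch: building a heartbeat checker for a non-/ping health
// check, which requires credentials (either from the heartbeat config or from
// WithDefaultUser, as enforced above). The query, credentials and address are
// made-up examples.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/contentsquare/chproxy/config"
	"github.com/contentsquare/chproxy/internal/heartbeat"
)

func main() {
	cfg := config.HeartBeat{
		Interval: config.Duration(5 * time.Second),
		Timeout:  config.Duration(3 * time.Second),
		Request:  "/?query=SELECT%201",
		Response: "1\n",
	}
	hb := heartbeat.NewHeartbeat(cfg, heartbeat.WithDefaultUser("default", ""))

	if err := hb.IsHealthy(context.Background(), "http://127.0.0.1:8123"); err != nil {
		fmt.Println("unhealthy:", err)
	} else {
		fmt.Println("healthy, next check in", hb.Interval())
	}
}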
package topology // TODO this is only here to avoid recursive imports. We should have a separate package for metrics. import ( "github.com/contentsquare/chproxy/config" "github.com/prometheus/client_golang/prometheus" ) var ( HostHealth *prometheus.GaugeVec HostPenalties *prometheus.CounterVec ) func initMetrics(cfg *config.Config) { namespace := cfg.Server.Metrics.Namespace HostHealth = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: namespace, Name: "host_health", Help: "Health state of hosts by clusters", }, []string{"cluster", "replica", "cluster_node"}, ) HostPenalties = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "host_penalties_total", Help: "Total number of given penalties by host", }, []string{"cluster", "replica", "cluster_node"}, ) } func RegisterMetrics(cfg *config.Config) { initMetrics(cfg) prometheus.MustRegister(HostHealth, HostPenalties) } func reportNodeHealthMetric(clusterName, replicaName, nodeName string, active bool) { label := prometheus.Labels{ "cluster": clusterName, "replica": replicaName, "cluster_node": nodeName, } if active { HostHealth.With(label).Set(1) } else { HostHealth.With(label).Set(0) } } func incrementPenaltiesMetric(clusterName, replicaName, nodeName string) { label := prometheus.Labels{ "cluster": clusterName, "replica": replicaName, "cluster_node": nodeName, } HostPenalties.With(label).Inc() }
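// Illustrative sketch: RegisterMetrics above registers host_health and
// host_penalties_total on the default Prometheus registry, so exposing them
// only needs the stock promhttp handler. The empty Config and the listen
// address are arbitrary examples.
package main

import (
	"net/http"

	"github.com/contentsquare/chproxy/config"
	"github.com/contentsquare/chproxy/internal/topology"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	topology.RegisterMetrics(&config.Config{})
	http.Handle("/metrics", promhttp.Handler())
	// host_health and host_penalties_total show up once nodes start reporting.
	_ = http.ListenAndServe(":9099", nil)
}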
package topology import ( "context" "net/url" "sync/atomic" "time" "github.com/contentsquare/chproxy/internal/counter" "github.com/contentsquare/chproxy/internal/heartbeat" "github.com/contentsquare/chproxy/log" ) const ( // prevents excess goroutine creating while penalizing overloaded host DefaultPenaltySize = 5 DefaultMaxSize = 300 DefaultPenaltyDuration = time.Second * 10 ) type nodeOpts struct { defaultActive bool penaltySize uint32 penaltyMaxSize uint32 penaltyDuration time.Duration } func defaultNodeOpts() nodeOpts { return nodeOpts{ penaltySize: DefaultPenaltySize, penaltyMaxSize: DefaultMaxSize, penaltyDuration: DefaultPenaltyDuration, } } type NodeOption interface { apply(*nodeOpts) } type defaultActive struct { active bool } func (o defaultActive) apply(opts *nodeOpts) { opts.defaultActive = o.active } func WithDefaultActiveState(active bool) NodeOption { return defaultActive{ active: active, } } type Node struct { // Node Address. addr *url.URL // Whether this node is alive. active atomic.Bool // Counter of currently running connections. connections counter.Counter // Counter of unsuccesfull request to decrease host priority. penalty atomic.Uint32 // Heartbeat function hb heartbeat.HeartBeat // TODO These fields are only used for labels in prometheus. We should have a different way to pass the labels. // For metrics only clusterName string replicaName string // Additional configuration options opts nodeOpts } func NewNode(addr *url.URL, hb heartbeat.HeartBeat, clusterName, replicaName string, opts ...NodeOption) *Node { nodeOpts := defaultNodeOpts() for _, opt := range opts { opt.apply(&nodeOpts) } n := &Node{ addr: addr, hb: hb, clusterName: clusterName, replicaName: replicaName, opts: nodeOpts, } if n.opts.defaultActive { n.SetIsActive(true) } return n } func (n *Node) IsActive() bool { return n.active.Load() } func (n *Node) SetIsActive(active bool) { n.active.Store(active) } // StartHeartbeat runs the heartbeat healthcheck against the node // until the done channel is closed. // If the heartbeat fails, the active status of the node is changed. func (n *Node) StartHeartbeat(done <-chan struct{}) { ctx, cancel := context.WithCancel(context.Background()) for { n.heartbeat(ctx) select { case <-done: cancel() return case <-time.After(n.hb.Interval()): } } } func (n *Node) heartbeat(ctx context.Context) { if err := n.hb.IsHealthy(ctx, n.addr.String()); err == nil { n.active.Store(true) reportNodeHealthMetric(n.clusterName, n.replicaName, n.Host(), true) } else { log.Errorf("error while health-checking %q host: %s", n.Host(), err) n.active.Store(false) reportNodeHealthMetric(n.clusterName, n.replicaName, n.Host(), false) } } // Penalize a node if a request failed to decrease it's priority. // If the penalty is already at the maximum allowed size this function // will not penalize the node further. // A function will be registered to run after the penalty duration to // increase the priority again. func (n *Node) Penalize() { penalty := n.penalty.Load() if penalty >= n.opts.penaltyMaxSize { return } incrementPenaltiesMetric(n.clusterName, n.replicaName, n.Host()) n.penalty.Add(n.opts.penaltySize) time.AfterFunc(n.opts.penaltyDuration, func() { n.penalty.Add(^uint32(n.opts.penaltySize - 1)) }) } // CurrentLoad returns the current node returns the number of open connections // plus the penalty. 
func (n *Node) CurrentLoad() uint32 { c := n.connections.Load() p := n.penalty.Load() return c + p } func (n *Node) CurrentConnections() uint32 { return n.connections.Load() } func (n *Node) CurrentPenalty() uint32 { return n.penalty.Load() } func (n *Node) IncrementConnections() { n.connections.Inc() } func (n *Node) DecrementConnections() { n.connections.Dec() } func (n *Node) Scheme() string { return n.addr.Scheme } func (n *Node) Host() string { return n.addr.Host } func (n *Node) ReplicaName() string { return n.replicaName } func (n *Node) String() string { return n.addr.String() }
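// Illustrative sketch: how open connections and penalties combine into
// CurrentLoad. RegisterMetrics must run first because Penalize updates the
// host_penalties counter; the address and names below are placeholders.
package main

import (
	"fmt"
	"net/url"
	"time"

	"github.com/contentsquare/chproxy/config"
	"github.com/contentsquare/chproxy/internal/heartbeat"
	"github.com/contentsquare/chproxy/internal/topology"
)

func main() {
	topology.RegisterMetrics(&config.Config{})

	addr, _ := url.Parse("http://127.0.0.1:8123")
	hb := heartbeat.NewHeartbeat(config.HeartBeat{
		Interval: config.Duration(5 * time.Second),
		Timeout:  config.Duration(3 * time.Second),
		Request:  "/ping",
		Response: "Ok.\n",
	})
	n := topology.NewNode(addr, hb, "example-cluster", "replica-1",
		topology.WithDefaultActiveState(true))

	n.IncrementConnections()
	fmt.Println(n.CurrentLoad()) // 1: one open connection, no penalty

	n.Penalize()
	fmt.Println(n.CurrentLoad()) // 6: DefaultPenaltySize (5) added for DefaultPenaltyDuration (10s)
}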
package main import ( "errors" "fmt" "io" "net/http" "sync" "time" "github.com/contentsquare/chproxy/cache" "github.com/contentsquare/chproxy/log" "github.com/prometheus/client_golang/prometheus" ) type ResponseWriterWithCode interface { http.ResponseWriter StatusCode() int } type StatResponseWriter interface { http.ResponseWriter http.CloseNotifier StatusCode() int SetStatusCode(code int) } var _ StatResponseWriter = &statResponseWriter{} // statResponseWriter collects the amount of bytes written. // // The wrapped ResponseWriter must implement http.CloseNotifier. // // Additionally it caches response status code. type statResponseWriter struct { http.ResponseWriter statusCode int // wroteHeader tells whether the header's been written to // the original ResponseWriter wroteHeader bool bytesWritten prometheus.Counter } const ( XCacheHit = "HIT" XCacheMiss = "MISS" XCacheNA = "N/A" ) func RespondWithData(rw http.ResponseWriter, data io.Reader, metadata cache.ContentMetadata, ttl time.Duration, cacheHit string, statusCode int, labels prometheus.Labels) error { h := rw.Header() if len(metadata.Type) > 0 { h.Set("Content-Type", metadata.Type) } if len(metadata.Encoding) > 0 { h.Set("Content-Encoding", metadata.Encoding) } h.Set("Content-Length", fmt.Sprintf("%d", metadata.Length)) if ttl > 0 { expireSeconds := uint(ttl / time.Second) h.Set("Cache-Control", fmt.Sprintf("max-age=%d", expireSeconds)) } h.Set("X-Cache", cacheHit) rw.WriteHeader(statusCode) if _, err := io.Copy(rw, data); err != nil { var perr *cache.RedisCacheError if errors.As(err, &perr) { cacheCorruptedFetch.With(labels).Inc() log.Debugf("redis cache error") } log.Errorf("cannot send response to client: %s", err) return fmt.Errorf("cannot send response to client: %w", err) } return nil } func (rw *statResponseWriter) SetStatusCode(code int) { rw.statusCode = code } func (rw *statResponseWriter) StatusCode() int { if rw.statusCode == 0 { return http.StatusOK } return rw.statusCode } func (rw *statResponseWriter) Write(b []byte) (int, error) { if rw.statusCode == 0 { rw.statusCode = http.StatusOK } if !rw.wroteHeader { rw.ResponseWriter.WriteHeader(rw.statusCode) rw.wroteHeader = true } n, err := rw.ResponseWriter.Write(b) rw.bytesWritten.Add(float64(n)) return n, err } func (rw *statResponseWriter) WriteHeader(statusCode int) { // cache statusCode to keep the opportunity to change it in further rw.statusCode = statusCode } // CloseNotify implements http.CloseNotifier func (rw *statResponseWriter) CloseNotify() <-chan bool { // The rw.ResponseWriter must implement http.CloseNotifier rwc, ok := rw.ResponseWriter.(http.CloseNotifier) if !ok { panic("BUG: the wrapped ResponseWriter must implement http.CloseNotifier") } return rwc.CloseNotify() } var _ io.ReadCloser = &statReadCloser{} // statReadCloser collects the amount of bytes read. type statReadCloser struct { io.ReadCloser bytesRead prometheus.Counter } func (src *statReadCloser) Read(p []byte) (int, error) { n, err := src.ReadCloser.Read(p) src.bytesRead.Add(float64(n)) return n, err } var _ io.ReadCloser = &cachedReadCloser{} // cachedReadCloser caches the first 1Kb form the wrapped ReadCloser. type cachedReadCloser struct { io.ReadCloser // bLock protects b from concurrent access when Read and String // are called from concurrent goroutines. bLock sync.Mutex // b holds up to 1Kb of the initial data read from ReadCloser. 
b []byte } func (crc *cachedReadCloser) Read(p []byte) (int, error) { n, err := crc.ReadCloser.Read(p) crc.bLock.Lock() if len(crc.b) < 1024 { crc.b = append(crc.b, p[:n]...) if len(crc.b) >= 1024 { crc.b = append(crc.b[:1024], "..."...) } } crc.bLock.Unlock() // Do not cache the last read operation, since it slows down // reading large amounts of data such as large INSERT queries. return n, err } func (crc *cachedReadCloser) String() string { crc.bLock.Lock() s := string(crc.b) crc.bLock.Unlock() return s }
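// Illustrative sketch (written as if it lived in a _test.go file of this
// package, since the types above are unexported): cachedReadCloser keeps at
// most 1KiB of the body plus a "..." marker, so query snippets logged on error
// stay bounded no matter how large the INSERT is.
package main

import (
	"fmt"
	"io"
	"strings"
)

func Example_cachedReadCloser() {
	crc := &cachedReadCloser{
		ReadCloser: io.NopCloser(strings.NewReader(strings.Repeat("a", 4096))),
	}
	if _, err := io.Copy(io.Discard, crc); err != nil {
		panic(err)
	}
	// 1024 cached bytes plus the trailing "..." marker.
	fmt.Println(len(crc.String()))
	// Output: 1027
}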
package log import ( "fmt" "io" "log" "os" "sync/atomic" ) var ( stdLogFlags = log.LstdFlags | log.Lshortfile | log.LUTC outputCallDepth = 2 debugLogger = log.New(os.Stderr, "DEBUG: ", stdLogFlags) infoLogger = log.New(os.Stderr, "INFO: ", stdLogFlags) errorLogger = log.New(os.Stderr, "ERROR: ", stdLogFlags) fatalLogger = log.New(os.Stderr, "FATAL: ", stdLogFlags) // NilLogger suppresses all the log messages. NilLogger = log.New(io.Discard, "", stdLogFlags) ) // SuppressOutput suppresses all output from logs if `suppress` is true // used while testing func SuppressOutput(suppress bool) { if suppress { debugLogger.SetOutput(io.Discard) infoLogger.SetOutput(io.Discard) errorLogger.SetOutput(io.Discard) } else { debugLogger.SetOutput(os.Stderr) infoLogger.SetOutput(os.Stderr) errorLogger.SetOutput(os.Stderr) } } var debug uint32 // SetDebug sets output into debug mode if true passed func SetDebug(val bool) { if val { atomic.StoreUint32(&debug, 1) } else { atomic.StoreUint32(&debug, 0) } } // Debugf prints debug message according to a format func Debugf(format string, args ...interface{}) { if atomic.LoadUint32(&debug) == 0 { return } s := fmt.Sprintf(format, args...) debugLogger.Output(outputCallDepth, s) // nolint } // Infof prints info message according to a format func Infof(format string, args ...interface{}) { s := fmt.Sprintf(format, args...) infoLogger.Output(outputCallDepth, s) // nolint } // Errorf prints warning message according to a format func Errorf(format string, args ...interface{}) { s := fmt.Sprintf(format, args...) errorLogger.Output(outputCallDepth, s) // nolint } // ErrorWithCallDepth prints err into error log using the given callDepth. func ErrorWithCallDepth(err error, callDepth int) { s := err.Error() errorLogger.Output(outputCallDepth+callDepth, s) //nolint } // Fatalf prints fatal message according to a format and exits program func Fatalf(format string, args ...interface{}) { s := fmt.Sprintf(format, args...) fatalLogger.Output(outputCallDepth, s) // nolint os.Exit(1) }
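// Illustrative sketch: Debugf is a no-op until SetDebug(true) is called; all
// loggers write to stderr with file:line (log.Lshortfile) and UTC timestamps.
package main

import "github.com/contentsquare/chproxy/log"

func main() {
	log.Debugf("dropped: debug is off by default")
	log.SetDebug(true)
	log.Debugf("visible now: %d", 42)
	log.Infof("plain info message")
	log.Errorf("something went wrong: %s", "example")
}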
package main import ( "context" "crypto/tls" "flag" "fmt" "net" "net/http" "os" "os/signal" "strings" "sync/atomic" "syscall" "time" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/log" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/crypto/acme/autocert" ) var ( configFile = flag.String("config", "", "Proxy configuration filename") version = flag.Bool("version", false, "Prints current version and exits") enableTCP6 = flag.Bool("enableTCP6", false, "Whether to enable listening for IPv6 TCP ports. "+ "By default only IPv4 TCP ports are listened") ) var ( proxy *reverseProxy // networks allow lists allowedNetworksHTTP atomic.Value allowedNetworksHTTPS atomic.Value allowedNetworksMetrics atomic.Value proxyHandler atomic.Value ) func main() { flag.Parse() if *version { fmt.Printf("%s\n", versionString()) os.Exit(0) } log.Infof("%s", versionString()) log.Infof("Loading config: %s", *configFile) cfg, err := loadConfig() if err != nil { log.Fatalf("error while loading config: %s", err) } registerMetrics(cfg) if err = applyConfig(cfg); err != nil { log.Fatalf("error while applying config: %s", err) } configSuccess.Set(1) configSuccessTime.Set(float64(time.Now().Unix())) log.Infof("Loading config %q: successful", *configFile) setupReloadConfigWatch() server := cfg.Server if len(server.HTTP.ListenAddr) == 0 && len(server.HTTPS.ListenAddr) == 0 { panic("BUG: broken config validation - `listen_addr` is not configured") } if server.HTTP.ForceAutocertHandler { autocertManager = newAutocertManager(server.HTTPS.Autocert) } if len(server.HTTPS.ListenAddr) != 0 { go serveTLS(server.HTTPS) } if len(server.HTTP.ListenAddr) != 0 { go serve(server.HTTP) } select {} } func setupReloadConfigWatch() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGHUP) go func() { for { if <-c == syscall.SIGHUP { log.Infof("SIGHUP received. 
Going to reload config %s ...", *configFile) if err := reloadConfig(); err != nil { log.Errorf("error while reloading config: %s", err) continue } log.Infof("Reloading config %s: successful", *configFile) } } }() } var autocertManager *autocert.Manager func newAutocertManager(cfg config.Autocert) *autocert.Manager { if len(cfg.CacheDir) > 0 { if err := os.MkdirAll(cfg.CacheDir, 0o700); err != nil { log.Fatalf("error while creating folder %q: %s", cfg.CacheDir, err) } } var hp autocert.HostPolicy if len(cfg.AllowedHosts) != 0 { allowedHosts := make(map[string]struct{}, len(cfg.AllowedHosts)) for _, v := range cfg.AllowedHosts { allowedHosts[v] = struct{}{} } hp = func(_ context.Context, host string) error { if _, ok := allowedHosts[host]; ok { return nil } return fmt.Errorf("host %q doesn't match `host_policy` configuration", host) } } return &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(cfg.CacheDir), HostPolicy: hp, } } func newListener(listenAddr string) net.Listener { network := "tcp4" if *enableTCP6 { // Enable listening on both tcp4 and tcp6 network = "tcp" } ln, err := net.Listen(network, listenAddr) if err != nil { log.Fatalf("cannot listen for %q: %s", listenAddr, err) } return ln } func serveTLS(cfg config.HTTPS) { ln := newListener(cfg.ListenAddr) h := http.HandlerFunc(serveHTTP) tlsCfg, err := cfg.TLS.BuildTLSConfig(autocertManager) if err != nil { log.Fatalf("cannot build TLS config: %s", err) } tln := tls.NewListener(ln, tlsCfg) log.Infof("Serving https on %q", cfg.ListenAddr) if err := listenAndServe(tln, h, cfg.TimeoutCfg); err != nil { log.Fatalf("TLS server error on %q: %s", cfg.ListenAddr, err) } } func serve(cfg config.HTTP) { var h http.Handler ln := newListener(cfg.ListenAddr) h = http.HandlerFunc(serveHTTP) if cfg.ForceAutocertHandler { if autocertManager == nil { panic("BUG: autocertManager is not inited") } addr := ln.Addr().String() parts := strings.Split(addr, ":") if parts[len(parts)-1] != "80" { log.Fatalf("`letsencrypt` specification requires http server to listen on :80 port to satisfy http-01 challenge. " + "Otherwise, certificates will be impossible to generate") } h = autocertManager.HTTPHandler(h) } log.Infof("Serving http on %q", cfg.ListenAddr) if err := listenAndServe(ln, h, cfg.TimeoutCfg); err != nil { log.Fatalf("HTTP server error on %q: %s", cfg.ListenAddr, err) } } func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Server { // nolint:gosec // We already configured ReadTimeout, so no need to set ReadHeaderTimeout as well. return &http.Server{ TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), Handler: h, ReadTimeout: time.Duration(cfg.ReadTimeout), WriteTimeout: time.Duration(cfg.WriteTimeout), IdleTimeout: time.Duration(cfg.IdleTimeout), // Suppress error logging from the server, since chproxy // must handle all these errors in the code. ErrorLog: log.NilLogger, } } func listenAndServe(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) error { s := newServer(ln, h, cfg) return s.Serve(ln) } var promHandler = promhttp.Handler() //nolint:cyclop //TODO reduce complexity here. func serveHTTP(rw http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet, http.MethodPost: // Only GET and POST methods are supported. 
case http.MethodOptions: // This is required for CORS shit :) rw.Header().Set("Allow", "GET,POST") return default: err := fmt.Errorf("%q: unsupported method %q", r.RemoteAddr, r.Method) rw.Header().Set("Connection", "close") respondWith(rw, err, http.StatusMethodNotAllowed) return } switch r.URL.Path { case "/favicon.ico": case "/metrics": // nolint:forcetypeassert // We will cover this by tests as we control what is stored. an := allowedNetworksMetrics.Load().(*config.Networks) if !an.Contains(r.RemoteAddr) { err := fmt.Errorf("connections to /metrics are not allowed from %s", r.RemoteAddr) rw.Header().Set("Connection", "close") respondWith(rw, err, http.StatusForbidden) return } proxy.refreshCacheMetrics() promHandler.ServeHTTP(rw, r) case "/", "/query": var err error // nolint:forcetypeassert // We will cover this by tests as we control what is stored. proxyHandler := proxyHandler.Load().(*ProxyHandler) r.RemoteAddr = proxyHandler.GetRemoteAddr(r) var an *config.Networks if r.TLS != nil { // nolint:forcetypeassert // We will cover this by tests as we control what is stored. an = allowedNetworksHTTPS.Load().(*config.Networks) err = fmt.Errorf("https connections are not allowed from %s", r.RemoteAddr) } else { // nolint:forcetypeassert // We will cover this by tests as we control what is stored. an = allowedNetworksHTTP.Load().(*config.Networks) err = fmt.Errorf("http connections are not allowed from %s", r.RemoteAddr) } if !an.Contains(r.RemoteAddr) { rw.Header().Set("Connection", "close") respondWith(rw, err, http.StatusForbidden) return } proxy.ServeHTTP(rw, r) default: badRequest.Inc() err := fmt.Errorf("%q: unsupported path: %q", r.RemoteAddr, r.URL.Path) rw.Header().Set("Connection", "close") respondWith(rw, err, http.StatusBadRequest) } } func loadConfig() (*config.Config, error) { if *configFile == "" { log.Fatalf("Missing -config flag") } cfg, err := config.LoadFile(*configFile) if err != nil { return nil, fmt.Errorf("can't load config %q: %w", *configFile, err) } return cfg, nil } // a configuration parameter value that is used in proxy initialization // changed func proxyConfigChanged(cfgCp *config.ConnectionPool, rp *reverseProxy) bool { return cfgCp.MaxIdleConns != proxy.maxIdleConns || cfgCp.MaxIdleConnsPerHost != proxy.maxIdleConnsPerHost } func applyConfig(cfg *config.Config) error { if proxy == nil || proxyConfigChanged(&cfg.ConnectionPool, proxy) { proxy = newReverseProxy(&cfg.ConnectionPool) } if err := proxy.applyConfig(cfg); err != nil { return err } allowedNetworksHTTP.Store(&cfg.Server.HTTP.AllowedNetworks) allowedNetworksHTTPS.Store(&cfg.Server.HTTPS.AllowedNetworks) allowedNetworksMetrics.Store(&cfg.Server.Metrics.AllowedNetworks) proxyHandler.Store(NewProxyHandler(&cfg.Server.Proxy)) log.SetDebug(cfg.LogDebug) log.Infof("Loaded config:\n%s", cfg) return nil } func reloadConfig() error { cfg, err := loadConfig() if err != nil { return err } return applyConfig(cfg) } var ( buildTag = "unknown" buildRevision = "unknown" buildTime = "unknown" ) func versionString() string { ver := buildTag if len(ver) == 0 { ver = "unknown" } return fmt.Sprintf("chproxy ver. %s, rev. %s, built at %s", ver, buildRevision, buildTime) }
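// Illustrative sketch (as a package main test): buildTag, buildRevision and
// buildTime are plain variables, so release builds are expected to inject them
// with the linker, e.g.
//
//	go build -ldflags "-X main.buildTag=v1.2.3 -X main.buildRevision=abc123"
//
// The exact flags are assumptions; without them versionString falls back to
// "unknown" everywhere.
package main

import "fmt"

func Example_versionString() {
	fmt.Println(versionString())
	// Output: chproxy ver. unknown, rev. unknown, built at unknown
}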
package main import ( "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/internal/topology" "github.com/prometheus/client_golang/prometheus" ) var ( statusCodes *prometheus.CounterVec requestSum *prometheus.CounterVec requestSuccess *prometheus.CounterVec limitExcess *prometheus.CounterVec concurrentQueries *prometheus.GaugeVec requestQueueSize *prometheus.GaugeVec userQueueOverflow *prometheus.CounterVec clusterUserQueueOverflow *prometheus.CounterVec requestBodyBytes *prometheus.CounterVec responseBodyBytes *prometheus.CounterVec cacheFailedInsert *prometheus.CounterVec cacheCorruptedFetch *prometheus.CounterVec cacheHit *prometheus.CounterVec cacheMiss *prometheus.CounterVec cacheSize *prometheus.GaugeVec cacheItems *prometheus.GaugeVec cacheSkipped *prometheus.CounterVec requestDuration *prometheus.SummaryVec proxiedResponseDuration *prometheus.SummaryVec cachedResponseDuration *prometheus.SummaryVec canceledRequest *prometheus.CounterVec cacheHitFromConcurrentQueries *prometheus.CounterVec cacheMissFromConcurrentQueries *prometheus.CounterVec killedRequests *prometheus.CounterVec timeoutRequest *prometheus.CounterVec configSuccess prometheus.Gauge configSuccessTime prometheus.Gauge badRequest prometheus.Counter retryRequest *prometheus.CounterVec ) func initMetrics(cfg *config.Config) { namespace := cfg.Server.Metrics.Namespace statusCodes = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "status_codes_total", Help: "Distribution by status codes", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node", "code"}, ) requestSum = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "request_sum_total", Help: "Total number of sent requests", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) requestSuccess = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "request_success_total", Help: "Total number of sent success requests", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) limitExcess = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "concurrent_limit_excess_total", Help: "Total number of max_concurrent_queries excess", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) concurrentQueries = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: namespace, Name: "concurrent_queries", Help: "The number of concurrent queries at current time", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) requestQueueSize = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: namespace, Name: "request_queue_size", Help: "Request queue sizes at the current time", }, []string{"user", "cluster", "cluster_user"}, ) userQueueOverflow = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "user_queue_overflow_total", Help: "The number of overflows for per-user request queues", }, []string{"user", "cluster", "cluster_user"}, ) clusterUserQueueOverflow = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cluster_user_queue_overflow_total", Help: "The number of overflows for per-cluster_user request queues", }, []string{"user", "cluster", "cluster_user"}, ) requestBodyBytes = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "request_body_bytes_total", Help: "The amount of bytes read from request bodies", }, []string{"user", "cluster", "cluster_user", "replica", 
"cluster_node"}, ) responseBodyBytes = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "response_body_bytes_total", Help: "The amount of bytes written to response bodies", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) cacheFailedInsert = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_insertion_failures_total", Help: "The number of insertion in the cache that didn't work out", }, []string{"cache", "user", "cluster", "cluster_user"}, ) cacheCorruptedFetch = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_get_corrutpion_total", Help: "The number of time a data fetching from redis was corrupted", }, []string{"cache", "user", "cluster", "cluster_user"}, ) cacheHit = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_hits_total", Help: "The amount of cache hits", }, []string{"cache", "user", "cluster", "cluster_user"}, ) cacheMiss = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_miss_total", Help: "The amount of cache misses", }, []string{"cache", "user", "cluster", "cluster_user"}, ) cacheSize = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: namespace, Name: "cache_size", Help: "Cache size at the current time", }, []string{"cache"}, ) cacheItems = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: namespace, Name: "cache_items", Help: "Cache items at the current time", }, []string{"cache"}, ) cacheSkipped = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_payloadsize_too_big_total", Help: "The amount of too big payloads to be cached", }, []string{"cache", "user", "cluster", "cluster_user"}, ) requestDuration = prometheus.NewSummaryVec( prometheus.SummaryOpts{ Namespace: namespace, Name: "request_duration_seconds", Help: "Request duration. 
Includes possible wait time in the queue", Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5}, }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) proxiedResponseDuration = prometheus.NewSummaryVec( prometheus.SummaryOpts{ Namespace: namespace, Name: "proxied_response_duration_seconds", Help: "Response duration proxied from clickhouse", Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5}, }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) cachedResponseDuration = prometheus.NewSummaryVec( prometheus.SummaryOpts{ Namespace: namespace, Name: "cached_response_duration_seconds", Help: "Response duration served from the cache", Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5}, }, []string{"cache", "user", "cluster", "cluster_user"}, ) canceledRequest = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "canceled_request_total", Help: "The number of requests canceled by remote client", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) cacheHitFromConcurrentQueries = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_hit_concurrent_query_total", Help: "The amount of cache hits after having awaited concurrently executed queries", }, []string{"cache", "user", "cluster", "cluster_user"}, ) cacheMissFromConcurrentQueries = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "cache_miss_concurrent_query_total", Help: "The amount of cache misses, even if previously reported as queries available in the cache, after having awaited concurrently executed queries", }, []string{"cache", "user", "cluster", "cluster_user"}, ) killedRequests = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "killed_request_total", Help: "The number of requests killed by proxy", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) timeoutRequest = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "timeout_request_total", Help: "The number of timed out requests", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) configSuccess = prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, Name: "config_last_reload_successful", Help: "Whether the last configuration reload attempt was successful.", }) configSuccessTime = prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, Name: "config_last_reload_success_timestamp_seconds", Help: "Timestamp of the last successful configuration reload.", }) badRequest = prometheus.NewCounter(prometheus.CounterOpts{ Namespace: namespace, Name: "bad_requests_total", Help: "Total number of unsupported requests", }) retryRequest = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: namespace, Name: "retry_request_total", Help: "The number of retry requests", }, []string{"user", "cluster", "cluster_user", "replica", "cluster_node"}, ) } func registerMetrics(cfg *config.Config) { topology.RegisterMetrics(cfg) initMetrics(cfg) prometheus.MustRegister(statusCodes, requestSum, requestSuccess, limitExcess, concurrentQueries, requestQueueSize, userQueueOverflow, clusterUserQueueOverflow, requestBodyBytes, responseBodyBytes, cacheFailedInsert, cacheCorruptedFetch, cacheHit, cacheMiss, cacheSize, cacheItems, cacheSkipped, requestDuration, proxiedResponseDuration, cachedResponseDuration, canceledRequest, 
cacheHitFromConcurrentQueries, cacheMissFromConcurrentQueries, killedRequests, timeoutRequest, configSuccess, configSuccessTime, badRequest, retryRequest) }
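// Hedged example (not part of the original source): a minimal, standalone sketch of how labeled
// counter/summary vectors like the ones registered above are typically used: resolve the label
// set once per request, then call Inc()/Observe() on it. The namespace, metric name and label
// values below are illustrative assumptions, not chproxy's real ones.
package metricsexample

import (
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

var exampleRequestDuration = prometheus.NewSummaryVec(
	prometheus.SummaryOpts{
		Namespace:  "example",
		Name:       "request_duration_seconds",
		Help:       "Request duration, including possible wait time in the queue",
		Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3},
	},
	[]string{"user", "cluster"},
)

func init() {
	prometheus.MustRegister(exampleRequestDuration)
}

// observeRequest records a single request duration for the given user/cluster pair.
func observeRequest(user, cluster string, start time.Time) {
	labels := prometheus.Labels{"user": user, "cluster": cluster}
	exampleRequestDuration.With(labels).Observe(time.Since(start).Seconds())
}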
package main import ( "bytes" "context" "errors" "fmt" "io" "net" "net/http" "net/http/httputil" "net/url" "strconv" "strings" "sync" "time" "github.com/contentsquare/chproxy/cache" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/internal/topology" "github.com/contentsquare/chproxy/log" "github.com/prometheus/client_golang/prometheus" ) // tmpDir temporary path to store ongoing queries results const tmpDir = "/tmp" // failedTransactionPrefix prefix added to the failed reason for concurrent queries registry const failedTransactionPrefix = "[concurrent query failed]" type reverseProxy struct { rp *httputil.ReverseProxy // configLock serializes access to applyConfig. // It protects reload* fields. configLock sync.Mutex reloadSignal chan struct{} reloadWG sync.WaitGroup // lock protects users, clusters and caches. // RWMutex enables concurrent access to getScope. lock sync.RWMutex users map[string]*user clusters map[string]*cluster caches map[string]*cache.AsyncCache hasWildcarded bool maxIdleConns int maxIdleConnsPerHost int } func newReverseProxy(cfgCp *config.ConnectionPool) *reverseProxy { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } return dialer.DialContext(ctx, network, addr) }, ForceAttemptHTTP2: true, MaxIdleConns: cfgCp.MaxIdleConns, MaxIdleConnsPerHost: cfgCp.MaxIdleConnsPerHost, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } return &reverseProxy{ rp: &httputil.ReverseProxy{ Director: func(*http.Request) {}, Transport: transport, // Suppress error logging in ReverseProxy, since all the errors // are handled and logged in the code below. ErrorLog: log.NilLogger, }, reloadSignal: make(chan struct{}), reloadWG: sync.WaitGroup{}, maxIdleConns: cfgCp.MaxIdleConnsPerHost, maxIdleConnsPerHost: cfgCp.MaxIdleConnsPerHost, } } func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { startTime := time.Now() s, status, err := rp.getScope(req) if err != nil { q := getQuerySnippet(req) err = fmt.Errorf("%q: %w; query: %q", req.RemoteAddr, err, q) respondWith(rw, err, status) return } // WARNING: don't use s.labels before s.incQueued, // since `replica` and `cluster_node` may change inside incQueued. if err := s.incQueued(); err != nil { limitExcess.With(s.labels).Inc() q := getQuerySnippet(req) err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(rw, err, http.StatusTooManyRequests) return } defer s.dec() log.Debugf("%s: request start", s) requestSum.With(s.labels).Inc() if s.user.allowCORS { origin := req.Header.Get("Origin") if len(origin) == 0 { origin = "*" } rw.Header().Set("Access-Control-Allow-Origin", origin) } req.Body = &statReadCloser{ ReadCloser: req.Body, bytesRead: requestBodyBytes.With(s.labels), } srw := &statResponseWriter{ ResponseWriter: rw, bytesWritten: responseBodyBytes.With(s.labels), } req, origParams := s.decorateRequest(req) // wrap body into cachedReadCloser, so we could obtain the original // request on error. 
req.Body = &cachedReadCloser{ ReadCloser: req.Body, } // publish session_id if needed if s.sessionId != "" { rw.Header().Set("X-ClickHouse-Server-Session-Id", s.sessionId) } q, shouldReturnFromCache, err := shouldRespondFromCache(s, origParams, req) if err != nil { respondWith(srw, err, http.StatusBadRequest) return } if shouldReturnFromCache { rp.serveFromCache(s, srw, req, origParams, q) } else { rp.proxyRequest(s, srw, srw, req) } // It is safe calling getQuerySnippet here, since the request // has been already read in proxyRequest or serveFromCache. query := getQuerySnippet(req) if srw.statusCode == http.StatusOK { requestSuccess.With(s.labels).Inc() log.Debugf("%s: request success; query: %q; Method: %s; URL: %q", s, query, req.Method, req.URL.String()) } else { log.Debugf("%s: request failure: non-200 status code %d; query: %q; Method: %s; URL: %q", s, srw.statusCode, query, req.Method, req.URL.String()) } statusCodes.With( prometheus.Labels{ "user": s.user.name, "cluster": s.cluster.name, "cluster_user": s.clusterUser.name, "replica": s.host.ReplicaName(), "cluster_node": s.host.Host(), "code": strconv.Itoa(srw.statusCode), }, ).Inc() since := time.Since(startTime).Seconds() requestDuration.With(s.labels).Observe(since) } func shouldRespondFromCache(s *scope, origParams url.Values, req *http.Request) ([]byte, bool, error) { if s.user.cache == nil || s.user.cache.Cache == nil { return nil, false, nil } noCache := origParams.Get("no_cache") if noCache == "1" || noCache == "true" { return nil, false, nil } q, err := getFullQuery(req) if err != nil { return nil, false, fmt.Errorf("%s: cannot read query: %w", s, err) } return q, canCacheQuery(q), nil } func executeWithRetry( ctx context.Context, s *scope, maxRetry int, rp func(http.ResponseWriter, *http.Request), rw ResponseWriterWithCode, srw StatResponseWriter, req *http.Request, monitorDuration func(float64), monitorRetryRequestInc func(prometheus.Labels), ) (float64, error) { startTime := time.Now() var since float64 // keep the request body body, err := io.ReadAll(req.Body) req.Body.Close() if err != nil { since = time.Since(startTime).Seconds() return since, err } numRetry := 0 for { // update body req.Body = io.NopCloser(bytes.NewBuffer(body)) req.Body.Close() rp(rw, req) err := ctx.Err() if err != nil { since = time.Since(startTime).Seconds() return since, err } // The request has been successfully proxied. srw.SetStatusCode(rw.StatusCode()) // StatusBadGateway response is returned by http.ReverseProxy when // it cannot establish connection to remote host. 
if rw.StatusCode() == http.StatusBadGateway { log.Debugf("the invalid host is: %s", s.host) s.host.Penalize() // comment s.host.dec() line to avoid double increment; issue #322 // s.host.dec() s.host.SetIsActive(false) nextHost := s.cluster.getHost() // The query could be retried if it has no stickiness to a certain server if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" { // the query execution has been failed monitorRetryRequestInc(s.labels) currentHost := s.host // decrement the current failed host counter and increment the new host // as for the end of the requests we will close the scope and in that closed scope // decrement the new host PR - https://github.com/ContentSquare/chproxy/pull/357 if currentHost != nextHost { currentHost.DecrementConnections() nextHost.IncrementConnections() } // update host s.host = nextHost req.URL.Host = s.host.Host() req.URL.Scheme = s.host.Scheme() log.Debugf("the valid host is: %s", s.host) } else { since = time.Since(startTime).Seconds() monitorDuration(since) q := getQuerySnippet(req) err1 := fmt.Errorf("%s: cannot reach %s; query: %q", s, s.host.Host(), q) respondWith(srw, err1, srw.StatusCode()) break } } else { since = time.Since(startTime).Seconds() break } numRetry++ } return since, nil } // proxyRequest proxies the given request to clickhouse and sends response // to rw. // // srw is required only for setting non-200 status codes on timeouts // or on client connection disconnects. func (rp *reverseProxy) proxyRequest(s *scope, rw ResponseWriterWithCode, srw *statResponseWriter, req *http.Request) { // wrap body into cachedReadCloser, so we could obtain the original // request on error. if _, ok := req.Body.(*cachedReadCloser); !ok { req.Body = &cachedReadCloser{ ReadCloser: req.Body, } } timeout, timeoutErrMsg := s.getTimeoutWithErrMsg() ctx := context.Background() if timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout) defer cancel() } // Cancel the ctx if client closes the remote connection, // so the proxied query may be killed instantly. ctx, ctxCancel := listenToCloseNotify(ctx, rw) defer ctxCancel() req = req.WithContext(ctx) startTime := time.Now() executeDuration, err := executeWithRetry(ctx, s, s.cluster.retryNumber, rp.rp.ServeHTTP, rw, srw, req, func(duration float64) { proxiedResponseDuration.With(s.labels).Observe(duration) }, func(labels prometheus.Labels) { retryRequest.With(labels).Inc() }) switch { case err == nil: return case errors.Is(err, context.Canceled): canceledRequest.With(s.labels).Inc() q := getQuerySnippet(req) log.Debugf("%s: remote client closed the connection in %s; query: %q", s, time.Since(startTime), q) if err := s.killQuery(); err != nil { log.Errorf("%s: cannot kill query: %s; query: %q", s, err, q) } srw.statusCode = 499 // See https://httpstatuses.com/499 . case errors.Is(err, context.DeadlineExceeded): timeoutRequest.With(s.labels).Inc() // Penalize host with the timed out query, because it may be overloaded. 
s.host.Penalize() q := getQuerySnippet(req) log.Debugf("%s: query timeout in %f; query: %q", s, executeDuration, q) if err := s.killQuery(); err != nil { log.Errorf("%s: cannot kill query: %s; query: %q", s, err, q) } err = fmt.Errorf("%s: %w; query: %q", s, timeoutErrMsg, q) respondWith(rw, err, http.StatusGatewayTimeout) srw.statusCode = http.StatusGatewayTimeout default: panic(fmt.Sprintf("BUG: context.Context.Err() returned unexpected error: %s", err)) } } func listenToCloseNotify(ctx context.Context, rw ResponseWriterWithCode) (context.Context, context.CancelFunc) { // Cancel the ctx if client closes the remote connection, // so the proxied query may be killed instantly. ctx, ctxCancel := context.WithCancel(ctx) // rw must implement http.CloseNotifier. rwc, ok := rw.(http.CloseNotifier) if !ok { panic("BUG: the wrapped ResponseWriter must implement http.CloseNotifier") } ch := rwc.CloseNotify() go func() { select { case <-ch: ctxCancel() case <-ctx.Done(): } }() return ctx, ctxCancel } //nolint:cyclop //TODO refactor this method, most likely requires some work. func (rp *reverseProxy) serveFromCache(s *scope, srw *statResponseWriter, req *http.Request, origParams url.Values, q []byte) { labels := makeCacheLabels(s) key := newCacheKey(s, origParams, q, req) startTime := time.Now() userCache := s.user.cache // Try to serve from cache cachedData, err := userCache.Get(key) if err == nil { // The response has been successfully served from cache. defer cachedData.Data.Close() cacheHit.With(labels).Inc() cachedResponseDuration.With(labels).Observe(time.Since(startTime).Seconds()) log.Debugf("%s: cache hit", s) _ = RespondWithData(srw, cachedData.Data, cachedData.ContentMetadata, cachedData.Ttl, XCacheHit, http.StatusOK, labels) return } // Await for potential result from concurrent query transactionStatus, err := userCache.AwaitForConcurrentTransaction(key) if err != nil { // log and continue processing log.Errorf("failed to await for concurrent transaction due to: %v", err) } else { if transactionStatus.State.IsCompleted() { cachedData, err := userCache.Get(key) if err == nil { defer cachedData.Data.Close() _ = RespondWithData(srw, cachedData.Data, cachedData.ContentMetadata, cachedData.Ttl, XCacheHit, http.StatusOK, labels) cacheHitFromConcurrentQueries.With(labels).Inc() log.Debugf("%s: cache hit after awaiting concurrent query", s) return } else { cacheMissFromConcurrentQueries.With(labels).Inc() log.Debugf("%s: cache miss after awaiting concurrent query", s) } } else if transactionStatus.State.IsFailed() { respondWith(srw, fmt.Errorf(transactionStatus.FailReason), http.StatusInternalServerError) return } } // The response wasn't found in the cache. // Request it from clickhouse. 
tmpFileRespWriter, err := cache.NewTmpFileResponseWriter(srw, tmpDir) if err != nil { err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(srw, err, http.StatusInternalServerError) return } defer tmpFileRespWriter.Close() // Initialise transaction err = userCache.Create(key) if err != nil { log.Errorf("%s: %s; query: %q - failed to register transaction", s, err, q) } // proxy request and capture response along with headers to [[TmpFileResponseWriter]] rp.proxyRequest(s, tmpFileRespWriter, srw, req) contentEncoding := tmpFileRespWriter.GetCapturedContentEncoding() contentType := tmpFileRespWriter.GetCapturedContentType() contentLength, err := tmpFileRespWriter.GetCapturedContentLength() if err != nil { log.Errorf("%s: %s; query: %q - failed to get contentLength of query", s, err, q) respondWith(srw, err, http.StatusInternalServerError) return } reader, err := tmpFileRespWriter.Reader() if err != nil { log.Errorf("%s: %s; query: %q - failed to get Reader from tmp file", s, err, q) respondWith(srw, err, http.StatusInternalServerError) return } contentMetadata := cache.ContentMetadata{Length: contentLength, Encoding: contentEncoding, Type: contentType} statusCode := tmpFileRespWriter.StatusCode() if statusCode != http.StatusOK || s.canceled { // Do not cache non-200 or cancelled responses. // Restore the original status code by proxyRequest if it was set. if srw.statusCode != 0 { tmpFileRespWriter.WriteHeader(srw.statusCode) } errString, err := toString(reader) if err != nil { log.Errorf("%s failed to get error reason: %s", s, err.Error()) } errReason := fmt.Sprintf("%s %s", failedTransactionPrefix, errString) rp.completeTransaction(s, statusCode, userCache, key, q, errReason) // we need to reset the offset since the reader of tmpFileRespWriter was already // consumed in RespondWithData(...) err = tmpFileRespWriter.ResetFileOffset() if err != nil { err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(srw, err, http.StatusInternalServerError) return } err = RespondWithData(srw, reader, contentMetadata, 0*time.Second, XCacheMiss, statusCode, labels) if err != nil { err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(srw, err, http.StatusInternalServerError) } } else { // Do not cache responses greater than max payload size. if contentLength > int64(s.user.cache.MaxPayloadSize) { cacheSkipped.With(labels).Inc() log.Infof("%s: Request will not be cached. Content length (%d) is greater than max payload size (%d)", s, contentLength, s.user.cache.MaxPayloadSize) rp.completeTransaction(s, statusCode, userCache, key, q, "") err = RespondWithData(srw, reader, contentMetadata, 0*time.Second, XCacheNA, tmpFileRespWriter.StatusCode(), labels) if err != nil { err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(srw, err, http.StatusInternalServerError) } return } cacheMiss.With(labels).Inc() log.Debugf("%s: cache miss", s) expiration, err := userCache.Put(reader, contentMetadata, key) if err != nil { cacheFailedInsert.With(labels).Inc() log.Errorf("%s: %s; query: %q - failed to put response in the cache", s, err, q) } rp.completeTransaction(s, statusCode, userCache, key, q, "") // we need to reset the offset since the reader of tmpFileRespWriter was already // consumed in RespondWithData(...) 
err = tmpFileRespWriter.ResetFileOffset() if err != nil { err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(srw, err, http.StatusInternalServerError) return } err = RespondWithData(srw, reader, contentMetadata, expiration, XCacheMiss, statusCode, labels) if err != nil { err = fmt.Errorf("%s: %w; query: %q", s, err, q) respondWith(srw, err, http.StatusInternalServerError) return } } } func makeCacheLabels(s *scope) prometheus.Labels { // Do not store `replica` and `cluster_node` in labels, since they make // no sense for cache metrics. return prometheus.Labels{ "cache": s.user.cache.Name(), "user": s.labels["user"], "cluster": s.labels["cluster"], "cluster_user": s.labels["cluster_user"], } } func newCacheKey(s *scope, origParams url.Values, q []byte, req *http.Request) *cache.Key { var userParamsHash uint32 if s.user.params != nil { userParamsHash = s.user.params.key } queryParamsHash := calcQueryParamsHash(origParams) credHash, err := uint32(0), error(nil) if !s.user.cache.SharedWithAllUsers { credHash, err = calcCredentialHash(s.clusterUser.name, s.clusterUser.password) } if err != nil { log.Errorf("failed to calc hash on credentials for user %s", s.user.name) credHash = 0 } return cache.NewKey( skipLeadingComments(q), origParams, sortHeader(req.Header.Get("Accept-Encoding")), userParamsHash, queryParamsHash, credHash, ) } func toString(stream io.Reader) (string, error) { buf := new(bytes.Buffer) _, err := buf.ReadFrom(stream) if err != nil { return "", err } return buf.String(), nil } // clickhouseRecoverableStatusCodes is the set of recoverable HTTP status codes returned by ClickHouse. // When such a code is returned, we mark the transaction as completed and let the concurrent query hit another ClickHouse shard. // possible http error codes in clickhouse (i.e: https://github.com/ClickHouse/ClickHouse/blob/master/src/Server/HTTPHandler.cpp) var clickhouseRecoverableStatusCodes = map[int]struct{}{http.StatusServiceUnavailable: {}, http.StatusRequestTimeout: {}} func (rp *reverseProxy) completeTransaction(s *scope, statusCode int, userCache *cache.AsyncCache, key *cache.Key, q []byte, failReason string, ) { // complete successful transactions or those with an empty fail reason if statusCode < 300 || failReason == "" { if err := userCache.Complete(key); err != nil { log.Errorf("%s: %s; query: %q", s, err, q) } return } if _, ok := clickhouseRecoverableStatusCodes[statusCode]; ok { if err := userCache.Complete(key); err != nil { log.Errorf("%s: %s; query: %q", s, err, q) } } else { if err := userCache.Fail(key, failReason); err != nil { log.Errorf("%s: %s; query: %q", s, err, q) } } } func calcQueryParamsHash(origParams url.Values) uint32 { queryParams := make(map[string]string) for param := range origParams { if strings.HasPrefix(param, "param_") { queryParams[param] = origParams.Get(param) } } queryParamsHash, err := calcMapHash(queryParams) if err != nil { log.Errorf("failed to calc hash for params %s; %s", origParams, err) return 0 } return queryParamsHash } // applyConfig applies the given cfg to reverseProxy. // // The new config is applied only if a nil error is returned. // Otherwise the old config version is kept. func (rp *reverseProxy) applyConfig(cfg *config.Config) error { // configLock protects from concurrent calls to applyConfig // by serializing such calls. // configLock shouldn't be used in other places.
rp.configLock.Lock() defer rp.configLock.Unlock() clusters, err := newClusters(cfg.Clusters) if err != nil { return err } caches := make(map[string]*cache.AsyncCache, len(cfg.Caches)) defer func() { // caches is swapped with old caches from rp.caches // on successful config reload - see the end of reloadConfig. for _, tmpCache := range caches { // Speed up applyConfig by closing caches in background, // since the process of cache closing may be lengthy // due to cleaning. go tmpCache.Close() } }() // transactionsTimeout used for creation of transactions registry inside async cache. // It is set to the highest configured execution time of all users to avoid setups were users use the same cache and have configured different maxExecutionTime. // This would provoke undesired behaviour of `dogpile effect` transactionsTimeout := config.Duration(0) for _, user := range cfg.Users { if user.MaxExecutionTime > transactionsTimeout { transactionsTimeout = user.MaxExecutionTime } if user.IsWildcarded { rp.hasWildcarded = true } } if err := initTempCaches(caches, transactionsTimeout, cfg.Caches); err != nil { return err } params, err := paramsFromConfig(cfg.ParamGroups) if err != nil { return err } profile := &usersProfile{ cfg: cfg.Users, clusters: clusters, caches: caches, params: params, } users, err := profile.newUsers() if err != nil { return err } if err := validateNoWildcardedUserForHeartbeat(clusters, cfg.Clusters); err != nil { return err } // New configs have been successfully prepared. // Restart service goroutines with new configs. // Stop the previous service goroutines. close(rp.reloadSignal) rp.reloadWG.Wait() rp.reloadSignal = make(chan struct{}) rp.restartWithNewConfig(caches, clusters, users) // Substitute old configs with the new configs in rp. // All the currently running requests will continue with old configs, // while all the new requests will use new configs. rp.lock.Lock() rp.clusters = clusters rp.users = users // Swap is needed for deferred closing of old caches. // See the code above where new caches are created. 
caches, rp.caches = rp.caches, caches rp.lock.Unlock() return nil } func initTempCaches(caches map[string]*cache.AsyncCache, transactionsTimeout config.Duration, cfg []config.Cache) error { for _, cc := range cfg { if _, ok := caches[cc.Name]; ok { return fmt.Errorf("duplicate config for cache %q", cc.Name) } tmpCache, err := cache.NewAsyncCache(cc, time.Duration(transactionsTimeout)) if err != nil { return err } caches[cc.Name] = tmpCache } return nil } func paramsFromConfig(cfg []config.ParamGroup) (map[string]*paramsRegistry, error) { params := make(map[string]*paramsRegistry, len(cfg)) for _, p := range cfg { if _, ok := params[p.Name]; ok { return nil, fmt.Errorf("duplicate config for ParamGroups %q", p.Name) } registry, err := newParamsRegistry(p.Params) if err != nil { return nil, fmt.Errorf("cannot initialize params %q: %w", p.Name, err) } params[p.Name] = registry } return params, nil } func validateNoWildcardedUserForHeartbeat(clusters map[string]*cluster, cfg []config.Cluster) error { for c := range cfg { cfgcl := cfg[c] clname := cfgcl.Name cuname := cfgcl.ClusterUsers[0].Name heartbeat := cfg[c].HeartBeat cl := clusters[clname] cu := cl.users[cuname] if cu.isWildcarded { if heartbeat.Request != "/ping" && len(heartbeat.User) == 0 { return fmt.Errorf( "`cluster.heartbeat.user ` cannot be unset for %q because a wildcarded user cannot send heartbeat", clname, ) } } } return nil } func (rp *reverseProxy) restartWithNewConfig(caches map[string]*cache.AsyncCache, clusters map[string]*cluster, users map[string]*user) { // Reset metrics from the previous configs, which may become irrelevant // with new configs. // Counters and Summary metrics are always relevant. // Gauge metrics may become irrelevant if they may freeze at non-zero // value after config reload. topology.HostHealth.Reset() cacheSize.Reset() cacheItems.Reset() // Start service goroutines with new configs. for _, c := range clusters { for _, r := range c.replicas { for _, h := range r.hosts { rp.reloadWG.Add(1) go func(h *topology.Node) { h.StartHeartbeat(rp.reloadSignal) rp.reloadWG.Done() }(h) } } for _, cu := range c.users { rp.reloadWG.Add(1) go func(cu *clusterUser) { cu.rateLimiter.run(rp.reloadSignal) rp.reloadWG.Done() }(cu) } } for _, u := range users { rp.reloadWG.Add(1) go func(u *user) { u.rateLimiter.run(rp.reloadSignal) rp.reloadWG.Done() }(u) } } // refreshCacheMetrics refreshes cacheSize and cacheItems metrics. 
func (rp *reverseProxy) refreshCacheMetrics() { rp.lock.RLock() defer rp.lock.RUnlock() for _, c := range rp.caches { stats := c.Stats() labels := prometheus.Labels{ "cache": c.Name(), } cacheSize.With(labels).Set(float64(stats.Size)) cacheItems.With(labels).Set(float64(stats.Items)) } } // find user, cluster and clusterUser // in case of wildcarded user, cluster user is crafted to use original credentials func (rp *reverseProxy) getUser(name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) { rp.lock.RLock() defer rp.lock.RUnlock() found = false u = rp.users[name] switch { case u != nil: found = (u.password == password) // existence of c and cu for toCluster is guaranteed by applyConfig c = rp.clusters[u.toCluster] cu = c.users[u.toUser] case name == "" || name == defaultUser: // default user can't work with the wildcarded feature for security reasons found = false case rp.hasWildcarded: // checking if we have wildcarded users and if username matches one 3 possibles patterns found, u, c, cu = rp.findWildcardedUserInformation(name, password) } return found, u, c, cu } func (rp *reverseProxy) findWildcardedUserInformation(name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) { // cf a validation in config.go, the names must contains either a prefix, a suffix or a wildcard // the wildcarded user is "*" // the wildcarded user is "*[suffix]" // the wildcarded user is "[prefix]*" for _, user := range rp.users { if user.isWildcarded { s := strings.Split(user.name, "*") switch { case s[0] == "" && s[1] == "": return rp.generateWildcardedUserInformation(user, name, password) case s[0] == "": suffix := s[1] if strings.HasSuffix(name, suffix) { return rp.generateWildcardedUserInformation(user, name, password) } case s[1] == "": prefix := s[0] if strings.HasPrefix(name, prefix) { return rp.generateWildcardedUserInformation(user, name, password) } } } } return false, nil, nil, nil } func (rp *reverseProxy) generateWildcardedUserInformation(user *user, name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) { found = false c = rp.clusters[user.toCluster] wildcardedCu := c.users[user.toUser] if wildcardedCu != nil { newCU := deepCopy(wildcardedCu) found = true u = user cu = newCU cu.name = name cu.password = password // TODO : improve the following behavior // the wildcarded user feature creates some side-effects on clusterUser limitations (like the max_concurrent_queries) // because of the use of a deep copy of the clusterUser. The side effect should not be too impactful since the limitation still works on user. // But we need this deep copy since we're changing the name & password of clusterUser and if we used the same instance for every call to chproxy, // it could lead to security issues since a specific query run by user A on chproxy side could trigger a query in clickhouse from user B. // Doing a clean fix would require a huge refactoring. 
} return } func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) { name, password := getAuth(req) sessionId := getSessionId(req) sessionTimeout := getSessionTimeout(req) var ( u *user c *cluster cu *clusterUser ) found, u, c, cu := rp.getUser(name, password) if !found { return nil, http.StatusUnauthorized, fmt.Errorf("invalid username or password for user %q", name) } if u.denyHTTP && req.TLS == nil { return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access via http", u.name) } if u.denyHTTPS && req.TLS != nil { return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access via https", u.name) } if !u.allowedNetworks.Contains(req.RemoteAddr) { return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access", u.name) } if !cu.allowedNetworks.Contains(req.RemoteAddr) { return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name) } s := newScope(req, u, c, cu, sessionId, sessionTimeout) q, err := getFullQuery(req) if err != nil { return nil, http.StatusBadRequest, fmt.Errorf("%s: cannot read query: %w", s, err) } s.requestPacketSize = len(q) return s, 0, nil }
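// Hedged example (not part of the original source): a simplified, standalone sketch of the retry
// idea used by executeWithRetry above. The request body is read once into memory and re-wrapped
// with io.NopCloser on every attempt, so the same request can be replayed against another host
// when the previous attempt came back with 502 Bad Gateway. Host penalties, metrics and session
// stickiness from the real code are deliberately left out; all names here are made up.
package retryexample

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

// doWithRetry sends req to each host in hosts until one of them answers with
// something other than 502 Bad Gateway, or the hosts are exhausted.
func doWithRetry(client *http.Client, req *http.Request, hosts []string) (*http.Response, error) {
	// Buffer the body once so it can be replayed for every attempt.
	var body []byte
	if req.Body != nil {
		var err error
		body, err = io.ReadAll(req.Body)
		if err != nil {
			return nil, err
		}
		req.Body.Close()
	}
	for i, host := range hosts {
		req.Body = io.NopCloser(bytes.NewBuffer(body))
		req.URL.Host = host
		resp, err := client.Do(req)
		if err != nil {
			return nil, err
		}
		if resp.StatusCode != http.StatusBadGateway || i == len(hosts)-1 {
			return resp, nil
		}
		// Discard the failed attempt and try the next host.
		resp.Body.Close()
	}
	return nil, fmt.Errorf("no hosts configured")
}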
package main import ( "net" "net/http" "strings" "github.com/contentsquare/chproxy/config" ) const ( xForwardedForHeader = "X-Forwarded-For" xRealIPHeader = "X-Real-Ip" forwardedHeader = "Forwarded" ) type ProxyHandler struct { proxy *config.Proxy } func NewProxyHandler(proxy *config.Proxy) *ProxyHandler { return &ProxyHandler{ proxy: proxy, } } func (m *ProxyHandler) GetRemoteAddr(r *http.Request) string { if m.proxy.Enable { var addr string if m.proxy.Header != "" { addr = r.Header.Get(m.proxy.Header) } else { addr = parseDefaultProxyHeaders(r) } if isValidAddr(addr) { return addr } } return r.RemoteAddr } // isValidAddr checks if the Addr is a valid IP or IP:port. func isValidAddr(addr string) bool { if addr == "" { return false } ip, _, err := net.SplitHostPort(addr) if err != nil { return net.ParseIP(addr) != nil } return net.ParseIP(ip) != nil } func parseDefaultProxyHeaders(r *http.Request) string { var addr string if fwd := r.Header.Get(xForwardedForHeader); fwd != "" { addr = extractFirstMatchFromIPList(fwd) } else if fwd := r.Header.Get(xRealIPHeader); fwd != "" { addr = extractFirstMatchFromIPList(fwd) } else if fwd := r.Header.Get(forwardedHeader); fwd != "" { // See: https://tools.ietf.org/html/rfc7239. addr = parseForwardedHeader(fwd) } return addr } func extractFirstMatchFromIPList(ipList string) string { if ipList == "" { return "" } s := strings.Index(ipList, ",") if s == -1 { s = len(ipList) } return ipList[:s] } func parseForwardedHeader(fwd string) string { splits := strings.Split(fwd, ";") if len(splits) == 0 { return "" } for _, split := range splits { trimmed := strings.TrimSpace(split) if strings.HasPrefix(strings.ToLower(trimmed), "for=") { forSplits := strings.Split(trimmed, ",") if len(forSplits) == 0 { return "" } addr := forSplits[0][4:] trimmedAddr := strings. NewReplacer("\"", "", "[", "", "]", ""). Replace(addr) // If IpV6, remove brackets and quotes. return trimmedAddr } } return "" }
package main import ( "context" "fmt" "io" "net" "net/http" "net/url" "regexp" "strconv" "strings" "sync/atomic" "time" "github.com/contentsquare/chproxy/cache" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/internal/heartbeat" "github.com/contentsquare/chproxy/internal/topology" "github.com/contentsquare/chproxy/log" "github.com/prometheus/client_golang/prometheus" "golang.org/x/time/rate" ) type scopeID uint64 func (sid scopeID) String() string { return fmt.Sprintf("%08X", uint64(sid)) } func newScopeID() scopeID { sid := atomic.AddUint64(&nextScopeID, 1) return scopeID(sid) } var nextScopeID = uint64(time.Now().UnixNano()) type scope struct { startTime time.Time id scopeID host *topology.Node cluster *cluster user *user clusterUser *clusterUser sessionId string sessionTimeout int remoteAddr string localAddr string // is true when KillQuery has been called canceled bool labels prometheus.Labels requestPacketSize int } func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope { h := c.getHost() if sessionId != "" { h = c.getHostSticky(sessionId) } var localAddr string if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { localAddr = addr.String() } s := &scope{ startTime: time.Now(), id: newScopeID(), host: h, cluster: c, user: u, clusterUser: cu, sessionId: sessionId, sessionTimeout: sessionTimeout, remoteAddr: req.RemoteAddr, localAddr: localAddr, labels: prometheus.Labels{ "user": u.name, "cluster": c.name, "cluster_user": cu.name, "replica": h.ReplicaName(), "cluster_node": h.Host(), }, } return s } func (s *scope) String() string { return fmt.Sprintf("[ Id: %s; User %q(%d) proxying as %q(%d) to %q(%d); RemoteAddr: %q; LocalAddr: %q; Duration: %d μs]", s.id, s.user.name, s.user.queryCounter.load(), s.clusterUser.name, s.clusterUser.queryCounter.load(), s.host.Host(), s.host.CurrentLoad(), s.remoteAddr, s.localAddr, time.Since(s.startTime).Nanoseconds()/1000.0) } //nolint:cyclop // TODO abstract user queues to reduce complexity here. func (s *scope) incQueued() error { if s.user.queueCh == nil && s.clusterUser.queueCh == nil { // Request queues in the current scope are disabled. return s.inc() } // Do not store `replica` and `cluster_node` in labels, since they have // no sense for queue metrics. labels := prometheus.Labels{ "user": s.labels["user"], "cluster": s.labels["cluster"], "cluster_user": s.labels["cluster_user"], } if s.user.queueCh != nil { select { case s.user.queueCh <- struct{}{}: defer func() { <-s.user.queueCh }() default: // Per-user request queue is full. // Give the request the last chance to run. err := s.inc() if err != nil { userQueueOverflow.With(labels).Inc() } return err } } if s.clusterUser.queueCh != nil { select { case s.clusterUser.queueCh <- struct{}{}: defer func() { <-s.clusterUser.queueCh }() default: // Per-clusterUser request queue is full. // Give the request the last chance to run. err := s.inc() if err != nil { clusterUserQueueOverflow.With(labels).Inc() } return err } } // The request has been successfully queued. queueSize := requestQueueSize.With(labels) queueSize.Inc() defer queueSize.Dec() // Try starting the request during the given duration. sleep, deadline := s.calculateQueueDeadlineAndSleep() return s.waitUntilAllowStart(sleep, deadline, labels) } func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, labels prometheus.Labels) error { for { err := s.inc() if err == nil { // The request is allowed to start. 
return nil } dLeft := time.Until(deadline) if dLeft <= 0 { // Give up: the request exceeded its wait time // in the queue :( return err } // The request has dLeft remaining time to wait in the queue. // Sleep for a bit and try starting it again. if sleep > dLeft { time.Sleep(dLeft) } else { time.Sleep(sleep) } var h *topology.Node // Choose new host, since the previous one may become obsolete // after sleeping. if s.sessionId == "" { h = s.cluster.getHost() } else { // if request has session_id, set same host h = s.cluster.getHostSticky(s.sessionId) } s.host = h s.labels["replica"] = h.ReplicaName() s.labels["cluster_node"] = h.Host() } } func (s *scope) calculateQueueDeadlineAndSleep() (time.Duration, time.Time) { d := s.maxQueueTime() dSleep := d / 10 if dSleep > time.Second { dSleep = time.Second } if dSleep < time.Millisecond { dSleep = time.Millisecond } deadline := time.Now().Add(d) return dSleep, deadline } func (s *scope) inc() error { uQueries := s.user.queryCounter.inc() cQueries := s.clusterUser.queryCounter.inc() var err error if s.user.maxConcurrentQueries > 0 && uQueries > s.user.maxConcurrentQueries { err = fmt.Errorf("limits for user %q are exceeded: max_concurrent_queries limit: %d", s.user.name, s.user.maxConcurrentQueries) } if s.clusterUser.maxConcurrentQueries > 0 && cQueries > s.clusterUser.maxConcurrentQueries { err = fmt.Errorf("limits for cluster user %q are exceeded: max_concurrent_queries limit: %d", s.clusterUser.name, s.clusterUser.maxConcurrentQueries) } err2 := s.checkTokenFreeRateLimiters() if err2 != nil { err = err2 } if err != nil { s.user.queryCounter.dec() s.clusterUser.queryCounter.dec() // Decrement rate limiter here, so it doesn't count requests // that didn't start due to limits overflow. s.user.rateLimiter.dec() s.clusterUser.rateLimiter.dec() return err } s.host.IncrementConnections() concurrentQueries.With(s.labels).Inc() return nil } func (s *scope) checkTokenFreeRateLimiters() error { var err error uRPM := s.user.rateLimiter.inc() cRPM := s.clusterUser.rateLimiter.inc() // int32(xRPM) > 0 check is required to detect races when RPM // is decremented on error below after per-minute zeroing // in rateLimiter.run. // These races become innocent with the given check. 
if s.user.reqPerMin > 0 && int32(uRPM) > 0 && uRPM > s.user.reqPerMin { err = fmt.Errorf("rate limit for user %q is exceeded: requests_per_minute limit: %d", s.user.name, s.user.reqPerMin) } if s.clusterUser.reqPerMin > 0 && int32(cRPM) > 0 && cRPM > s.clusterUser.reqPerMin { err = fmt.Errorf("rate limit for cluster user %q is exceeded: requests_per_minute limit: %d", s.clusterUser.name, s.clusterUser.reqPerMin) } err2 := s.checkTokenFreePacketSizeRateLimiters() if err2 != nil { err = err2 } return err } func (s *scope) checkTokenFreePacketSizeRateLimiters() error { var err error // reserving tokens num s.requestPacketSize if s.user.reqPacketSizeTokensBurst > 0 { tl := s.user.reqPacketSizeTokenLimiter ok := tl.AllowN(time.Now(), s.requestPacketSize) if !ok { err = fmt.Errorf("limits for user %q is exceeded: request_packet_size_tokens_burst limit: %d", s.user.name, s.user.reqPacketSizeTokensBurst) } } if s.clusterUser.reqPacketSizeTokensBurst > 0 { tl := s.clusterUser.reqPacketSizeTokenLimiter ok := tl.AllowN(time.Now(), s.requestPacketSize) if !ok { err = fmt.Errorf("limits for cluster user %q is exceeded: request_packet_size_tokens_burst limit: %d", s.clusterUser.name, s.clusterUser.reqPacketSizeTokensBurst) } } return err } func (s *scope) dec() { // There is no need in ratelimiter.dec here, since the rate limiter // is automatically zeroed every minute in rateLimiter.run. s.user.queryCounter.dec() s.clusterUser.queryCounter.dec() s.host.DecrementConnections() concurrentQueries.With(s.labels).Dec() } const killQueryTimeout = time.Second * 30 func (s *scope) killQuery() error { log.Debugf("killing the query with query_id=%s", s.id) killedRequests.With(s.labels).Inc() s.canceled = true query := fmt.Sprintf("KILL QUERY WHERE query_id = '%s'", s.id) r := strings.NewReader(query) addr := s.host.String() req, err := http.NewRequest("POST", addr, r) if err != nil { return fmt.Errorf("error while creating kill query request to %s: %w", addr, err) } ctx, cancel := context.WithTimeout(context.Background(), killQueryTimeout) defer cancel() req = req.WithContext(ctx) // send request as kill_query_user userName := s.cluster.killQueryUserName if len(userName) == 0 { userName = defaultUser } req.SetBasicAuth(userName, s.cluster.killQueryUserPassword) resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("error while executing clickhouse query %q at %q: %w", query, addr, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("unexpected status code returned from query %q at %q: %d. Response body: %q", query, addr, resp.StatusCode, responseBody) } respBody, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("cannot read response body for the query %q: %w", query, err) } log.Debugf("killed the query with query_id=%s; respBody: %q", s.id, respBody) return nil } // allowedParams contains query args allowed to be proxied. // See https://clickhouse.com/docs/en/operations/settings/ // // All the other params passed via query args are stripped before // proxying the request. This is for the sake of security. var allowedParams = []string{ "query", "database", "default_format", // if `compress=1`, CH will compress the data it sends you "compress", // if `decompress=1` , CH will decompress the same data that you pass in the POST method "decompress", // compress the result if the client over HTTP said that it understands data compressed by gzip or deflate. 
"enable_http_compression", // limit on the number of rows in the result "max_result_rows", // whether to count extreme values "extremes", // what to do if the volume of the result exceeds one of the limits "result_overflow_mode", // session stickiness "session_id", // session timeout "session_timeout", } // This regexp must match params needed to describe a way to use external data // @see https://clickhouse.yandex/docs/en/table_engines/external_data/ var externalDataParams = regexp.MustCompile(`(_types|_structure|_format)$`) func (s *scope) decorateRequest(req *http.Request) (*http.Request, url.Values) { // Make new params to purify URL. params := make(url.Values) // Set user params if s.user.params != nil { for _, param := range s.user.params.params { params.Set(param.Key, param.Value) } } // Keep allowed params. origParams := req.URL.Query() for _, param := range allowedParams { val := origParams.Get(param) if len(val) > 0 { params.Set(param, val) } } // Keep parametrized queries params for param := range origParams { if strings.HasPrefix(param, "param_") { params.Set(param, origParams.Get(param)) } } // Keep external_data params if req.Method == "POST" { s.decoratePostRequest(req, origParams, params) } // Set query_id as scope_id to have possibility to kill query if needed. params.Set("query_id", s.id.String()) // Set session_timeout an idle timeout for session params.Set("session_timeout", strconv.Itoa(s.sessionTimeout)) req.URL.RawQuery = params.Encode() // Rewrite possible previous Basic Auth and send request // as cluster user. req.SetBasicAuth(s.clusterUser.name, s.clusterUser.password) // Delete possible X-ClickHouse headers, // it is not allowed to use X-ClickHouse HTTP headers and other authentication methods simultaneously req.Header.Del("X-ClickHouse-User") req.Header.Del("X-ClickHouse-Key") // Send request to the chosen host from cluster. req.URL.Scheme = s.host.Scheme() req.URL.Host = s.host.Host() // Extend ua with additional info, so it may be queried // via system.query_log.http_user_agent. ua := fmt.Sprintf("RemoteAddr: %s; LocalAddr: %s; CHProxy-User: %s; CHProxy-ClusterUser: %s; %s", s.remoteAddr, s.localAddr, s.user.name, s.clusterUser.name, req.UserAgent()) req.Header.Set("User-Agent", ua) return req, origParams } func (s *scope) decoratePostRequest(req *http.Request, origParams, params url.Values) { ct := req.Header.Get("Content-Type") if strings.Contains(ct, "multipart/form-data") { for key := range origParams { if externalDataParams.MatchString(key) { params.Set(key, origParams.Get(key)) } } // disable cache for external_data queries origParams.Set("no_cache", "1") log.Debugf("external data params detected - cache will be disabled") } } func (s *scope) getTimeoutWithErrMsg() (time.Duration, error) { var ( timeout time.Duration timeoutErrMsg error ) if s.user.maxExecutionTime > 0 { timeout = s.user.maxExecutionTime timeoutErrMsg = fmt.Errorf("timeout for user %q exceeded: %v", s.user.name, timeout) } if timeout == 0 || (s.clusterUser.maxExecutionTime > 0 && s.clusterUser.maxExecutionTime < timeout) { timeout = s.clusterUser.maxExecutionTime timeoutErrMsg = fmt.Errorf("timeout for cluster user %q exceeded: %v", s.clusterUser.name, timeout) } return timeout, timeoutErrMsg } func (s *scope) maxQueueTime() time.Duration { d := s.user.maxQueueTime if d <= 0 || s.clusterUser.maxQueueTime > 0 && s.clusterUser.maxQueueTime < d { d = s.clusterUser.maxQueueTime } if d <= 0 { // Default queue time. 
d = 10 * time.Second } return d } type paramsRegistry struct { // key is a hashed concatenation of the params list key uint32 params []config.Param } func newParamsRegistry(params []config.Param) (*paramsRegistry, error) { if len(params) == 0 { return nil, fmt.Errorf("params can't be empty") } paramsMap := make(map[string]string, len(params)) for _, k := range params { paramsMap[k.Key] = k.Value } key, err := calcMapHash(paramsMap) if err != nil { return nil, err } return ¶msRegistry{ key: key, params: params, }, nil } type user struct { name string password string toCluster string toUser string maxConcurrentQueries uint32 queryCounter counter maxExecutionTime time.Duration reqPerMin uint32 rateLimiter rateLimiter reqPacketSizeTokenLimiter *rate.Limiter reqPacketSizeTokensBurst config.ByteSize reqPacketSizeTokensRate config.ByteSize queueCh chan struct{} maxQueueTime time.Duration allowedNetworks config.Networks denyHTTP bool denyHTTPS bool allowCORS bool isWildcarded bool cache *cache.AsyncCache params *paramsRegistry } type usersProfile struct { cfg []config.User clusters map[string]*cluster caches map[string]*cache.AsyncCache params map[string]*paramsRegistry } func (up usersProfile) newUsers() (map[string]*user, error) { users := make(map[string]*user, len(up.cfg)) for _, u := range up.cfg { if _, ok := users[u.Name]; ok { return nil, fmt.Errorf("duplicate config for user %q", u.Name) } tmpU, err := up.newUser(u) if err != nil { return nil, fmt.Errorf("cannot initialize user %q: %w", u.Name, err) } users[u.Name] = tmpU } return users, nil } func (up usersProfile) newUser(u config.User) (*user, error) { c, ok := up.clusters[u.ToCluster] if !ok { return nil, fmt.Errorf("unknown `to_cluster` %q", u.ToCluster) } var cu *clusterUser if cu, ok = c.users[u.ToUser]; !ok { return nil, fmt.Errorf("unknown `to_user` %q in cluster %q", u.ToUser, u.ToCluster) } else if u.IsWildcarded { // a wildcarded user is mapped to this cluster user // used to check if a proper user to send heartbeat exists cu.isWildcarded = true } var queueCh chan struct{} if u.MaxQueueSize > 0 { queueCh = make(chan struct{}, u.MaxQueueSize) } var cc *cache.AsyncCache if len(u.Cache) > 0 { cc = up.caches[u.Cache] if cc == nil { return nil, fmt.Errorf("unknown `cache` %q", u.Cache) } } var params *paramsRegistry if len(u.Params) > 0 { params = up.params[u.Params] if params == nil { return nil, fmt.Errorf("unknown `params` %q", u.Params) } } return &user{ name: u.Name, password: u.Password, toCluster: u.ToCluster, toUser: u.ToUser, maxConcurrentQueries: u.MaxConcurrentQueries, maxExecutionTime: time.Duration(u.MaxExecutionTime), reqPerMin: u.ReqPerMin, queueCh: queueCh, maxQueueTime: time.Duration(u.MaxQueueTime), reqPacketSizeTokenLimiter: rate.NewLimiter(rate.Limit(u.ReqPacketSizeTokensRate), int(u.ReqPacketSizeTokensBurst)), reqPacketSizeTokensBurst: u.ReqPacketSizeTokensBurst, reqPacketSizeTokensRate: u.ReqPacketSizeTokensRate, allowedNetworks: u.AllowedNetworks, denyHTTP: u.DenyHTTP, denyHTTPS: u.DenyHTTPS, allowCORS: u.AllowCORS, isWildcarded: u.IsWildcarded, cache: cc, params: params, }, nil } type clusterUser struct { name string password string maxConcurrentQueries uint32 queryCounter counter maxExecutionTime time.Duration reqPerMin uint32 rateLimiter rateLimiter queueCh chan struct{} maxQueueTime time.Duration reqPacketSizeTokenLimiter *rate.Limiter reqPacketSizeTokensBurst config.ByteSize reqPacketSizeTokensRate config.ByteSize allowedNetworks config.Networks isWildcarded bool } func deepCopy(cu *clusterUser) 
*clusterUser { var queueCh chan struct{} if cu.maxQueueTime > 0 { queueCh = make(chan struct{}, cu.maxQueueTime) } return &clusterUser{ name: cu.name, password: cu.password, maxConcurrentQueries: cu.maxConcurrentQueries, maxExecutionTime: time.Duration(cu.maxExecutionTime), reqPerMin: cu.reqPerMin, queueCh: queueCh, maxQueueTime: time.Duration(cu.maxQueueTime), allowedNetworks: cu.allowedNetworks, } } func newClusterUser(cu config.ClusterUser) *clusterUser { var queueCh chan struct{} if cu.MaxQueueSize > 0 { queueCh = make(chan struct{}, cu.MaxQueueSize) } return &clusterUser{ name: cu.Name, password: cu.Password, maxConcurrentQueries: cu.MaxConcurrentQueries, maxExecutionTime: time.Duration(cu.MaxExecutionTime), reqPerMin: cu.ReqPerMin, reqPacketSizeTokenLimiter: rate.NewLimiter(rate.Limit(cu.ReqPacketSizeTokensRate), int(cu.ReqPacketSizeTokensBurst)), reqPacketSizeTokensBurst: cu.ReqPacketSizeTokensBurst, reqPacketSizeTokensRate: cu.ReqPacketSizeTokensRate, queueCh: queueCh, maxQueueTime: time.Duration(cu.MaxQueueTime), allowedNetworks: cu.AllowedNetworks, } } type replica struct { cluster *cluster name string hosts []*topology.Node nextHostIdx uint32 } func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c *cluster) ([]*replica, error) { if len(nodes) > 0 { // No replicas, just flat nodes. Create default replica // containing all the nodes. r := &replica{ cluster: c, name: "default", } hosts, err := newNodes(nodes, scheme, r) if err != nil { return nil, err } r.hosts = hosts return []*replica{r}, nil } replicas := make([]*replica, len(replicasCfg)) for i, rCfg := range replicasCfg { r := &replica{ cluster: c, name: rCfg.Name, } hosts, err := newNodes(rCfg.Nodes, scheme, r) if err != nil { return nil, fmt.Errorf("cannot initialize replica %q: %w", rCfg.Name, err) } r.hosts = hosts replicas[i] = r } return replicas, nil } func newNodes(nodes []string, scheme string, r *replica) ([]*topology.Node, error) { hosts := make([]*topology.Node, len(nodes)) for i, node := range nodes { addr, err := url.Parse(fmt.Sprintf("%s://%s", scheme, node)) if err != nil { return nil, fmt.Errorf("cannot parse `node` %q with `scheme` %q: %w", node, scheme, err) } hosts[i] = topology.NewNode(addr, r.cluster.heartBeat, r.cluster.name, r.name) } return hosts, nil } func (r *replica) isActive() bool { // The replica is active if at least a single host is active. 
for _, h := range r.hosts { if h.IsActive() { return true } } return false } func (r *replica) load() uint32 { var reqs uint32 for _, h := range r.hosts { reqs += h.CurrentLoad() } return reqs } type cluster struct { name string replicas []*replica nextReplicaIdx uint32 users map[string]*clusterUser killQueryUserName string killQueryUserPassword string heartBeat heartbeat.HeartBeat retryNumber int } func newCluster(c config.Cluster) (*cluster, error) { clusterUsers := make(map[string]*clusterUser, len(c.ClusterUsers)) for _, cu := range c.ClusterUsers { if _, ok := clusterUsers[cu.Name]; ok { return nil, fmt.Errorf("duplicate config for cluster user %q", cu.Name) } clusterUsers[cu.Name] = newClusterUser(cu) } heartBeat := heartbeat.NewHeartbeat(c.HeartBeat, heartbeat.WithDefaultUser(c.ClusterUsers[0].Name, c.ClusterUsers[0].Password)) newC := &cluster{ name: c.Name, users: clusterUsers, killQueryUserName: c.KillQueryUser.Name, killQueryUserPassword: c.KillQueryUser.Password, heartBeat: heartBeat, retryNumber: c.RetryNumber, } replicas, err := newReplicas(c.Replicas, c.Nodes, c.Scheme, newC) if err != nil { return nil, fmt.Errorf("cannot initialize replicas: %w", err) } newC.replicas = replicas return newC, nil } func newClusters(cfg []config.Cluster) (map[string]*cluster, error) { clusters := make(map[string]*cluster, len(cfg)) for _, c := range cfg { if _, ok := clusters[c.Name]; ok { return nil, fmt.Errorf("duplicate config for cluster %q", c.Name) } tmpC, err := newCluster(c) if err != nil { return nil, fmt.Errorf("cannot initialize cluster %q: %w", c.Name, err) } clusters[c.Name] = tmpC } return clusters, nil } // getReplica returns least loaded + round-robin replica from the cluster. // // Always returns non-nil. func (c *cluster) getReplica() *replica { idx := atomic.AddUint32(&c.nextReplicaIdx, 1) n := uint32(len(c.replicas)) if n == 1 { return c.replicas[0] } idx %= n r := c.replicas[idx] reqs := r.load() // Set least priority to inactive replica. if !r.isActive() { reqs = ^uint32(0) } if reqs == 0 { return r } // Scan all the replicas for the least loaded replica. for i := uint32(1); i < n; i++ { tmpIdx := (idx + i) % n tmpR := c.replicas[tmpIdx] if !tmpR.isActive() { continue } tmpReqs := tmpR.load() if tmpReqs == 0 { return tmpR } if tmpReqs < reqs { r = tmpR reqs = tmpReqs } } // The returned replica may be inactive. This is OK, // since this means all the replicas are inactive, // so let's try proxying the request to any replica. return r } func (c *cluster) getReplicaSticky(sessionId string) *replica { idx := atomic.AddUint32(&c.nextReplicaIdx, 1) n := uint32(len(c.replicas)) if n == 1 { return c.replicas[0] } idx %= n r := c.replicas[idx] for i := uint32(1); i < n; i++ { // handling sticky session sessionId := hash(sessionId) tmpIdx := (sessionId) % n tmpRSticky := c.replicas[tmpIdx] log.Debugf("Sticky replica candidate is: %s", tmpRSticky.name) if !tmpRSticky.isActive() { log.Debugf("Sticky session replica has been picked up, but it is not available") continue } log.Debugf("Sticky session replica is: %s, session_id: %d, replica_idx: %d, max replicas in pool: %d", tmpRSticky.name, sessionId, tmpIdx, n) return tmpRSticky } // The returned replica may be inactive. This is OK, // since this means all the replicas are inactive, // so let's try proxying the request to any replica. return r } // getHostSticky returns host by stickiness from replica. // // Always returns non-nil. 
func (r *replica) getHostSticky(sessionId string) *topology.Node { idx := atomic.AddUint32(&r.nextHostIdx, 1) n := uint32(len(r.hosts)) if n == 1 { return r.hosts[0] } idx %= n h := r.hosts[idx] // Scan all the hosts for the least loaded host. for i := uint32(1); i < n; i++ { // handling sticky session sessionId := hash(sessionId) tmpIdx := (sessionId) % n tmpHSticky := r.hosts[tmpIdx] log.Debugf("Sticky server candidate is: %s", tmpHSticky) if !tmpHSticky.IsActive() { log.Debugf("Sticky session server has been picked up, but it is not available") continue } log.Debugf("Sticky session server is: %s, session_id: %d, server_idx: %d, max nodes in pool: %d", tmpHSticky, sessionId, tmpIdx, n) return tmpHSticky } // The returned host may be inactive. This is OK, // since this means all the hosts are inactive, // so let's try proxying the request to any host. return h } // getHost returns least loaded + round-robin host from replica. // // Always returns non-nil. func (r *replica) getHost() *topology.Node { idx := atomic.AddUint32(&r.nextHostIdx, 1) n := uint32(len(r.hosts)) if n == 1 { return r.hosts[0] } idx %= n h := r.hosts[idx] reqs := h.CurrentLoad() // Set least priority to inactive host. if !h.IsActive() { reqs = ^uint32(0) } if reqs == 0 { return h } // Scan all the hosts for the least loaded host. for i := uint32(1); i < n; i++ { tmpIdx := (idx + i) % n tmpH := r.hosts[tmpIdx] if !tmpH.IsActive() { continue } tmpReqs := tmpH.CurrentLoad() if tmpReqs == 0 { return tmpH } if tmpReqs < reqs { h = tmpH reqs = tmpReqs } } // The returned host may be inactive. This is OK, // since this means all the hosts are inactive, // so let's try proxying the request to any host. return h } // getHostSticky returns host based on stickiness from cluster. // // Always returns non-nil. func (c *cluster) getHostSticky(sessionId string) *topology.Node { r := c.getReplicaSticky(sessionId) return r.getHostSticky(sessionId) } // getHost returns least loaded + round-robin host from cluster. // // Always returns non-nil. func (c *cluster) getHost() *topology.Node { r := c.getReplica() return r.getHost() } type rateLimiter struct { counter } func (rl *rateLimiter) run(done <-chan struct{}) { for { select { case <-done: return case <-time.After(time.Minute): rl.store(0) } } } type counter struct { value uint32 } func (c *counter) store(n uint32) { atomic.StoreUint32(&c.value, n) } func (c *counter) load() uint32 { return atomic.LoadUint32(&c.value) } func (c *counter) dec() { atomic.AddUint32(&c.value, ^uint32(0)) } func (c *counter) inc() uint32 { return atomic.AddUint32(&c.value, 1) }
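// Hedged example (not part of the original source): a standalone sketch of the packet-size
// limiting performed by checkTokenFreePacketSizeRateLimiters above. The token bucket counts
// bytes: AllowN reserves the request's packet size in tokens and rejects the request once the
// burst budget is exhausted. The rate and burst values are illustrative, not chproxy defaults.
package ratelimitexample

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

// allowPacket reserves packetSize tokens from the limiter, failing the request
// when not enough tokens are available.
func allowPacket(limiter *rate.Limiter, packetSize int) error {
	if !limiter.AllowN(time.Now(), packetSize) {
		return fmt.Errorf("request packet of %d bytes exceeds the configured token budget", packetSize)
	}
	return nil
}

func examplePacketLimiter() {
	// Refill 1 MiB of tokens per second, allow bursts of up to 4 MiB.
	limiter := rate.NewLimiter(rate.Limit(1<<20), 4<<20)
	fmt.Println(allowPacket(limiter, 512*1024)) // <nil>: fits into the burst budget
	fmt.Println(allowPacket(limiter, 8<<20))    // error: larger than the burst can ever hold
}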
package main import ( "bytes" "compress/gzip" "fmt" "hash/fnv" "io" "net/http" "sort" "strconv" "strings" "github.com/contentsquare/chproxy/chdecompressor" "github.com/contentsquare/chproxy/log" ) func respondWith(rw http.ResponseWriter, err error, status int) { log.ErrorWithCallDepth(err, 1) rw.WriteHeader(status) fmt.Fprintf(rw, "%s\n", err) } var defaultUser = "default" // getAuth retrieves auth credentials from request // according to CH documentation @see "https://clickhouse.yandex/docs/en/interfaces/http/" func getAuth(req *http.Request) (string, string) { // check X-ClickHouse- headers name := req.Header.Get("X-ClickHouse-User") pass := req.Header.Get("X-ClickHouse-Key") if name != "" { return name, pass } // if header is empty - check basicAuth if name, pass, ok := req.BasicAuth(); ok { return name, pass } // if basicAuth is empty - check URL params `user` and `password` params := req.URL.Query() if name := params.Get("user"); name != "" { pass := params.Get("password") return name, pass } // if still no credentials - treat it as `default` user request return defaultUser, "" } // getSessionId retrieves session id func getSessionId(req *http.Request) string { params := req.URL.Query() sessionId := params.Get("session_id") return sessionId } // getSessionId retrieves session id func getSessionTimeout(req *http.Request) int { params := req.URL.Query() sessionTimeout, err := strconv.Atoi(params.Get("session_timeout")) if err == nil && sessionTimeout > 0 { return sessionTimeout } return 60 } // getQuerySnippet returns query snippet. // // getQuerySnippet must be called only for error reporting. func getQuerySnippet(req *http.Request) string { query := req.URL.Query().Get("query") body := getQuerySnippetFromBody(req) if len(query) != 0 && len(body) != 0 { query += "\n" } return query + body } func hash(s string) uint32 { h := fnv.New32a() h.Write([]byte(s)) return h.Sum32() } func getQuerySnippetFromBody(req *http.Request) string { if req.Body == nil { return "" } crc, ok := req.Body.(*cachedReadCloser) if !ok { crc = &cachedReadCloser{ ReadCloser: req.Body, } } // 'read' request body, so it traps into to crc. // Ignore any errors, since getQuerySnippet is called only // during error reporting. io.Copy(io.Discard, crc) // nolint data := crc.String() u := getDecompressor(req) if u == nil { return data } bs := bytes.NewBufferString(data) b, err := u.decompress(bs) if err == nil { return string(b) } // It is better to return partially decompressed data instead of an empty string. if len(b) > 0 { return string(b) } // The data failed to be decompressed. Return compressed data // instead of an empty string. return data } // getFullQuery returns full query from req. 
func getFullQuery(req *http.Request) ([]byte, error) { var result bytes.Buffer if req.URL.Query().Get("query") != "" { result.WriteString(req.URL.Query().Get("query")) } body, err := getFullQueryFromBody(req) if err != nil { return nil, err } if result.Len() != 0 && len(body) != 0 { result.WriteByte('\n') } result.Write(body) return result.Bytes(), nil } func getFullQueryFromBody(req *http.Request) ([]byte, error) { if req.Body == nil { return nil, nil } data, err := io.ReadAll(req.Body) if err != nil { return nil, err } // restore body for further reading req.Body = io.NopCloser(bytes.NewBuffer(data)) u := getDecompressor(req) if u == nil { return data, nil } br := bytes.NewReader(data) b, err := u.decompress(br) if err != nil { return nil, fmt.Errorf("cannot uncompress query: %w", err) } return b, nil } var cachableStatements = []string{"SELECT", "WITH"} // canCacheQuery returns true if q can be cached. func canCacheQuery(q []byte) bool { q = skipLeadingComments(q) for _, statement := range cachableStatements { if len(q) < len(statement) { continue } l := bytes.ToUpper(q[:len(statement)]) if bytes.HasPrefix(l, []byte(statement)) { return true } } return false } //nolint:cyclop // No clean way to split this. func skipLeadingComments(q []byte) []byte { for len(q) > 0 { switch q[0] { case '\t', '\n', '\v', '\f', '\r', ' ': q = q[1:] case '-': if len(q) < 2 || q[1] != '-' { return q } // skip `-- comment` n := bytes.IndexByte(q, '\n') if n < 0 { return nil } q = q[n+1:] case '/': if len(q) < 2 || q[1] != '*' { return q } // skip `/* comment */` for { n := bytes.IndexByte(q, '*') if n < 0 { return nil } q = q[n+1:] if len(q) == 0 { return nil } if q[0] == '/' { q = q[1:] break } } default: return q } } return nil } // splits header string in sorted slice func sortHeader(header string) string { h := strings.Split(header, ",") for i, v := range h { h[i] = strings.TrimSpace(v) } sort.Strings(h) return strings.Join(h, ",") } func getDecompressor(req *http.Request) decompressor { if req.Header.Get("Content-Encoding") == "gzip" { return gzipDecompressor{} } if req.URL.Query().Get("decompress") == "1" { return chDecompressor{} } return nil } type decompressor interface { decompress(r io.Reader) ([]byte, error) } type gzipDecompressor struct{} func (dc gzipDecompressor) decompress(r io.Reader) ([]byte, error) { gr, err := gzip.NewReader(r) if err != nil { return nil, fmt.Errorf("cannot ungzip query: %w", err) } return io.ReadAll(gr) } type chDecompressor struct{} func (dc chDecompressor) decompress(r io.Reader) ([]byte, error) { lr := chdecompressor.NewReader(r) return io.ReadAll(lr) } func calcMapHash(m map[string]string) (uint32, error) { if len(m) == 0 { return 0, nil } var keys []string for key := range m { keys = append(keys, key) } sort.Strings(keys) h := fnv.New32a() for _, k := range keys { str := fmt.Sprintf("%s=%s&", k, m[k]) _, err := h.Write([]byte(str)) if err != nil { return 0, err } } return h.Sum32(), nil } func calcCredentialHash(user string, pwd string) (uint32, error) { h := fnv.New32a() _, err := h.Write([]byte(user + pwd)) return h.Sum32(), err }
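// Hedged example (not part of the original source): a hypothetical sketch of how the helpers
// above behave on typical input. Leading SQL comments are skipped before the SELECT/WITH check,
// and Accept-Encoding values are sorted so equivalent headers map to the same cache key. Assumes
// it would sit next to these helpers in package main; the queries are placeholders.
package main

import "fmt"

func exampleQueryHelpers() {
	q := []byte("-- dashboard refresh\nSELECT count() FROM hits")
	fmt.Println(canCacheQuery(q))                                     // true: SELECT follows the comment
	fmt.Println(canCacheQuery([]byte("INSERT INTO hits VALUES (1)"))) // false: not cacheable

	fmt.Println(sortHeader("gzip, br, deflate")) // "br,deflate,gzip"
}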