package cache
import (
"fmt"
"time"
"github.com/contentsquare/chproxy/clients"
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/log"
"github.com/redis/go-redis/v9"
)
// AsyncCache is a transactional cache able to serve results to concurrent identical queries.
// When queries A and B are equal, and query B arrives after query A within the configured deadline interval (graceTime),
// query B will await the results of query A for at most:
// max_awaiting_time = graceTime - (arrivalB - arrivalA)
type AsyncCache struct {
Cache
TransactionRegistry
graceTime time.Duration
MaxPayloadSize config.ByteSize
SharedWithAllUsers bool
}
func (c *AsyncCache) Close() error {
if c.TransactionRegistry != nil {
c.TransactionRegistry.Close()
}
if c.Cache != nil {
c.Cache.Close()
}
return nil
}
func (c *AsyncCache) AwaitForConcurrentTransaction(key *Key) (TransactionStatus, error) {
startTime := time.Now()
seenState := transactionAbsent
for {
elapsedTime := time.Since(startTime)
if elapsedTime > c.graceTime {
// The entry didn't appear within the deadline.
// Let the caller create it.
return TransactionStatus{State: seenState}, nil
}
status, err := c.TransactionRegistry.Status(key)
if err != nil {
return TransactionStatus{State: seenState}, err
}
if !status.State.IsPending() {
return status, nil
}
// Wait for deadline in the hope the entry will appear
// in the cache.
//
// This should protect from thundering herd problem when
// a single slow query is executed from concurrent requests.
d := 100 * time.Millisecond
if d > c.graceTime {
d = c.graceTime
}
time.Sleep(d)
}
}
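// Illustrative sketch of the polling loop above (durations are hypothetical,
// not taken from any real configuration):
//
//	graceTime = 5s
//	t=0ms    query B calls AwaitForConcurrentTransaction
//	t=0ms    Status -> pending, sleep 100ms
//	t=100ms  Status -> pending, sleep 100ms
//	...
//	t=900ms  Status -> completed, return it; the caller may now read the
//	         result that the concurrent query has just put into the cache.
//
// If graceTime elapses while the transaction is still pending, the last seen
// state is returned and the caller executes the query itself.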
func NewAsyncCache(cfg config.Cache, maxExecutionTime time.Duration) (*AsyncCache, error) {
graceTime := time.Duration(cfg.GraceTime)
if graceTime > 0 {
log.Errorf("[DEPRECATED] detected grace time configuration %s. It will be removed in the new version",
graceTime)
}
if graceTime == 0 {
// Default grace time.
graceTime = maxExecutionTime
}
if graceTime < 0 {
// Disable protection from `dogpile effect`.
graceTime = 0
}
var cache Cache
var transaction TransactionRegistry
var err error
// transaction will be kept until we're sure there's no possible concurrent query running
transactionDeadline := 2 * graceTime
switch cfg.Mode {
case "file_system":
cache, err = newFilesSystemCache(cfg, graceTime)
transaction = newInMemoryTransactionRegistry(transactionDeadline, transactionEndedTTL)
case "redis":
var redisClient redis.UniversalClient
redisClient, err = clients.NewRedisClient(cfg.Redis)
cache = newRedisCache(redisClient, cfg)
transaction = newRedisTransactionRegistry(redisClient, transactionDeadline, transactionEndedTTL)
default:
return nil, fmt.Errorf("unknown config mode")
}
if err != nil {
return nil, err
}
maxPayloadSize := cfg.MaxPayloadSize
return &AsyncCache{
Cache: cache,
TransactionRegistry: transaction,
graceTime: graceTime,
MaxPayloadSize: maxPayloadSize,
SharedWithAllUsers: cfg.SharedWithAllUsers,
}, nil
}
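// Hedged usage sketch: building an AsyncCache backed by the file system
// (the literal field values below are illustrative assumptions):
//
//	cfg := config.Cache{Name: "shortterm", Mode: "file_system", Expire: config.Duration(time.Hour)}
//	cfg.FileSystem.Dir = "/var/cache/chproxy"
//	cfg.FileSystem.MaxSize = config.ByteSize(1 << 30)
//	c, err := NewAsyncCache(cfg, time.Minute)
//	if err != nil {
//		// handle error
//	}
//	defer c.Close()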
package cache
import (
"fmt"
"io"
"math/rand"
"os"
"path/filepath"
"regexp"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/log"
)
var cachefileRegexp = regexp.MustCompile(`^[0-9a-f]{32}$`)
// fileSystemCache represents a file cache.
type fileSystemCache struct {
// name is cache name.
name string
dir string
maxSize uint64
expire time.Duration
grace time.Duration
stats Stats
wg sync.WaitGroup
stopCh chan struct{}
}
// newFilesSystemCache returns a new file system cache for the given cfg.
func newFilesSystemCache(cfg config.Cache, graceTime time.Duration) (*fileSystemCache, error) {
if len(cfg.FileSystem.Dir) == 0 {
return nil, fmt.Errorf("`dir` cannot be empty")
}
if cfg.FileSystem.MaxSize <= 0 {
return nil, fmt.Errorf("`max_size` must be positive")
}
if cfg.Expire <= 0 {
return nil, fmt.Errorf("`expire` must be positive")
}
c := &fileSystemCache{
name: cfg.Name,
dir: cfg.FileSystem.Dir,
maxSize: uint64(cfg.FileSystem.MaxSize),
expire: time.Duration(cfg.Expire),
grace: graceTime,
stopCh: make(chan struct{}),
}
if err := os.MkdirAll(c.dir, 0700); err != nil {
return nil, fmt.Errorf("cannot create %q: %w", c.dir, err)
}
c.wg.Add(1)
go func() {
log.Debugf("cache %q: cleaner start", c.Name())
c.cleaner()
log.Debugf("cache %q: cleaner stop", c.Name())
c.wg.Done()
}()
return c, nil
}
func (f *fileSystemCache) Name() string {
return f.name
}
func (f *fileSystemCache) Close() error {
log.Debugf("cache %q: stopping", f.Name())
close(f.stopCh)
f.wg.Wait()
log.Debugf("cache %q: stopped", f.Name())
return nil
}
func (f *fileSystemCache) Stats() Stats {
var s Stats
s.Size = atomic.LoadUint64(&f.stats.Size)
s.Items = atomic.LoadUint64(&f.stats.Items)
return s
}
func (f *fileSystemCache) Get(key *Key) (*CachedData, error) {
fp := key.filePath(f.dir)
file, err := os.Open(fp)
if err != nil {
return nil, ErrMissing
}
// The file will be closed once it's read, since it is returned as an io.ReadCloser
// stored in the returned CachedData.
fi, err := file.Stat()
if err != nil {
return nil, fmt.Errorf("cache %q: cannot stat %q: %w", f.Name(), fp, err)
}
mt := fi.ModTime()
age := time.Since(mt)
if age > f.expire {
// check if file exceeded expiration time + grace time
if age > f.expire+f.grace {
file.Close()
return nil, ErrMissing
}
// Serve expired file in the hope it will be substituted
// with the fresh file during deadline.
}
metadata, err := decodeHeader(file)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to read file content from %q: %w", f.Name(), err)
}
value := &CachedData{
ContentMetadata: *metadata,
Data: file,
Ttl: f.expire - age,
}
return value, nil
}
// decodeHeader decodes header from raw byte stream. Data is encoded as follows:
// length(contentType)|contentType|length(contentEncoding)|contentEncoding|length(contentLength)|contentLength|cachedData
func decodeHeader(reader io.Reader) (*ContentMetadata, error) {
contentType, err := readHeader(reader)
if err != nil {
return nil, fmt.Errorf("cannot read Content-Type from provided reader: %w", err)
}
contentEncoding, err := readHeader(reader)
if err != nil {
return nil, fmt.Errorf("cannot read Content-Encoding from provided reader: %w", err)
}
contentLengthStr, err := readHeader(reader)
if err != nil {
return nil, fmt.Errorf("cannot read Content-Encoding from provided reader: %w", err)
}
contentLength, err := strconv.Atoi(contentLengthStr)
if err != nil {
log.Errorf("found corrupted content length %s", err)
contentLength = 0
}
return &ContentMetadata{
Length: int64(contentLength),
Type: contentType,
Encoding: contentEncoding,
}, nil
}
func (f *fileSystemCache) Put(r io.Reader, contentMetadata ContentMetadata, key *Key) (time.Duration, error) {
fp := key.filePath(f.dir)
file, err := os.Create(fp)
if err != nil {
return 0, fmt.Errorf("cache %q: cannot create file: %s : %w", f.Name(), key, err)
}
if err := writeHeader(file, contentMetadata.Type); err != nil {
fn := file.Name()
return 0, fmt.Errorf("cannot write Content-Type to %q: %w", fn, err)
}
if err := writeHeader(file, contentMetadata.Encoding); err != nil {
fn := file.Name()
return 0, fmt.Errorf("cannot write Content-Encoding to %q: %w", fn, err)
}
if err := writeHeader(file, fmt.Sprintf("%d", contentMetadata.Length)); err != nil {
fn := file.Name()
return 0, fmt.Errorf("cannot write Content-Encoding to %q: %w", fn, err)
}
cnt, err := io.Copy(file, r)
if err != nil {
return 0, fmt.Errorf("cache %q: cannot write results to file: %s : %w", f.Name(), key, err)
}
atomic.AddUint64(&f.stats.Size, uint64(cnt))
atomic.AddUint64(&f.stats.Items, 1)
return f.expire, nil
}
func (f *fileSystemCache) cleaner() {
d := f.expire / 2
if d < time.Minute {
d = time.Minute
}
if d > time.Hour {
d = time.Hour
}
forceCleanCh := time.After(d)
f.clean()
for {
select {
case <-time.After(time.Second):
// Clean cache only on cache size overflow.
stats := f.Stats()
if stats.Size > f.maxSize {
f.clean()
}
case <-forceCleanCh:
// Forcibly clean cache from expired items.
f.clean()
forceCleanCh = time.After(d)
case <-f.stopCh:
return
}
}
}
func (f *fileSystemCache) fileInfoPath(fi os.FileInfo) string {
return filepath.Join(f.dir, fi.Name())
}
func (f *fileSystemCache) clean() {
currentTime := time.Now()
log.Debugf("cache %q: start cleaning dir %q", f.Name(), f.dir)
// Remove cached files after a deadline from their expiration,
// so they may be served until they are substituted with fresh files.
expire := f.expire + f.grace
// Calculate total cache size and remove expired files.
var totalSize uint64
var totalItems uint64
var removedSize uint64
var removedItems uint64
err := walkDir(f.dir, func(fi os.FileInfo) {
mt := fi.ModTime()
fs := uint64(fi.Size())
if currentTime.Sub(mt) > expire {
fn := f.fileInfoPath(fi)
err := os.Remove(fn)
if err == nil {
removedSize += fs
removedItems++
return
}
log.Errorf("cache %q: cannot remove file %q: %s", f.Name(), fn, err)
// The return is skipped intentionally so the remaining file is still counted below.
}
totalSize += fs
totalItems++
})
if err != nil {
log.Errorf("cache %q: %s", f.Name(), err)
return
}
loopsCount := 0
// Use dedicated random generator instead of global one from math/rand,
// since the global generator is slow due to locking.
//
// Seed the generator with the current time in order to randomize
// set of files to be removed below.
// nolint:gosec // not security sensitive, only used internally.
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
for totalSize > f.maxSize && loopsCount < 3 {
// Remove some files in order to reduce cache size.
excessSize := totalSize - f.maxSize
p := int32(float64(excessSize) / float64(totalSize) * 100)
// Add 10 percentage points so slightly more than the bare excess is removed.
p += 10
err := walkDir(f.dir, func(fi os.FileInfo) {
if rnd.Int31n(100) > p {
return
}
fs := uint64(fi.Size())
fn := f.fileInfoPath(fi)
if err := os.Remove(fn); err != nil {
log.Errorf("cache %q: cannot remove file %q: %s", f.Name(), fn, err)
return
}
removedSize += fs
removedItems++
totalSize -= fs
totalItems--
})
if err != nil {
log.Errorf("cache %q: %s", f.Name(), err)
return
}
// This should protect from infinite loop.
loopsCount++
}
atomic.StoreUint64(&f.stats.Size, totalSize)
atomic.StoreUint64(&f.stats.Items, totalItems)
log.Debugf("cache %q: final size %d; final items %d; removed size %d; removed items %d",
f.Name(), totalSize, totalItems, removedSize, removedItems)
log.Debugf("cache %q: finish cleaning dir %q", f.Name(), f.dir)
}
// writeHeader writes the string prefixed with its length encoded in big endian.
func writeHeader(w io.Writer, s string) error {
n := uint32(len(s))
b := make([]byte, 0, n+4)
b = append(b, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
b = append(b, s...)
_, err := w.Write(b)
return err
}
// readHeader reads a length-prefixed string whose length is encoded in big endian.
func readHeader(r io.Reader) (string, error) {
b := make([]byte, 4)
if _, err := io.ReadFull(r, b); err != nil {
return "", fmt.Errorf("cannot read header length: %w", err)
}
n := uint32(b[3]) | (uint32(b[2]) << 8) | (uint32(b[1]) << 16) | (uint32(b[0]) << 24)
s := make([]byte, n)
if _, err := io.ReadFull(r, s); err != nil {
return "", fmt.Errorf("cannot read header value with length %d: %w", n, err)
}
return string(s), nil
}
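// Layout sketch: writeHeader("gzip") emits a 4-byte big-endian length followed
// by the raw bytes, and readHeader reverses it:
//
//	[0x00 0x00 0x00 0x04]['g' 'z' 'i' 'p']
//
// A cached file therefore starts with three such records (Content-Type,
// Content-Encoding, Content-Length) followed by the response payload.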
package cache
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/url"
"path/filepath"
)
// Version must be increased with each backward-incompatible change
// in the cache storage.
const Version = 5
// Key is the key for use in the cache.
type Key struct {
// Query must contain full request query.
Query []byte
// AcceptEncoding must contain 'Accept-Encoding' request header value.
AcceptEncoding string
// DefaultFormat must contain `default_format` query arg.
DefaultFormat string
// Database must contain `database` query arg.
Database string
// Compress must contain `compress` query arg.
Compress string
// EnableHTTPCompression must contain `enable_http_compression` query arg.
EnableHTTPCompression string
// Namespace is an optional cache namespace.
Namespace string
// MaxResultRows must contain `max_result_rows` query arg
MaxResultRows string
// Extremes must contain `extremes` query arg
Extremes string
// ResultOverflowMode must contain `result_overflow_mode` query arg
ResultOverflowMode string
// UserParamsHash must contain hashed value of users params
UserParamsHash uint32
// Version represents data encoding version number
Version int
// QueryParamsHash must contain hashed value of query params
QueryParamsHash uint32
// UserCredentialHash must contain hashed value of username & password
UserCredentialHash uint32
}
// NewKey construct cache key from provided parameters with default version number
func NewKey(query []byte, originParams url.Values, acceptEncoding string, userParamsHash uint32, queryParamsHash uint32, userCredentialHash uint32) *Key {
return &Key{
Query: query,
AcceptEncoding: acceptEncoding,
DefaultFormat: originParams.Get("default_format"),
Database: originParams.Get("database"),
Compress: originParams.Get("compress"),
EnableHTTPCompression: originParams.Get("enable_http_compression"),
Namespace: originParams.Get("cache_namespace"),
Extremes: originParams.Get("extremes"),
MaxResultRows: originParams.Get("max_result_rows"),
ResultOverflowMode: originParams.Get("result_overflow_mode"),
UserParamsHash: userParamsHash,
Version: Version,
QueryParamsHash: queryParamsHash,
UserCredentialHash: userCredentialHash,
}
}
func (k *Key) filePath(dir string) string {
return filepath.Join(dir, k.String())
}
// String returns string representation of the key.
func (k *Key) String() string {
s := fmt.Sprintf("V%d; Query=%q; AcceptEncoding=%q; DefaultFormat=%q; Database=%q; Compress=%q; EnableHTTPCompression=%q; Namespace=%q; MaxResultRows=%q; Extremes=%q; ResultOverflowMode=%q; UserParams=%d; QueryParams=%d; UserCredentialHash=%d",
k.Version, k.Query, k.AcceptEncoding, k.DefaultFormat, k.Database, k.Compress, k.EnableHTTPCompression, k.Namespace,
k.MaxResultRows, k.Extremes, k.ResultOverflowMode, k.UserParamsHash, k.QueryParamsHash, k.UserCredentialHash)
h := sha256.Sum256([]byte(s))
// The first 16 bytes of the hash should be enough
// for collision prevention :)
return hex.EncodeToString(h[:16])
}
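// Usage sketch (the resulting hash below is made up for illustration):
//
//	k := NewKey([]byte("SELECT 1"), url.Values{}, "gzip", 0, 0, 0)
//	k.String() // e.g. "0cc175b9c0f1b6a831c399e269772661": 32 hex characters,
//	           // which is what cachefileRegexp in the file system cache expects.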
package cache
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
"os"
"regexp"
"strconv"
"time"
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/log"
"github.com/redis/go-redis/v9"
)
type redisCache struct {
name string
client redis.UniversalClient
expire time.Duration
}
const getTimeout = 2 * time.Second
const removeTimeout = 1 * time.Second
const renameTimeout = 1 * time.Second
const putTimeout = 2 * time.Second
const statsTimeout = 500 * time.Millisecond
// minTTLForRedisStreamingReader selects whether the result is streamed
// directly from redis to the http response or whether chproxy first copies the
// result from redis into a temporary file before sending it to the http response.
const minTTLForRedisStreamingReader = 15 * time.Second
// tmpDir temporary path to store ongoing queries results
const tmpDir = "/tmp"
const redisTmpFilePrefix = "chproxyRedisTmp"
func newRedisCache(client redis.UniversalClient, cfg config.Cache) *redisCache {
redisCache := &redisCache{
name: cfg.Name,
expire: time.Duration(cfg.Expire),
client: client,
}
return redisCache
}
func (r *redisCache) Close() error {
return r.client.Close()
}
var usedMemoryRegexp = regexp.MustCompile(`used_memory:([0-9]+)\r\n`)
// Stats will make two calls to redis.
// First one fetches the number of keys stored in redis (DBSize)
// Second one will fetch memory info that will be parsed to fetch the used_memory
// NOTE : we can only fetch database size, not cache size
func (r *redisCache) Stats() Stats {
return Stats{
Items: r.nbOfKeys(),
Size: r.nbOfBytes(),
}
}
func (r *redisCache) nbOfKeys() uint64 {
ctx, cancelFunc := context.WithTimeout(context.Background(), statsTimeout)
defer cancelFunc()
nbOfKeys, err := r.client.DBSize(ctx).Result()
if err != nil {
log.Errorf("failed to fetch nb of keys in redis: %s", err)
}
return uint64(nbOfKeys)
}
func (r *redisCache) nbOfBytes() uint64 {
ctx, cancelFunc := context.WithTimeout(context.Background(), statsTimeout)
defer cancelFunc()
memoryInfo, err := r.client.Info(ctx, "memory").Result()
if err != nil {
log.Errorf("failed to fetch nb of bytes in redis: %s", err)
}
matches := usedMemoryRegexp.FindStringSubmatch(memoryInfo)
var cacheSize int
if len(matches) > 1 {
cacheSize, err = strconv.Atoi(matches[1])
if err != nil {
log.Errorf("failed to parse memory usage with error %s", err)
}
}
return uint64(cacheSize)
}
func (r *redisCache) Get(key *Key) (*CachedData, error) {
ctx, cancelFunc := context.WithTimeout(context.Background(), getTimeout)
defer cancelFunc()
nbBytesToFetch := int64(100 * 1024)
stringKey := key.String()
// fetching 100 kBytes from redis to be sure to get the full metadata and,
// for the majority of queries that return little data, the full cached result
val, err := r.client.GetRange(ctx, stringKey, 0, nbBytesToFetch).Result()
// errors, such as timeouts
if err != nil {
log.Errorf("failed to get key %s with error: %s", stringKey, err)
return nil, ErrMissing
}
// if key not found in cache
if len(val) == 0 {
return nil, ErrMissing
}
ttl, err := r.client.TTL(ctx, stringKey).Result()
if err != nil {
return nil, fmt.Errorf("failed to ttl of key %s with error: %w", stringKey, err)
}
b := []byte(val)
metadata, offset, err := r.decodeMetadata(b)
if err != nil {
if errors.Is(err, &RedisCacheCorruptionError{}) {
log.Errorf("an error happened while handling redis key =%s, err=%s", stringKey, err)
}
return nil, err
}
if (int64(offset) + metadata.Length) < nbBytesToFetch {
// the condition is true only if the fetched bytes contain the metadata plus the whole cached result,
// so we extract the cached result from the remaining bytes
payload := b[offset:]
reader := &ioReaderDecorator{Reader: bytes.NewReader(payload)}
value := &CachedData{
ContentMetadata: *metadata,
Data: reader,
Ttl: ttl,
}
return value, nil
}
return r.readResultsAboveLimit(offset, stringKey, metadata, ttl)
}
func (r *redisCache) readResultsAboveLimit(offset int, stringKey string, metadata *ContentMetadata, ttl time.Duration) (*CachedData, error) {
// since the cached result in redis is too big, we can't fetch all of it at once because of the memory overhead.
// Instead we create an io.Reader that fetches it from redis chunk by chunk to reduce memory usage.
redisStreamreader := newRedisStreamReader(uint64(offset), r.client, stringKey, metadata.Length)
// But since reading through that reader can take time and the object in redis could disappear between two fetches,
// we need to make sure the TTL is long enough to avoid nasty side effects.
// If the TTL is too short we copy all the data into a temporary file and use it as the streamer instead.
// nb: it would be better to retry the flow if such a failure happened, but this requires a huge refactoring of proxy.go
if ttl <= minTTLForRedisStreamingReader {
fileStream, err := newFileWriterReader(tmpDir)
if err != nil {
return nil, err
}
_, err = io.Copy(fileStream, redisStreamreader)
if err != nil {
return nil, err
}
err = fileStream.resetOffset()
if err != nil {
return nil, err
}
value := &CachedData{
ContentMetadata: *metadata,
Data: fileStream,
Ttl: ttl,
}
return value, nil
}
value := &CachedData{
ContentMetadata: *metadata,
Data: &ioReaderDecorator{Reader: redisStreamreader},
Ttl: ttl,
}
return value, nil
}
// this struct is here because CachedData requires an io.ReadCloser
// while the logic in the Get function produces only an io.Reader
type ioReaderDecorator struct {
io.Reader
}
func (m ioReaderDecorator) Close() error {
return nil
}
func (r *redisCache) encodeString(s string) []byte {
n := uint32(len(s))
b := make([]byte, 0, n+4)
b = append(b, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
b = append(b, s...)
return b
}
func (r *redisCache) decodeString(bytes []byte) (string, int, error) {
if len(bytes) < 4 {
return "", 0, &RedisCacheCorruptionError{}
}
b := bytes[:4]
n := uint32(b[3]) | (uint32(b[2]) << 8) | (uint32(b[1]) << 16) | (uint32(b[0]) << 24)
if len(bytes) < int(4+n) {
return "", 0, &RedisCacheCorruptionError{}
}
s := bytes[4 : 4+n]
return string(s), int(4 + n), nil
}
func (r *redisCache) encodeMetadata(contentMetadata *ContentMetadata) []byte {
cLength := contentMetadata.Length
cType := r.encodeString(contentMetadata.Type)
cEncoding := r.encodeString(contentMetadata.Encoding)
b := make([]byte, 0, len(cEncoding)+len(cType)+8)
b = append(b, byte(cLength>>56), byte(cLength>>48), byte(cLength>>40), byte(cLength>>32), byte(cLength>>24), byte(cLength>>16), byte(cLength>>8), byte(cLength))
b = append(b, cType...)
b = append(b, cEncoding...)
return b
}
func (r *redisCache) decodeMetadata(b []byte) (*ContentMetadata, int, error) {
if len(b) < 8 {
return nil, 0, &RedisCacheCorruptionError{}
}
cLength := uint64(b[7]) | (uint64(b[6]) << 8) | (uint64(b[5]) << 16) | (uint64(b[4]) << 24) | uint64(b[3])<<32 | (uint64(b[2]) << 40) | (uint64(b[1]) << 48) | (uint64(b[0]) << 56)
offset := 8
cType, sizeCType, err := r.decodeString(b[offset:])
if err != nil {
return nil, 0, err
}
offset += sizeCType
cEncoding, sizeCEncoding, err := r.decodeString(b[offset:])
if err != nil {
return nil, 0, err
}
offset += sizeCEncoding
metadata := &ContentMetadata{
Length: int64(cLength),
Type: cType,
Encoding: cEncoding,
}
return metadata, offset, nil
}
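// Byte layout sketch of a cached value in redis, as produced by encodeMetadata
// and consumed by decodeMetadata (all lengths are big endian):
//
//	[8 bytes: content length]
//	[4 bytes: len(Type)][Type]
//	[4 bytes: len(Encoding)][Encoding]
//	[cached payload]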
func (r *redisCache) Put(reader io.Reader, contentMetadata ContentMetadata, key *Key) (time.Duration, error) {
metadata := r.encodeMetadata(&contentMetadata)
stringKey := key.String()
// in order to make the streaming operation atomic, chproxy streams into a temporary key (only known by the current goroutine),
// then renames the fully written result to the "real" stringKey visible to other goroutines
// nolint:gosec // not security sensitve, only used internally.
random := strconv.Itoa(rand.Int())
// Redis RENAME is considered a multi-key operation. In Cluster mode, both the old key and the renamed key must be in the same hash slot;
// see the Redis documentation: https://redis.io/commands/rename/
// To solve this, we force the temporary key into the same hash slot by adding a hash tag built from the
// actual key. When a key contains a "{...}" pattern, only the substring between the braces "{" and "}"
// is hashed to obtain the hash slot.
// See the hash tags section of the Redis documentation: https://redis.io/docs/reference/cluster-spec/#hash-tags
stringKeyTmp := "{" + stringKey + "}" + random + "_tmp"
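// e.g. with stringKey "3f2a" and random "1234" the temporary key is "{3f2a}1234_tmp" (values made up);
// only the "3f2a" part is hashed, so it lands in the same slot as the final key and RENAME stays valid.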
ctxSet, cancelFuncSet := context.WithTimeout(context.Background(), putTimeout)
defer cancelFuncSet()
err := r.client.Set(ctxSet, stringKeyTmp, metadata, r.expire).Err()
if err != nil {
return 0, err
}
// we don't read the whole content of the reader at once; it is appended to redis chunk by chunk to avoid memory issues
// when the content is big (which is the case when chproxy users are fetching a lot of data)
buffer := make([]byte, 2*1024*1024)
totalByteWrittenExpected := len(metadata)
for {
n, err := reader.Read(buffer)
// the reader should return err = io.EOF once it has nothing left to read, or on the last read call that still returns content.
// But this is not the case with this reader, so we check the condition n == 0 to exit the read loop.
// We kept the err == io.EOF check in the loop in case the behavior of the reader changes.
if n == 0 {
break
}
if err != nil && !errors.Is(err, io.EOF) {
return 0, err
}
ctxAppend, cancelFuncAppend := context.WithTimeout(context.Background(), putTimeout)
defer cancelFuncAppend()
totalByteWritten, err := r.client.Append(ctxAppend, stringKeyTmp, string(buffer[:n])).Result()
if err != nil {
// trying to clean redis from this partially inserted item
r.clean(stringKeyTmp)
return 0, err
}
totalByteWrittenExpected += n
if int(totalByteWritten) != totalByteWrittenExpected {
// trying to clean redis from this partially inserted item
r.clean(stringKeyTmp)
return 0, fmt.Errorf("could not stream the value into redis, only %d bytes were written instead of %d", totalByteWritten, totalByteWrittenExpected)
}
if errors.Is(err, io.EOF) {
break
}
}
// at this step we know that the item stored in stringKeyTmp is fully written
// so we can put it to its final stringKey
ctxRename, cancelFuncRename := context.WithTimeout(context.Background(), renameTimeout)
defer cancelFuncRename()
if err := r.client.Rename(ctxRename, stringKeyTmp, stringKey).Err(); err != nil {
// trying to clean redis from this partially inserted item
r.clean(stringKeyTmp)
return 0, err
}
return r.expire, nil
}
func (r *redisCache) clean(stringKey string) {
delCtx, cancelFunc := context.WithTimeout(context.Background(), removeTimeout)
defer cancelFunc()
delErr := r.client.Del(delCtx, stringKey).Err()
if delErr != nil {
log.Debugf("redis item was only partially inserted and chproxy couldn't remove the partial result because of %s", delErr)
} else {
log.Debugf("redis item was only partially inserted, chproxy was able to remove it")
}
}
func (r *redisCache) Name() string {
return r.name
}
type redisStreamReader struct {
isRedisEOF bool
redisOffset uint64 // the redisOffset that gives the beginning of the next bulk to fetch
key string // the key of the value we want to stream from redis
buffer []byte // internal buffer to store the bulks fetched from redis
bufferOffset int // the offset in the buffer from which Read() continues copying data
client redis.UniversalClient // the redis client
expectedPayloadSize int // the size of the object the streamer is supposed to read.
readPayloadSize int // the size of the object currently written by the reader
}
func newRedisStreamReader(offset uint64, client redis.UniversalClient, key string, payloadSize int64) *redisStreamReader {
bufferSize := uint64(2 * 1024 * 1024)
return &redisStreamReader{
isRedisEOF: false,
redisOffset: offset,
key: key,
bufferOffset: int(bufferSize),
buffer: make([]byte, bufferSize),
client: client,
expectedPayloadSize: int(payloadSize),
}
}
func (r *redisStreamReader) Read(destBuf []byte) (n int, err error) {
// the logic is simple:
// 1) if the buffer still has data to write, it writes it into destBuf without overflowing destBuf
// 2) if the buffer only contains already written data, the stream reader refreshes the buffer with new data from redis
// 3) if the buffer only contains already written data & redis has no more data to read, the stream reader returns an io.EOF error
bufSize := len(r.buffer)
bytesWritten := 0
// case 3) both the buffer & redis were fully consumed, we can tell the reader to stop reading
if r.bufferOffset >= bufSize && r.isRedisEOF {
// Because of the way we fetch from redis, we need to do an extra check because we have no way
// to know if redis is really EOF or if the value was expired from cache while reading it
if r.readPayloadSize != r.expectedPayloadSize {
log.Debugf("error while fetching data from redis payload size doesn't match")
return 0, &RedisCacheError{key: r.key, readPayloadSize: r.readPayloadSize, expectedPayloadSize: r.expectedPayloadSize}
}
return 0, io.EOF
}
// case 2) the buffer only contains already written data, we need to refresh it with redis data
if r.bufferOffset >= bufSize {
if err := r.readRangeFromRedis(bufSize); err != nil {
return bytesWritten, err
}
}
// case 1) the buffer contains data to write into destBuf
if r.bufferOffset < bufSize {
bytesWritten = copy(destBuf, r.buffer[r.bufferOffset:])
r.bufferOffset += bytesWritten
r.readPayloadSize += bytesWritten
}
return bytesWritten, nil
}
func (r *redisStreamReader) readRangeFromRedis(bufSize int) error {
ctx, cancelFunc := context.WithTimeout(context.Background(), getTimeout)
defer cancelFunc()
newBuf, err := r.client.GetRange(ctx, r.key, int64(r.redisOffset), int64(r.redisOffset+uint64(bufSize))).Result()
r.redisOffset += uint64(len(newBuf))
if errors.Is(err, redis.Nil) || len(newBuf) == 0 {
r.isRedisEOF = true
}
// if redis gave less data than asked it means that it reached the end of the value
if len(newBuf) < bufSize {
r.isRedisEOF = true
}
// other errors, such as timeouts
if err != nil && !errors.Is(err, redis.Nil) {
log.Debugf("failed to get key %s with error: %s", r.key, err)
err2 := &RedisCacheError{key: r.key, readPayloadSize: r.readPayloadSize,
expectedPayloadSize: r.expectedPayloadSize, rootcause: err}
return err2
}
r.bufferOffset = 0
r.buffer = []byte(newBuf)
return nil
}
type fileWriterReader struct {
f *os.File
}
func newFileWriterReader(dir string) (*fileWriterReader, error) {
f, err := os.CreateTemp(dir, redisTmpFilePrefix)
if err != nil {
return nil, fmt.Errorf("cannot create temporary file in %q: %w", dir, err)
}
return &fileWriterReader{
f: f,
}, nil
}
func (r *fileWriterReader) Close() error {
err := r.f.Close()
if err != nil {
return err
}
return os.Remove(r.f.Name())
}
func (r *fileWriterReader) Read(destBuf []byte) (n int, err error) {
return r.f.Read(destBuf)
}
func (r *fileWriterReader) Write(p []byte) (n int, err error) {
return r.f.Write(p)
}
func (r *fileWriterReader) resetOffset() error {
if _, err := r.f.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("cannot reset offset in: %w", err)
}
return nil
}
type RedisCacheError struct {
key string
readPayloadSize int
expectedPayloadSize int
rootcause error
}
func (e *RedisCacheError) Error() string {
errorMsg := fmt.Sprintf("error while reading cached result in redis for key %s, only %d bytes of %d were fetched",
e.key, e.readPayloadSize, e.expectedPayloadSize)
if e.rootcause != nil {
errorMsg = fmt.Sprintf("%s, root cause:%s", errorMsg, e.rootcause)
}
return errorMsg
}
type RedisCacheCorruptionError struct {
}
func (e *RedisCacheCorruptionError) Error() string {
return "chproxy can't decode the cached result from redis, it seems to have been corrupted"
}
package cache
import (
"bufio"
"fmt"
"io"
"net/http"
"os"
"github.com/contentsquare/chproxy/log"
)
// TmpFileResponseWriter caches the ClickHouse response in a temporary file.
// The http headers are kept in memory.
type TmpFileResponseWriter struct {
http.ResponseWriter // the original response writer
contentLength int64
contentType string
contentEncoding string
headersCaptured bool
statusCode int
tmpFile *os.File // temporary file for response streaming
bw *bufio.Writer // buffered writer for the temporary file
}
func NewTmpFileResponseWriter(rw http.ResponseWriter, dir string) (*TmpFileResponseWriter, error) {
_, ok := rw.(http.CloseNotifier)
if !ok {
return nil, fmt.Errorf("the response writer does not implement http.CloseNotifier")
}
f, err := os.CreateTemp(dir, "tmp")
if err != nil {
return nil, fmt.Errorf("cannot create temporary file in %q: %w", dir, err)
}
return &TmpFileResponseWriter{
ResponseWriter: rw,
tmpFile: f,
bw: bufio.NewWriter(f),
}, nil
}
func (rw *TmpFileResponseWriter) Close() error {
rw.tmpFile.Close()
return os.Remove(rw.tmpFile.Name())
}
func (rw *TmpFileResponseWriter) GetFile() (*os.File, error) {
if err := rw.bw.Flush(); err != nil {
fn := rw.tmpFile.Name()
errTmp := rw.tmpFile.Close()
if errTmp != nil {
log.Errorf("cannot close tmpFile: %s, error: %s", fn, errTmp)
}
errTmp = os.Remove(fn)
if errTmp != nil {
log.Errorf("cannot remove tmpFile: %s, error: %s", fn, errTmp)
}
return nil, fmt.Errorf("cannot flush data into %q: %w", fn, err)
}
return rw.tmpFile, nil
}
func (rw *TmpFileResponseWriter) Reader() (io.Reader, error) {
f, err := rw.GetFile()
if err != nil {
return nil, fmt.Errorf("cannot open tmp file: %w", err)
}
return f, nil
}
func (rw *TmpFileResponseWriter) ResetFileOffset() error {
data, err := rw.GetFile()
if err != nil {
return err
}
if _, err := data.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("cannot reset offset in: %w", err)
}
return nil
}
func (rw *TmpFileResponseWriter) captureHeaders() error {
if rw.headersCaptured {
return nil
}
rw.headersCaptured = true
h := rw.Header()
ct := h.Get("Content-Type")
ce := h.Get("Content-Encoding")
rw.contentEncoding = ce
rw.contentType = ct
// nb: the Content-Length http header is not set by CH so we can't get it
return nil
}
func (rw *TmpFileResponseWriter) GetCapturedContentType() string {
return rw.contentType
}
func (rw *TmpFileResponseWriter) GetCapturedContentLength() (int64, error) {
if rw.contentLength == 0 {
// Determine Content-Length by looking at the size of the temporary file
data, err := rw.GetFile()
if err != nil {
return 0, fmt.Errorf("GetCapturedContentLength: cannot open tmp file: %w", err)
}
end, err := data.Seek(0, io.SeekEnd)
if err != nil {
return 0, fmt.Errorf("GetCapturedContentLength: cannot determine the last position in: %w", err)
}
if err := rw.ResetFileOffset(); err != nil {
return 0, err
}
return end, nil
}
return rw.contentLength, nil
}
func (rw *TmpFileResponseWriter) GetCapturedContentEncoding() string {
return rw.contentEncoding
}
// CloseNotify implements http.CloseNotifier
func (rw *TmpFileResponseWriter) CloseNotify() <-chan bool {
// nolint:forcetypeassert // it is guaranteed by NewTmpFileResponseWriter
// The rw.ResponseWriter is guaranteed to implement http.CloseNotifier.
return rw.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// WriteHeader captures response status code.
func (rw *TmpFileResponseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
// Do not call rw.ResponseWriter.WriteHeader here
// It will be called explicitly in Finalize / Unregister.
}
// StatusCode returns captured status code from WriteHeader.
func (rw *TmpFileResponseWriter) StatusCode() int {
if rw.statusCode == 0 {
return http.StatusOK
}
return rw.statusCode
}
// Write writes b into rw.
func (rw *TmpFileResponseWriter) Write(b []byte) (int, error) {
if err := rw.captureHeaders(); err != nil {
return 0, err
}
return rw.bw.Write(b)
}
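// Hedged usage sketch: capture the proxied ClickHouse response body into the
// temporary file, then rewind it and feed it to the cache:
//
//	trw, err := NewTmpFileResponseWriter(rw, os.TempDir())
//	if err != nil {
//		// handle error
//	}
//	defer trw.Close()
//	// ... the proxied response is written into trw ...
//	if err := trw.ResetFileOffset(); err != nil {
//		// handle error
//	}
//	reader, _ := trw.Reader() // yields the captured body from the beginning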
package cache
import (
"io"
"time"
)
// TransactionRegistry is a registry of ongoing queries identified by Key.
type TransactionRegistry interface {
io.Closer
// Create creates a new transaction record
Create(key *Key) error
// Complete completes a transaction for given key
Complete(key *Key) error
// Fail fails a transaction for given key
Fail(key *Key, reason string) error
// Status checks the status of the transaction
Status(key *Key) (TransactionStatus, error)
}
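// Typical lifecycle sketch (error handling omitted):
//
//	reg.Create(key)         // the query starts executing
//	s, _ := reg.Status(key) // concurrent callers observe a pending transaction
//	reg.Complete(key)       // the result has been stored in the cache
//	// or, on failure: reg.Fail(key, "clickhouse returned an error")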
// transactionEndedTTL amount of time transaction record is kept after being updated
const transactionEndedTTL = 500 * time.Millisecond
type TransactionStatus struct {
State TransactionState
FailReason string // filled in only if state of transaction is transactionFailed
}
type TransactionState uint8
const (
transactionCreated TransactionState = 0
transactionCompleted TransactionState = 1
transactionFailed TransactionState = 2
transactionAbsent TransactionState = 3
)
func (t *TransactionState) IsAbsent() bool {
if t != nil {
return *t == transactionAbsent
}
return false
}
func (t *TransactionState) IsFailed() bool {
if t != nil {
return *t == transactionFailed
}
return false
}
func (t *TransactionState) IsCompleted() bool {
if t != nil {
return *t == transactionCompleted
}
return false
}
func (t *TransactionState) IsPending() bool {
if t != nil {
return *t == transactionCreated
}
return false
}
package cache
import (
"sync"
"time"
"github.com/contentsquare/chproxy/log"
)
type pendingEntry struct {
deadline time.Time
state TransactionState
failedReason string
}
type inMemoryTransactionRegistry struct {
pendingEntriesLock sync.Mutex
pendingEntries map[string]pendingEntry
deadline time.Duration
transactionEndedDeadline time.Duration
stopCh chan struct{}
wg sync.WaitGroup
}
func newInMemoryTransactionRegistry(deadline, transactionEndedDeadline time.Duration) *inMemoryTransactionRegistry {
transaction := &inMemoryTransactionRegistry{
pendingEntriesLock: sync.Mutex{},
pendingEntries: make(map[string]pendingEntry),
deadline: deadline,
transactionEndedDeadline: transactionEndedDeadline,
stopCh: make(chan struct{}),
}
transaction.wg.Add(1)
go func() {
log.Debugf("inmem transaction: cleaner start")
transaction.pendingEntriesCleaner()
transaction.wg.Done()
log.Debugf("inmem transaction: cleaner stop")
}()
return transaction
}
func (i *inMemoryTransactionRegistry) Create(key *Key) error {
i.pendingEntriesLock.Lock()
defer i.pendingEntriesLock.Unlock()
k := key.String()
_, exists := i.pendingEntries[k]
if !exists {
i.pendingEntries[k] = pendingEntry{
deadline: time.Now().Add(i.deadline),
state: transactionCreated,
}
}
return nil
}
func (i *inMemoryTransactionRegistry) Complete(key *Key) error {
i.updateTransactionState(key, transactionCompleted, "")
return nil
}
func (i *inMemoryTransactionRegistry) Fail(key *Key, reason string) error {
i.updateTransactionState(key, transactionFailed, reason)
return nil
}
func (i *inMemoryTransactionRegistry) updateTransactionState(key *Key, state TransactionState, failReason string) {
i.pendingEntriesLock.Lock()
defer i.pendingEntriesLock.Unlock()
k := key.String()
if entry, ok := i.pendingEntries[k]; ok {
entry.state = state
entry.failedReason = failReason
entry.deadline = time.Now().Add(i.transactionEndedDeadline)
i.pendingEntries[k] = entry
} else {
log.Errorf("[attempt to complete transaction] entry not found for key: %s, registering new entry with %v status", key.String(), state)
i.pendingEntries[k] = pendingEntry{
deadline: time.Now().Add(i.transactionEndedDeadline),
state: state,
failedReason: failReason,
}
}
}
func (i *inMemoryTransactionRegistry) Status(key *Key) (TransactionStatus, error) {
i.pendingEntriesLock.Lock()
defer i.pendingEntriesLock.Unlock()
k := key.String()
if entry, ok := i.pendingEntries[k]; ok {
return TransactionStatus{State: entry.state, FailReason: entry.failedReason}, nil
}
return TransactionStatus{State: transactionAbsent}, nil
}
func (i *inMemoryTransactionRegistry) Close() error {
close(i.stopCh)
i.wg.Wait()
return nil
}
func (i *inMemoryTransactionRegistry) pendingEntriesCleaner() {
d := i.deadline
if d < 100*time.Millisecond {
d = 100 * time.Millisecond
}
if d > time.Second {
d = time.Second
}
for {
currentTime := time.Now()
// Clear outdated pending entries, since they may remain here
// forever if the corresponding Complete or Fail call never happens.
i.pendingEntriesLock.Lock()
for path, pe := range i.pendingEntries {
if currentTime.After(pe.deadline) {
delete(i.pendingEntries, path)
}
}
i.pendingEntriesLock.Unlock()
select {
case <-time.After(d):
case <-i.stopCh:
return
}
}
}
package cache
import (
"context"
"errors"
"time"
"fmt"
"github.com/contentsquare/chproxy/log"
"github.com/redis/go-redis/v9"
)
type redisTransactionRegistry struct {
redisClient redis.UniversalClient
// deadline specifies TTL of the record to be kept
deadline time.Duration
// transactionEndedDeadline specifies the TTL of a record whose transaction has ended (either completed or failed)
transactionEndedDeadline time.Duration
}
func newRedisTransactionRegistry(redisClient redis.UniversalClient, deadline time.Duration,
endedDeadline time.Duration) *redisTransactionRegistry {
return &redisTransactionRegistry{
redisClient: redisClient,
deadline: deadline,
transactionEndedDeadline: endedDeadline,
}
}
func (r *redisTransactionRegistry) Create(key *Key) error {
return r.redisClient.Set(context.Background(), toTransactionKey(key),
[]byte{uint8(transactionCreated)}, r.deadline).Err()
}
func (r *redisTransactionRegistry) Complete(key *Key) error {
return r.updateTransactionState(key, []byte{uint8(transactionCompleted)})
}
func (r *redisTransactionRegistry) Fail(key *Key, reason string) error {
b := make([]byte, 0, uint32(len(reason))+1)
b = append(b, byte(transactionFailed))
b = append(b, []byte(reason)...)
return r.updateTransactionState(key, b)
}
func (r *redisTransactionRegistry) updateTransactionState(key *Key, value []byte) error {
return r.redisClient.Set(context.Background(), toTransactionKey(key), value, r.transactionEndedDeadline).Err()
}
func (r *redisTransactionRegistry) Status(key *Key) (TransactionStatus, error) {
raw, err := r.redisClient.Get(context.Background(), toTransactionKey(key)).Bytes()
if errors.Is(err, redis.Nil) {
return TransactionStatus{State: transactionAbsent}, nil
}
if err != nil {
log.Errorf("Failed to fetch transaction status from redis for key: %s", key.String())
return TransactionStatus{State: transactionAbsent}, err
}
if len(raw) == 0 {
log.Errorf("Failed to fetch transaction status from redis raw value: %s", key.String())
return TransactionStatus{State: transactionAbsent}, err
}
state := TransactionState(uint8(raw[0]))
var reason string
if state.IsFailed() && len(raw) > 1 {
reason = string(raw[1:])
}
return TransactionStatus{State: state, FailReason: reason}, nil
}
func (r *redisTransactionRegistry) Close() error {
return r.redisClient.Close()
}
func toTransactionKey(key *Key) string {
return fmt.Sprintf("%s-transaction", key.String())
}
package cache
import (
"fmt"
"io"
"os"
)
// walkDir calls f on all the cache files in the given dir.
func walkDir(dir string, f func(fi os.FileInfo)) error {
// Do not use filepath.Walk, since it is inefficient
// for large number of files.
// See https://golang.org/pkg/path/filepath/#Walk .
fd, err := os.Open(dir)
if err != nil {
return fmt.Errorf("cannot open %q: %w", dir, err)
}
defer fd.Close()
for {
fis, err := fd.Readdir(1024)
if err != nil {
if err == io.EOF {
return nil
}
return fmt.Errorf("cannot read files in %q: %w", dir, err)
}
for _, fi := range fis {
if fi.IsDir() {
// Skip subdirectories
continue
}
fn := fi.Name()
if !cachefileRegexp.MatchString(fn) {
// Skip invalid filenames
continue
}
f(fi)
}
}
}
package chdecompressor
import (
"errors"
"fmt"
"io"
"github.com/klauspost/compress/zstd"
"github.com/pierrec/lz4"
)
// Reader reads clickhouse compressed stream.
// See https://github.com/yandex/ClickHouse/blob/ae8783aee3ef982b6eb7e1721dac4cb3ce73f0fe/dbms/src/IO/CompressedStream.h
type Reader struct {
src io.Reader
data []byte
scratch []byte
}
// NewReader returns a new clickhouse compressed stream reader reading from src.
func NewReader(src io.Reader) *Reader {
return &Reader{
src: src,
scratch: make([]byte, 16),
}
}
// Read reads up to len(buf) bytes from clickhouse compressed stream.
func (r *Reader) Read(buf []byte) (int, error) {
// exhaust remaining data from previous Read()
if len(r.data) == 0 {
if err := r.readNextBlock(); err != nil {
return 0, err
}
}
n := copy(buf, r.data)
r.data = r.data[n:]
return n, nil
}
func (r *Reader) readNextBlock() error {
// Skip checksum
if _, err := io.ReadFull(r.src, r.scratch[:16]); err != nil {
if errors.Is(err, io.EOF) {
return io.EOF
}
return fmt.Errorf("cannot read checksum: %w", err)
}
// Read compression type
if _, err := io.ReadFull(r.src, r.scratch[:1]); err != nil {
return fmt.Errorf("cannot read compression type: %w", err)
}
compressionType := r.scratch[0]
// Read compressed size
compressedSize, err := r.readUint32()
if err != nil {
return fmt.Errorf("cannot read compressed size: %w", err)
}
compressedSize -= 9 // minus header length
// Read decompressed size
decompressedSize, err := r.readUint32()
if err != nil {
return fmt.Errorf("cannot read decompressed size: %w", err)
}
// Read compressed block
block := make([]byte, compressedSize)
if _, err = io.ReadFull(r.src, block); err != nil {
return fmt.Errorf("cannot read compressed block: %w", err)
}
// Decompress block
if err := r.decompressBlock(block, compressionType, decompressedSize); err != nil {
return err
}
return nil
}
func (r *Reader) decompressBlock(block []byte, compressionType byte, decompressedSize uint32) error {
var err error
r.data = make([]byte, decompressedSize)
var decoder, _ = zstd.NewReader(nil)
switch compressionType {
case noneType:
r.data = block
case lz4Type:
if _, err := lz4.UncompressBlock(block, r.data); err != nil {
return fmt.Errorf("cannot decompress lz4 block: %w", err)
}
case zstdType:
r.data = r.data[:0] // Wipe the slice but keep allocated memory
r.data, err = decoder.DecodeAll(block, r.data)
if err != nil {
return fmt.Errorf("cannot decompress zstd block: %w", err)
}
default:
return fmt.Errorf("unknown compressionType: %X", compressionType)
}
return nil
}
const (
noneType = 0x02
lz4Type = 0x82
zstdType = 0x90
)
func (r *Reader) readUint32() (uint32, error) {
b := r.scratch[:4]
_, err := io.ReadFull(r.src, b)
if err != nil {
return 0, err
}
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
return n, nil
}
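// Wire format sketch of a single block, as parsed by readNextBlock above
// (sizes are little endian, per readUint32):
//
//	[16 bytes checksum, skipped]
//	[1 byte compression type: 0x02 none, 0x82 lz4, 0x90 zstd]
//	[4 bytes compressed size, including the 9 header bytes]
//	[4 bytes decompressed size]
//	[compressed data]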
package clients
import (
"context"
"fmt"
"github.com/contentsquare/chproxy/config"
"github.com/redis/go-redis/v9"
)
func NewRedisClient(cfg config.RedisCacheConfig) (redis.UniversalClient, error) {
options := &redis.UniversalOptions{
Addrs: cfg.Addresses,
Username: cfg.Username,
Password: cfg.Password,
MaxRetries: 7, // default value = 3; since MinRetryBackoff = 8 msec & MaxRetryBackoff = 512 msec,
// the redis client will wait up to 1016 msec between the 7 retries
}
if len(cfg.Addresses) == 1 {
options.DB = cfg.DBIndex
}
if len(cfg.CertFile) != 0 || len(cfg.KeyFile) != 0 {
tlsConfig, err := cfg.TLS.BuildTLSConfig(nil)
if err != nil {
return nil, err
}
options.TLSConfig = tlsConfig
}
r := redis.NewUniversalClient(options)
err := r.Ping(context.Background()).Err()
if err != nil {
return nil, fmt.Errorf("failed to reach redis: %w", err)
}
return r, nil
}
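// Hedged usage sketch (the address is a placeholder):
//
//	client, err := NewRedisClient(config.RedisCacheConfig{
//		Addresses: []string{"127.0.0.1:6379"},
//	})
//	if err != nil {
//		// handle error
//	}
//	defer client.Close()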
package config
import (
"bytes"
"crypto/tls"
"fmt"
"os"
"regexp"
"strings"
"time"
"github.com/mohae/deepcopy"
"golang.org/x/crypto/acme/autocert"
"gopkg.in/yaml.v2"
)
var (
defaultConfig = Config{
Clusters: []Cluster{defaultCluster},
}
defaultCluster = Cluster{
Scheme: "http",
ClusterUsers: []ClusterUser{defaultClusterUser},
HeartBeat: defaultHeartBeat,
RetryNumber: defaultRetryNumber,
}
defaultClusterUser = ClusterUser{
Name: "default",
}
defaultHeartBeat = HeartBeat{
Interval: Duration(time.Second * 5),
Timeout: Duration(time.Second * 3),
Request: "/ping",
Response: "Ok.\n",
User: "",
Password: "",
}
defaultConnectionPool = ConnectionPool{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 2,
}
defaultExecutionTime = Duration(120 * time.Second)
defaultMaxPayloadSize = ByteSize(1 << 50)
defaultRetryNumber = 0
)
// Config describes server configuration, access and proxy rules
type Config struct {
Server Server `yaml:"server,omitempty"`
Clusters []Cluster `yaml:"clusters"`
Users []User `yaml:"users"`
// Whether to print debug logs
LogDebug bool `yaml:"log_debug,omitempty"`
// Whether to ignore security warnings
HackMePlease bool `yaml:"hack_me_please,omitempty"`
NetworkGroups []NetworkGroups `yaml:"network_groups,omitempty"`
Caches []Cache `yaml:"caches,omitempty"`
ParamGroups []ParamGroup `yaml:"param_groups,omitempty"`
ConnectionPool ConnectionPool `yaml:"connection_pool,omitempty"`
networkReg map[string]Networks
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
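// A minimal illustrative YAML configuration (values are made up; the
// `to_cluster` and `to_user` fields on the user entry are assumed from the
// upstream documentation and are not defined in this file):
//
//	server:
//	  http:
//	    listen_addr: ":9090"
//	users:
//	  - name: "default"
//	    to_cluster: "example"
//	    to_user: "default"
//	clusters:
//	  - name: "example"
//	    nodes: ["127.0.0.1:8123"]
//	    users:
//	      - name: "default"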
// String implements the Stringer interface
func (c *Config) String() string {
b, err := yaml.Marshal(withoutSensitiveInfo(c))
if err != nil {
panic(err)
}
return string(b)
}
func withoutSensitiveInfo(config *Config) *Config {
const pswPlaceHolder = "XXX"
// nolint: forcetypeassert // no need to check type, it is specified by function.
c := deepcopy.Copy(config).(*Config)
for i := range c.Users {
c.Users[i].Password = pswPlaceHolder
}
for i := range c.Clusters {
if len(c.Clusters[i].KillQueryUser.Name) > 0 {
c.Clusters[i].KillQueryUser.Password = pswPlaceHolder
}
for j := range c.Clusters[i].ClusterUsers {
c.Clusters[i].ClusterUsers[j].Password = pswPlaceHolder
}
}
for i := range c.Caches {
if len(c.Caches[i].Redis.Username) > 0 {
c.Caches[i].Redis.Password = pswPlaceHolder
}
}
return c
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
// set c to the defaults and then overwrite it with the input.
*c = defaultConfig
type plain Config
if err := unmarshal((*plain)(c)); err != nil {
return err
}
if err := c.validate(); err != nil {
return err
}
return checkOverflow(c.XXX, "config")
}
func (c *Config) validate() error {
if len(c.Users) == 0 {
return fmt.Errorf("`users` must contain at least 1 user")
}
if len(c.Clusters) == 0 {
return fmt.Errorf("`clusters` must contain at least 1 cluster")
}
if len(c.Server.HTTP.ListenAddr) == 0 && len(c.Server.HTTPS.ListenAddr) == 0 {
return fmt.Errorf("neither HTTP nor HTTPS not configured")
}
if len(c.Server.HTTPS.ListenAddr) > 0 {
if len(c.Server.HTTPS.Autocert.CacheDir) == 0 && len(c.Server.HTTPS.CertFile) == 0 && len(c.Server.HTTPS.KeyFile) == 0 {
return fmt.Errorf("configuration `https` is missing. " +
"Must be specified `https.cache_dir` for autocert " +
"OR `https.key_file` and `https.cert_file` for already existing certs")
}
if len(c.Server.HTTPS.Autocert.CacheDir) > 0 {
c.Server.HTTP.ForceAutocertHandler = true
}
}
return nil
}
func (cfg *Config) setDefaults() error {
var maxResponseTime time.Duration
var err error
for i := range cfg.Clusters {
c := &cfg.Clusters[i]
for j := range c.ClusterUsers {
u := &c.ClusterUsers[j]
cud := time.Duration(u.MaxExecutionTime + u.MaxQueueTime)
if cud > maxResponseTime {
maxResponseTime = cud
}
if u.AllowedNetworks, err = cfg.groupToNetwork(u.NetworksOrGroups); err != nil {
return err
}
}
}
for i := range cfg.Users {
u := &cfg.Users[i]
u.setDefaults()
ud := time.Duration(u.MaxExecutionTime + u.MaxQueueTime)
if ud > maxResponseTime {
maxResponseTime = ud
}
if u.AllowedNetworks, err = cfg.groupToNetwork(u.NetworksOrGroups); err != nil {
return err
}
}
for i := range cfg.Caches {
c := &cfg.Caches[i]
c.setDefaults()
}
cfg.setServerMaxResponseTime(maxResponseTime)
return nil
}
func (cfg *Config) setServerMaxResponseTime(maxResponseTime time.Duration) {
if maxResponseTime < 0 {
maxResponseTime = 0
}
// Give an additional minute for the maximum response time,
// so the response body may be sent to the requester.
maxResponseTime += time.Minute
if len(cfg.Server.HTTP.ListenAddr) > 0 && cfg.Server.HTTP.WriteTimeout == 0 {
cfg.Server.HTTP.WriteTimeout = Duration(maxResponseTime)
}
if len(cfg.Server.HTTPS.ListenAddr) > 0 && cfg.Server.HTTPS.WriteTimeout == 0 {
cfg.Server.HTTPS.WriteTimeout = Duration(maxResponseTime)
}
}
func (c *Config) groupToNetwork(src NetworksOrGroups) (Networks, error) {
if len(src) == 0 {
return nil, nil
}
dst := make(Networks, 0)
for _, v := range src {
group, ok := c.networkReg[v]
if ok {
dst = append(dst, group...)
} else {
ipnet, err := stringToIPnet(v)
if err != nil {
return nil, err
}
dst = append(dst, ipnet)
}
}
return dst, nil
}
// Server describes configuration of proxy server
// These settings are immutable and can't be reloaded without restart
type Server struct {
// Optional HTTP configuration
HTTP HTTP `yaml:"http,omitempty"`
// Optional TLS configuration
HTTPS HTTPS `yaml:"https,omitempty"`
// Optional metrics handler configuration
Metrics Metrics `yaml:"metrics,omitempty"`
// Optional Proxy configuration
Proxy Proxy `yaml:"proxy,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (s *Server) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain Server
if err := unmarshal((*plain)(s)); err != nil {
return err
}
return checkOverflow(s.XXX, "server")
}
// TimeoutCfg contains configurable http.Server timeouts
type TimeoutCfg struct {
// ReadTimeout is the maximum duration for reading the entire
// request, including the body.
// Default value is 1m
ReadTimeout Duration `yaml:"read_timeout,omitempty"`
// WriteTimeout is the maximum duration before timing out writes of the response.
// Default is largest MaxExecutionTime + MaxQueueTime value from Users or Clusters
WriteTimeout Duration `yaml:"write_timeout,omitempty"`
// IdleTimeout is the maximum amount of time to wait for the next request.
// Default is 10m
IdleTimeout Duration `yaml:"idle_timeout,omitempty"`
}
// HTTP describes configuration for server to listen HTTP connections
type HTTP struct {
// TCP address to listen to for http
ListenAddr string `yaml:"listen_addr"`
NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
// List of networks that access is allowed from
// Each list item could be IP address or subnet mask
// if omitted or zero - no limits would be applied
AllowedNetworks Networks `yaml:"-"`
// Whether to support Autocert handler for http-01 challenge
ForceAutocertHandler bool
TimeoutCfg `yaml:",inline"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *HTTP) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain HTTP
if err := unmarshal((*plain)(c)); err != nil {
return err
}
if err := c.validate(); err != nil {
return err
}
return checkOverflow(c.XXX, "http")
}
func (c *HTTP) validate() error {
if c.ReadTimeout == 0 {
c.ReadTimeout = Duration(time.Minute)
}
if c.IdleTimeout == 0 {
c.IdleTimeout = Duration(time.Minute * 10)
}
return nil
}
// TLS describes generic configuration for TLS connections,
// it can be used for both HTTPS and Redis TLS.
type TLS struct {
// Certificate and key files for client cert authentication to the server
CertFile string `yaml:"cert_file,omitempty"`
KeyFile string `yaml:"key_file,omitempty"`
Autocert Autocert `yaml:"autocert,omitempty"`
InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"`
}
// BuildTLSConfig builds tls.Config from TLS configuration.
func (c *TLS) BuildTLSConfig(acm *autocert.Manager) (*tls.Config, error) {
tlsCfg := tls.Config{
PreferServerCipherSuites: true,
MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{
tls.CurveP256,
tls.X25519,
},
InsecureSkipVerify: c.InsecureSkipVerify, // nolint: gosec
}
if len(c.KeyFile) > 0 && len(c.CertFile) > 0 {
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("cannot load cert for `cert_file`=%q, `key_file`=%q: %w",
c.CertFile, c.KeyFile, err)
}
tlsCfg.Certificates = []tls.Certificate{cert}
} else {
if acm == nil {
return nil, fmt.Errorf("autocert manager is not configured")
}
tlsCfg.GetCertificate = acm.GetCertificate
}
return &tlsCfg, nil
}
// HTTPS describes configuration for server to listen HTTPS connections
// It can be autocert with letsencrypt
// or custom certificate
type HTTPS struct {
// TCP address to listen to for https
// Default is `:443`
ListenAddr string `yaml:"listen_addr,omitempty"`
// TLS configuration
TLS `yaml:",inline"`
NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
// List of networks that access is allowed from
// Each list item could be IP address or subnet mask
// if omitted or zero - no limits would be applied
AllowedNetworks Networks `yaml:"-"`
TimeoutCfg `yaml:",inline"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *HTTPS) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain HTTPS
if err := unmarshal((*plain)(c)); err != nil {
return err
}
if err := c.validate(); err != nil {
return err
}
return checkOverflow(c.XXX, "https")
}
func (c *HTTPS) validate() error {
if c.ReadTimeout == 0 {
c.ReadTimeout = Duration(time.Minute)
}
if c.IdleTimeout == 0 {
c.IdleTimeout = Duration(time.Minute * 10)
}
if len(c.ListenAddr) == 0 {
c.ListenAddr = ":443"
}
if err := c.validateCertConfig(); err != nil {
return err
}
return nil
}
func (c *HTTPS) validateCertConfig() error {
if len(c.Autocert.CacheDir) > 0 {
if len(c.CertFile) > 0 || len(c.KeyFile) > 0 {
return fmt.Errorf("it is forbidden to specify certificate and `https.autocert` at the same time. Choose one way")
}
if len(c.NetworksOrGroups) > 0 {
return fmt.Errorf("`letsencrypt` specification requires https server to be without `allowed_networks` limits. " +
"Otherwise, certificates will be impossible to generate")
}
}
if len(c.CertFile) > 0 && len(c.KeyFile) == 0 {
return fmt.Errorf("`https.key_file` must be specified")
}
if len(c.KeyFile) > 0 && len(c.CertFile) == 0 {
return fmt.Errorf("`https.cert_file` must be specified")
}
return nil
}
// Autocert configuration via letsencrypt
// It requires port :80 to be open
// see https://community.letsencrypt.org/t/2018-01-11-update-regarding-acme-tls-sni-and-shared-hosting-infrastructure/50188
type Autocert struct {
// Path to the directory where autocert certs are cached
CacheDir string `yaml:"cache_dir,omitempty"`
// List of host names to which proxy is allowed to respond to
// see https://godoc.org/golang.org/x/crypto/acme/autocert#HostPolicy
AllowedHosts []string `yaml:"allowed_hosts,omitempty"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Autocert) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain Autocert
if err := unmarshal((*plain)(c)); err != nil {
return err
}
return checkOverflow(c.XXX, "autocert")
}
// Metrics describes configuration to access metrics endpoint
type Metrics struct {
NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
// List of networks that access is allowed from
// Each list item could be IP address or subnet mask
// if omitted or zero - no limits would be applied
AllowedNetworks Networks `yaml:"-"`
// Prometheus metric namespace
Namespace string `yaml:"namespace,omitempty"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Metrics) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain Metrics
if err := unmarshal((*plain)(c)); err != nil {
return err
}
return checkOverflow(c.XXX, "metrics")
}
type Proxy struct {
// Enable enables parsing proxy headers. In proxy mode, CHProxy will try to
// parse the X-Forwarded-For, X-Real-IP or Forwarded header to extract the IP. If another header is configured
// in the proxy settings, CHProxy will use that header instead.
Enable bool `yaml:"enable,omitempty"`
// Header allows for configuring an alternative header to parse the remote IP from, e.g.
// CF-Connecting-IP. If this is set, Enable must also be set to true; otherwise
// config parsing fails with an error.
Header string `yaml:"header,omitempty"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Proxy) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain Proxy
if err := unmarshal((*plain)(c)); err != nil {
return err
}
if !c.Enable && c.Header != "" {
return fmt.Errorf("`proxy_header` cannot be set without enabling proxy settings")
}
return checkOverflow(c.XXX, "proxy")
}
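// Illustrative example (not from the original source): enabling proxy-header parsing
// with an alternative header, assuming the `server.proxy` placement used elsewhere in
// this package (cfg.Server.Proxy). Setting `header` without `enable: true` is rejected
// by the check above; the header name below is hypothetical.
//
//	server:
//	  proxy:
//	    enable: true
//	    header: "CF-Connecting-IP"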
// Cluster describes CH cluster configuration
// The simplest configuration consists of:
//
// cluster description - see <remote_servers> section in CH config.xml
// and users - see <users> section in CH users.xml
type Cluster struct {
// Name of ClickHouse cluster
Name string `yaml:"name"`
// Scheme: `http` or `https`; would be applied to all nodes
// default value is `http`
Scheme string `yaml:"scheme,omitempty"`
// Nodes contains cluster nodes.
//
// Either Nodes or Replicas must be set, but not both.
Nodes []string `yaml:"nodes,omitempty"`
// Replicas contains replicas.
//
// Either Replicas or Nodes must be set, but not both.
Replicas []Replica `yaml:"replicas,omitempty"`
// ClusterUsers - list of ClickHouse users
ClusterUsers []ClusterUser `yaml:"users"`
// KillQueryUser - user configuration for killing timed out queries.
// By default timed out queries are killed under `default` user.
KillQueryUser KillQueryUser `yaml:"kill_query_user,omitempty"`
// HeartBeat - user configuration for heart beat requests
HeartBeat HeartBeat `yaml:"heartbeat,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
// RetryNumber is how many times a query can be retried after receiving a recoverable but failed response from a ClickHouse node
RetryNumber int `yaml:"retry_number,omitempty"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Cluster) UnmarshalYAML(unmarshal func(interface{}) error) error {
*c = defaultCluster
type plain Cluster
if err := unmarshal((*plain)(c)); err != nil {
return err
}
if err := c.validate(); err != nil {
return err
}
return checkOverflow(c.XXX, fmt.Sprintf("cluster %q", c.Name))
}
func (c *Cluster) validate() error {
if len(c.Name) == 0 {
return fmt.Errorf("`cluster.name` cannot be empty")
}
if err := c.validateMinimumRequirements(); err != nil {
return err
}
if c.Scheme != "http" && c.Scheme != "https" {
return fmt.Errorf("`cluster.scheme` must be `http` or `https`, got %q instead for %q", c.Scheme, c.Name)
}
if c.HeartBeat.Interval == 0 && c.HeartBeat.Timeout == 0 && c.HeartBeat.Response == "" {
return fmt.Errorf("`cluster.heartbeat` cannot be unset for %q", c.Name)
}
return nil
}
func (c *Cluster) validateMinimumRequirements() error {
if len(c.Nodes) == 0 && len(c.Replicas) == 0 {
return fmt.Errorf("either `cluster.nodes` or `cluster.replicas` must be set for %q", c.Name)
}
if len(c.Nodes) > 0 && len(c.Replicas) > 0 {
return fmt.Errorf("`cluster.nodes` cannot be simultaneously set with `cluster.replicas` for %q", c.Name)
}
if len(c.ClusterUsers) == 0 {
return fmt.Errorf("`cluster.users` must contain at least 1 user for %q", c.Name)
}
return nil
}
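// Illustrative example (not from the original source): a minimal cluster that satisfies
// the checks above, assuming the top-level `clusters` list used by chproxy configs.
// Node addresses and user names are hypothetical.
//
//	clusters:
//	  - name: "first cluster"
//	    scheme: "http"
//	    nodes: ["127.0.0.1:8123", "127.0.0.2:8123"]
//	    users:
//	      - name: "default"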
// Replica contains ClickHouse replica configuration.
type Replica struct {
// Name is replica name.
Name string `yaml:"name"`
// Nodes contains replica nodes.
Nodes []string `yaml:"nodes"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (r *Replica) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain Replica
if err := unmarshal((*plain)(r)); err != nil {
return err
}
if len(r.Name) == 0 {
return fmt.Errorf("`replica.name` cannot be empty")
}
if len(r.Nodes) == 0 {
return fmt.Errorf("`replica.nodes` cannot be empty for %q", r.Name)
}
return checkOverflow(r.XXX, fmt.Sprintf("replica %q", r.Name))
}
// KillQueryUser - user configuration for killing timed out queries.
type KillQueryUser struct {
// User name
Name string `yaml:"name"`
// User password to access CH with basic auth
Password string `yaml:"password,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (u *KillQueryUser) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain KillQueryUser
if err := unmarshal((*plain)(u)); err != nil {
return err
}
if len(u.Name) == 0 {
return fmt.Errorf("`cluster.kill_query_user.name` must be specified")
}
return checkOverflow(u.XXX, "kill_query_user")
}
// HeartBeat - configuration for heartbeat.
type HeartBeat struct {
// Interval is an interval of checking
// all cluster nodes for availability
// if omitted or zero - interval will be set to 5s
Interval Duration `yaml:"interval,omitempty"`
// Timeout is the timeout for awaiting a response from cluster nodes
// if omitted or zero - timeout will be set to 3s
Timeout Duration `yaml:"timeout,omitempty"`
// Request is the URI of the health-check request
// default value is `/ping`
Request string `yaml:"request,omitempty"`
// Reference response from ClickHouse to the health check request
// default value is `Ok.\n`
Response string `yaml:"response,omitempty"`
// Credentials to send heartbeat requests
// for anything except '/ping'.
// If not specified, the first cluster user's credentials are used
User string `yaml:"user,omitempty"`
Password string `yaml:"password,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (h *HeartBeat) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain HeartBeat
if err := unmarshal((*plain)(h)); err != nil {
return err
}
return checkOverflow(h.XXX, "heartbeat")
}
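// Illustrative example (not from the original source): a `heartbeat` section inside a
// cluster definition that uses a request other than the default `/ping` and therefore
// provides credentials (see NewHeartbeat in the heartbeat package). All values are
// hypothetical.
//
//	heartbeat:
//	  interval: 5s
//	  timeout: 3s
//	  request: "/?query=SELECT%201"
//	  response: "1\n"
//	  user: "hb_user"
//	  password: "hb_password"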
// User describes list of allowed users
// which requests will be proxied to ClickHouse
type User struct {
// User name
Name string `yaml:"name"`
// User password to access proxy with basic auth
Password string `yaml:"password,omitempty"`
// ToCluster is the name of cluster where requests
// will be proxied
ToCluster string `yaml:"to_cluster"`
// ToUser is the name of cluster_user from cluster's ToCluster
// whose credentials will be used for proxying requests to CH
ToUser string `yaml:"to_user"`
// Maximum number of concurrently running queries for user
// if omitted or zero - no limits would be applied
MaxConcurrentQueries uint32 `yaml:"max_concurrent_queries,omitempty"`
// Maximum duration of query execution for user
// if omitted or zero - limit is set to 120 seconds
MaxExecutionTime Duration `yaml:"max_execution_time,omitempty"`
// Maximum number of requests per minute for user
// if omitted or zero - no limits would be applied
ReqPerMin uint32 `yaml:"requests_per_minute,omitempty"`
// The burst of request packet size token bucket for user
// if omitted or zero - no limits would be applied
ReqPacketSizeTokensBurst ByteSize `yaml:"request_packet_size_tokens_burst,omitempty"`
// The request packet size tokens produced rate per second for user
// if omitted or zero - no limits would be applied
ReqPacketSizeTokensRate ByteSize `yaml:"request_packet_size_tokens_rate,omitempty"`
// Maximum number of queries waiting for execution in the queue
// if omitted or zero - queries are executed without waiting
// in the queue
MaxQueueSize uint32 `yaml:"max_queue_size,omitempty"`
// Maximum duration the query may wait in the queue
// if omitted or zero - 10s duration is used
MaxQueueTime Duration `yaml:"max_queue_time,omitempty"`
NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
// List of networks that access is allowed from
// Each list item could be IP address or subnet mask
// if omitted or zero - no limits would be applied
AllowedNetworks Networks `yaml:"-"`
// Whether to deny http connections for this user
DenyHTTP bool `yaml:"deny_http,omitempty"`
// Whether to deny https connections for this user
DenyHTTPS bool `yaml:"deny_https,omitempty"`
// Whether to allow CORS requests for this user
AllowCORS bool `yaml:"allow_cors,omitempty"`
// Name of Cache configuration to use for responses of this user
Cache string `yaml:"cache,omitempty"`
// Name of ParamGroup to use
Params string `yaml:"params,omitempty"`
// IsWildcarded marks the user name as a wildcard pattern, e.g. `prefix*`, `*suffix` or `*`
IsWildcarded bool `yaml:"is_wildcarded,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (u *User) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain User
if err := unmarshal((*plain)(u)); err != nil {
return err
}
if err := u.validate(); err != nil {
return err
}
return checkOverflow(u.XXX, fmt.Sprintf("user %q", u.Name))
}
func (u *User) validate() error {
if len(u.Name) == 0 {
return fmt.Errorf("`user.name` cannot be empty")
}
if len(u.ToUser) == 0 {
return fmt.Errorf("`user.to_user` cannot be empty for %q", u.Name)
}
if len(u.ToCluster) == 0 {
return fmt.Errorf("`user.to_cluster` cannot be empty for %q", u.Name)
}
if u.DenyHTTP && u.DenyHTTPS {
return fmt.Errorf("`deny_http` and `deny_https` cannot be simultaneously set to `true` for %q", u.Name)
}
if err := u.validateWildcarded(); err != nil {
return err
}
if err := u.validateRateLimitConfig(); err != nil {
return err
}
return nil
}
func (u *User) validateWildcarded() error {
if u.IsWildcarded {
if s := strings.Split(u.Name, "*"); !(len(s) == 2 && (s[0] == "" || s[1] == "")) {
return fmt.Errorf("user name %q marked 'is_wildcared' does not match 'prefix*' or '*suffix' or '*'", u.Name)
}
}
return nil
}
func (u *User) validateRateLimitConfig() error {
if u.MaxQueueTime > 0 && u.MaxQueueSize == 0 {
return fmt.Errorf("`max_queue_size` must be set if `max_queue_time` is set for %q", u.Name)
}
if u.ReqPacketSizeTokensBurst > 0 && u.ReqPacketSizeTokensRate == 0 {
return fmt.Errorf("`request_packet_size_tokens_rate` must be set if `request_packet_size_tokens_burst` is set for %q", u.Name)
}
return nil
}
func (u *User) validateSecurity(hasHTTP, hasHTTPS bool) error {
if len(u.Password) == 0 {
if !u.DenyHTTPS && hasHTTPS {
return fmt.Errorf("https: user %q has neither password nor `allowed_networks` on `user` or `server.http` level",
u.Name)
}
if !u.DenyHTTP && hasHTTP {
return fmt.Errorf("http: user %q has neither password nor `allowed_networks` on `user` or `server.http` level",
u.Name)
}
}
if len(u.Password) > 0 && hasHTTP {
return fmt.Errorf("http: user %q is allowed to connect via http, but not limited by `allowed_networks` "+
"on `user` or `server.http` level - password could be stolen", u.Name)
}
return nil
}
func (u *User) setDefaults() {
if u.MaxExecutionTime == 0 {
u.MaxExecutionTime = defaultExecutionTime
}
}
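// Illustrative example (not from the original source): a proxy user routed to a cluster
// user, assuming the top-level `users` list used by chproxy configs. Names, credentials
// and limits are hypothetical.
//
//	users:
//	  - name: "report_user"
//	    password: "secret"
//	    to_cluster: "first cluster"
//	    to_user: "default"
//	    max_concurrent_queries: 4
//	    max_execution_time: 1m
//	    deny_http: true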
// NetworkGroups describes a named list of networks
type NetworkGroups struct {
// Name of the group
Name string `yaml:"name"`
// List of networks
// Each list item could be IP address or subnet mask
Networks Networks `yaml:"networks"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (ng *NetworkGroups) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain NetworkGroups
if err := unmarshal((*plain)(ng)); err != nil {
return err
}
if len(ng.Name) == 0 {
return fmt.Errorf("`network_group.name` must be specified")
}
if len(ng.Networks) == 0 {
return fmt.Errorf("`network_group.networks` must contain at least one network")
}
return checkOverflow(ng.XXX, fmt.Sprintf("network_group %q", ng.Name))
}
// NetworksOrGroups is a list of strings with names of NetworkGroups
// or just Networks
type NetworksOrGroups []string
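// Illustrative example (not from the original source): an `allowed_networks` list may
// mix names of `network_groups` with plain addresses and subnet masks; the group name
// below is hypothetical.
//
//	allowed_networks: ["office", "10.10.0.0/24", "127.0.0.1"]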
// Cache describes configuration options for caching
// responses from CH clusters
type Cache struct {
// Mode of cache (file_system, redis)
// TODO: make it an enum
Mode string `yaml:"mode"`
// Name of the configuration, used to reference this cache from user settings
Name string `yaml:"name"`
// Expiration period for cached response
// Files which are older than expiration period will be deleted
// on new request and re-cached
Expire Duration `yaml:"expire,omitempty"`
// Deprecated: GraceTime is the duration before an expired entry is deleted from the cache.
// In future versions it will be replaced by the user's MaxExecutionTime;
// this is already the behavior today when GraceTime is omitted.
GraceTime Duration `yaml:"grace_time,omitempty"`
FileSystem FileSystemCacheConfig `yaml:"file_system,omitempty"`
Redis RedisCacheConfig `yaml:"redis,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
// Maximum total size of request payload for caching
MaxPayloadSize ByteSize `yaml:"max_payload_size,omitempty"`
// Whether a query cached by a user could be used by another user
SharedWithAllUsers bool `yaml:"shared_with_all_users,omitempty"`
}
func (c *Cache) setDefaults() {
if c.MaxPayloadSize <= 0 {
c.MaxPayloadSize = defaultMaxPayloadSize
}
}
type FileSystemCacheConfig struct {
// Path to the directory where cached files will be saved
Dir string `yaml:"dir"`
// Maximum total size of all cached to Dir files
// If size is exceeded - the oldest files in Dir will be deleted
// until total size becomes normal
MaxSize ByteSize `yaml:"max_size"`
}
type RedisCacheConfig struct {
TLS `yaml:",inline"`
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
Addresses []string `yaml:"addresses"`
DBIndex int `yaml:"db_index,omitempty"`
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *Cache) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
type plain Cache
if err = unmarshal((*plain)(c)); err != nil {
return err
}
if len(c.Name) == 0 {
return fmt.Errorf("`cache.name` must be specified")
}
switch c.Mode {
case "file_system":
err = c.checkFileSystemConfig()
case "redis":
err = c.checkRedisConfig()
default:
err = fmt.Errorf("not supported cache type %v. Supported types: [file_system]", c.Mode)
}
if err != nil {
return fmt.Errorf("failed to configure cache for %q", c.Name)
}
return checkOverflow(c.XXX, fmt.Sprintf("cache %q", c.Name))
}
func (c *Cache) checkFileSystemConfig() error {
if len(c.FileSystem.Dir) == 0 {
return fmt.Errorf("`cache.filesystem.dir` must be specified for %q", c.Name)
}
if c.FileSystem.MaxSize <= 0 {
return fmt.Errorf("`cache.filesystem.max_size` must be specified for %q", c.Name)
}
return nil
}
func (c *Cache) checkRedisConfig() error {
if len(c.Redis.Addresses) == 0 {
return fmt.Errorf("`cache.redis.addresses` must be specified for %q", c.Name)
}
return nil
}
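// Illustrative examples (not from the original source): one cache per supported mode,
// assuming the top-level `caches` list used by chproxy configs. Sizes, TTLs and
// addresses are hypothetical.
//
//	caches:
//	  - name: "local"
//	    mode: "file_system"
//	    expire: 10m
//	    file_system:
//	      dir: "/data/cache"
//	      max_size: 4Gb
//	  - name: "shared"
//	    mode: "redis"
//	    expire: 10m
//	    shared_with_all_users: true
//	    redis:
//	      addresses: ["127.0.0.1:6379"]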
// ParamGroup describes named group of GET params
// for sending with each query
type ParamGroup struct {
// Name of the configuration, referenced from user settings via `params`
Name string `yaml:"name"`
// Params contains a list of GET params
Params []Param `yaml:"params"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (pg *ParamGroup) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain ParamGroup
if err := unmarshal((*plain)(pg)); err != nil {
return err
}
if len(pg.Name) == 0 {
return fmt.Errorf("`param_group.name` must be specified")
}
if len(pg.Params) == 0 {
return fmt.Errorf("`param_group.params` must contain at least one param")
}
return checkOverflow(pg.XXX, fmt.Sprintf("param_group %q", pg.Name))
}
// Param describes URL param value
type Param struct {
// Key is a name of params
Key string `yaml:"key"`
// Value is a value of param
Value string `yaml:"value"`
}
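// Illustrative example (not from the original source): a named group of GET params,
// referenced from a user via `params`, assuming the top-level `param_groups` list used
// by chproxy configs. Keys and values are hypothetical.
//
//	param_groups:
//	  - name: "readonly"
//	    params:
//	      - key: "readonly"
//	        value: "1"
//	      - key: "max_memory_usage"
//	        value: "10000000000"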
// ConnectionPool describes the settings for the pool of connections
// with ClickHouse
type ConnectionPool struct {
// Maximum total number of idle connections between chproxy and all ClickHouse instances
MaxIdleConns int `yaml:"max_idle_conns,omitempty"`
// Maximum number of idle connections between chproxy and a particular ClickHouse instance
MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host,omitempty"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (cp *ConnectionPool) UnmarshalYAML(unmarshal func(interface{}) error) error {
*cp = defaultConnectionPool
type plain ConnectionPool
if err := unmarshal((*plain)(cp)); err != nil {
return err
}
if cp.MaxIdleConnsPerHost > cp.MaxIdleConns || cp.MaxIdleConns < 0 {
return fmt.Errorf("inconsistent ConnectionPool settings")
}
return checkOverflow(cp.XXX, "connection_pool")
}
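// Illustrative example (not from the original source): a connection pool section that
// satisfies the consistency check above (the per-host limit must not exceed the total),
// assuming the top-level `connection_pool` key. Values are hypothetical.
//
//	connection_pool:
//	  max_idle_conns: 100
//	  max_idle_conns_per_host: 10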
// ClusterUser describes simplest <users> configuration
type ClusterUser struct {
// User name in ClickHouse users.xml config
Name string `yaml:"name"`
// User password in ClickHouse users.xml config
Password string `yaml:"password,omitempty"`
// Maximum number of concurrently running queries for user
// if omitted or zero - no limits would be applied
MaxConcurrentQueries uint32 `yaml:"max_concurrent_queries,omitempty"`
// Maximum duration of query execution for user
// if omitted or zero - limit is set to 120 seconds
MaxExecutionTime Duration `yaml:"max_execution_time,omitempty"`
// Maximum number of requests per minute for user
// if omitted or zero - no limits would be applied
ReqPerMin uint32 `yaml:"requests_per_minute,omitempty"`
// The burst of request packet size token bucket for user
// if omitted or zero - no limits would be applied
ReqPacketSizeTokensBurst ByteSize `yaml:"request_packet_size_tokens_burst,omitempty"`
// The request packet size tokens produced rate per second for user
// if omitted or zero - no limits would be applied
ReqPacketSizeTokensRate ByteSize `yaml:"request_packet_size_tokens_rate,omitempty"`
// Maximum number of queries waiting for execution in the queue
// if omitted or zero - queries are executed without waiting
// in the queue
MaxQueueSize uint32 `yaml:"max_queue_size,omitempty"`
// Maximum duration the query may wait in the queue
// if omitted or zero - 10s duration is used
MaxQueueTime Duration `yaml:"max_queue_time,omitempty"`
NetworksOrGroups NetworksOrGroups `yaml:"allowed_networks,omitempty"`
// List of networks that access is allowed from
// Each list item could be IP address or subnet mask
// if omitted or zero - no limits would be applied
AllowedNetworks Networks `yaml:"-"`
// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (cu *ClusterUser) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain ClusterUser
if err := unmarshal((*plain)(cu)); err != nil {
return err
}
if len(cu.Name) == 0 {
return fmt.Errorf("`cluster.user.name` cannot be empty")
}
if cu.MaxQueueTime > 0 && cu.MaxQueueSize == 0 {
return fmt.Errorf("`max_queue_size` must be set if `max_queue_time` is set for %q", cu.Name)
}
if cu.ReqPacketSizeTokensBurst > 0 && cu.ReqPacketSizeTokensRate == 0 {
return fmt.Errorf("`request_packet_size_tokens_rate` must be set if `request_packet_size_tokens_burst` is set for %q", cu.Name)
}
return checkOverflow(cu.XXX, fmt.Sprintf("cluster.user %q", cu.Name))
}
// LoadFile loads and validates configuration from provided .yml file
func LoadFile(filename string) (*Config, error) {
content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
content = findAndReplacePlaceholders(content)
cfg := &Config{}
if err := yaml.Unmarshal(content, cfg); err != nil {
return nil, err
}
cfg.networkReg = make(map[string]Networks, len(cfg.NetworkGroups))
for _, ng := range cfg.NetworkGroups {
if _, ok := cfg.networkReg[ng.Name]; ok {
return nil, fmt.Errorf("duplicate `network_groups.name` %q", ng.Name)
}
cfg.networkReg[ng.Name] = ng.Networks
}
if cfg.Server.HTTP.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.HTTP.NetworksOrGroups); err != nil {
return nil, err
}
if cfg.Server.HTTPS.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.HTTPS.NetworksOrGroups); err != nil {
return nil, err
}
if cfg.Server.Metrics.AllowedNetworks, err = cfg.groupToNetwork(cfg.Server.Metrics.NetworksOrGroups); err != nil {
return nil, err
}
if err := cfg.setDefaults(); err != nil {
return nil, err
}
if err := cfg.checkVulnerabilities(); err != nil {
return nil, fmt.Errorf("security breach: %w\nSet option `hack_me_please=true` to disable security errors", err)
}
return cfg, nil
}
var envVarRegex = regexp.MustCompile(`\${([a-zA-Z_][a-zA-Z0-9_]*)}`)
// findAndReplacePlaceholders finds all environment variables placeholders in the config.
// Each placeholder is a string like ${VAR_NAME}. They will be replaced with the value of the
// corresponding environment variable. It returns the new content with replaced placeholders.
func findAndReplacePlaceholders(content []byte) []byte {
for _, match := range envVarRegex.FindAllSubmatch(content, -1) {
envVar := os.Getenv(string(match[1]))
if envVar != "" {
content = bytes.ReplaceAll(content, match[0], []byte(envVar))
}
}
return content
}
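// Illustrative example (not from the original source): with REDIS_PASSWORD exported in
// the environment, the placeholder below is substituted before YAML parsing; placeholders
// for unset or empty variables are left untouched.
//
//	caches:
//	  - name: "shared"
//	    mode: "redis"
//	    redis:
//	      addresses: ["127.0.0.1:6379"]
//	      password: ${REDIS_PASSWORD}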
func (c Config) checkVulnerabilities() error {
if c.HackMePlease {
return nil
}
hasHTTPS := len(c.Server.HTTPS.ListenAddr) > 0 && len(c.Server.HTTPS.NetworksOrGroups) == 0
hasHTTP := len(c.Server.HTTP.ListenAddr) > 0 && len(c.Server.HTTP.NetworksOrGroups) == 0
for _, u := range c.Users {
if len(u.NetworksOrGroups) != 0 {
continue
}
if err := u.validateSecurity(hasHTTP, hasHTTPS); err != nil {
return err
}
}
return nil
}
package config
import (
"fmt"
"math"
"net"
"regexp"
"strconv"
"strings"
"time"
)
// ByteSize holds size in bytes.
//
// May be used in yaml for parsing byte size values.
type ByteSize uint64
var byteSizeRegexp = regexp.MustCompile(`^(\d+(?:\.\d+)?)\s*([KMGTP]?)B?$`)
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (bs *ByteSize) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
return err
}
s = strings.ToUpper(s)
value, unit, err := parseStringParts(s)
if err != nil {
return err
}
if value <= 0 {
return fmt.Errorf("byte size %q must be positive", s)
}
k := float64(1)
switch unit {
case "P":
k = 1 << 50
case "T":
k = 1 << 40
case "G":
k = 1 << 30
case "M":
k = 1 << 20
case "K":
k = 1 << 10
}
value *= k
*bs = ByteSize(value)
// check for overflow
e := math.Abs(float64(*bs)-value) / value
if e > 1e-6 {
return fmt.Errorf("byte size %q is too big", s)
}
return nil
}
func parseStringParts(s string) (float64, string, error) {
s = strings.ToUpper(s)
parts := byteSizeRegexp.FindStringSubmatch(strings.TrimSpace(s))
if len(parts) < 3 {
return -1, "", fmt.Errorf("cannot parse byte size %q: it must be positive float followed by optional units. For example, 1.5Gb, 3T", s)
}
value, err := strconv.ParseFloat(parts[1], 64)
if err != nil {
return -1, "", fmt.Errorf("cannot parse byte size %q: it must be positive float followed by optional units. For example, 1.5Gb, 3T; err: %w", s, err)
}
unit := parts[2]
return value, unit, nil
}
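// Illustrative examples (not from the original source) of accepted byte size values,
// using binary multiples (1K = 1024): "4KB" parses to 4096, "1.5Gb" to 1610612736 and
// a bare "10" to 10 bytes. Zero is rejected as non-positive and negative values fail
// the format check.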
// Networks is a list of IPNet entities
type Networks []*net.IPNet
// MarshalYAML implements yaml.Marshaler interface.
//
// It prettifies yaml output for Networks.
func (n Networks) MarshalYAML() (interface{}, error) {
var a []string
for _, x := range n {
a = append(a, x.String())
}
return a, nil
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (n *Networks) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s []string
if err := unmarshal(&s); err != nil {
return err
}
networks := make(Networks, len(s))
for i, s := range s {
ipnet, err := stringToIPnet(s)
if err != nil {
return err
}
networks[i] = ipnet
}
*n = networks
return nil
}
// Contains checks whether passed addr is in the range of networks
func (n Networks) Contains(addr string) bool {
if len(n) == 0 {
return true
}
h, _, err := net.SplitHostPort(addr)
if err != nil {
// addr is likely a bare IP address without a port. This happens when the proxy middleware is enabled.
h = addr
}
ip := net.ParseIP(h)
if ip == nil {
panic(fmt.Sprintf("BUG: unexpected error while parsing IP: %s", h))
}
for _, ipnet := range n {
if ipnet.Contains(ip) {
return true
}
}
return false
}
// Duration wraps time.Duration. It is used to parse the custom duration format
type Duration time.Duration
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
return err
}
dur, err := StringToDuration(s)
if err != nil {
return err
}
*d = dur
return nil
}
// String implements the Stringer interface.
func (d Duration) String() string {
factors := map[string]time.Duration{
"w": time.Hour * 24 * 7,
"d": time.Hour * 24,
"h": time.Hour,
"m": time.Minute,
"s": time.Second,
"ms": time.Millisecond,
"µs": time.Microsecond,
"ns": 1,
}
var t = time.Duration(d)
unit := "ns"
//nolint:exhaustive // Custom duration counter that doesn't switch on the duration.
switch time.Duration(0) {
case t % factors["w"]:
unit = "w"
case t % factors["d"]:
unit = "d"
case t % factors["h"]:
unit = "h"
case t % factors["m"]:
unit = "m"
case t % factors["s"]:
unit = "s"
case t % factors["ms"]:
unit = "ms"
case t % factors["µs"]:
unit = "µs"
}
return fmt.Sprintf("%d%v", t/factors[unit], unit)
}
// MarshalYAML implements the yaml.Marshaler interface.
func (d Duration) MarshalYAML() (interface{}, error) {
return d.String(), nil
}
// borrowed from github.com/prometheus/prometheus
var durationRE = regexp.MustCompile("^([0-9]+)(w|d|h|m|s|ms|µs|ns)$")
// StringToDuration parses a string into a time.Duration,
// assuming that a week always has 7d, and a day always has 24h.
func StringToDuration(durationStr string) (Duration, error) {
n, unit, err := parseDurationParts(durationStr)
if err != nil {
return 0, err
}
return calculateDuration(n, unit)
}
func parseDurationParts(s string) (int, string, error) {
matches := durationRE.FindStringSubmatch(s)
if len(matches) != 3 {
return 0, "", fmt.Errorf("not a valid duration string: %q", s)
}
n, err := strconv.Atoi(matches[1])
if err != nil {
return 0, "", fmt.Errorf("duration too long: %q", matches[1])
}
unit := matches[2]
return n, unit, nil
}
func calculateDuration(n int, unit string) (Duration, error) {
dur := time.Duration(n)
switch unit {
case "w":
dur *= time.Hour * 24 * 7
case "d":
dur *= time.Hour * 24
case "h":
dur *= time.Hour
case "m":
dur *= time.Minute
case "s":
dur *= time.Second
case "ms":
dur *= time.Millisecond
case "µs":
dur *= time.Microsecond
case "ns":
default:
return 0, fmt.Errorf("invalid time unit in duration string: %q", unit)
}
return Duration(dur), nil
}
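// Illustrative examples (not from the original source): "10s", "5m", "2h", "1d" and "1w"
// are accepted (a week is always 7d and a day 24h), while compound values such as "1h30m"
// or bare numbers without a unit are rejected by durationRE.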
package config
import (
"fmt"
"net"
"strings"
)
const entireIPv4 = "0.0.0.0/0"
func stringToIPnet(s string) (*net.IPNet, error) {
if s == entireIPv4 {
return nil, fmt.Errorf("suspicious mask specified \"0.0.0.0/0\". " +
"If you want to allow all then just omit `allowed_networks` field")
}
ip := s
if !strings.Contains(ip, `/`) {
ip += "/32"
}
_, ipnet, err := net.ParseCIDR(ip)
if err != nil {
return nil, fmt.Errorf("wrong network group name or address %q: %w", s, err)
}
return ipnet, nil
}
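// Illustrative examples (not from the original source): "10.10.0.0/24" is parsed as a
// subnet, a bare address such as "127.0.0.1" is treated as "127.0.0.1/32", and
// "0.0.0.0/0" is rejected with a hint to omit `allowed_networks` instead.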
func checkOverflow(m map[string]interface{}, ctx string) error {
if len(m) > 0 {
var keys []string
for k := range m {
keys = append(keys, k)
}
return fmt.Errorf("unknown fields in %s: %s", ctx, strings.Join(keys, ", "))
}
return nil
}
package counter
import "sync/atomic"
type Counter struct {
value atomic.Uint32
}
func (c *Counter) Store(n uint32) { c.value.Store(n) }
func (c *Counter) Load() uint32 { return c.value.Load() }
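// Note: Dec below subtracts one by adding ^uint32(0), i.e. the two's-complement
// representation of -1 for an unsigned 32-bit counter.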
func (c *Counter) Dec() { c.value.Add(^uint32(0)) }
func (c *Counter) Inc() uint32 { return c.value.Add(1) }
package heartbeat
import (
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/contentsquare/chproxy/config"
)
var errUnexpectedResponse = fmt.Errorf("unexpected response")
type HeartBeat interface {
IsHealthy(ctx context.Context, addr string) error
Interval() time.Duration
}
type heartBeatOpts struct {
defaultUser string
defaultPassword string
}
type Option interface {
apply(*heartBeatOpts)
}
type defaultUser struct {
defaultUser string
defaultPassword string
}
func (o defaultUser) apply(opts *heartBeatOpts) {
opts.defaultUser = o.defaultUser
opts.defaultPassword = o.defaultPassword
}
func WithDefaultUser(user, password string) Option {
return defaultUser{
defaultUser: user,
defaultPassword: password,
}
}
type heartBeat struct {
interval time.Duration
timeout time.Duration
request string
response string
user string
password string
}
// defaultEndpoint is the health-check path for which user credentials are not needed.
const defaultEndpoint string = "/ping"
func NewHeartbeat(c config.HeartBeat, options ...Option) HeartBeat {
opts := &heartBeatOpts{}
for _, o := range options {
o.apply(opts)
}
newHB := &heartBeat{
interval: time.Duration(c.Interval),
timeout: time.Duration(c.Timeout),
request: c.Request,
response: c.Response,
}
if c.Request != defaultEndpoint {
if c.User != "" {
newHB.user = c.User
newHB.password = c.Password
} else {
newHB.user = opts.defaultUser
newHB.password = opts.defaultPassword
}
}
if newHB.request != defaultEndpoint && newHB.user == "" {
panic("BUG: user is empty, no default user provided")
}
return newHB
}
func (hb *heartBeat) IsHealthy(ctx context.Context, addr string) error {
req, err := http.NewRequest("GET", addr+hb.request, nil)
if err != nil {
return err
}
if hb.request != defaultEndpoint {
req.SetBasicAuth(hb.user, hb.password)
}
ctx, cancel := context.WithTimeout(ctx, hb.timeout)
defer cancel()
req = req.WithContext(ctx)
startTime := time.Now()
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("cannot send request in %s: %w", time.Since(startTime), err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("non-200 status code: %s", resp.Status)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("cannot read response in %s: %w", time.Since(startTime), err)
}
r := string(body)
if r != hb.response {
return fmt.Errorf("%w: %s", errUnexpectedResponse, r)
}
return nil
}
func (hb *heartBeat) Interval() time.Duration {
return hb.interval
}
package topology
// TODO this is only here to avoid recursive imports. We should have a separate package for metrics.
import (
"github.com/contentsquare/chproxy/config"
"github.com/prometheus/client_golang/prometheus"
)
var (
HostHealth *prometheus.GaugeVec
HostPenalties *prometheus.CounterVec
)
func initMetrics(cfg *config.Config) {
namespace := cfg.Server.Metrics.Namespace
HostHealth = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Name: "host_health",
Help: "Health state of hosts by clusters",
},
[]string{"cluster", "replica", "cluster_node"},
)
HostPenalties = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "host_penalties_total",
Help: "Total number of given penalties by host",
},
[]string{"cluster", "replica", "cluster_node"},
)
}
func RegisterMetrics(cfg *config.Config) {
initMetrics(cfg)
prometheus.MustRegister(HostHealth, HostPenalties)
}
func reportNodeHealthMetric(clusterName, replicaName, nodeName string, active bool) {
label := prometheus.Labels{
"cluster": clusterName,
"replica": replicaName,
"cluster_node": nodeName,
}
if active {
HostHealth.With(label).Set(1)
} else {
HostHealth.With(label).Set(0)
}
}
func incrementPenaltiesMetric(clusterName, replicaName, nodeName string) {
label := prometheus.Labels{
"cluster": clusterName,
"replica": replicaName,
"cluster_node": nodeName,
}
HostPenalties.With(label).Inc()
}
package topology
import (
"context"
"net/url"
"sync/atomic"
"time"
"github.com/contentsquare/chproxy/internal/counter"
"github.com/contentsquare/chproxy/internal/heartbeat"
"github.com/contentsquare/chproxy/log"
)
const (
// prevents excess goroutine creation while penalizing an overloaded host
DefaultPenaltySize = 5
DefaultMaxSize = 300
DefaultPenaltyDuration = time.Second * 10
)
type nodeOpts struct {
defaultActive bool
penaltySize uint32
penaltyMaxSize uint32
penaltyDuration time.Duration
}
func defaultNodeOpts() nodeOpts {
return nodeOpts{
penaltySize: DefaultPenaltySize,
penaltyMaxSize: DefaultMaxSize,
penaltyDuration: DefaultPenaltyDuration,
}
}
type NodeOption interface {
apply(*nodeOpts)
}
type defaultActive struct {
active bool
}
func (o defaultActive) apply(opts *nodeOpts) {
opts.defaultActive = o.active
}
func WithDefaultActiveState(active bool) NodeOption {
return defaultActive{
active: active,
}
}
type Node struct {
// Node Address.
addr *url.URL
// Whether this node is alive.
active atomic.Bool
// Counter of currently running connections.
connections counter.Counter
// Counter of unsuccessful requests used to decrease host priority.
penalty atomic.Uint32
// Heartbeat function
hb heartbeat.HeartBeat
// TODO These fields are only used for labels in prometheus. We should have a different way to pass the labels.
// For metrics only
clusterName string
replicaName string
// Additional configuration options
opts nodeOpts
}
func NewNode(addr *url.URL, hb heartbeat.HeartBeat, clusterName, replicaName string, opts ...NodeOption) *Node {
nodeOpts := defaultNodeOpts()
for _, opt := range opts {
opt.apply(&nodeOpts)
}
n := &Node{
addr: addr,
hb: hb,
clusterName: clusterName,
replicaName: replicaName,
opts: nodeOpts,
}
if n.opts.defaultActive {
n.SetIsActive(true)
}
return n
}
func (n *Node) IsActive() bool {
return n.active.Load()
}
func (n *Node) SetIsActive(active bool) {
n.active.Store(active)
}
// StartHeartbeat runs the heartbeat healthcheck against the node
// until the done channel is closed.
// If the heartbeat fails, the active status of the node is changed.
func (n *Node) StartHeartbeat(done <-chan struct{}) {
ctx, cancel := context.WithCancel(context.Background())
for {
n.heartbeat(ctx)
select {
case <-done:
cancel()
return
case <-time.After(n.hb.Interval()):
}
}
}
func (n *Node) heartbeat(ctx context.Context) {
if err := n.hb.IsHealthy(ctx, n.addr.String()); err == nil {
n.active.Store(true)
reportNodeHealthMetric(n.clusterName, n.replicaName, n.Host(), true)
} else {
log.Errorf("error while health-checking %q host: %s", n.Host(), err)
n.active.Store(false)
reportNodeHealthMetric(n.clusterName, n.replicaName, n.Host(), false)
}
}
// Penalize decreases a node's priority after a failed request.
// If the penalty is already at the maximum allowed size this function
// will not penalize the node further.
// A function is registered to run after the penalty duration
// to restore the priority again.
func (n *Node) Penalize() {
penalty := n.penalty.Load()
if penalty >= n.opts.penaltyMaxSize {
return
}
incrementPenaltiesMetric(n.clusterName, n.replicaName, n.Host())
n.penalty.Add(n.opts.penaltySize)
time.AfterFunc(n.opts.penaltyDuration, func() {
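// Adding ^(penaltySize-1) is unsigned two's-complement arithmetic for subtracting
// penaltySize, restoring the node's priority once the penalty duration has elapsed.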
n.penalty.Add(^uint32(n.opts.penaltySize - 1))
})
}
// CurrentLoad returns the node's number of open connections
// plus the penalty.
func (n *Node) CurrentLoad() uint32 {
c := n.connections.Load()
p := n.penalty.Load()
return c + p
}
func (n *Node) CurrentConnections() uint32 {
return n.connections.Load()
}
func (n *Node) CurrentPenalty() uint32 {
return n.penalty.Load()
}
func (n *Node) IncrementConnections() {
n.connections.Inc()
}
func (n *Node) DecrementConnections() {
n.connections.Dec()
}
func (n *Node) Scheme() string {
return n.addr.Scheme
}
func (n *Node) Host() string {
return n.addr.Host
}
func (n *Node) ReplicaName() string {
return n.replicaName
}
func (n *Node) String() string {
return n.addr.String()
}
package main
import (
"errors"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/contentsquare/chproxy/cache"
"github.com/contentsquare/chproxy/log"
"github.com/prometheus/client_golang/prometheus"
)
type ResponseWriterWithCode interface {
http.ResponseWriter
StatusCode() int
}
type StatResponseWriter interface {
http.ResponseWriter
http.CloseNotifier
StatusCode() int
SetStatusCode(code int)
}
var _ StatResponseWriter = &statResponseWriter{}
// statResponseWriter collects the number of bytes written.
//
// The wrapped ResponseWriter must implement http.CloseNotifier.
//
// Additionally it caches response status code.
type statResponseWriter struct {
http.ResponseWriter
statusCode int
// wroteHeader tells whether the header's been written to
// the original ResponseWriter
wroteHeader bool
bytesWritten prometheus.Counter
}
const (
XCacheHit = "HIT"
XCacheMiss = "MISS"
XCacheNA = "N/A"
)
func RespondWithData(rw http.ResponseWriter, data io.Reader, metadata cache.ContentMetadata, ttl time.Duration, cacheHit string, statusCode int, labels prometheus.Labels) error {
h := rw.Header()
if len(metadata.Type) > 0 {
h.Set("Content-Type", metadata.Type)
}
if len(metadata.Encoding) > 0 {
h.Set("Content-Encoding", metadata.Encoding)
}
h.Set("Content-Length", fmt.Sprintf("%d", metadata.Length))
if ttl > 0 {
expireSeconds := uint(ttl / time.Second)
h.Set("Cache-Control", fmt.Sprintf("max-age=%d", expireSeconds))
}
h.Set("X-Cache", cacheHit)
rw.WriteHeader(statusCode)
if _, err := io.Copy(rw, data); err != nil {
var perr *cache.RedisCacheError
if errors.As(err, &perr) {
cacheCorruptedFetch.With(labels).Inc()
log.Debugf("redis cache error")
}
log.Errorf("cannot send response to client: %s", err)
return fmt.Errorf("cannot send response to client: %w", err)
}
return nil
}
func (rw *statResponseWriter) SetStatusCode(code int) {
rw.statusCode = code
}
func (rw *statResponseWriter) StatusCode() int {
if rw.statusCode == 0 {
return http.StatusOK
}
return rw.statusCode
}
func (rw *statResponseWriter) Write(b []byte) (int, error) {
if rw.statusCode == 0 {
rw.statusCode = http.StatusOK
}
if !rw.wroteHeader {
rw.ResponseWriter.WriteHeader(rw.statusCode)
rw.wroteHeader = true
}
n, err := rw.ResponseWriter.Write(b)
rw.bytesWritten.Add(float64(n))
return n, err
}
func (rw *statResponseWriter) WriteHeader(statusCode int) {
// cache statusCode so it can still be changed before the header is actually written
rw.statusCode = statusCode
}
// CloseNotify implements http.CloseNotifier
func (rw *statResponseWriter) CloseNotify() <-chan bool {
// The rw.ResponseWriter must implement http.CloseNotifier
rwc, ok := rw.ResponseWriter.(http.CloseNotifier)
if !ok {
panic("BUG: the wrapped ResponseWriter must implement http.CloseNotifier")
}
return rwc.CloseNotify()
}
var _ io.ReadCloser = &statReadCloser{}
// statReadCloser collects the number of bytes read.
type statReadCloser struct {
io.ReadCloser
bytesRead prometheus.Counter
}
func (src *statReadCloser) Read(p []byte) (int, error) {
n, err := src.ReadCloser.Read(p)
src.bytesRead.Add(float64(n))
return n, err
}
var _ io.ReadCloser = &cachedReadCloser{}
// cachedReadCloser caches the first 1Kb from the wrapped ReadCloser.
type cachedReadCloser struct {
io.ReadCloser
// bLock protects b from concurrent access when Read and String
// are called from concurrent goroutines.
bLock sync.Mutex
// b holds up to 1Kb of the initial data read from ReadCloser.
b []byte
}
func (crc *cachedReadCloser) Read(p []byte) (int, error) {
n, err := crc.ReadCloser.Read(p)
crc.bLock.Lock()
if len(crc.b) < 1024 {
crc.b = append(crc.b, p[:n]...)
if len(crc.b) >= 1024 {
crc.b = append(crc.b[:1024], "..."...)
}
}
crc.bLock.Unlock()
// Do not cache the last read operation, since it slows down
// reading large amounts of data such as large INSERT queries.
return n, err
}
func (crc *cachedReadCloser) String() string {
crc.bLock.Lock()
s := string(crc.b)
crc.bLock.Unlock()
return s
}
package log
import (
"fmt"
"io"
"log"
"os"
"sync/atomic"
)
var (
stdLogFlags = log.LstdFlags | log.Lshortfile | log.LUTC
outputCallDepth = 2
debugLogger = log.New(os.Stderr, "DEBUG: ", stdLogFlags)
infoLogger = log.New(os.Stderr, "INFO: ", stdLogFlags)
errorLogger = log.New(os.Stderr, "ERROR: ", stdLogFlags)
fatalLogger = log.New(os.Stderr, "FATAL: ", stdLogFlags)
// NilLogger suppresses all the log messages.
NilLogger = log.New(io.Discard, "", stdLogFlags)
)
// SuppressOutput suppresses all output from logs if `suppress` is true
// used while testing
func SuppressOutput(suppress bool) {
if suppress {
debugLogger.SetOutput(io.Discard)
infoLogger.SetOutput(io.Discard)
errorLogger.SetOutput(io.Discard)
} else {
debugLogger.SetOutput(os.Stderr)
infoLogger.SetOutput(os.Stderr)
errorLogger.SetOutput(os.Stderr)
}
}
var debug uint32
// SetDebug sets output into debug mode if true passed
func SetDebug(val bool) {
if val {
atomic.StoreUint32(&debug, 1)
} else {
atomic.StoreUint32(&debug, 0)
}
}
// Debugf prints debug message according to a format
func Debugf(format string, args ...interface{}) {
if atomic.LoadUint32(&debug) == 0 {
return
}
s := fmt.Sprintf(format, args...)
debugLogger.Output(outputCallDepth, s) // nolint
}
// Infof prints info message according to a format
func Infof(format string, args ...interface{}) {
s := fmt.Sprintf(format, args...)
infoLogger.Output(outputCallDepth, s) // nolint
}
// Errorf prints error message according to a format
func Errorf(format string, args ...interface{}) {
s := fmt.Sprintf(format, args...)
errorLogger.Output(outputCallDepth, s) // nolint
}
// ErrorWithCallDepth prints err into error log using the given callDepth.
func ErrorWithCallDepth(err error, callDepth int) {
s := err.Error()
errorLogger.Output(outputCallDepth+callDepth, s) //nolint
}
// Fatalf prints fatal message according to a format and exits program
func Fatalf(format string, args ...interface{}) {
s := fmt.Sprintf(format, args...)
fatalLogger.Output(outputCallDepth, s) // nolint
os.Exit(1)
}
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/log"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/acme/autocert"
)
var (
configFile = flag.String("config", "", "Proxy configuration filename")
version = flag.Bool("version", false, "Prints current version and exits")
enableTCP6 = flag.Bool("enableTCP6", false, "Whether to enable listening for IPv6 TCP ports. "+
"By default only IPv4 TCP ports are listened")
)
var (
proxy *reverseProxy
// networks allow lists
allowedNetworksHTTP atomic.Value
allowedNetworksHTTPS atomic.Value
allowedNetworksMetrics atomic.Value
proxyHandler atomic.Value
)
func main() {
flag.Parse()
if *version {
fmt.Printf("%s\n", versionString())
os.Exit(0)
}
log.Infof("%s", versionString())
log.Infof("Loading config: %s", *configFile)
cfg, err := loadConfig()
if err != nil {
log.Fatalf("error while loading config: %s", err)
}
registerMetrics(cfg)
if err = applyConfig(cfg); err != nil {
log.Fatalf("error while applying config: %s", err)
}
configSuccess.Set(1)
configSuccessTime.Set(float64(time.Now().Unix()))
log.Infof("Loading config %q: successful", *configFile)
setupReloadConfigWatch()
server := cfg.Server
if len(server.HTTP.ListenAddr) == 0 && len(server.HTTPS.ListenAddr) == 0 {
panic("BUG: broken config validation - `listen_addr` is not configured")
}
if server.HTTP.ForceAutocertHandler {
autocertManager = newAutocertManager(server.HTTPS.Autocert)
}
if len(server.HTTPS.ListenAddr) != 0 {
go serveTLS(server.HTTPS)
}
if len(server.HTTP.ListenAddr) != 0 {
go serve(server.HTTP)
}
select {}
}
func setupReloadConfigWatch() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGHUP)
go func() {
for {
if <-c == syscall.SIGHUP {
log.Infof("SIGHUP received. Going to reload config %s ...", *configFile)
if err := reloadConfig(); err != nil {
log.Errorf("error while reloading config: %s", err)
continue
}
log.Infof("Reloading config %s: successful", *configFile)
}
}
}()
}
var autocertManager *autocert.Manager
func newAutocertManager(cfg config.Autocert) *autocert.Manager {
if len(cfg.CacheDir) > 0 {
if err := os.MkdirAll(cfg.CacheDir, 0o700); err != nil {
log.Fatalf("error while creating folder %q: %s", cfg.CacheDir, err)
}
}
var hp autocert.HostPolicy
if len(cfg.AllowedHosts) != 0 {
allowedHosts := make(map[string]struct{}, len(cfg.AllowedHosts))
for _, v := range cfg.AllowedHosts {
allowedHosts[v] = struct{}{}
}
hp = func(_ context.Context, host string) error {
if _, ok := allowedHosts[host]; ok {
return nil
}
return fmt.Errorf("host %q doesn't match `host_policy` configuration", host)
}
}
return &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(cfg.CacheDir),
HostPolicy: hp,
}
}
func newListener(listenAddr string) net.Listener {
network := "tcp4"
if *enableTCP6 {
// Enable listening on both tcp4 and tcp6
network = "tcp"
}
ln, err := net.Listen(network, listenAddr)
if err != nil {
log.Fatalf("cannot listen for %q: %s", listenAddr, err)
}
return ln
}
func serveTLS(cfg config.HTTPS) {
ln := newListener(cfg.ListenAddr)
h := http.HandlerFunc(serveHTTP)
tlsCfg, err := cfg.TLS.BuildTLSConfig(autocertManager)
if err != nil {
log.Fatalf("cannot build TLS config: %s", err)
}
tln := tls.NewListener(ln, tlsCfg)
log.Infof("Serving https on %q", cfg.ListenAddr)
if err := listenAndServe(tln, h, cfg.TimeoutCfg); err != nil {
log.Fatalf("TLS server error on %q: %s", cfg.ListenAddr, err)
}
}
func serve(cfg config.HTTP) {
var h http.Handler
ln := newListener(cfg.ListenAddr)
h = http.HandlerFunc(serveHTTP)
if cfg.ForceAutocertHandler {
if autocertManager == nil {
panic("BUG: autocertManager is not inited")
}
addr := ln.Addr().String()
parts := strings.Split(addr, ":")
if parts[len(parts)-1] != "80" {
log.Fatalf("`letsencrypt` specification requires http server to listen on :80 port to satisfy http-01 challenge. " +
"Otherwise, certificates will be impossible to generate")
}
h = autocertManager.HTTPHandler(h)
}
log.Infof("Serving http on %q", cfg.ListenAddr)
if err := listenAndServe(ln, h, cfg.TimeoutCfg); err != nil {
log.Fatalf("HTTP server error on %q: %s", cfg.ListenAddr, err)
}
}
func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Server {
// nolint:gosec // We already configured ReadTimeout, so no need to set ReadHeaderTimeout as well.
return &http.Server{
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
Handler: h,
ReadTimeout: time.Duration(cfg.ReadTimeout),
WriteTimeout: time.Duration(cfg.WriteTimeout),
IdleTimeout: time.Duration(cfg.IdleTimeout),
// Suppress error logging from the server, since chproxy
// must handle all these errors in the code.
ErrorLog: log.NilLogger,
}
}
func listenAndServe(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) error {
s := newServer(ln, h, cfg)
return s.Serve(ln)
}
var promHandler = promhttp.Handler()
//nolint:cyclop //TODO reduce complexity here.
func serveHTTP(rw http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodPost:
// Only GET and POST methods are supported.
case http.MethodOptions:
// This is required for CORS preflight requests.
rw.Header().Set("Allow", "GET,POST")
return
default:
err := fmt.Errorf("%q: unsupported method %q", r.RemoteAddr, r.Method)
rw.Header().Set("Connection", "close")
respondWith(rw, err, http.StatusMethodNotAllowed)
return
}
switch r.URL.Path {
case "/favicon.ico":
case "/metrics":
// nolint:forcetypeassert // We will cover this by tests as we control what is stored.
an := allowedNetworksMetrics.Load().(*config.Networks)
if !an.Contains(r.RemoteAddr) {
err := fmt.Errorf("connections to /metrics are not allowed from %s", r.RemoteAddr)
rw.Header().Set("Connection", "close")
respondWith(rw, err, http.StatusForbidden)
return
}
proxy.refreshCacheMetrics()
promHandler.ServeHTTP(rw, r)
case "/", "/query":
var err error
// nolint:forcetypeassert // We will cover this by tests as we control what is stored.
proxyHandler := proxyHandler.Load().(*ProxyHandler)
r.RemoteAddr = proxyHandler.GetRemoteAddr(r)
var an *config.Networks
if r.TLS != nil {
// nolint:forcetypeassert // We will cover this by tests as we control what is stored.
an = allowedNetworksHTTPS.Load().(*config.Networks)
err = fmt.Errorf("https connections are not allowed from %s", r.RemoteAddr)
} else {
// nolint:forcetypeassert // We will cover this by tests as we control what is stored.
an = allowedNetworksHTTP.Load().(*config.Networks)
err = fmt.Errorf("http connections are not allowed from %s", r.RemoteAddr)
}
if !an.Contains(r.RemoteAddr) {
rw.Header().Set("Connection", "close")
respondWith(rw, err, http.StatusForbidden)
return
}
proxy.ServeHTTP(rw, r)
default:
badRequest.Inc()
err := fmt.Errorf("%q: unsupported path: %q", r.RemoteAddr, r.URL.Path)
rw.Header().Set("Connection", "close")
respondWith(rw, err, http.StatusBadRequest)
}
}
func loadConfig() (*config.Config, error) {
if *configFile == "" {
log.Fatalf("Missing -config flag")
}
cfg, err := config.LoadFile(*configFile)
if err != nil {
return nil, fmt.Errorf("can't load config %q: %w", *configFile, err)
}
return cfg, nil
}
// proxyConfigChanged reports whether a connection-pool parameter used during
// proxy initialization has changed.
func proxyConfigChanged(cfgCp *config.ConnectionPool, rp *reverseProxy) bool {
return cfgCp.MaxIdleConns != rp.maxIdleConns ||
cfgCp.MaxIdleConnsPerHost != rp.maxIdleConnsPerHost
}
func applyConfig(cfg *config.Config) error {
if proxy == nil || proxyConfigChanged(&cfg.ConnectionPool, proxy) {
proxy = newReverseProxy(&cfg.ConnectionPool)
}
if err := proxy.applyConfig(cfg); err != nil {
return err
}
allowedNetworksHTTP.Store(&cfg.Server.HTTP.AllowedNetworks)
allowedNetworksHTTPS.Store(&cfg.Server.HTTPS.AllowedNetworks)
allowedNetworksMetrics.Store(&cfg.Server.Metrics.AllowedNetworks)
proxyHandler.Store(NewProxyHandler(&cfg.Server.Proxy))
log.SetDebug(cfg.LogDebug)
log.Infof("Loaded config:\n%s", cfg)
return nil
}
func reloadConfig() error {
cfg, err := loadConfig()
if err != nil {
return err
}
return applyConfig(cfg)
}
var (
buildTag = "unknown"
buildRevision = "unknown"
buildTime = "unknown"
)
func versionString() string {
ver := buildTag
if len(ver) == 0 {
ver = "unknown"
}
return fmt.Sprintf("chproxy ver. %s, rev. %s, built at %s", ver, buildRevision, buildTime)
}
package main
import (
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/internal/topology"
"github.com/prometheus/client_golang/prometheus"
)
var (
statusCodes *prometheus.CounterVec
requestSum *prometheus.CounterVec
requestSuccess *prometheus.CounterVec
limitExcess *prometheus.CounterVec
concurrentQueries *prometheus.GaugeVec
requestQueueSize *prometheus.GaugeVec
userQueueOverflow *prometheus.CounterVec
clusterUserQueueOverflow *prometheus.CounterVec
requestBodyBytes *prometheus.CounterVec
responseBodyBytes *prometheus.CounterVec
cacheFailedInsert *prometheus.CounterVec
cacheCorruptedFetch *prometheus.CounterVec
cacheHit *prometheus.CounterVec
cacheMiss *prometheus.CounterVec
cacheSize *prometheus.GaugeVec
cacheItems *prometheus.GaugeVec
cacheSkipped *prometheus.CounterVec
requestDuration *prometheus.SummaryVec
proxiedResponseDuration *prometheus.SummaryVec
cachedResponseDuration *prometheus.SummaryVec
canceledRequest *prometheus.CounterVec
cacheHitFromConcurrentQueries *prometheus.CounterVec
cacheMissFromConcurrentQueries *prometheus.CounterVec
killedRequests *prometheus.CounterVec
timeoutRequest *prometheus.CounterVec
configSuccess prometheus.Gauge
configSuccessTime prometheus.Gauge
badRequest prometheus.Counter
retryRequest *prometheus.CounterVec
)
func initMetrics(cfg *config.Config) {
namespace := cfg.Server.Metrics.Namespace
statusCodes = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "status_codes_total",
Help: "Distribution by status codes",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node", "code"},
)
requestSum = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "request_sum_total",
Help: "Total number of sent requests",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
requestSuccess = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "request_success_total",
Help: "Total number of sent success requests",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
limitExcess = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "concurrent_limit_excess_total",
Help: "Total number of max_concurrent_queries excess",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
concurrentQueries = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Name: "concurrent_queries",
Help: "The number of concurrent queries at current time",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
requestQueueSize = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Name: "request_queue_size",
Help: "Request queue sizes at the current time",
},
[]string{"user", "cluster", "cluster_user"},
)
userQueueOverflow = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "user_queue_overflow_total",
Help: "The number of overflows for per-user request queues",
},
[]string{"user", "cluster", "cluster_user"},
)
clusterUserQueueOverflow = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cluster_user_queue_overflow_total",
Help: "The number of overflows for per-cluster_user request queues",
},
[]string{"user", "cluster", "cluster_user"},
)
requestBodyBytes = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "request_body_bytes_total",
Help: "The amount of bytes read from request bodies",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
responseBodyBytes = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "response_body_bytes_total",
Help: "The amount of bytes written to response bodies",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
cacheFailedInsert = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_insertion_failures_total",
Help: "The number of insertion in the cache that didn't work out",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
cacheCorruptedFetch = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_get_corrutpion_total",
Help: "The number of time a data fetching from redis was corrupted",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
cacheHit = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_hits_total",
Help: "The amount of cache hits",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
cacheMiss = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_miss_total",
Help: "The amount of cache misses",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
cacheSize = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Name: "cache_size",
Help: "Cache size at the current time",
},
[]string{"cache"},
)
cacheItems = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Name: "cache_items",
Help: "Cache items at the current time",
},
[]string{"cache"},
)
cacheSkipped = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_payloadsize_too_big_total",
Help: "The amount of too big payloads to be cached",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
requestDuration = prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Namespace: namespace,
Name: "request_duration_seconds",
Help: "Request duration. Includes possible wait time in the queue",
Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5},
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
proxiedResponseDuration = prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Namespace: namespace,
Name: "proxied_response_duration_seconds",
Help: "Response duration proxied from clickhouse",
Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5},
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
cachedResponseDuration = prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Namespace: namespace,
Name: "cached_response_duration_seconds",
Help: "Response duration served from the cache",
Objectives: map[float64]float64{0.5: 1e-1, 0.9: 1e-2, 0.99: 1e-3, 0.999: 1e-4, 1: 1e-5},
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
canceledRequest = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "canceled_request_total",
Help: "The number of requests canceled by remote client",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
cacheHitFromConcurrentQueries = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_hit_concurrent_query_total",
Help: "The amount of cache hits after having awaited concurrently executed queries",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
cacheMissFromConcurrentQueries = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "cache_miss_concurrent_query_total",
Help: "The amount of cache misses, even if previously reported as queries available in the cache, after having awaited concurrently executed queries",
},
[]string{"cache", "user", "cluster", "cluster_user"},
)
killedRequests = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "killed_request_total",
Help: "The number of requests killed by proxy",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
timeoutRequest = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "timeout_request_total",
Help: "The number of timed out requests",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
configSuccess = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "config_last_reload_successful",
Help: "Whether the last configuration reload attempt was successful.",
})
configSuccessTime = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "config_last_reload_success_timestamp_seconds",
Help: "Timestamp of the last successful configuration reload.",
})
badRequest = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "bad_requests_total",
Help: "Total number of unsupported requests",
})
retryRequest = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Name: "retry_request_total",
Help: "The number of retry requests",
},
[]string{"user", "cluster", "cluster_user", "replica", "cluster_node"},
)
}
func registerMetrics(cfg *config.Config) {
topology.RegisterMetrics(cfg)
initMetrics(cfg)
prometheus.MustRegister(statusCodes, requestSum, requestSuccess,
limitExcess, concurrentQueries,
requestQueueSize, userQueueOverflow, clusterUserQueueOverflow,
requestBodyBytes, responseBodyBytes, cacheFailedInsert, cacheCorruptedFetch,
cacheHit, cacheMiss, cacheSize, cacheItems, cacheSkipped,
requestDuration, proxiedResponseDuration, cachedResponseDuration,
canceledRequest, timeoutRequest, killedRequests,
cacheHitFromConcurrentQueries, cacheMissFromConcurrentQueries,
configSuccess, configSuccessTime, badRequest, retryRequest)
}
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/contentsquare/chproxy/cache"
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/internal/topology"
"github.com/contentsquare/chproxy/log"
"github.com/prometheus/client_golang/prometheus"
)
// tmpDir is the temporary path used to store the results of in-flight queries
const tmpDir = "/tmp"
// failedTransactionPrefix is the prefix added to the failure reason stored in the concurrent queries registry
const failedTransactionPrefix = "[concurrent query failed]"
type reverseProxy struct {
rp *httputil.ReverseProxy
// configLock serializes access to applyConfig.
// It protects reload* fields.
configLock sync.Mutex
reloadSignal chan struct{}
reloadWG sync.WaitGroup
// lock protects users, clusters and caches.
// RWMutex enables concurrent access to getScope.
lock sync.RWMutex
users map[string]*user
clusters map[string]*cluster
caches map[string]*cache.AsyncCache
hasWildcarded bool
maxIdleConns int
maxIdleConnsPerHost int
}
func newReverseProxy(cfgCp *config.ConnectionPool) *reverseProxy {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
MaxIdleConns: cfgCp.MaxIdleConns,
MaxIdleConnsPerHost: cfgCp.MaxIdleConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
return &reverseProxy{
rp: &httputil.ReverseProxy{
Director: func(*http.Request) {},
Transport: transport,
// Suppress error logging in ReverseProxy, since all the errors
// are handled and logged in the code below.
ErrorLog: log.NilLogger,
},
reloadSignal: make(chan struct{}),
reloadWG: sync.WaitGroup{},
maxIdleConns: cfgCp.MaxIdleConns,
maxIdleConnsPerHost: cfgCp.MaxIdleConnsPerHost,
}
}
func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
startTime := time.Now()
s, status, err := rp.getScope(req)
if err != nil {
q := getQuerySnippet(req)
err = fmt.Errorf("%q: %w; query: %q", req.RemoteAddr, err, q)
respondWith(rw, err, status)
return
}
// WARNING: don't use s.labels before s.incQueued,
// since `replica` and `cluster_node` may change inside incQueued.
if err := s.incQueued(); err != nil {
limitExcess.With(s.labels).Inc()
q := getQuerySnippet(req)
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(rw, err, http.StatusTooManyRequests)
return
}
defer s.dec()
log.Debugf("%s: request start", s)
requestSum.With(s.labels).Inc()
if s.user.allowCORS {
origin := req.Header.Get("Origin")
if len(origin) == 0 {
origin = "*"
}
rw.Header().Set("Access-Control-Allow-Origin", origin)
}
req.Body = &statReadCloser{
ReadCloser: req.Body,
bytesRead: requestBodyBytes.With(s.labels),
}
srw := &statResponseWriter{
ResponseWriter: rw,
bytesWritten: responseBodyBytes.With(s.labels),
}
req, origParams := s.decorateRequest(req)
// wrap body into cachedReadCloser, so we could obtain the original
// request on error.
req.Body = &cachedReadCloser{
ReadCloser: req.Body,
}
// publish session_id if needed
if s.sessionId != "" {
rw.Header().Set("X-ClickHouse-Server-Session-Id", s.sessionId)
}
q, shouldReturnFromCache, err := shouldRespondFromCache(s, origParams, req)
if err != nil {
respondWith(srw, err, http.StatusBadRequest)
return
}
if shouldReturnFromCache {
rp.serveFromCache(s, srw, req, origParams, q)
} else {
rp.proxyRequest(s, srw, srw, req)
}
// It is safe to call getQuerySnippet here, since the request body
// has already been read in proxyRequest or serveFromCache.
query := getQuerySnippet(req)
if srw.statusCode == http.StatusOK {
requestSuccess.With(s.labels).Inc()
log.Debugf("%s: request success; query: %q; Method: %s; URL: %q", s, query, req.Method, req.URL.String())
} else {
log.Debugf("%s: request failure: non-200 status code %d; query: %q; Method: %s; URL: %q", s, srw.statusCode, query, req.Method, req.URL.String())
}
statusCodes.With(
prometheus.Labels{
"user": s.user.name,
"cluster": s.cluster.name,
"cluster_user": s.clusterUser.name,
"replica": s.host.ReplicaName(),
"cluster_node": s.host.Host(),
"code": strconv.Itoa(srw.statusCode),
},
).Inc()
since := time.Since(startTime).Seconds()
requestDuration.With(s.labels).Observe(since)
}
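// shouldRespondFromCache reports whether the response for req may be served from the cache:
// the user must have a cache configured, the `no_cache` parameter must not be set,
// and the query must be cacheable (see canCacheQuery). It also returns the full query.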
func shouldRespondFromCache(s *scope, origParams url.Values, req *http.Request) ([]byte, bool, error) {
if s.user.cache == nil || s.user.cache.Cache == nil {
return nil, false, nil
}
noCache := origParams.Get("no_cache")
if noCache == "1" || noCache == "true" {
return nil, false, nil
}
q, err := getFullQuery(req)
if err != nil {
return nil, false, fmt.Errorf("%s: cannot read query: %w", s, err)
}
return q, canCacheQuery(q), nil
}
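// executeWithRetry proxies the request via rp and retries it on another host when
// ClickHouse is unreachable (http.StatusBadGateway), up to maxRetry times. Retries are
// skipped for sticky sessions. It returns the elapsed time in seconds and a non-nil
// error only when the request context was canceled or timed out.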
func executeWithRetry(
ctx context.Context,
s *scope,
maxRetry int,
rp func(http.ResponseWriter, *http.Request),
rw ResponseWriterWithCode,
srw StatResponseWriter,
req *http.Request,
monitorDuration func(float64),
monitorRetryRequestInc func(prometheus.Labels),
) (float64, error) {
startTime := time.Now()
var since float64
// keep the request body
body, err := io.ReadAll(req.Body)
req.Body.Close()
if err != nil {
since = time.Since(startTime).Seconds()
return since, err
}
numRetry := 0
for {
// update body
req.Body = io.NopCloser(bytes.NewBuffer(body))
req.Body.Close()
rp(rw, req)
err := ctx.Err()
if err != nil {
since = time.Since(startTime).Seconds()
return since, err
}
// The request has been successfully proxied.
srw.SetStatusCode(rw.StatusCode())
// StatusBadGateway response is returned by http.ReverseProxy when
// it cannot establish connection to remote host.
if rw.StatusCode() == http.StatusBadGateway {
log.Debugf("the invalid host is: %s", s.host)
s.host.Penalize()
// comment s.host.dec() line to avoid double increment; issue #322
// s.host.dec()
s.host.SetIsActive(false)
nextHost := s.cluster.getHost()
// The query could be retried if it has no stickiness to a certain server
if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" {
// the query execution failed
monitorRetryRequestInc(s.labels)
currentHost := s.host
// Decrement the failed host's connection counter and increment the new host's,
// since the scope is closed at the end of the request and the new host's counter
// is decremented there; see https://github.com/ContentSquare/chproxy/pull/357
if currentHost != nextHost {
currentHost.DecrementConnections()
nextHost.IncrementConnections()
}
// update host
s.host = nextHost
req.URL.Host = s.host.Host()
req.URL.Scheme = s.host.Scheme()
log.Debugf("the valid host is: %s", s.host)
} else {
since = time.Since(startTime).Seconds()
monitorDuration(since)
q := getQuerySnippet(req)
err1 := fmt.Errorf("%s: cannot reach %s; query: %q", s, s.host.Host(), q)
respondWith(srw, err1, srw.StatusCode())
break
}
} else {
since = time.Since(startTime).Seconds()
break
}
numRetry++
}
return since, nil
}
// proxyRequest proxies the given request to clickhouse and sends response
// to rw.
//
// srw is required only for setting non-200 status codes on timeouts
// or on client connection disconnects.
func (rp *reverseProxy) proxyRequest(s *scope, rw ResponseWriterWithCode, srw *statResponseWriter, req *http.Request) {
// wrap body into cachedReadCloser, so we could obtain the original
// request on error.
if _, ok := req.Body.(*cachedReadCloser); !ok {
req.Body = &cachedReadCloser{
ReadCloser: req.Body,
}
}
timeout, timeoutErrMsg := s.getTimeoutWithErrMsg()
ctx := context.Background()
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
// Cancel the ctx if client closes the remote connection,
// so the proxied query may be killed instantly.
ctx, ctxCancel := listenToCloseNotify(ctx, rw)
defer ctxCancel()
req = req.WithContext(ctx)
startTime := time.Now()
executeDuration, err := executeWithRetry(ctx, s, s.cluster.retryNumber, rp.rp.ServeHTTP, rw, srw, req, func(duration float64) {
proxiedResponseDuration.With(s.labels).Observe(duration)
}, func(labels prometheus.Labels) { retryRequest.With(labels).Inc() })
switch {
case err == nil:
return
case errors.Is(err, context.Canceled):
canceledRequest.With(s.labels).Inc()
q := getQuerySnippet(req)
log.Debugf("%s: remote client closed the connection in %s; query: %q", s, time.Since(startTime), q)
if err := s.killQuery(); err != nil {
log.Errorf("%s: cannot kill query: %s; query: %q", s, err, q)
}
srw.statusCode = 499 // See https://httpstatuses.com/499 .
case errors.Is(err, context.DeadlineExceeded):
timeoutRequest.With(s.labels).Inc()
// Penalize host with the timed out query, because it may be overloaded.
s.host.Penalize()
q := getQuerySnippet(req)
log.Debugf("%s: query timeout in %f; query: %q", s, executeDuration, q)
if err := s.killQuery(); err != nil {
log.Errorf("%s: cannot kill query: %s; query: %q", s, err, q)
}
err = fmt.Errorf("%s: %w; query: %q", s, timeoutErrMsg, q)
respondWith(rw, err, http.StatusGatewayTimeout)
srw.statusCode = http.StatusGatewayTimeout
default:
panic(fmt.Sprintf("BUG: context.Context.Err() returned unexpected error: %s", err))
}
}
func listenToCloseNotify(ctx context.Context, rw ResponseWriterWithCode) (context.Context, context.CancelFunc) {
// Cancel the ctx if client closes the remote connection,
// so the proxied query may be killed instantly.
ctx, ctxCancel := context.WithCancel(ctx)
// rw must implement http.CloseNotifier.
rwc, ok := rw.(http.CloseNotifier)
if !ok {
panic("BUG: the wrapped ResponseWriter must implement http.CloseNotifier")
}
ch := rwc.CloseNotify()
go func() {
select {
case <-ch:
ctxCancel()
case <-ctx.Done():
}
}()
return ctx, ctxCancel
}
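// serveFromCache tries to serve the response from the cache. On a miss it awaits a
// possibly concurrent identical query, then proxies the request to ClickHouse, capturing
// the response into a temporary file; successful responses not exceeding MaxPayloadSize
// are stored in the cache before being sent to the client.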
//nolint:cyclop //TODO refactor this method, most likely requires some work.
func (rp *reverseProxy) serveFromCache(s *scope, srw *statResponseWriter, req *http.Request, origParams url.Values, q []byte) {
labels := makeCacheLabels(s)
key := newCacheKey(s, origParams, q, req)
startTime := time.Now()
userCache := s.user.cache
// Try to serve from cache
cachedData, err := userCache.Get(key)
if err == nil {
// The response has been successfully served from cache.
defer cachedData.Data.Close()
cacheHit.With(labels).Inc()
cachedResponseDuration.With(labels).Observe(time.Since(startTime).Seconds())
log.Debugf("%s: cache hit", s)
_ = RespondWithData(srw, cachedData.Data, cachedData.ContentMetadata, cachedData.Ttl, XCacheHit, http.StatusOK, labels)
return
}
// Await for potential result from concurrent query
transactionStatus, err := userCache.AwaitForConcurrentTransaction(key)
if err != nil {
// log and continue processing
log.Errorf("failed to await for concurrent transaction due to: %v", err)
} else {
if transactionStatus.State.IsCompleted() {
cachedData, err := userCache.Get(key)
if err == nil {
defer cachedData.Data.Close()
_ = RespondWithData(srw, cachedData.Data, cachedData.ContentMetadata, cachedData.Ttl, XCacheHit, http.StatusOK, labels)
cacheHitFromConcurrentQueries.With(labels).Inc()
log.Debugf("%s: cache hit after awaiting concurrent query", s)
return
} else {
cacheMissFromConcurrentQueries.With(labels).Inc()
log.Debugf("%s: cache miss after awaiting concurrent query", s)
}
} else if transactionStatus.State.IsFailed() {
respondWith(srw, errors.New(transactionStatus.FailReason), http.StatusInternalServerError)
return
}
}
// The response wasn't found in the cache.
// Request it from clickhouse.
tmpFileRespWriter, err := cache.NewTmpFileResponseWriter(srw, tmpDir)
if err != nil {
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
return
}
defer tmpFileRespWriter.Close()
// Initialise transaction
err = userCache.Create(key)
if err != nil {
log.Errorf("%s: %s; query: %q - failed to register transaction", s, err, q)
}
// proxy request and capture response along with headers to [[TmpFileResponseWriter]]
rp.proxyRequest(s, tmpFileRespWriter, srw, req)
contentEncoding := tmpFileRespWriter.GetCapturedContentEncoding()
contentType := tmpFileRespWriter.GetCapturedContentType()
contentLength, err := tmpFileRespWriter.GetCapturedContentLength()
if err != nil {
log.Errorf("%s: %s; query: %q - failed to get contentLength of query", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
return
}
reader, err := tmpFileRespWriter.Reader()
if err != nil {
log.Errorf("%s: %s; query: %q - failed to get Reader from tmp file", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
return
}
contentMetadata := cache.ContentMetadata{Length: contentLength, Encoding: contentEncoding, Type: contentType}
statusCode := tmpFileRespWriter.StatusCode()
if statusCode != http.StatusOK || s.canceled {
// Do not cache non-200 or canceled responses.
// Restore the original status code set by proxyRequest, if any.
if srw.statusCode != 0 {
tmpFileRespWriter.WriteHeader(srw.statusCode)
}
errString, err := toString(reader)
if err != nil {
log.Errorf("%s failed to get error reason: %s", s, err.Error())
}
errReason := fmt.Sprintf("%s %s", failedTransactionPrefix, errString)
rp.completeTransaction(s, statusCode, userCache, key, q, errReason)
// reset the file offset, since the reader of tmpFileRespWriter was already
// consumed by toString above
err = tmpFileRespWriter.ResetFileOffset()
if err != nil {
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
return
}
err = RespondWithData(srw, reader, contentMetadata, 0*time.Second, XCacheMiss, statusCode, labels)
if err != nil {
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
}
} else {
// Do not cache responses greater than max payload size.
if contentLength > int64(s.user.cache.MaxPayloadSize) {
cacheSkipped.With(labels).Inc()
log.Infof("%s: Request will not be cached. Content length (%d) is greater than max payload size (%d)", s, contentLength, s.user.cache.MaxPayloadSize)
rp.completeTransaction(s, statusCode, userCache, key, q, "")
err = RespondWithData(srw, reader, contentMetadata, 0*time.Second, XCacheNA, tmpFileRespWriter.StatusCode(), labels)
if err != nil {
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
}
return
}
cacheMiss.With(labels).Inc()
log.Debugf("%s: cache miss", s)
expiration, err := userCache.Put(reader, contentMetadata, key)
if err != nil {
cacheFailedInsert.With(labels).Inc()
log.Errorf("%s: %s; query: %q - failed to put response in the cache", s, err, q)
}
rp.completeTransaction(s, statusCode, userCache, key, q, "")
// reset the file offset, since the reader of tmpFileRespWriter was already
// consumed by userCache.Put above
err = tmpFileRespWriter.ResetFileOffset()
if err != nil {
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
return
}
err = RespondWithData(srw, reader, contentMetadata, expiration, XCacheMiss, statusCode, labels)
if err != nil {
err = fmt.Errorf("%s: %w; query: %q", s, err, q)
respondWith(srw, err, http.StatusInternalServerError)
return
}
}
}
func makeCacheLabels(s *scope) prometheus.Labels {
// Do not store `replica` and `cluster_node` in labels, since they have
// no sense for cache metrics.
return prometheus.Labels{
"cache": s.user.cache.Name(),
"user": s.labels["user"],
"cluster": s.labels["cluster"],
"cluster_user": s.labels["cluster_user"],
}
}
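// newCacheKey builds the cache key from the normalized query, the original query args,
// the sorted Accept-Encoding header, the user params hash, the `param_*` args hash and,
// unless the cache is shared with all users, a hash of the cluster user credentials.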
func newCacheKey(s *scope, origParams url.Values, q []byte, req *http.Request) *cache.Key {
var userParamsHash uint32
if s.user.params != nil {
userParamsHash = s.user.params.key
}
queryParamsHash := calcQueryParamsHash(origParams)
credHash, err := uint32(0), error(nil)
if !s.user.cache.SharedWithAllUsers {
credHash, err = calcCredentialHash(s.clusterUser.name, s.clusterUser.password)
}
if err != nil {
log.Errorf("fail to calc hash on credentials for user %s", s.user.name)
credHash = 0
}
return cache.NewKey(
skipLeadingComments(q),
origParams,
sortHeader(req.Header.Get("Accept-Encoding")),
userParamsHash,
queryParamsHash,
credHash,
)
}
func toString(stream io.Reader) (string, error) {
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(stream)
if err != nil {
return "", err
}
return buf.String(), nil
}
// clickhouseRecoverableStatusCodes is the set of recoverable HTTP status codes returned by ClickHouse.
// When such a code is returned we mark the transaction as completed and let the concurrent query hit another ClickHouse shard.
// Possible HTTP error codes in ClickHouse: https://github.com/ClickHouse/ClickHouse/blob/master/src/Server/HTTPHandler.cpp
var clickhouseRecoverableStatusCodes = map[int]struct{}{http.StatusServiceUnavailable: {}, http.StatusRequestTimeout: {}}
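// completeTransaction marks the transaction for key as completed when the status code
// is below 300, the fail reason is empty, or the status is recoverable; otherwise the
// transaction is marked as failed with failReason.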
func (rp *reverseProxy) completeTransaction(s *scope, statusCode int, userCache *cache.AsyncCache, key *cache.Key,
q []byte,
failReason string,
) {
// complete successful transactions or those with empty fail reason
if statusCode < 300 || failReason == "" {
if err := userCache.Complete(key); err != nil {
log.Errorf("%s: %s; query: %q", s, err, q)
}
return
}
if _, ok := clickhouseRecoverableStatusCodes[statusCode]; ok {
if err := userCache.Complete(key); err != nil {
log.Errorf("%s: %s; query: %q", s, err, q)
}
} else {
if err := userCache.Fail(key, failReason); err != nil {
log.Errorf("%s: %s; query: %q", s, err, q)
}
}
}
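// calcQueryParamsHash hashes the `param_*` query args used by parametrized queries;
// it returns 0 when the hash cannot be calculated.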
func calcQueryParamsHash(origParams url.Values) uint32 {
queryParams := make(map[string]string)
for param := range origParams {
if strings.HasPrefix(param, "param_") {
queryParams[param] = origParams.Get(param)
}
}
queryParamsHash, err := calcMapHash(queryParams)
if err != nil {
log.Errorf("fail to calc hash for params %s; %s", origParams, err)
return 0
}
return queryParamsHash
}
// applyConfig applies the given cfg to reverseProxy.
//
// The new config is applied only if nil error is returned.
// Otherwise the old config version is kept.
func (rp *reverseProxy) applyConfig(cfg *config.Config) error {
// configLock protects from concurrent calls to applyConfig
// by serializing such calls.
// configLock shouldn't be used in other places.
rp.configLock.Lock()
defer rp.configLock.Unlock()
clusters, err := newClusters(cfg.Clusters)
if err != nil {
return err
}
caches := make(map[string]*cache.AsyncCache, len(cfg.Caches))
defer func() {
// caches is swapped with old caches from rp.caches
// on successful config reload - see the end of reloadConfig.
for _, tmpCache := range caches {
// Speed up applyConfig by closing caches in background,
// since the process of cache closing may be lengthy
// due to cleaning.
go tmpCache.Close()
}
}()
// transactionsTimeout is used to create the transactions registry inside the async cache.
// It is set to the highest maxExecutionTime configured across all users, so that users sharing
// the same cache with different maxExecutionTime values don't reintroduce the `dogpile effect`.
transactionsTimeout := config.Duration(0)
for _, user := range cfg.Users {
if user.MaxExecutionTime > transactionsTimeout {
transactionsTimeout = user.MaxExecutionTime
}
if user.IsWildcarded {
rp.hasWildcarded = true
}
}
if err := initTempCaches(caches, transactionsTimeout, cfg.Caches); err != nil {
return err
}
params, err := paramsFromConfig(cfg.ParamGroups)
if err != nil {
return err
}
profile := &usersProfile{
cfg: cfg.Users,
clusters: clusters,
caches: caches,
params: params,
}
users, err := profile.newUsers()
if err != nil {
return err
}
if err := validateNoWildcardedUserForHeartbeat(clusters, cfg.Clusters); err != nil {
return err
}
// New configs have been successfully prepared.
// Restart service goroutines with new configs.
// Stop the previous service goroutines.
close(rp.reloadSignal)
rp.reloadWG.Wait()
rp.reloadSignal = make(chan struct{})
rp.restartWithNewConfig(caches, clusters, users)
// Substitute old configs with the new configs in rp.
// All the currently running requests will continue with old configs,
// while all the new requests will use new configs.
rp.lock.Lock()
rp.clusters = clusters
rp.users = users
// Swap is needed for deferred closing of old caches.
// See the code above where new caches are created.
caches, rp.caches = rp.caches, caches
rp.lock.Unlock()
return nil
}
func initTempCaches(caches map[string]*cache.AsyncCache, transactionsTimeout config.Duration, cfg []config.Cache) error {
for _, cc := range cfg {
if _, ok := caches[cc.Name]; ok {
return fmt.Errorf("duplicate config for cache %q", cc.Name)
}
tmpCache, err := cache.NewAsyncCache(cc, time.Duration(transactionsTimeout))
if err != nil {
return err
}
caches[cc.Name] = tmpCache
}
return nil
}
func paramsFromConfig(cfg []config.ParamGroup) (map[string]*paramsRegistry, error) {
params := make(map[string]*paramsRegistry, len(cfg))
for _, p := range cfg {
if _, ok := params[p.Name]; ok {
return nil, fmt.Errorf("duplicate config for ParamGroups %q", p.Name)
}
registry, err := newParamsRegistry(p.Params)
if err != nil {
return nil, fmt.Errorf("cannot initialize params %q: %w", p.Name, err)
}
params[p.Name] = registry
}
return params, nil
}
func validateNoWildcardedUserForHeartbeat(clusters map[string]*cluster, cfg []config.Cluster) error {
for c := range cfg {
cfgcl := cfg[c]
clname := cfgcl.Name
cuname := cfgcl.ClusterUsers[0].Name
heartbeat := cfg[c].HeartBeat
cl := clusters[clname]
cu := cl.users[cuname]
if cu.isWildcarded {
if heartbeat.Request != "/ping" && len(heartbeat.User) == 0 {
return fmt.Errorf(
"`cluster.heartbeat.user ` cannot be unset for %q because a wildcarded user cannot send heartbeat",
clname,
)
}
}
}
return nil
}
func (rp *reverseProxy) restartWithNewConfig(caches map[string]*cache.AsyncCache, clusters map[string]*cluster, users map[string]*user) {
// Reset metrics from the previous configs, which may become irrelevant
// with new configs.
// Counters and Summary metrics are always relevant.
// Gauge metrics may become irrelevant if they may freeze at non-zero
// value after config reload.
topology.HostHealth.Reset()
cacheSize.Reset()
cacheItems.Reset()
// Start service goroutines with new configs.
for _, c := range clusters {
for _, r := range c.replicas {
for _, h := range r.hosts {
rp.reloadWG.Add(1)
go func(h *topology.Node) {
h.StartHeartbeat(rp.reloadSignal)
rp.reloadWG.Done()
}(h)
}
}
for _, cu := range c.users {
rp.reloadWG.Add(1)
go func(cu *clusterUser) {
cu.rateLimiter.run(rp.reloadSignal)
rp.reloadWG.Done()
}(cu)
}
}
for _, u := range users {
rp.reloadWG.Add(1)
go func(u *user) {
u.rateLimiter.run(rp.reloadSignal)
rp.reloadWG.Done()
}(u)
}
}
// refreshCacheMetrics refreshes cacheSize and cacheItems metrics.
func (rp *reverseProxy) refreshCacheMetrics() {
rp.lock.RLock()
defer rp.lock.RUnlock()
for _, c := range rp.caches {
stats := c.Stats()
labels := prometheus.Labels{
"cache": c.Name(),
}
cacheSize.With(labels).Set(float64(stats.Size))
cacheItems.With(labels).Set(float64(stats.Items))
}
}
// getUser finds the user, cluster and clusterUser.
// In case of a wildcarded user, the cluster user is crafted to use the original credentials.
func (rp *reverseProxy) getUser(name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) {
rp.lock.RLock()
defer rp.lock.RUnlock()
found = false
u = rp.users[name]
switch {
case u != nil:
found = (u.password == password)
// existence of c and cu for toCluster is guaranteed by applyConfig
c = rp.clusters[u.toCluster]
cu = c.users[u.toUser]
case name == "" || name == defaultUser:
// default user can't work with the wildcarded feature for security reasons
found = false
case rp.hasWildcarded:
// check whether we have wildcarded users and whether the username matches one of the 3 possible patterns
found, u, c, cu = rp.findWildcardedUserInformation(name, password)
}
return found, u, c, cu
}
func (rp *reverseProxy) findWildcardedUserInformation(name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) {
// cf. the validation in config.go: the name must contain either a prefix, a suffix or a wildcard
// the wildcarded user is "*"
// the wildcarded user is "*[suffix]"
// the wildcarded user is "[prefix]*"
for _, user := range rp.users {
if user.isWildcarded {
s := strings.Split(user.name, "*")
switch {
case s[0] == "" && s[1] == "":
return rp.generateWildcardedUserInformation(user, name, password)
case s[0] == "":
suffix := s[1]
if strings.HasSuffix(name, suffix) {
return rp.generateWildcardedUserInformation(user, name, password)
}
case s[1] == "":
prefix := s[0]
if strings.HasPrefix(name, prefix) {
return rp.generateWildcardedUserInformation(user, name, password)
}
}
}
}
return false, nil, nil, nil
}
func (rp *reverseProxy) generateWildcardedUserInformation(user *user, name string, password string) (found bool, u *user, c *cluster, cu *clusterUser) {
found = false
c = rp.clusters[user.toCluster]
wildcardedCu := c.users[user.toUser]
if wildcardedCu != nil {
newCU := deepCopy(wildcardedCu)
found = true
u = user
cu = newCU
cu.name = name
cu.password = password
// TODO : improve the following behavior
// the wildcarded user feature creates some side-effects on clusterUser limitations (like the max_concurrent_queries)
// because a deep copy of the clusterUser is used. The side effect should not be too impactful since the limitation still applies at the user level.
// But we need this deep copy since we're changing the name & password of clusterUser and if we used the same instance for every call to chproxy,
// it could lead to security issues since a specific query run by user A on chproxy side could trigger a query in clickhouse from user B.
// Doing a clean fix would require a huge refactoring.
}
return
}
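// getScope authenticates the request, resolves the user, cluster and cluster user,
// enforces network and HTTP/HTTPS restrictions, and builds the request scope.
// On failure it returns the HTTP status code to respond with.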
func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
name, password := getAuth(req)
sessionId := getSessionId(req)
sessionTimeout := getSessionTimeout(req)
var (
u *user
c *cluster
cu *clusterUser
)
found, u, c, cu := rp.getUser(name, password)
if !found {
return nil, http.StatusUnauthorized, fmt.Errorf("invalid username or password for user %q", name)
}
if u.denyHTTP && req.TLS == nil {
return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access via http", u.name)
}
if u.denyHTTPS && req.TLS != nil {
return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access via https", u.name)
}
if !u.allowedNetworks.Contains(req.RemoteAddr) {
return nil, http.StatusForbidden, fmt.Errorf("user %q is not allowed to access", u.name)
}
if !cu.allowedNetworks.Contains(req.RemoteAddr) {
return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
}
s := newScope(req, u, c, cu, sessionId, sessionTimeout)
q, err := getFullQuery(req)
if err != nil {
return nil, http.StatusBadRequest, fmt.Errorf("%s: cannot read query: %w", s, err)
}
s.requestPacketSize = len(q)
return s, 0, nil
}
package main
import (
"net"
"net/http"
"strings"
"github.com/contentsquare/chproxy/config"
)
const (
xForwardedForHeader = "X-Forwarded-For"
xRealIPHeader = "X-Real-Ip"
forwardedHeader = "Forwarded"
)
type ProxyHandler struct {
proxy *config.Proxy
}
func NewProxyHandler(proxy *config.Proxy) *ProxyHandler {
return &ProxyHandler{
proxy: proxy,
}
}
func (m *ProxyHandler) GetRemoteAddr(r *http.Request) string {
if m.proxy.Enable {
var addr string
if m.proxy.Header != "" {
addr = r.Header.Get(m.proxy.Header)
} else {
addr = parseDefaultProxyHeaders(r)
}
if isValidAddr(addr) {
return addr
}
}
return r.RemoteAddr
}
// isValidAddr checks if the Addr is a valid IP or IP:port.
func isValidAddr(addr string) bool {
if addr == "" {
return false
}
ip, _, err := net.SplitHostPort(addr)
if err != nil {
return net.ParseIP(addr) != nil
}
return net.ParseIP(ip) != nil
}
func parseDefaultProxyHeaders(r *http.Request) string {
var addr string
if fwd := r.Header.Get(xForwardedForHeader); fwd != "" {
addr = extractFirstMatchFromIPList(fwd)
} else if fwd := r.Header.Get(xRealIPHeader); fwd != "" {
addr = extractFirstMatchFromIPList(fwd)
} else if fwd := r.Header.Get(forwardedHeader); fwd != "" {
// See: https://tools.ietf.org/html/rfc7239.
addr = parseForwardedHeader(fwd)
}
return addr
}
func extractFirstMatchFromIPList(ipList string) string {
if ipList == "" {
return ""
}
s := strings.Index(ipList, ",")
if s == -1 {
s = len(ipList)
}
return ipList[:s]
}
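// parseForwardedHeader extracts the `for=` address from an RFC 7239 Forwarded header,
// removing quotes and IPv6 brackets.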
func parseForwardedHeader(fwd string) string {
splits := strings.Split(fwd, ";")
if len(splits) == 0 {
return ""
}
for _, split := range splits {
trimmed := strings.TrimSpace(split)
if strings.HasPrefix(strings.ToLower(trimmed), "for=") {
forSplits := strings.Split(trimmed, ",")
if len(forSplits) == 0 {
return ""
}
addr := forSplits[0][4:]
trimmedAddr := strings.
NewReplacer("\"", "", "[", "", "]", "").
Replace(addr) // If IpV6, remove brackets and quotes.
return trimmedAddr
}
}
return ""
}
package main
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/contentsquare/chproxy/cache"
"github.com/contentsquare/chproxy/config"
"github.com/contentsquare/chproxy/internal/heartbeat"
"github.com/contentsquare/chproxy/internal/topology"
"github.com/contentsquare/chproxy/log"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/time/rate"
)
type scopeID uint64
func (sid scopeID) String() string {
return fmt.Sprintf("%08X", uint64(sid))
}
func newScopeID() scopeID {
sid := atomic.AddUint64(&nextScopeID, 1)
return scopeID(sid)
}
var nextScopeID = uint64(time.Now().UnixNano())
type scope struct {
startTime time.Time
id scopeID
host *topology.Node
cluster *cluster
user *user
clusterUser *clusterUser
sessionId string
sessionTimeout int
remoteAddr string
localAddr string
// is true when KillQuery has been called
canceled bool
labels prometheus.Labels
requestPacketSize int
}
func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
h := c.getHost()
if sessionId != "" {
h = c.getHostSticky(sessionId)
}
var localAddr string
if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
localAddr = addr.String()
}
s := &scope{
startTime: time.Now(),
id: newScopeID(),
host: h,
cluster: c,
user: u,
clusterUser: cu,
sessionId: sessionId,
sessionTimeout: sessionTimeout,
remoteAddr: req.RemoteAddr,
localAddr: localAddr,
labels: prometheus.Labels{
"user": u.name,
"cluster": c.name,
"cluster_user": cu.name,
"replica": h.ReplicaName(),
"cluster_node": h.Host(),
},
}
return s
}
func (s *scope) String() string {
return fmt.Sprintf("[ Id: %s; User %q(%d) proxying as %q(%d) to %q(%d); RemoteAddr: %q; LocalAddr: %q; Duration: %d μs]",
s.id,
s.user.name, s.user.queryCounter.load(),
s.clusterUser.name, s.clusterUser.queryCounter.load(),
s.host.Host(), s.host.CurrentLoad(),
s.remoteAddr, s.localAddr, time.Since(s.startTime).Nanoseconds()/1000.0)
}
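// incQueued reserves a slot in the per-user and per-cluster-user request queues (when
// configured) and then waits until the request is allowed to start or its queue deadline
// expires. When queues are disabled it simply tries to start the request via inc.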
//nolint:cyclop // TODO abstract user queues to reduce complexity here.
func (s *scope) incQueued() error {
if s.user.queueCh == nil && s.clusterUser.queueCh == nil {
// Request queues in the current scope are disabled.
return s.inc()
}
// Do not store `replica` and `cluster_node` in labels, since they have
// no sense for queue metrics.
labels := prometheus.Labels{
"user": s.labels["user"],
"cluster": s.labels["cluster"],
"cluster_user": s.labels["cluster_user"],
}
if s.user.queueCh != nil {
select {
case s.user.queueCh <- struct{}{}:
defer func() {
<-s.user.queueCh
}()
default:
// Per-user request queue is full.
// Give the request the last chance to run.
err := s.inc()
if err != nil {
userQueueOverflow.With(labels).Inc()
}
return err
}
}
if s.clusterUser.queueCh != nil {
select {
case s.clusterUser.queueCh <- struct{}{}:
defer func() {
<-s.clusterUser.queueCh
}()
default:
// Per-clusterUser request queue is full.
// Give the request the last chance to run.
err := s.inc()
if err != nil {
clusterUserQueueOverflow.With(labels).Inc()
}
return err
}
}
// The request has been successfully queued.
queueSize := requestQueueSize.With(labels)
queueSize.Inc()
defer queueSize.Dec()
// Try starting the request during the given duration.
sleep, deadline := s.calculateQueueDeadlineAndSleep()
return s.waitUntilAllowStart(sleep, deadline, labels)
}
func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, labels prometheus.Labels) error {
for {
err := s.inc()
if err == nil {
// The request is allowed to start.
return nil
}
dLeft := time.Until(deadline)
if dLeft <= 0 {
// Give up: the request exceeded its wait time
// in the queue :(
return err
}
// The request has dLeft remaining time to wait in the queue.
// Sleep for a bit and try starting it again.
if sleep > dLeft {
time.Sleep(dLeft)
} else {
time.Sleep(sleep)
}
var h *topology.Node
// Choose new host, since the previous one may become obsolete
// after sleeping.
if s.sessionId == "" {
h = s.cluster.getHost()
} else {
// if request has session_id, set same host
h = s.cluster.getHostSticky(s.sessionId)
}
s.host = h
s.labels["replica"] = h.ReplicaName()
s.labels["cluster_node"] = h.Host()
}
}
func (s *scope) calculateQueueDeadlineAndSleep() (time.Duration, time.Time) {
d := s.maxQueueTime()
dSleep := d / 10
if dSleep > time.Second {
dSleep = time.Second
}
if dSleep < time.Millisecond {
dSleep = time.Millisecond
}
deadline := time.Now().Add(d)
return dSleep, deadline
}
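// inc increments the query counters and checks the concurrency and rate limits.
// When any limit is exceeded, the counters are rolled back and an error is returned;
// otherwise the chosen host's connection counter and the concurrency gauge are incremented.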
func (s *scope) inc() error {
uQueries := s.user.queryCounter.inc()
cQueries := s.clusterUser.queryCounter.inc()
var err error
if s.user.maxConcurrentQueries > 0 && uQueries > s.user.maxConcurrentQueries {
err = fmt.Errorf("limits for user %q are exceeded: max_concurrent_queries limit: %d",
s.user.name, s.user.maxConcurrentQueries)
}
if s.clusterUser.maxConcurrentQueries > 0 && cQueries > s.clusterUser.maxConcurrentQueries {
err = fmt.Errorf("limits for cluster user %q are exceeded: max_concurrent_queries limit: %d",
s.clusterUser.name, s.clusterUser.maxConcurrentQueries)
}
err2 := s.checkTokenFreeRateLimiters()
if err2 != nil {
err = err2
}
if err != nil {
s.user.queryCounter.dec()
s.clusterUser.queryCounter.dec()
// Decrement rate limiter here, so it doesn't count requests
// that didn't start due to limits overflow.
s.user.rateLimiter.dec()
s.clusterUser.rateLimiter.dec()
return err
}
s.host.IncrementConnections()
concurrentQueries.With(s.labels).Inc()
return nil
}
func (s *scope) checkTokenFreeRateLimiters() error {
var err error
uRPM := s.user.rateLimiter.inc()
cRPM := s.clusterUser.rateLimiter.inc()
// int32(xRPM) > 0 check is required to detect races when RPM
// is decremented on error below after per-minute zeroing
// in rateLimiter.run.
// This check makes such races harmless.
if s.user.reqPerMin > 0 && int32(uRPM) > 0 && uRPM > s.user.reqPerMin {
err = fmt.Errorf("rate limit for user %q is exceeded: requests_per_minute limit: %d",
s.user.name, s.user.reqPerMin)
}
if s.clusterUser.reqPerMin > 0 && int32(cRPM) > 0 && cRPM > s.clusterUser.reqPerMin {
err = fmt.Errorf("rate limit for cluster user %q is exceeded: requests_per_minute limit: %d",
s.clusterUser.name, s.clusterUser.reqPerMin)
}
err2 := s.checkTokenFreePacketSizeRateLimiters()
if err2 != nil {
err = err2
}
return err
}
func (s *scope) checkTokenFreePacketSizeRateLimiters() error {
var err error
// reserving tokens num s.requestPacketSize
if s.user.reqPacketSizeTokensBurst > 0 {
tl := s.user.reqPacketSizeTokenLimiter
ok := tl.AllowN(time.Now(), s.requestPacketSize)
if !ok {
err = fmt.Errorf("limits for user %q is exceeded: request_packet_size_tokens_burst limit: %d",
s.user.name, s.user.reqPacketSizeTokensBurst)
}
}
if s.clusterUser.reqPacketSizeTokensBurst > 0 {
tl := s.clusterUser.reqPacketSizeTokenLimiter
ok := tl.AllowN(time.Now(), s.requestPacketSize)
if !ok {
err = fmt.Errorf("limits for cluster user %q is exceeded: request_packet_size_tokens_burst limit: %d",
s.clusterUser.name, s.clusterUser.reqPacketSizeTokensBurst)
}
}
return err
}
func (s *scope) dec() {
// There is no need in ratelimiter.dec here, since the rate limiter
// is automatically zeroed every minute in rateLimiter.run.
s.user.queryCounter.dec()
s.clusterUser.queryCounter.dec()
s.host.DecrementConnections()
concurrentQueries.With(s.labels).Dec()
}
const killQueryTimeout = time.Second * 30
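// killQuery asks ClickHouse to kill the query identified by the scope id, using the
// configured kill_query_user credentials (falling back to the default user).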
func (s *scope) killQuery() error {
log.Debugf("killing the query with query_id=%s", s.id)
killedRequests.With(s.labels).Inc()
s.canceled = true
query := fmt.Sprintf("KILL QUERY WHERE query_id = '%s'", s.id)
r := strings.NewReader(query)
addr := s.host.String()
req, err := http.NewRequest("POST", addr, r)
if err != nil {
return fmt.Errorf("error while creating kill query request to %s: %w", addr, err)
}
ctx, cancel := context.WithTimeout(context.Background(), killQueryTimeout)
defer cancel()
req = req.WithContext(ctx)
// send request as kill_query_user
userName := s.cluster.killQueryUserName
if len(userName) == 0 {
userName = defaultUser
}
req.SetBasicAuth(userName, s.cluster.killQueryUserPassword)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("error while executing clickhouse query %q at %q: %w", query, addr, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("unexpected status code returned from query %q at %q: %d. Response body: %q",
query, addr, resp.StatusCode, responseBody)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("cannot read response body for the query %q: %w", query, err)
}
log.Debugf("killed the query with query_id=%s; respBody: %q", s.id, respBody)
return nil
}
// allowedParams contains query args allowed to be proxied.
// See https://clickhouse.com/docs/en/operations/settings/
//
// All the other params passed via query args are stripped before
// proxying the request. This is for the sake of security.
var allowedParams = []string{
"query",
"database",
"default_format",
// if `compress=1`, CH will compress the data it sends you
"compress",
// if `decompress=1`, CH will decompress the data that you pass in the POST body
"decompress",
// compress the result if the HTTP client indicated it understands data compressed by gzip or deflate
"enable_http_compression",
// limit on the number of rows in the result
"max_result_rows",
// whether to count extreme values
"extremes",
// what to do if the volume of the result exceeds one of the limits
"result_overflow_mode",
// session stickiness
"session_id",
// session timeout
"session_timeout",
}
// This regexp must match params needed to describe a way to use external data
// @see https://clickhouse.yandex/docs/en/table_engines/external_data/
var externalDataParams = regexp.MustCompile(`(_types|_structure|_format)$`)
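// decorateRequest rebuilds the query string from user params, allowed params, `param_*`
// args and external-data params, sets query_id and session_timeout, rewrites the
// authentication to the cluster user and points the request at the chosen host.
// It returns the decorated request together with the original query args.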
func (s *scope) decorateRequest(req *http.Request) (*http.Request, url.Values) {
// Make new params to purify URL.
params := make(url.Values)
// Set user params
if s.user.params != nil {
for _, param := range s.user.params.params {
params.Set(param.Key, param.Value)
}
}
// Keep allowed params.
origParams := req.URL.Query()
for _, param := range allowedParams {
val := origParams.Get(param)
if len(val) > 0 {
params.Set(param, val)
}
}
// Keep parametrized queries params
for param := range origParams {
if strings.HasPrefix(param, "param_") {
params.Set(param, origParams.Get(param))
}
}
// Keep external_data params
if req.Method == "POST" {
s.decoratePostRequest(req, origParams, params)
}
// Set query_id as scope_id to have possibility to kill query if needed.
params.Set("query_id", s.id.String())
// Set session_timeout as the idle timeout for the session
params.Set("session_timeout", strconv.Itoa(s.sessionTimeout))
req.URL.RawQuery = params.Encode()
// Rewrite possible previous Basic Auth and send request
// as cluster user.
req.SetBasicAuth(s.clusterUser.name, s.clusterUser.password)
// Delete possible X-ClickHouse headers,
// it is not allowed to use X-ClickHouse HTTP headers and other authentication methods simultaneously
req.Header.Del("X-ClickHouse-User")
req.Header.Del("X-ClickHouse-Key")
// Send request to the chosen host from cluster.
req.URL.Scheme = s.host.Scheme()
req.URL.Host = s.host.Host()
// Extend ua with additional info, so it may be queried
// via system.query_log.http_user_agent.
ua := fmt.Sprintf("RemoteAddr: %s; LocalAddr: %s; CHProxy-User: %s; CHProxy-ClusterUser: %s; %s",
s.remoteAddr, s.localAddr, s.user.name, s.clusterUser.name, req.UserAgent())
req.Header.Set("User-Agent", ua)
return req, origParams
}
func (s *scope) decoratePostRequest(req *http.Request, origParams, params url.Values) {
ct := req.Header.Get("Content-Type")
if strings.Contains(ct, "multipart/form-data") {
for key := range origParams {
if externalDataParams.MatchString(key) {
params.Set(key, origParams.Get(key))
}
}
// disable cache for external_data queries
origParams.Set("no_cache", "1")
log.Debugf("external data params detected - cache will be disabled")
}
}
func (s *scope) getTimeoutWithErrMsg() (time.Duration, error) {
var (
timeout time.Duration
timeoutErrMsg error
)
if s.user.maxExecutionTime > 0 {
timeout = s.user.maxExecutionTime
timeoutErrMsg = fmt.Errorf("timeout for user %q exceeded: %v", s.user.name, timeout)
}
if timeout == 0 || (s.clusterUser.maxExecutionTime > 0 && s.clusterUser.maxExecutionTime < timeout) {
timeout = s.clusterUser.maxExecutionTime
timeoutErrMsg = fmt.Errorf("timeout for cluster user %q exceeded: %v", s.clusterUser.name, timeout)
}
return timeout, timeoutErrMsg
}
func (s *scope) maxQueueTime() time.Duration {
d := s.user.maxQueueTime
if d <= 0 || s.clusterUser.maxQueueTime > 0 && s.clusterUser.maxQueueTime < d {
d = s.clusterUser.maxQueueTime
}
if d <= 0 {
// Default queue time.
d = 10 * time.Second
}
return d
}
type paramsRegistry struct {
// key is a hashed concatenation of the params list
key uint32
params []config.Param
}
func newParamsRegistry(params []config.Param) (*paramsRegistry, error) {
if len(params) == 0 {
return nil, fmt.Errorf("params can't be empty")
}
paramsMap := make(map[string]string, len(params))
for _, k := range params {
paramsMap[k.Key] = k.Value
}
key, err := calcMapHash(paramsMap)
if err != nil {
return nil, err
}
return &paramsRegistry{
key: key,
params: params,
}, nil
}
type user struct {
name string
password string
toCluster string
toUser string
maxConcurrentQueries uint32
queryCounter counter
maxExecutionTime time.Duration
reqPerMin uint32
rateLimiter rateLimiter
reqPacketSizeTokenLimiter *rate.Limiter
reqPacketSizeTokensBurst config.ByteSize
reqPacketSizeTokensRate config.ByteSize
queueCh chan struct{}
maxQueueTime time.Duration
allowedNetworks config.Networks
denyHTTP bool
denyHTTPS bool
allowCORS bool
isWildcarded bool
cache *cache.AsyncCache
params *paramsRegistry
}
type usersProfile struct {
cfg []config.User
clusters map[string]*cluster
caches map[string]*cache.AsyncCache
params map[string]*paramsRegistry
}
func (up usersProfile) newUsers() (map[string]*user, error) {
users := make(map[string]*user, len(up.cfg))
for _, u := range up.cfg {
if _, ok := users[u.Name]; ok {
return nil, fmt.Errorf("duplicate config for user %q", u.Name)
}
tmpU, err := up.newUser(u)
if err != nil {
return nil, fmt.Errorf("cannot initialize user %q: %w", u.Name, err)
}
users[u.Name] = tmpU
}
return users, nil
}
func (up usersProfile) newUser(u config.User) (*user, error) {
c, ok := up.clusters[u.ToCluster]
if !ok {
return nil, fmt.Errorf("unknown `to_cluster` %q", u.ToCluster)
}
var cu *clusterUser
if cu, ok = c.users[u.ToUser]; !ok {
return nil, fmt.Errorf("unknown `to_user` %q in cluster %q", u.ToUser, u.ToCluster)
} else if u.IsWildcarded {
// a wildcarded user is mapped to this cluster user
// used to check if a proper user to send heartbeat exists
cu.isWildcarded = true
}
var queueCh chan struct{}
if u.MaxQueueSize > 0 {
queueCh = make(chan struct{}, u.MaxQueueSize)
}
var cc *cache.AsyncCache
if len(u.Cache) > 0 {
cc = up.caches[u.Cache]
if cc == nil {
return nil, fmt.Errorf("unknown `cache` %q", u.Cache)
}
}
var params *paramsRegistry
if len(u.Params) > 0 {
params = up.params[u.Params]
if params == nil {
return nil, fmt.Errorf("unknown `params` %q", u.Params)
}
}
return &user{
name: u.Name,
password: u.Password,
toCluster: u.ToCluster,
toUser: u.ToUser,
maxConcurrentQueries: u.MaxConcurrentQueries,
maxExecutionTime: time.Duration(u.MaxExecutionTime),
reqPerMin: u.ReqPerMin,
queueCh: queueCh,
maxQueueTime: time.Duration(u.MaxQueueTime),
reqPacketSizeTokenLimiter: rate.NewLimiter(rate.Limit(u.ReqPacketSizeTokensRate), int(u.ReqPacketSizeTokensBurst)),
reqPacketSizeTokensBurst: u.ReqPacketSizeTokensBurst,
reqPacketSizeTokensRate: u.ReqPacketSizeTokensRate,
allowedNetworks: u.AllowedNetworks,
denyHTTP: u.DenyHTTP,
denyHTTPS: u.DenyHTTPS,
allowCORS: u.AllowCORS,
isWildcarded: u.IsWildcarded,
cache: cc,
params: params,
}, nil
}
type clusterUser struct {
name string
password string
maxConcurrentQueries uint32
queryCounter counter
maxExecutionTime time.Duration
reqPerMin uint32
rateLimiter rateLimiter
queueCh chan struct{}
maxQueueTime time.Duration
reqPacketSizeTokenLimiter *rate.Limiter
reqPacketSizeTokensBurst config.ByteSize
reqPacketSizeTokensRate config.ByteSize
allowedNetworks config.Networks
isWildcarded bool
}
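// deepCopy returns a copy of the clusterUser. It is used for wildcarded users, where the
// name and password are rewritten per request, so the shared instance must not be mutated.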
func deepCopy(cu *clusterUser) *clusterUser {
var queueCh chan struct{}
// Preserve the original queue capacity; maxQueueTime is a duration, not a queue size.
if cap(cu.queueCh) > 0 {
queueCh = make(chan struct{}, cap(cu.queueCh))
}
return &clusterUser{
name: cu.name,
password: cu.password,
maxConcurrentQueries: cu.maxConcurrentQueries,
maxExecutionTime: time.Duration(cu.maxExecutionTime),
reqPerMin: cu.reqPerMin,
queueCh: queueCh,
maxQueueTime: time.Duration(cu.maxQueueTime),
allowedNetworks: cu.allowedNetworks,
}
}
func newClusterUser(cu config.ClusterUser) *clusterUser {
var queueCh chan struct{}
if cu.MaxQueueSize > 0 {
queueCh = make(chan struct{}, cu.MaxQueueSize)
}
return &clusterUser{
name: cu.Name,
password: cu.Password,
maxConcurrentQueries: cu.MaxConcurrentQueries,
maxExecutionTime: time.Duration(cu.MaxExecutionTime),
reqPerMin: cu.ReqPerMin,
reqPacketSizeTokenLimiter: rate.NewLimiter(rate.Limit(cu.ReqPacketSizeTokensRate), int(cu.ReqPacketSizeTokensBurst)),
reqPacketSizeTokensBurst: cu.ReqPacketSizeTokensBurst,
reqPacketSizeTokensRate: cu.ReqPacketSizeTokensRate,
queueCh: queueCh,
maxQueueTime: time.Duration(cu.MaxQueueTime),
allowedNetworks: cu.AllowedNetworks,
}
}
type replica struct {
cluster *cluster
name string
hosts []*topology.Node
nextHostIdx uint32
}
func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c *cluster) ([]*replica, error) {
if len(nodes) > 0 {
// No replicas, just flat nodes. Create default replica
// containing all the nodes.
r := &replica{
cluster: c,
name: "default",
}
hosts, err := newNodes(nodes, scheme, r)
if err != nil {
return nil, err
}
r.hosts = hosts
return []*replica{r}, nil
}
replicas := make([]*replica, len(replicasCfg))
for i, rCfg := range replicasCfg {
r := &replica{
cluster: c,
name: rCfg.Name,
}
hosts, err := newNodes(rCfg.Nodes, scheme, r)
if err != nil {
return nil, fmt.Errorf("cannot initialize replica %q: %w", rCfg.Name, err)
}
r.hosts = hosts
replicas[i] = r
}
return replicas, nil
}
func newNodes(nodes []string, scheme string, r *replica) ([]*topology.Node, error) {
hosts := make([]*topology.Node, len(nodes))
for i, node := range nodes {
addr, err := url.Parse(fmt.Sprintf("%s://%s", scheme, node))
if err != nil {
return nil, fmt.Errorf("cannot parse `node` %q with `scheme` %q: %w", node, scheme, err)
}
hosts[i] = topology.NewNode(addr, r.cluster.heartBeat, r.cluster.name, r.name)
}
return hosts, nil
}
func (r *replica) isActive() bool {
// The replica is active if at least a single host is active.
for _, h := range r.hosts {
if h.IsActive() {
return true
}
}
return false
}
func (r *replica) load() uint32 {
var reqs uint32
for _, h := range r.hosts {
reqs += h.CurrentLoad()
}
return reqs
}
type cluster struct {
name string
replicas []*replica
nextReplicaIdx uint32
users map[string]*clusterUser
killQueryUserName string
killQueryUserPassword string
heartBeat heartbeat.HeartBeat
retryNumber int
}
func newCluster(c config.Cluster) (*cluster, error) {
clusterUsers := make(map[string]*clusterUser, len(c.ClusterUsers))
for _, cu := range c.ClusterUsers {
if _, ok := clusterUsers[cu.Name]; ok {
return nil, fmt.Errorf("duplicate config for cluster user %q", cu.Name)
}
clusterUsers[cu.Name] = newClusterUser(cu)
}
heartBeat := heartbeat.NewHeartbeat(c.HeartBeat, heartbeat.WithDefaultUser(c.ClusterUsers[0].Name, c.ClusterUsers[0].Password))
newC := &cluster{
name: c.Name,
users: clusterUsers,
killQueryUserName: c.KillQueryUser.Name,
killQueryUserPassword: c.KillQueryUser.Password,
heartBeat: heartBeat,
retryNumber: c.RetryNumber,
}
replicas, err := newReplicas(c.Replicas, c.Nodes, c.Scheme, newC)
if err != nil {
return nil, fmt.Errorf("cannot initialize replicas: %w", err)
}
newC.replicas = replicas
return newC, nil
}
func newClusters(cfg []config.Cluster) (map[string]*cluster, error) {
clusters := make(map[string]*cluster, len(cfg))
for _, c := range cfg {
if _, ok := clusters[c.Name]; ok {
return nil, fmt.Errorf("duplicate config for cluster %q", c.Name)
}
tmpC, err := newCluster(c)
if err != nil {
return nil, fmt.Errorf("cannot initialize cluster %q: %w", c.Name, err)
}
clusters[c.Name] = tmpC
}
return clusters, nil
}
// getReplica returns least loaded + round-robin replica from the cluster.
//
// Always returns non-nil.
func (c *cluster) getReplica() *replica {
idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
n := uint32(len(c.replicas))
if n == 1 {
return c.replicas[0]
}
idx %= n
r := c.replicas[idx]
reqs := r.load()
// Set least priority to inactive replica.
if !r.isActive() {
reqs = ^uint32(0)
}
if reqs == 0 {
return r
}
// Scan all the replicas for the least loaded replica.
for i := uint32(1); i < n; i++ {
tmpIdx := (idx + i) % n
tmpR := c.replicas[tmpIdx]
if !tmpR.isActive() {
continue
}
tmpReqs := tmpR.load()
if tmpReqs == 0 {
return tmpR
}
if tmpReqs < reqs {
r = tmpR
reqs = tmpReqs
}
}
// The returned replica may be inactive. This is OK,
// since this means all the replicas are inactive,
// so let's try proxying the request to any replica.
return r
}
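// getReplicaSticky returns the replica selected by a hash of the session id, falling
// back to the round-robin choice when the sticky replica is inactive.
//
// Always returns non-nil.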
func (c *cluster) getReplicaSticky(sessionId string) *replica {
idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
n := uint32(len(c.replicas))
if n == 1 {
return c.replicas[0]
}
idx %= n
r := c.replicas[idx]
for i := uint32(1); i < n; i++ {
// handling sticky session
sessionId := hash(sessionId)
tmpIdx := (sessionId) % n
tmpRSticky := c.replicas[tmpIdx]
log.Debugf("Sticky replica candidate is: %s", tmpRSticky.name)
if !tmpRSticky.isActive() {
log.Debugf("Sticky session replica has been picked up, but it is not available")
continue
}
log.Debugf("Sticky session replica is: %s, session_id: %d, replica_idx: %d, max replicas in pool: %d", tmpRSticky.name, sessionId, tmpIdx, n)
return tmpRSticky
}
// The returned replica may be inactive. This is OK,
// since this means all the replicas are inactive,
// so let's try proxying the request to any replica.
return r
}
// getHostSticky returns host by stickiness from replica.
//
// Always returns non-nil.
func (r *replica) getHostSticky(sessionId string) *topology.Node {
idx := atomic.AddUint32(&r.nextHostIdx, 1)
n := uint32(len(r.hosts))
if n == 1 {
return r.hosts[0]
}
idx %= n
h := r.hosts[idx]
// Scan all the hosts for the least loaded host.
for i := uint32(1); i < n; i++ {
// handling sticky session
sessionId := hash(sessionId)
tmpIdx := (sessionId) % n
tmpHSticky := r.hosts[tmpIdx]
log.Debugf("Sticky server candidate is: %s", tmpHSticky)
if !tmpHSticky.IsActive() {
log.Debugf("Sticky session server has been picked up, but it is not available")
continue
}
log.Debugf("Sticky session server is: %s, session_id: %d, server_idx: %d, max nodes in pool: %d", tmpHSticky, sessionId, tmpIdx, n)
return tmpHSticky
}
// The returned host may be inactive. This is OK,
// since this means all the hosts are inactive,
// so let's try proxying the request to any host.
return h
}
// getHost returns least loaded + round-robin host from replica.
//
// Always returns non-nil.
func (r *replica) getHost() *topology.Node {
idx := atomic.AddUint32(&r.nextHostIdx, 1)
n := uint32(len(r.hosts))
if n == 1 {
return r.hosts[0]
}
idx %= n
h := r.hosts[idx]
reqs := h.CurrentLoad()
// Set least priority to inactive host.
if !h.IsActive() {
reqs = ^uint32(0)
}
if reqs == 0 {
return h
}
// Scan all the hosts for the least loaded host.
for i := uint32(1); i < n; i++ {
tmpIdx := (idx + i) % n
tmpH := r.hosts[tmpIdx]
if !tmpH.IsActive() {
continue
}
tmpReqs := tmpH.CurrentLoad()
if tmpReqs == 0 {
return tmpH
}
if tmpReqs < reqs {
h = tmpH
reqs = tmpReqs
}
}
// The returned host may be inactive. This is OK,
// since this means all the hosts are inactive,
// so let's try proxying the request to any host.
return h
}
// getHostSticky returns host based on stickiness from cluster.
//
// Always returns non-nil.
func (c *cluster) getHostSticky(sessionId string) *topology.Node {
r := c.getReplicaSticky(sessionId)
return r.getHostSticky(sessionId)
}
// getHost returns least loaded + round-robin host from cluster.
//
// Always returns non-nil.
func (c *cluster) getHost() *topology.Node {
r := c.getReplica()
return r.getHost()
}
type rateLimiter struct {
counter
}
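// run zeroes the rate limiter counter every minute until done is closed.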
func (rl *rateLimiter) run(done <-chan struct{}) {
for {
select {
case <-done:
return
case <-time.After(time.Minute):
rl.store(0)
}
}
}
type counter struct {
value uint32
}
func (c *counter) store(n uint32) { atomic.StoreUint32(&c.value, n) }
func (c *counter) load() uint32 { return atomic.LoadUint32(&c.value) }
func (c *counter) dec() { atomic.AddUint32(&c.value, ^uint32(0)) }
func (c *counter) inc() uint32 { return atomic.AddUint32(&c.value, 1) }
package main
import (
"bytes"
"compress/gzip"
"fmt"
"hash/fnv"
"io"
"net/http"
"sort"
"strconv"
"strings"
"github.com/contentsquare/chproxy/chdecompressor"
"github.com/contentsquare/chproxy/log"
)
func respondWith(rw http.ResponseWriter, err error, status int) {
log.ErrorWithCallDepth(err, 1)
rw.WriteHeader(status)
fmt.Fprintf(rw, "%s\n", err)
}
var defaultUser = "default"
// getAuth retrieves auth credentials from request
// according to CH documentation @see "https://clickhouse.yandex/docs/en/interfaces/http/"
func getAuth(req *http.Request) (string, string) {
// check X-ClickHouse- headers
name := req.Header.Get("X-ClickHouse-User")
pass := req.Header.Get("X-ClickHouse-Key")
if name != "" {
return name, pass
}
// if header is empty - check basicAuth
if name, pass, ok := req.BasicAuth(); ok {
return name, pass
}
// if basicAuth is empty - check URL params `user` and `password`
params := req.URL.Query()
if name := params.Get("user"); name != "" {
pass := params.Get("password")
return name, pass
}
// if still no credentials - treat it as `default` user request
return defaultUser, ""
}
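// exampleGetAuth is an illustrative sketch (not part of the original source);
// the user names, passwords and URL are made up. It shows the lookup order:
// X-ClickHouse-* headers win over Basic auth, which wins over the user/password
// URL parameters; with none of them set, the "default" user is assumed.
func exampleGetAuth() (string, string) {
req, _ := http.NewRequest("GET", "http://127.0.0.1:9090/?user=alice&password=urlpass", nil)
req.SetBasicAuth("bob", "basicpass")
req.Header.Set("X-ClickHouse-User", "carol")
req.Header.Set("X-ClickHouse-Key", "headerpass")
// The X-ClickHouse-* headers take precedence, so this returns ("carol", "headerpass").
return getAuth(req)
}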
// getSessionId retrieves session id
func getSessionId(req *http.Request) string {
params := req.URL.Query()
sessionId := params.Get("session_id")
return sessionId
}
// getSessionTimeout retrieves the session timeout in seconds, defaulting to 60
func getSessionTimeout(req *http.Request) int {
params := req.URL.Query()
sessionTimeout, err := strconv.Atoi(params.Get("session_timeout"))
if err == nil && sessionTimeout > 0 {
return sessionTimeout
}
return 60
}
// getQuerySnippet returns query snippet.
//
// getQuerySnippet must be called only for error reporting.
func getQuerySnippet(req *http.Request) string {
query := req.URL.Query().Get("query")
body := getQuerySnippetFromBody(req)
if len(query) != 0 && len(body) != 0 {
query += "\n"
}
return query + body
}
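// hash returns the 32-bit FNV-1a hash of s.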
func hash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
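// exampleStickyIndex is an illustrative sketch (not part of the original source)
// of how a session id is mapped to a host slot by the sticky-session logic:
// the FNV-1a hash of the id is reduced modulo the number of hosts, so the same
// session id always lands on the same index while different ids spread across
// the pool. numHosts is assumed to be greater than zero.
func exampleStickyIndex(sessionID string, numHosts uint32) uint32 {
return hash(sessionID) % numHosts
}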
func getQuerySnippetFromBody(req *http.Request) string {
if req.Body == nil {
return ""
}
crc, ok := req.Body.(*cachedReadCloser)
if !ok {
crc = &cachedReadCloser{
ReadCloser: req.Body,
}
}
// 'Read' the request body so its contents get captured by crc.
// Ignore any errors, since getQuerySnippet is called only
// during error reporting.
io.Copy(io.Discard, crc) // nolint
data := crc.String()
u := getDecompressor(req)
if u == nil {
return data
}
bs := bytes.NewBufferString(data)
b, err := u.decompress(bs)
if err == nil {
return string(b)
}
// It is better to return partially decompressed data instead of an empty string.
if len(b) > 0 {
return string(b)
}
// The data failed to be decompressed. Return compressed data
// instead of an empty string.
return data
}
// getFullQuery returns full query from req.
func getFullQuery(req *http.Request) ([]byte, error) {
var result bytes.Buffer
if req.URL.Query().Get("query") != "" {
result.WriteString(req.URL.Query().Get("query"))
}
body, err := getFullQueryFromBody(req)
if err != nil {
return nil, err
}
if result.Len() != 0 && len(body) != 0 {
result.WriteByte('\n')
}
result.Write(body)
return result.Bytes(), nil
}
func getFullQueryFromBody(req *http.Request) ([]byte, error) {
if req.Body == nil {
return nil, nil
}
data, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
// restore body for further reading
req.Body = io.NopCloser(bytes.NewBuffer(data))
u := getDecompressor(req)
if u == nil {
return data, nil
}
br := bytes.NewReader(data)
b, err := u.decompress(br)
if err != nil {
return nil, fmt.Errorf("cannot uncompress query: %w", err)
}
return b, nil
}
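// exampleGetFullQuery is an illustrative sketch (not part of the original
// source); the URL and query text are made up. It shows how the "query" URL
// parameter and the request body are joined with a newline, and that req.Body
// is restored afterwards so the request can still be proxied upstream.
func exampleGetFullQuery() ([]byte, error) {
body := strings.NewReader("WHERE id = 1")
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/?query=SELECT+*+FROM+t", body)
if err != nil {
return nil, err
}
// Returns "SELECT * FROM t\nWHERE id = 1"; req.Body remains readable.
return getFullQuery(req)
}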
var cachableStatements = []string{"SELECT", "WITH"}
// canCacheQuery returns true if q can be cached.
func canCacheQuery(q []byte) bool {
q = skipLeadingComments(q)
for _, statement := range cachableStatements {
if len(q) < len(statement) {
continue
}
l := bytes.ToUpper(q[:len(statement)])
if bytes.HasPrefix(l, []byte(statement)) {
return true
}
}
return false
}
//nolint:cyclop // No clean way to split this.
func skipLeadingComments(q []byte) []byte {
for len(q) > 0 {
switch q[0] {
case '\t', '\n', '\v', '\f', '\r', ' ':
q = q[1:]
case '-':
if len(q) < 2 || q[1] != '-' {
return q
}
// skip `-- comment`
n := bytes.IndexByte(q, '\n')
if n < 0 {
return nil
}
q = q[n+1:]
case '/':
if len(q) < 2 || q[1] != '*' {
return q
}
// skip `/* comment */`
for {
n := bytes.IndexByte(q, '*')
if n < 0 {
return nil
}
q = q[n+1:]
if len(q) == 0 {
return nil
}
if q[0] == '/' {
q = q[1:]
break
}
}
default:
return q
}
}
return nil
}
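// exampleCanCacheQuery is an illustrative sketch (not part of the original
// source); the queries are made up. Leading whitespace, `-- line` comments and
// `/* block */` comments are stripped before the first keyword is checked, so
// a commented SELECT is still considered cacheable, while an INSERT is not.
func exampleCanCacheQuery() (bool, bool) {
cacheable := canCacheQuery([]byte("/* report */ -- daily\nSELECT count() FROM hits"))
notCacheable := canCacheQuery([]byte("INSERT INTO hits VALUES (1)"))
return cacheable, notCacheable // (true, false)
}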
// sortHeader splits a comma-separated header value, trims each part and joins the parts back in sorted order.
func sortHeader(header string) string {
h := strings.Split(header, ",")
for i, v := range h {
h[i] = strings.TrimSpace(v)
}
sort.Strings(h)
return strings.Join(h, ",")
}
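// exampleSortHeader is an illustrative sketch (not part of the original source);
// the header value is made up. Splitting on commas, trimming and sorting makes
// equivalent header values compare equal regardless of how the client ordered
// or spaced them.
func exampleSortHeader() string {
return sortHeader("gzip, deflate , br") // "br,deflate,gzip"
}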
func getDecompressor(req *http.Request) decompressor {
if req.Header.Get("Content-Encoding") == "gzip" {
return gzipDecompressor{}
}
if req.URL.Query().Get("decompress") == "1" {
return chDecompressor{}
}
return nil
}
type decompressor interface {
decompress(r io.Reader) ([]byte, error)
}
type gzipDecompressor struct{}
func (dc gzipDecompressor) decompress(r io.Reader) ([]byte, error) {
gr, err := gzip.NewReader(r)
if err != nil {
return nil, fmt.Errorf("cannot ungzip query: %w", err)
}
return io.ReadAll(gr)
}
type chDecompressor struct{}
func (dc chDecompressor) decompress(r io.Reader) ([]byte, error) {
lr := chdecompressor.NewReader(r)
return io.ReadAll(lr)
}
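// exampleGzipRoundTrip is an illustrative sketch (not part of the original
// source); the URL is made up. A gzip-encoded request body is detected via the
// Content-Encoding header and decompressed by gzipDecompressor, mirroring what
// getQuerySnippetFromBody and getFullQueryFromBody do with compressed queries.
func exampleGzipRoundTrip(query string) ([]byte, error) {
var buf bytes.Buffer
zw := gzip.NewWriter(&buf)
if _, err := zw.Write([]byte(query)); err != nil {
return nil, err
}
if err := zw.Close(); err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "http://127.0.0.1:9090/", &buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Encoding", "gzip")
u := getDecompressor(req) // gzipDecompressor{}
return u.decompress(req.Body)
}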
func calcMapHash(m map[string]string) (uint32, error) {
if len(m) == 0 {
return 0, nil
}
var keys []string
for key := range m {
keys = append(keys, key)
}
sort.Strings(keys)
h := fnv.New32a()
for _, k := range keys {
str := fmt.Sprintf("%s=%s&", k, m[k])
_, err := h.Write([]byte(str))
if err != nil {
return 0, err
}
}
return h.Sum32(), nil
}
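// exampleCalcMapHash is an illustrative sketch (not part of the original
// source); the parameter names are made up. Keys are sorted before hashing, so
// the result does not depend on Go's randomized map iteration order and two
// equal maps always produce the same hash.
func exampleCalcMapHash() (uint32, uint32, error) {
a, err := calcMapHash(map[string]string{"max_threads": "4", "readonly": "1"})
if err != nil {
return 0, 0, err
}
b, err := calcMapHash(map[string]string{"readonly": "1", "max_threads": "4"})
return a, b, err // a == b
}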
func calcCredentialHash(user string, pwd string) (uint32, error) {
h := fnv.New32a()
_, err := h.Write([]byte(user + pwd))
return h.Sum32(), err
}