package main
import (
"errors"
"flag"
"fmt"
"os"
"github.com/lucmq/go-shelve/shelve"
)
// Shelf is the shelve.Shelf instantiation used throughout this CLI: both keys
// and values are plain strings.
type Shelf = shelve.Shelf[string, string]
// exitOnError controls whether main terminates the process with a non-zero
// status after run reports an error.
var exitOnError = true
// exit is the function main calls to terminate the process. It defaults to
// os.Exit.
// NOTE(review): presumably kept as a variable so tests can replace it —
// confirm against the test suite.
var exit = os.Exit
// main executes the CLI. When run fails, the error is written to standard
// error and, unless exitOnError has been cleared, the process exits with
// status 1 through the exit hook.
func main() {
	err := run()
	if err == nil {
		return
	}

	// The result of writing the diagnostic is deliberately ignored: there is
	// nowhere left to report a failure to print to stderr.
	_, _ = fmt.Fprintf(os.Stderr, "run failed: %v\n", err)

	if !exitOnError {
		return
	}
	exit(1)
}
// run parses the global flags, opens the shelve store and dispatches to the
// handler of the requested command.
//
// When no command is given, the usage text is printed and nil is returned.
// The store is always closed before run returns; a failure to close it is
// reported to the caller (joined with the command's own error, if any) rather
// than being silently dropped, because closing is when the store gets a last
// chance to persist its state.
func run() (err error) {
	flag.Usage = printUsage

	storePath := flag.String("path", ".store", "Path to the shelve store")
	codecName := flag.String("codec", "json", "value serialization format: gob, json, or text")
	flag.Parse()

	args := flag.Args()
	if len(args) < 1 {
		printUsage()
		return nil
	}
	command := args[0]
	commandArgs := args[1:]

	codec, err := getCodec(*codecName)
	if err != nil {
		return fmt.Errorf("get codec: %w", err)
	}

	// Open the shelve store
	store, err := shelve.Open[string, string](
		*storePath,
		shelve.WithCodec(codec),
	)
	if err != nil {
		return fmt.Errorf("open store: %w", err)
	}
	defer func() {
		// errors.Join ignores nil operands, so a clean command combined
		// with a failing Close yields just the close error.
		if cerr := store.Close(); cerr != nil {
			err = errors.Join(err, fmt.Errorf("close store: %w", cerr))
		}
	}()

	// Execute the appropriate command
	switch command {
	case "put":
		return handlePut(store, commandArgs)
	case "get":
		return handleGet(store, commandArgs)
	case "has":
		return handleHas(store, commandArgs)
	case "delete":
		return handleDelete(store, commandArgs)
	case "len":
		return handleLen(store)
	case "items":
		return handleItems(store, "items", commandArgs)
	case "keys":
		return handleItems(store, "keys", commandArgs)
	case "values":
		return handleItems(store, "values", commandArgs)
	default:
		return fmt.Errorf("unknown command: %s", command)
	}
}
// getCodec maps the value of the -codec flag ("gob", "json" or "text") to the
// matching shelve codec. Any other name yields an "unsupported codec" error.
func getCodec(name string) (shelve.Codec, error) {
	if name == "gob" {
		return shelve.GobCodec(), nil
	}
	if name == "json" {
		return shelve.JSONCodec(), nil
	}
	if name == "text" {
		return shelve.TextCodec(), nil
	}
	return nil, fmt.Errorf("unsupported codec: %s", name)
}
// handlePut stores one or more key-value pairs. args must hold a non-empty,
// even number of elements laid out as key, value, key, value, ...; otherwise
// a usage error is returned. "OK" is printed once every pair has been stored.
func handlePut(store *Shelf, args []string) error {
	n := len(args)
	if n < 2 || n%2 != 0 {
		return errors.New("usage: shelve put <key> <value> [<key> <value> ...]")
	}

	// Consume the arguments two at a time.
	for rest := args; len(rest) > 0; rest = rest[2:] {
		k, v := rest[0], rest[1]
		err := store.Put(k, v)
		if err != nil {
			return fmt.Errorf("put key-value pair (%s, %s): %w", k, v, err)
		}
	}

	fmt.Println("OK")
	return nil
}
// handleGet prints the value stored under the key given in args[0].
//
// It returns a usage error when no key is supplied, the wrapped store error
// when the lookup fails, and a "key not found" error when the store reports
// that the key is absent. Without that last check a missing key would be
// indistinguishable from a key holding an empty value.
func handleGet(store *Shelf, args []string) error {
	if len(args) < 1 {
		return errors.New("usage: shelve get <key>")
	}
	key := args[0]

	// NOTE(review): the second result is treated as the "found" flag of
	// shelve.Shelf.Get — confirm against the shelve API.
	value, ok, err := store.Get(key)
	if err != nil {
		return fmt.Errorf("get key: %w", err)
	}
	if !ok {
		return fmt.Errorf("key not found: %s", key)
	}

	fmt.Println(value)
	return nil
}
// handleHas prints "true" or "false" depending on whether the key given in
// args[0] exists in the store. A usage error is returned when no key is
// supplied.
func handleHas(store *Shelf, args []string) error {
	if len(args) < 1 {
		return errors.New("usage: shelve has <key>")
	}

	found, err := store.Has(args[0])
	if err != nil {
		return fmt.Errorf("check key existence: %w", err)
	}

	// Println renders a bool as the literal text "true" or "false".
	fmt.Println(found)
	return nil
}
// handleDelete removes the key given in args[0] from the store and prints
// "OK" on success. A usage error is returned when no key is supplied.
func handleDelete(store *Shelf, args []string) error {
	if len(args) < 1 {
		return errors.New("usage: shelve delete <key>")
	}

	err := store.Delete(args[0])
	if err != nil {
		return fmt.Errorf("delete key: %w", err)
	}

	fmt.Println("OK")
	return nil
}
// handleLen prints the number of keys reported by the store. A count of -1
// is treated as the store's failure indicator and turned into an error.
func handleLen(store *Shelf) error {
	switch count := store.Len(); count {
	case -1:
		return errors.New("failed to get length")
	default:
		fmt.Println(count)
		return nil
	}
}
// handleItems implements the "items", "keys" and "values" commands. It parses
// the listing flags (-start, -end, -limit, -desc) from args and forwards them
// to the printer that matches mode. An unrecognised mode is an error.
func handleItems(store *Shelf, mode string, args []string) error {
	var (
		start, end string
		limit      int
		desc       bool
	)

	fs := flag.NewFlagSet(mode, flag.ContinueOnError)
	fs.StringVar(&start, "start", "", "Inclusive start key (Asc: k ≥ start, Desc: k ≤ start)")
	fs.StringVar(&end, "end", "", "Exclusive end key (Asc: k < end, Desc: k > end)")
	fs.IntVar(&limit, "limit", shelve.All, "Maximum number of items")
	fs.BoolVar(&desc, "desc", false, "Iterate in descending order")

	if err := fs.Parse(args); err != nil {
		return fmt.Errorf("parse flags: %w", err)
	}

	order := shelve.Asc
	if desc {
		order = shelve.Desc
	}

	if mode == "items" {
		return printItems(store, &start, &end, order, limit)
	}
	if mode == "keys" {
		return printKeys(store, &start, &end, order, limit)
	}
	if mode == "values" {
		return printValues(store, &start, &end, order, limit)
	}
	return fmt.Errorf("invalid mode: %s", mode)
}
// printItems prints "key value" lines for the records visited by
// store.Items, beginning at start and honouring limit and order.
//
// end is an exclusive bound and is ignored when it points to an empty string.
// The comparison follows the iteration direction: ascending listings stop at
// the first key >= end, descending listings stop at the first key <= end.
// Testing only key >= end would end a descending listing at its very first
// key.
func printItems(store *Shelf, start, end *string, order, limit int) error {
	descending := order == shelve.Desc
	return store.Items(start, limit, order, func(key, value string) (bool, error) {
		if *end != "" {
			if descending && key <= *end {
				return false, nil
			}
			if !descending && key >= *end {
				return false, nil
			}
		}
		fmt.Println(key, value)
		return true, nil
	})
}
// printKeys prints one key per line for the records visited by store.Keys,
// beginning at start and honouring limit and order.
//
// end is an exclusive bound and is ignored when it points to an empty string.
// The comparison follows the iteration direction: ascending listings stop at
// the first key >= end, descending listings stop at the first key <= end.
// Testing only key >= end would end a descending listing at its very first
// key.
func printKeys(store *Shelf, start, end *string, order, limit int) error {
	descending := order == shelve.Desc
	return store.Keys(start, limit, order, func(key, _ string) (bool, error) {
		if *end != "" {
			if descending && key <= *end {
				return false, nil
			}
			if !descending && key >= *end {
				return false, nil
			}
		}
		fmt.Println(key)
		return true, nil
	})
}
// printValues prints one value per line for the records visited by
// store.Items, beginning at start and honouring limit and order. The key is
// only used to evaluate the end bound.
//
// end is an exclusive bound and is ignored when it points to an empty string.
// The comparison follows the iteration direction: ascending listings stop at
// the first key >= end, descending listings stop at the first key <= end.
// Testing only key >= end would end a descending listing at its very first
// key.
func printValues(store *Shelf, start, end *string, order, limit int) error {
	descending := order == shelve.Desc
	return store.Items(start, limit, order, func(key, value string) (bool, error) {
		if *end != "" {
			if descending && key <= *end {
				return false, nil
			}
			if !descending && key >= *end {
				return false, nil
			}
		}
		fmt.Println(value)
		return true, nil
	})
}
func printUsage() {
fmt.Println(`shelve is a CLI tool for managing a shelve key-value store.
Usage:
shelve [options] <command> [arguments]
The commands are:
put store one or more key-value pairs
get retrieve the value of a key
has check if a key exists
delete remove a key
len count total keys in the store
items list key-value pairs
keys list only the keys
values list only the values
Options:
`)
flag.PrintDefaults()
}
// Package sdb offers a simple key-value database that can be utilized with the
// go-shelve project.
//
// It should be suitable for a wide range of applications, but the driver
// directory (go-shelve/driver) provides additional options for configuring the
// Shelf with other supported databases from the Go ecosystem.
//
// # DB Records
//
// In sdb, each database record is represented by a distinct file stored in a
// bucket, which is a corresponding filesystem directory. The number of
// documents stored in each bucket is unlimited, and modern filesystems should
// be able to handle large buckets without significantly affecting performance.
//
// Each file record's name is the "base32hex" encoding of the key, which
// preserves lexical sort order [1]. Keys are limited to 128 bytes. The record
// file is stored as binary data. With this design, users do not need to worry
// about hitting the maximum filename length or storing keys with forbidden
// characters.
//
// # Cache
//
// The sdb database uses a memory-based cache to speed up operations. By
// default, the cache size is unlimited, but it can be configured to a fixed
// size or disabled altogether.
//
// The cache's design, albeit simple, can enhance the performance of "DB.Get"
// and "DB.Items" to more than 1 million reads per second on standard hardware.
//
// # Atomicity
//
// New records are written atomically to the key-value store. With a
// file-per-record design, sdb achieves this by using atomic file writes, which
// consist of creating a temporary file and then renaming it [2].
//
// This ensures that the database's methods are always performed with one
// atomic operation, significantly simplifying the recovery process.
//
// Currently, the only data that can become inconsistent is the count of stored
// records, but if this happens, it is detected and corrected at the DB
// initialization.
//
// As an optimization, records might be written directly without needing a
// temporary file if the data fits in a single sector since a single-sector
// write can be assumed to be atomic on some systems [3] [4].
//
// # Durability
//
// By default, sdb leverages the filesystem cache to speed up the database
// writes. This is generally suitable for most applications for which sdb is
// intended, as modern hardware can offer sufficient protection against
// crashes and ensure durability.
//
// For the highest level of durability, the WithSynchronousWrites option makes
// the database synchronize data to persistent storage on each write.
//
// # Notes
//
// [1] https://datatracker.ietf.org/doc/html/rfc4648#section-7
//
// [2] On Windows, additional configuration is involved.
//
// [3] https://stackoverflow.com/questions/2009063/are-disk-sector-writes-atomic
//
// [4] https://web.cs.ucla.edu/classes/spring07/cs111-2/scribe/lecture14.html
package sdb
import (
"encoding/base32"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sync"
"time"
"unsafe"
"github.com/lucmq/go-shelve/sdb/internal"
)
const (
// Asc and Desc can be used with the DB.Items method to make the
// iteration order ascending or descending respectively.
//
// They are just syntactic sugar to make the iteration order more
// explicit.
Asc = 1
// Desc is the opposite of Asc.
Desc = -1
)
const (
// DefaultCacheSize is the default size of the cache used to speed up the
// database operations. A value of -1 represents an unlimited cache.
DefaultCacheSize = -1
// MaxKeyLength is the maximum length of a key, in bytes. DB.Put rejects
// longer keys with ErrKeyTooLarge.
MaxKeyLength = 128
// metadataSyncInterval is the default interval at which the background
// loop syncs the metadata to disk.
metadataSyncInterval = 1 * time.Minute
)
const (
// dataDirectory is the sub-directory of the database root that contains
// the shard directories holding the record files.
dataDirectory = "data"
// metadataDirectory is the sub-directory of the database root that
// contains the metadata file.
metadataDirectory = "meta"
// metadataFilename is the name of the metadata file inside
// metadataDirectory.
metadataFilename = "meta.gob"
)
// version is the format version recorded in freshly created metadata and
// checked by metadata.Validate.
const version = "1.2"
var (
// ErrKeyTooLarge is returned when a key exceeds the maximum length.
ErrKeyTooLarge = errors.New("key exceeds maximum length")
// ErrDatabaseClosed is returned when an operation is attempted on a
// database that has already been closed.
ErrDatabaseClosed = errors.New("database is closed")
)
// Yield is the callback invoked by DB.Items for each key-value pair it
// visits. Returning false stops the iteration without an error; returning a
// non-nil error stops the iteration and makes DB.Items return an error that
// wraps it.
type Yield = func(key, value []byte) (bool, error)
// DB represents a database, which is created with the Open function.
//
// Client applications must call DB.Close() when done with the database.
//
// A DB is safe for concurrent use by multiple goroutines.
type DB struct {
// mu serializes access to the database state: read operations take it
// shared, mutations and syncs take it exclusively.
mu sync.RWMutex
// path is the root directory of the database.
path string
// metadata holds the in-memory counters (entries, generation,
// checkpoint) that metadataStore persists.
metadata metadata
metadataStore *metadataStore
// shards lists the shard directories the records are distributed over.
shards []shard
// cache maps keys to values that were recently written.
cache internal.Cache[cacheEntry]
// fs is the filesystem abstraction all I/O goes through.
fs fileSystem
// closed is set by Close; every public operation checks it.
closed bool
// Controls the background sync loop.
done chan struct{}
wg sync.WaitGroup
// maxFilesPerShard is the shard size above which Put splits the shard.
maxFilesPerShard int64
// syncWrites is forwarded to the atomic writer used for record files.
syncWrites bool
// autoSync enables the background sync loop. Can be removed if a WAL
// is adopted for consistency, since the WAL would make the sync loop
// unnecessary.
autoSync bool
// syncInterval is the period of the background sync loop.
syncInterval time.Duration
}
// cacheEntry is the type of the values held by the cache: the raw bytes of a
// record's value, keyed by the record's key.
type cacheEntry = []byte
// Open returns the database rooted at path, creating it on disk when the
// path does not exist yet. The defaults can be adjusted through options,
// which are applied before the database is initialized.
//
// Unless disabled by an option, Open also starts a background goroutine
// that periodically syncs the metadata; it is stopped by DB.Close, which
// client applications must call when done with the database.
func Open(path string, options ...Option) (*DB, error) {
	db := &DB{
		path:             path,
		metadata:         makeMetadata(),
		shards:           []shard{{maxKey: sentinelDir}},
		cache:            internal.NewCache[cacheEntry](DefaultCacheSize),
		fs:               &osFS{},
		done:             make(chan struct{}),
		maxFilesPerShard: defaultMaxFilesPerShard,
		syncWrites:       false,
		autoSync:         true,
		syncInterval:     metadataSyncInterval,
	}

	// Let the caller override the defaults.
	for i := range options {
		options[i](db)
	}

	err := initializeDatabase(db)
	if err != nil {
		return nil, fmt.Errorf("initialize database: %w", err)
	}

	// The periodic metadata sync is optional.
	if db.autoSync {
		db.wg.Add(1)
		go syncMetadata(db)
	}
	return db, nil
}
// Close synchronizes and closes the database. Users must ensure no pending
// operations are in progress before calling Close(). Calling Close on an
// already closed database is a no-op that returns nil.
//
// Example:
//
//	var wg sync.WaitGroup
//	db, _ := sdb.Open("path")
//
//	// Start concurrent writes
//	for i := 0; i < 10; i++ {
//		wg.Add(1)
//		go func(i int) {
//			defer wg.Done()
//			db.Put([]byte(fmt.Sprintf("key-%d", i)), []byte("value"))
//		}(i)
//	}
//
//	wg.Wait()  // Ensure all writes are done
//	db.Close() // Safe to close now
func (db *DB) Close() error {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil
	}
	db.closed = true
	// Signal the background goroutine to stop. The closed flag set above
	// guarantees the channel is closed exactly once.
	close(db.done)
	db.mu.Unlock()

	// Wait for the background goroutine WITHOUT holding the lock. The
	// goroutine may be blocked inside db.Sync() waiting for db.mu; waiting
	// for it while holding the lock would deadlock. Once it acquires the
	// lock it observes db.closed, returns, and then sees the closed done
	// channel.
	db.wg.Wait()

	// Final sync.
	db.mu.Lock()
	defer db.mu.Unlock()
	return syncInternal(db)
}
// Len reports how many entries the database holds according to its metadata.
// It returns -1 when the database has been closed.
func (db *DB) Len() int64 {
	db.mu.RLock()
	defer db.mu.RUnlock()

	switch {
	case db.closed:
		return -1
	default:
		return int64(db.metadata.TotalEntries)
	}
}
// Sync marks the current state as consistent and writes the metadata to
// persistent storage. It returns ErrDatabaseClosed on a closed database.
func (db *DB) Sync() error {
	db.mu.Lock()
	defer db.mu.Unlock()

	if !db.closed {
		return syncInternal(db)
	}
	return ErrDatabaseClosed
}
// Has reports whether key is present in the database. The cache is consulted
// first; on a miss the record file is looked up on disk. It returns
// ErrDatabaseClosed on a closed database.
func (db *DB) Has(key []byte) (bool, error) {
	db.mu.RLock()
	defer db.mu.RUnlock()

	if db.closed {
		return false, ErrDatabaseClosed
	}

	if _, cached := cacheGet(db, key); cached {
		return true, nil
	}

	recordPath, _ := keyPath(db, key)
	_, err := fs.Stat(db.fs, recordPath)
	switch {
	case err == nil:
		return true, nil
	case os.IsNotExist(err):
		// No record file means the key is absent; that is not an error.
		return false, nil
	default:
		return false, fmt.Errorf("stat: %w", err)
	}
}
// Get returns the value stored under key, or nil (with a nil error) when the
// key does not exist. The cache is consulted first; on a miss the record file
// is read from disk. It returns ErrDatabaseClosed on a closed database.
func (db *DB) Get(key []byte) ([]byte, error) {
	db.mu.RLock()
	defer db.mu.RUnlock()

	if db.closed {
		return nil, ErrDatabaseClosed
	}

	if cached, hit := cacheGet(db, key); hit {
		return cached, nil
	}

	recordPath, _ := keyPath(db, key)
	data, err := fs.ReadFile(db.fs, recordPath)
	switch {
	case err == nil:
		return data, nil
	case os.IsNotExist(err):
		// A missing record file is reported as "no value", not an error.
		return nil, nil
	default:
		return nil, fmt.Errorf("read file: %w", err)
	}
}
// Put adds a key-value pair to the database. If the key already exists, it
// overwrites the existing value.
//
// It returns ErrKeyTooLarge if the key is longer than [MaxKeyLength] and
// ErrDatabaseClosed if the database has been closed. The key is validated
// before anything else so that an invalid call causes no metadata I/O.
//
// The value slice is stored in the cache as-is (not copied), so callers must
// not modify it after the call.
func (db *DB) Put(key, value []byte) error {
	if len(key) > MaxKeyLength {
		return ErrKeyTooLarge
	}
	if err := prepareForMutation(db); err != nil {
		return fmt.Errorf("prepare for mutation: %w", err)
	}

	db.mu.Lock()
	defer db.mu.Unlock()

	if db.closed {
		return ErrDatabaseClosed
	}

	path, shardID := keyPath(db, key)
	updated, err := putPath(db, path, value)
	if err != nil {
		return fmt.Errorf("put path: %w", err)
	}

	sh := &db.shards[shardID]
	if !updated {
		// A brand-new record: account for it in the shard and globally.
		sh.count++
		db.metadata.TotalEntries++
	}
	db.metadata.Generation++

	// Cache aside. This happens before the optional shard split: the record
	// is already on disk at this point, so the cache must reflect it even if
	// the split below fails. Otherwise a previously cached value for this
	// key would be served stale.
	db.cache.Put(string(key), value)

	if int64(sh.count) > db.maxFilesPerShard {
		if err = db.splitShard(shardID); err != nil {
			return fmt.Errorf("split shard: %w", err)
		}
	}
	return nil
}
// putPath writes value to the record file at path through an atomic writer.
// updated reports whether the file already existed, i.e. whether the write
// replaced an existing record. New files are created exclusively.
func putPath(db *DB, path string, value []byte) (updated bool, err error) {
	switch _, statErr := fs.Stat(db.fs, path); {
	case statErr == nil:
		updated = true
	case !os.IsNotExist(statErr):
		return false, fmt.Errorf("stat: %w", statErr)
	}

	exclusive := !updated
	err = newAtomicWriter(db.fs, db.syncWrites).WriteFile(path, value, exclusive)
	return updated, err
}
// Delete removes key and its value from the database. Deleting a key that
// does not exist is not an error. It returns ErrDatabaseClosed on a closed
// database.
func (db *DB) Delete(key []byte) error {
	if err := prepareForMutation(db); err != nil {
		return fmt.Errorf("prepare for mutation: %w", err)
	}

	db.mu.Lock()
	defer db.mu.Unlock()

	if db.closed {
		return ErrDatabaseClosed
	}

	recordPath, shardID := keyPath(db, key)
	sh := &db.shards[shardID]

	switch err := db.fs.Remove(recordPath); {
	case err == nil:
		// A record was actually removed: update the counters.
		sh.count--
		db.metadata.TotalEntries--
	case !os.IsNotExist(err):
		return fmt.Errorf("remove: %w", err)
	}

	db.metadata.Generation++
	db.cache.Delete(string(key))
	return nil
}
// Items iterates over key-value pairs in the database, invoking fn(k, v)
// for each pair. Iteration stops early if fn returns false or an error.
//
// start is the first key to include in the iteration (inclusive).
// If start is nil or empty, iteration begins at the logical extremity
// determined by order.
//
// order controls the traversal direction:
//
//	Asc  (value +1) – ascending lexical order
//	Desc (value –1) – descending lexical order
//
// Any order greater than Desc is treated as ascending and any other value
// as descending. This is the same test the directory helpers (streamDir,
// readDir) apply, so the shard order and the file order within a shard
// always agree.
//
// Items holds a read lock for the whole iteration, including every fn
// callback. Callbacks should therefore return as quickly as possible, and
// must not modify the database within the same goroutine as the iteration,
// as this would cause a deadlock.
//
// It returns ErrDatabaseClosed on a closed database.
func (db *DB) Items(start []byte, order int, fn Yield) error {
	db.mu.RLock()
	defer db.mu.RUnlock()

	if db.closed {
		return ErrDatabaseClosed
	}

	n := len(db.shards)
	asc := order > Desc
	encStart := encodeKey(start)

	// Pick initial shard (Use the db.shards slice to prune the
	// search space).
	idx := 0
	if len(start) != 0 {
		idx = db.shardForKey(encStart)
	} else if !asc {
		idx = n - 1
	}

	// Walk the shards towards the end (ascending) or the beginning
	// (descending) of the slice.
	step := 1
	stop := n
	if !asc {
		step = -1
		stop = -1
	}

	for k := idx; k != stop; k += step {
		sh := db.shards[k]
		dir := filepath.Join(db.path, dataDirectory, sh.maxKey)

		keep, err := streamDir(db.fs, dir, encStart, order, func(filename string) (bool, error) {
			return handleFileWithLock(db, dir, filename, fn)
		})
		if err != nil {
			return err
		}
		if !keep {
			return nil
		}
	}
	return nil
}
// handleFileWithLock passes the record stored in file name of directory dir
// to fn. The key is recovered by decoding the file name; the value comes from
// the cache when present and from the file otherwise. A file that no longer
// exists is skipped and iteration continues.
//
// The function does not take any lock itself.
func handleFileWithLock(db *DB, dir, name string, fn Yield) (bool, error) {
	key, err := decodeKey(name)
	if err != nil {
		return false, fmt.Errorf("decode key: %w", err)
	}

	// Values read from disk are intentionally not added to the cache here:
	// doing so during iteration would churn the cache with keys that may
	// never be requested again.
	if cached, hit := cacheGet(db, key); hit {
		return fn(key, cached)
	}

	data, err := fs.ReadFile(db.fs, filepath.Join(dir, name))
	switch {
	case err == nil:
		return fn(key, data)
	case errors.Is(err, os.ErrNotExist):
		// Deleted while iterating? Ignore.
		return true, nil
	default:
		return false, fmt.Errorf("read key-value: %w", err)
	}
}
// Helpers
// keyPath returns the path of the record file for key together with the
// index of the shard the key belongs to. The file name is the encoded key.
func keyPath(db *DB, key []byte) (path string, shardID int) {
	encoded := encodeKey(key)
	shardID = db.shardForKey(encoded)
	path = filepath.Join(db.shardPath(shardID), encoded)
	return path, shardID
}
// encodeKey converts key into the padded base32hex text used as the record's
// file name.
func encodeKey(key []byte) string {
	enc := base32.HexEncoding
	out := make([]byte, enc.EncodedLen(len(key)))
	enc.Encode(out, key)
	return string(out)
}
// decodeKey is the inverse of encodeKey: it turns a base32hex file name back
// into the original key bytes. Malformed input yields an error.
func decodeKey(key string) ([]byte, error) {
	enc := base32.HexEncoding
	out := make([]byte, enc.DecodedLen(len(key)))
	n, err := enc.Decode(out, []byte(key))
	return out[:n], err
}
// cacheGet looks key up in the database cache and reports whether it was
// found.
//
// The key bytes are viewed as a string without copying, which avoids an
// allocation on this hot path; the string is only used for the duration of
// the lookup. unsafe.SliceData is used instead of &key[0] so that an empty
// or nil key yields the empty string rather than an index-out-of-range
// panic.
func cacheGet(db *DB, key []byte) (cacheEntry, bool) {
	s := unsafe.String(unsafe.SliceData(key), len(key))
	return db.cache.Get(s)
}
// prepareForMutation records on persistent storage that the database is
// about to diverge from its last checkpoint, so that the divergence can be
// detected when the database is next opened.
//
// When generation and checkpoint are equal, the generation is advanced and
// the metadata is saved. When they already differ there is pending state and
// nothing needs to be written, which is how the I/O of this function gets
// amortized over many mutations.
//
// The lock is only tried, never waited for: if it is currently held the
// function returns nil without doing anything.
func prepareForMutation(db *DB) error {
	if !db.mu.TryLock() {
		return nil
	}
	defer db.mu.Unlock()

	if db.closed {
		return ErrDatabaseClosed
	}

	meta := &db.metadata
	if meta.Generation == meta.Checkpoint {
		// First mutation since the last checkpoint: mark as drifted and
		// persist that fact.
		meta.Generation = meta.Checkpoint + 1
		return db.metadataStore.Save(db.metadata)
	}

	// Already drifted
	return nil
}
// syncInternal marks the current generation as the consistent checkpoint and
// saves the metadata. It does no locking of its own; the visible callers
// (Sync and Close) hold db.mu while calling it.
func syncInternal(db *DB) error {
	meta := &db.metadata
	meta.Checkpoint = meta.Generation
	return db.metadataStore.Save(*meta)
}
// syncMetadata is the body of the background goroutine started by Open. It
// calls db.Sync every db.syncInterval until db.done is closed, and signals
// db.wg when it exits.
//
// The loop only lowers the odds of a recovery at the next start-up (after a
// crash or a missing DB.Close call); the database does not depend on it, so
// sync errors are deliberately ignored.
func syncMetadata(db *DB) {
	defer db.wg.Done()

	tick := time.NewTicker(db.syncInterval)
	defer tick.Stop()

	for {
		select {
		case <-db.done:
			// Close() closed the channel; stop the loop.
			return
		case <-tick.C:
			_ = db.Sync()
		}
	}
}
package sdb
import (
"fmt"
"io"
"io/fs"
"math/rand/v2"
"os"
"path/filepath"
"runtime"
"sort"
"time"
)
var (
// defaultDiskSectorSize is the sector size, in bytes, assumed by the
// atomic writer when deciding whether data fits in a single sector.
defaultDiskSectorSize = 4096
// defaultPermissions is the mode used for files created by the atomic
// writer (read/write for the owner only).
defaultPermissions = os.FileMode(0600)
// defaultDirPermissions is the mode used for the database directories
// (full access for the owner only).
defaultDirPermissions = os.FileMode(0700)
)
// fileSystem abstracts the subset of os/fs operations that SDB needs.
//
// Implementations are expected to wrap a concrete filesystem (e.g. the real
// os filesystem, an in-memory mock, or an overlay that injects faults for
// testing). All methods MUST be safe for concurrent use by multiple
// goroutines.
//
// Notes:
//
// - Rename MUST be atomic: it should guarantee to either replace the target
// file entirely, or not change either the destination or the source.
type fileSystem interface {
// FS provides Open; the package also reads through it with the io/fs
// helpers fs.Stat, fs.ReadFile and fs.WalkDir.
fs.FS
// OpenFile opens the named file with the given flags, creating it with
// permissions perm if the flags request creation.
OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
// Remove deletes the named file.
Remove(name string) error
// Rename moves oldpath to newpath; see the atomicity note above.
Rename(oldpath, newpath string) error
// MkdirAll creates path and any missing parents with permissions perm.
MkdirAll(path string, perm fs.FileMode) error
}
// OS Filesystem
// osFS is the implementation of fileSystem that delegates every call to the
// standard library's os package. The zero value is ready to use.
type osFS struct{}
// Compile-time interface check.
var _ fileSystem = (*osFS)(nil)
// Open opens the named file for reading via os.Open.
func (*osFS) Open(name string) (fs.File, error) {
return os.Open(name)
}
// OpenFile opens the named file via os.OpenFile.
func (*osFS) OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) {
return os.OpenFile(name, flag, perm)
}
// Stat returns the file info of the named file via os.Stat.
func (*osFS) Stat(name string) (fs.FileInfo, error) {
return os.Stat(name)
}
// ReadFile returns the contents of the named file via os.ReadFile.
func (*osFS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}
// Remove deletes the named file via os.Remove.
func (*osFS) Remove(name string) error {
return os.Remove(name)
}
// Rename moves oldpath to newpath using the platform-specific renameFile.
func (*osFS) Rename(oldpath, newpath string) error {
return renameFile(oldpath, newpath)
}
// MkdirAll creates path and any missing parents via os.MkdirAll.
func (*osFS) MkdirAll(path string, perm fs.FileMode) error {
return os.MkdirAll(path, perm)
}
// Atomic Writer
// The main objective of atomicWriter is to protect against incomplete writes.
// When used together with O_SYNC, atomicWriter also provides some additional
// durability guarantees.
type atomicWriter struct {
// fs is the filesystem the writer operates on.
fs fileSystem
// syncWrites adds O_SYNC to file opens and makes WriteFile sync the
// parent directory after a successful write.
syncWrites bool
// diskSectorSize is the size, in bytes, up to which data is considered
// to fit in a single disk sector.
diskSectorSize int
// perm is the permission mode for newly created files.
perm os.FileMode
}
// newAtomicWriter returns an atomicWriter that operates on fsys, using the
// package defaults for the sector size and the file permissions. syncWrites
// selects synchronous writes.
func newAtomicWriter(fsys fileSystem, syncWrites bool) *atomicWriter {
	w := new(atomicWriter)
	w.fs = fsys
	w.syncWrites = syncWrites
	// The sector size is a fixed default rather than a value queried from
	// the host. Should it ever be queried, doing that once at package
	// initialization would keep this constructor free of errors and of
	// repeated OS calls.
	w.diskSectorSize = defaultDiskSectorSize
	w.perm = defaultPermissions
	return w
}
// flag builds the open flags for a write: write-only, create and truncate,
// plus O_SYNC when the writer is configured for synchronous writes and
// O_EXCL when excl is set (the file must not exist yet).
func (w *atomicWriter) flag(excl bool) int {
	mode := os.O_WRONLY | os.O_CREATE | os.O_TRUNC
	if excl {
		mode |= os.O_EXCL
	}
	if w.syncWrites {
		mode |= os.O_SYNC
	}
	return mode
}
// WriteFile writes data to path so that readers never observe a partially
// written file. When excl is true the file being created must not exist yet.
//
// On Linux, data that fits in a single disk sector is written directly to
// path. Otherwise the data is written to a temporary file which is then
// renamed over path through the writer's filesystem, so that filesystem's
// atomic Rename guarantee applies. The temporary file is removed (best
// effort) if the write or the rename fails.
//
// When synchronous writes are enabled, the parent directory of path is
// synced after a successful write.
func (w *atomicWriter) WriteFile(path string, data []byte, excl bool) (err error) {
	defer func() {
		// Sync the parent directory for more durability guarantees. See:
		//  - https://lwn.net/Articles/457667/#:~:text=When%20should%20you%20Fsync
		if err == nil && w.syncWrites {
			err = syncFile(w.fs, filepath.Dir(path))
		}
	}()

	if runtime.GOOS == "linux" && len(data) <= w.diskSectorSize {
		// Optimization: Write directly if the data fits in a single sector,
		// since a single-sector write can be assumed to be atomic. See:
		//
		//  - https://stackoverflow.com/questions/2009063/are-disk-sector-writes-atomic
		//  - https://web.cs.ucla.edu/classes/spring07/cs111-2/scribe/lecture14.html
		//
		// This optimization assumes that the host supports atomic writes to a
		// disk sector.
		return w._writeFile(path, data, excl)
	}

	tmpPath := makeTempPath(path)

	// w._writeFile will sync, if configured to do so.
	if err = w._writeFile(tmpPath, data, excl); err != nil {
		// Do not leave a partially written temporary file behind.
		_ = w.fs.Remove(tmpPath)
		return fmt.Errorf("write: %w", err)
	}

	// Rename through the filesystem abstraction rather than the OS directly,
	// so that mock or fault-injecting filesystems see the operation too.
	if err = w.fs.Rename(tmpPath, path); err != nil {
		_ = w.fs.Remove(tmpPath)
		return fmt.Errorf("rename: %w", err)
	}
	return nil
}
// _writeFile writes data to the named file, creating it if necessary.
// If the file does not exist, _writeFile creates it with the writer's
// permissions (before umask); otherwise it truncates the file before writing,
// without changing permissions. When excl is true, opening fails if the file
// already exists.
//
// Since _writeFile requires multiple system calls to complete, a failure
// mid-operation can leave the file in a partially written state.
//
// An error is returned (instead of panicking) when the filesystem hands back
// a file that cannot be written to; the file is closed in every case.
func (w *atomicWriter) _writeFile(name string, data []byte, excl bool) error {
	// Adapted from `os.WriteFile()`
	f, err := w.fs.OpenFile(name, w.flag(excl), w.perm)
	if err != nil {
		return err
	}

	// fs.File only promises Stat, Read and Close, so writability has to be
	// checked explicitly.
	dst, ok := f.(io.Writer)
	if !ok {
		_ = f.Close()
		return fmt.Errorf("file %q does not support writing", name)
	}

	_, err = dst.Write(data)
	if err1 := f.Close(); err1 != nil && err == nil {
		err = err1
	}
	return err
}
// Utilities
// mkdirs creates every directory in paths, in order and including missing
// parents, with permissions perm. It stops at the first failure.
func mkdirs(fs fileSystem, paths []string, perm os.FileMode) error {
	for i := range paths {
		err := fs.MkdirAll(paths[i], perm)
		if err != nil {
			return fmt.Errorf("MkdirAll: %w", err)
		}
	}
	return nil
}
// streamDir calls fn for the file names of dir in sorted order: ascending
// when order is greater than Desc, descending otherwise. When start is not
// empty, names that sort before it (in iteration direction) are skipped.
//
// The boolean result tells the caller whether to keep iterating: it is false
// when fn asked to stop or an error occurred, and true when the directory
// was exhausted.
func streamDir(fs fileSystem, dir, start string, order int, fn func(filename string) (bool, error)) (bool, error) {
	filenames, err := readDir(fs, dir, order)
	if err != nil {
		return false, fmt.Errorf("readDir: %w", err)
	}

	ascending := order > Desc
	skipping := start != ""

	for i := range filenames {
		name := filenames[i]

		if skipping {
			beforeStart := (ascending && name < start) || (!ascending && name > start)
			if beforeStart {
				continue
			}
			// The boundary has been reached; every later name qualifies.
			skipping = false
		}

		switch keep, err := fn(name); {
		case err != nil:
			return false, fmt.Errorf("fn: %w", err)
		case !keep:
			return false, nil
		}
	}
	return true, nil
}
// readDir returns the names of the entries of dir, sorted ascending when
// order is greater than Desc and descending otherwise.
func readDir(fsys fileSystem, dir string, order int) ([]string, error) {
	names, err := readdirnames(fsys, dir)
	if err != nil {
		return nil, fmt.Errorf("readdirnames: %w", err)
	}

	if order > Desc {
		sort.Strings(names)
	} else {
		sort.Sort(sort.Reverse(sort.StringSlice(names)))
	}
	return names, nil
}
// readdirnames returns the names of all entries of the directory name in
// fsys, in no particular order.
//
// It uses the file's Readdirnames method when available (as on *os.File) and
// falls back to the standard fs.ReadDirFile interface otherwise. If the
// opened file supports neither, an error is returned instead of panicking
// on a nil interface.
func readdirnames(fsys fs.FS, name string) ([]string, error) {
	f, err := fsys.Open(name)
	if err != nil {
		return nil, fmt.Errorf("open: %w", err)
	}
	defer f.Close()

	type dirReader interface{ Readdirnames(n int) ([]string, error) }
	if dir, ok := f.(dirReader); ok {
		return dir.Readdirnames(-1)
	}

	if dir, ok := f.(fs.ReadDirFile); ok {
		entries, err := dir.ReadDir(-1)
		if err != nil {
			return nil, err
		}
		names := make([]string, len(entries))
		for i, entry := range entries {
			names[i] = entry.Name()
		}
		return names, nil
	}

	return nil, fmt.Errorf("%s: not a readable directory", name)
}
// countRegularFiles walks the directory tree rooted at path and returns how
// many regular files it contains. An error reported by the walk is passed
// back to the walker and ends up in the returned error.
func countRegularFiles(fsys fileSystem, path string) (uint64, error) {
	var total uint64

	visit := func(_ string, entry fs.DirEntry, walkErr error) error {
		if entry == nil || !entry.Type().IsRegular() {
			// Not a regular file; still hand any I/O or permission error
			// back to the walker.
			return walkErr
		}
		total++
		return walkErr
	}

	err := fs.WalkDir(fsys, path, visit)
	return total, err
}
// Helpers
func makeTempPath(path string) string {
tmpBase := fmt.Sprintf(
"%s-%d-%d",
filepath.Base(path),
rand.Uint32(),
time.Now().UnixNano(),
)
tmpPath := filepath.Join(os.TempDir(), tmpBase)
return tmpPath
}
// syncFile opens path (a file or a directory) on fsys and flushes it to
// persistent storage through the handle's Sync method.
//
// An error is returned (instead of panicking) when the filesystem hands back
// a handle without a Sync method; the handle is closed in every case.
func syncFile(fsys fileSystem, path string) error {
	f, err := fsys.Open(path)
	if err != nil {
		return fmt.Errorf("open: %w", err)
	}

	// fs.File does not include Sync, so support has to be checked.
	type syncer interface{ Sync() error }
	s, ok := f.(syncer)
	if !ok {
		_ = f.Close()
		return fmt.Errorf("file %q does not support sync", path)
	}

	err = s.Sync()
	if err1 := f.Close(); err1 != nil && err == nil {
		err = err1
	}
	return err
}
//go:build !windows
package sdb
import "os"
// renameFile atomically replaces the destination file or directory with the
// source. It is guaranteed to either replace the target file entirely, or not
// change either file.
//
// This is the variant for non-Windows platforms (see the build constraint of
// this file); it delegates directly to os.Rename.
func renameFile(oldpath, newpath string) error {
return os.Rename(oldpath, newpath)
}
package sdb
import (
"fmt"
"io/fs"
"os"
"path/filepath"
)
// initializeDatabase prepares db for use. It installs the metadata store and
// then either creates the on-disk layout (when db.path does not exist) or
// validates the existing directory and loads the database from it.
//
// An existing path must be a directory on which the owner has read, write
// and execute permission.
func initializeDatabase(db *DB) error {
	db.metadataStore = newMetadataStore(db.fs, db.path)

	info, err := fs.Stat(db.fs, db.path)
	switch {
	case err == nil:
		// The path exists; validate it below.
	case os.IsNotExist(err):
		return createDatabaseStorage(db)
	default:
		return fmt.Errorf("stat path: %w", err)
	}

	if !info.IsDir() {
		return fmt.Errorf("path is not a directory")
	}
	if info.Mode().Perm()&0700 != 0700 {
		return fmt.Errorf("path permissions are not 0700")
	}
	return loadDatabase(db)
}
// createDatabaseStorage lays out a new database on disk: the root directory,
// the data directory with its sentinel shard directory, and the metadata
// directory. It then syncs the database so the initial metadata is written.
func createDatabaseStorage(db *DB) error {
	dataRoot := filepath.Join(db.path, dataDirectory)
	dirs := []string{
		db.path,
		dataRoot,
		filepath.Join(dataRoot, sentinelDir),
		filepath.Join(db.path, metadataDirectory),
	}

	if err := mkdirs(db.fs, dirs, defaultDirPermissions); err != nil {
		return fmt.Errorf("create directories: %w", err)
	}

	// Persist the initial state.
	if err := db.Sync(); err != nil {
		return fmt.Errorf("sync: %w", err)
	}
	return nil
}
// loadDatabase restores an existing database: it loads the persisted
// metadata, loads the shards and finally runs the sanity check, which may
// trigger a recovery.
func loadDatabase(db *DB) error {
	loaded, err := db.metadataStore.Load()
	if err != nil {
		return fmt.Errorf("load metadata: %w", err)
	}
	db.metadata = loaded

	err = db.loadShards()
	if err != nil {
		return fmt.Errorf("load shards: %w", err)
	}

	// Verify consistency and, if necessary, recover from a corrupted state.
	return sanityCheck(db)
}
// sanityCheck validates the loaded metadata and compares the generation with
// the checkpoint. Equal values mean the database was synced after its last
// mutation; different values mean it was not, and a recovery is started.
func sanityCheck(db *DB) error {
	err := db.metadata.Validate()
	if err != nil {
		return err
	}

	if db.metadata.Generation == db.metadata.Checkpoint {
		return nil
	}
	return recoverDatabase(db)
}
package internal
import (
"sync/atomic"
)
// TKey is the key type used by the cache. It is an alias of string.
type TKey = string
// Cache is a generic cache interface mapping TKey keys to values of type
// TValue.
type Cache[TValue any] interface {
// Get retrieves a value from the cache based on the provided key. It
// returns the value and a boolean indicating whether the value was
// found in the cache.
Get(key TKey) (TValue, bool)
// Put adds a new key-value pair to the cache.
Put(key TKey, value TValue)
// Delete removes a value from the cache based on the provided key.
Delete(key TKey)
}
// Default Cache
// DefaultCache is the default implementation of the Cache interface. It wraps
// a base cache and counts hits and misses. It is not safe for concurrent use,
// as it is meant to be embedded in code that does the concurrency control.
type DefaultCache[TValue any] struct {
	cache Cache[TValue]
	// The counters are atomic because they are updated by Get, which callers
	// treat as a read operation.
	hits   atomic.Int64
	misses atomic.Int64
}

// Assert DefaultCache implements Cache
var _ Cache[any] = (*DefaultCache[any])(nil)

// NewCache creates a new cache whose eviction policy depends on maxLength:
//
//   - maxLength <= -1: unbounded cache, elements are never evicted;
//   - maxLength == 0: pass-through cache that stores nothing;
//   - maxLength > 0: cache of at most maxLength elements with random
//     eviction.
func NewCache[TValue any](maxLength int) *DefaultCache[TValue] {
	if maxLength <= -1 {
		return newCacheWithBase[TValue](newUnboundedCache[TValue]())
	}
	if maxLength == 0 {
		return newCacheWithBase[TValue](newPassThroughCache[TValue]())
	}
	return newCacheWithBase[TValue](newRandomCache[TValue](maxLength))
}

// newCacheWithBase wraps the base cache c in a DefaultCache with zeroed
// counters.
func newCacheWithBase[TValue any](c Cache[TValue]) *DefaultCache[TValue] {
	wrapped := new(DefaultCache[TValue])
	wrapped.cache = c
	return wrapped
}

// Get looks key up in the base cache, records the hit or miss, and returns
// the value together with a flag telling whether it was found.
func (c *DefaultCache[TValue]) Get(key TKey) (TValue, bool) {
	value, found := c.cache.Get(key)
	if found {
		c.hits.Add(1)
	} else {
		c.misses.Add(1)
	}
	return value, found
}

// Put stores value under key in the base cache.
func (c *DefaultCache[TValue]) Put(key TKey, value TValue) {
	c.cache.Put(key, value)
}

// Delete removes key from the base cache.
func (c *DefaultCache[TValue]) Delete(key TKey) {
	c.cache.Delete(key)
}

// Hits returns the number of Get calls that found their key since the last
// reset.
func (c *DefaultCache[TValue]) Hits() int {
	return int(c.hits.Load())
}

// Misses returns the number of Get calls that did not find their key since
// the last reset.
func (c *DefaultCache[TValue]) Misses() int {
	return int(c.misses.Load())
}

// ResetRatio sets both the hit and the miss counter back to zero.
func (c *DefaultCache[TValue]) ResetRatio() {
	c.misses.Store(0)
	c.hits.Store(0)
}
// Unbounded Cache
// unboundedCache is a map-backed cache that never evicts anything.
type unboundedCache[TValue any] struct {
	m map[TKey]TValue
}

// Check unboundedCache implements Cache interface
var _ Cache[any] = (*unboundedCache[any])(nil)

// newUnboundedCache returns an empty unboundedCache.
func newUnboundedCache[TValue any]() *unboundedCache[TValue] {
	c := new(unboundedCache[TValue])
	c.m = map[TKey]TValue{}
	return c
}

// Get returns the value stored under key and whether it was present.
func (c *unboundedCache[TValue]) Get(key TKey) (value TValue, found bool) {
	value, found = c.m[key]
	return value, found
}

// Put stores value under key, replacing any previous value.
func (c *unboundedCache[TValue]) Put(key TKey, value TValue) {
	c.m[key] = value
}

// Delete removes key; it is a no-op when the key is absent.
func (c *unboundedCache[TValue]) Delete(key TKey) {
	delete(c.m, key)
}
// Pass-Through Cache
// passThroughCache is a cache that stores nothing: every Get is a miss and
// Put and Delete do nothing.
type passThroughCache[TValue any] struct{}

// Check passThroughCache implements Cache interface
var _ Cache[any] = (*passThroughCache[any])(nil)

// newPassThroughCache creates a new pass-through cache.
func newPassThroughCache[TValue any]() *passThroughCache[TValue] {
	return new(passThroughCache[TValue])
}

// Get always reports a miss and returns the zero value of TValue.
func (passThroughCache[TValue]) Get(TKey) (TValue, bool) {
	var zero TValue
	return zero, false
}

// Put discards the pair.
func (passThroughCache[TValue]) Put(TKey, TValue) {}

// Delete does nothing.
func (passThroughCache[TValue]) Delete(TKey) {}
// Random Cache

// randomCache provides a cache that evicts elements randomly.
type randomCache[TValue any] struct {
	cache   map[TKey]TValue
	maxSize int
}

// Check randomCache implements Cache interface
var _ Cache[any] = (*randomCache[any])(nil)

// newRandomCache creates a new instance of the randomCache struct, that can
// hold up to maxSize elements.
//
// Setting the maxSize to 0 or less will disable the eviction of elements
// from the cache.
func newRandomCache[TValue any](maxSize int) *randomCache[TValue] {
	return &randomCache[TValue]{
		cache:   make(map[TKey]TValue),
		maxSize: maxSize,
	}
}

// Get returns the value stored under key and whether it was present.
func (c *randomCache[TValue]) Get(key TKey) (value TValue, ok bool) {
	value, ok = c.cache[key]
	return value, ok
}

// Put stores value under key. When the key is new and the cache already holds
// maxSize entries, one arbitrary entry is evicted first to make room.
//
// Overwriting a key that is already cached never grows the map, so it never
// triggers an eviction.
func (c *randomCache[TValue]) Put(key TKey, value TValue) {
	_, exists := c.cache[key]
	if !exists && c.maxSize > 0 && len(c.cache) >= c.maxSize {
		// Remove whichever key the map iteration yields first.
		for k := range c.cache {
			delete(c.cache, k)
			break
		}
	}
	c.cache[key] = value
}

// Delete removes key from the cache; deleting an absent key is a no-op.
func (c *randomCache[TValue]) Delete(key TKey) {
	delete(c.cache, key)
}
package sdb
import (
"bytes"
"encoding/gob"
"fmt"
"io/fs"
"path/filepath"
)
// metadata is the bookkeeping record persisted alongside the data files.
type metadata struct {
	Version      string
	TotalEntries uint64
	Generation   uint64
	Checkpoint   uint64
}

// makeMetadata returns a metadata value stamped with the current version and
// with all counters at zero.
func makeMetadata() metadata {
	var m metadata
	m.Version = version
	return m
}

// Validate reports an error when the stored version differs from the version
// of the running code.
func (m *metadata) Validate() error {
	if m.Version == version {
		return nil
	}
	return fmt.Errorf("version mismatch: expected %s, got %s",
		version, m.Version)
}
// metadataStore loads and saves the metadata record of a database rooted at
// root, using the given file system.
type metadataStore struct {
	fs        fileSystem
	root      string // absolute path to DB root
	writer    *atomicWriter
	marshalFn func(v any) ([]byte, error)
}

// newMetadataStore builds a metadataStore for the database at root. Metadata
// is serialized with gobEncode and written through an atomicWriter.
func newMetadataStore(fsys fileSystem, root string) *metadataStore {
	store := new(metadataStore)
	store.fs = fsys
	store.root = root
	store.writer = newAtomicWriter(fsys, false)
	store.marshalFn = gobEncode
	return store
}

// FilePath returns the location of the metadata file below the DB root.
func (s *metadataStore) FilePath() string {
	dir := filepath.Join(s.root, metadataDirectory)
	return filepath.Join(dir, metadataFilename)
}

// Load reads the metadata file and decodes it. A read failure is wrapped with
// "read file"; a decoding failure is returned as-is.
func (s *metadataStore) Load() (metadata, error) {
	raw, err := fs.ReadFile(s.fs, s.FilePath())
	if err == nil {
		return s.unmarshal(raw)
	}
	return metadata{}, fmt.Errorf("read file: %w", err)
}

// Save encodes m and writes it to the metadata file via the store's writer.
func (s *metadataStore) Save(m metadata) error {
	encoded, err := s.marshal(m)
	if err != nil {
		return fmt.Errorf("marshal metadata: %w", err)
	}
	target := s.FilePath()
	return s.writer.WriteFile(target, encoded, false)
}

// marshal serializes m with the store's configured marshal function.
func (s *metadataStore) marshal(m metadata) ([]byte, error) {
	encode := s.marshalFn
	return encode(m)
}
// gobEncode serializes v with encoding/gob and returns the encoded bytes.
//
// On failure it returns a nil slice together with the encoder's error, so a
// caller can never mistake a partially written gob stream (for example, type
// descriptors emitted before the value itself failed to encode) for valid
// output.
func gobEncode(v any) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(v); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// unmarshal gob-decodes data into a metadata value. The decoder's error, if
// any, is returned unchanged together with whatever was decoded so far.
func (*metadataStore) unmarshal(data []byte) (metadata, error) {
	var decoded metadata
	err := gob.NewDecoder(bytes.NewReader(data)).Decode(&decoded)
	return decoded, err
}
package sdb
import (
"time"
"github.com/lucmq/go-shelve/sdb/internal"
)
// Option is passed to the Open function to create a customized DB.
//
// An Option is a function that receives the *DB being configured. The
// With*/with* constructors in this file each return an Option that assigns a
// single DB field.
type Option func(*DB)
// WithCacheSize configures the size of the cache used by the database.
//
// A size of -1 means an unlimited cache, and a size of 0 disables caching.
// When this option is not supplied the cache size defaults to -1.
func WithCacheSize(size int64) Option {
	return func(db *DB) {
		capacity := int(size)
		db.cache = internal.NewCache[cacheEntry](capacity)
	}
}
// WithSynchronousWrites turns synchronous writes on (true) or off (false).
// Synchronous writes are disabled unless this option enables them.
func WithSynchronousWrites(sync bool) Option {
	apply := func(db *DB) { db.syncWrites = sync }
	return apply
}
// withMaxFilesPerShard returns an Option that caps the number of regular data
// files a single shard directory may hold before SDB triggers a split.
//
// It currently exists for tests: a very small threshold makes it quick to
// exercise shard-splitting behaviour.
func withMaxFilesPerShard(maxFilesPerShard int64) Option {
	apply := func(db *DB) { db.maxFilesPerShard = maxFilesPerShard }
	return apply
}
// withSyncInterval returns an Option that replaces the background metadata
// sync interval. Tests use it to shorten the default (one minute) when
// stressing time-dependent logic.
func withSyncInterval(d time.Duration) Option {
	apply := func(db *DB) { db.syncInterval = d }
	return apply
}
// withFileSystem returns an Option that swaps in a custom fileSystem
// implementation, such as an in-memory or fault-injecting mock. Production
// code normally keeps the default os-backed implementation; tests supply a
// mock to stay off the real disk or to exercise error-handling paths.
func withFileSystem(fsys fileSystem) Option {
	apply := func(db *DB) { db.fs = fsys }
	return apply
}
package sdb
import (
"fmt"
"path/filepath"
)
// recoverDatabase repairs a database whose corrupted state was detected at
// initialization.
//
// Rationale for the approach:
//   - The DB design is simple, and every operation mutates at most one file.
//   - The metadata lives in a single file, so updating it is also a single
//     file mutation.
//   - Currently the only thing that can get corrupted is the metadata, in
//     particular the metadata.TotalEntries count.
//
// Recovery is therefore limited to recounting the files in the data folder,
// storing that count, moving the checkpoint to the current generation, and
// saving the metadata.
func recoverDatabase(db *DB) error {
	recounted, err := countItems(db.fs, filepath.Join(db.path, dataDirectory))
	if err != nil {
		return fmt.Errorf("count items: %w", err)
	}

	db.metadata.TotalEntries = recounted
	db.metadata.Checkpoint = db.metadata.Generation

	return db.metadataStore.Save(db.metadata)
}
// countItems returns the number of database records below path. Every record
// is represented by one regular file, so this is the regular-file count.
func countItems(fsys fileSystem, path string) (uint64, error) {
	total, err := countRegularFiles(fsys, path)
	return total, err
}
package sdb
import (
"fmt"
"os"
"path/filepath"
"sort"
)
const (
// defaultMaxFilesPerShard is, judging by its name, the default value of the
// per-shard file limit that withMaxFilesPerShard overrides — confirm at the
// use site.
//
// This limit is arbitrary and chosen to balance between:
// - the cost of creating many directories
// - the cost of walking a large number of files
defaultMaxFilesPerShard = 10_000
// The sentinelDir is a special directory that is guaranteed to have a
// higher name than any other directory. It is created by the db and is
// used to simplify the logic of sharding.
sentinelDir = "_"
)
// shard is the in-memory description of one shard directory.
type shard struct {
// maxKey is the highest encoded key the shard may contain. It is also the
// name of the shard's directory below the data directory (see shardPath).
maxKey string // upper bound, inclusive
// count is the number of entries in the shard directory, as set by
// loadShards and updateSplitShards.
count uint32
}
// shardForKey returns the smallest index whose shard has an *upper* bound
// (maxKey) >= enc. If no shard qualifies, len(db.shards) is returned.
//
// The binary search assumes db.shards is ordered by ascending maxKey, which
// is the order loadShards produces.
func (db *DB) shardForKey(enc string) int {
	covers := func(j int) bool {
		return db.shards[j].maxKey >= enc
	}
	return sort.Search(len(db.shards), covers)
}
// shardPath returns the directory of shard i: the data directory joined with
// the shard's maxKey, which doubles as its directory name.
func (db *DB) shardPath(i int) string {
	dirName := db.shards[i].maxKey
	return filepath.Join(db.path, dataDirectory, dirName)
}
// loadShards rebuilds db.shards from the data directory. Every entry of the
// data directory becomes one shard, named after the entry and ordered by
// name; its count is the number of entries found inside it.
func (db *DB) loadShards() error {
	names, err := readdirnames(db.fs, filepath.Join(db.path, dataDirectory))
	if err != nil {
		return fmt.Errorf("read data dir: %w", err)
	}
	sort.Strings(names)

	db.shards = make([]shard, len(names))
	for i := range names {
		dir := filepath.Join(db.path, dataDirectory, names[i])
		entries, err := readdirnames(db.fs, dir)
		if err != nil {
			return fmt.Errorf("read shard dir: %w", err)
		}
		db.shards[i].maxKey = names[i]
		db.shards[i].count = uint32(len(entries))
	}
	return nil
}
// splitShard splits shard `idx` in two. It moves the _lower_ half of shard
// `idx` into a freshly-created directory whose name is the *highest* key that
// stays inside that new shard (`names[mid-1]`). The original directory keeps
// the upper half unchanged.
//
// A shard holding fewer than two entries cannot be divided; in that case
// splitShard leaves everything untouched and returns nil.
//
// Errors from listing the shard, creating the new directory, or renaming a
// file are returned wrapped. A rename failure may leave some files already
// moved.
func (db *DB) splitShard(idx int) error {
	// 1. Enumerate & sort entries in the *old* directory.
	oldPath := db.shardPath(idx)
	files, err := readdirnames(db.fs, oldPath)
	if err != nil {
		return fmt.Errorf("read shard dir: %w", err)
	}
	if len(files) < 2 {
		// With zero or one entry the lower half would be empty and there
		// would be no key to name the new shard after.
		return nil
	}
	sort.Strings(files)

	mid := len(files) / 2
	lowerHalf := files[:mid]
	newLowMax := files[mid-1]
	newPath := filepath.Join(db.path, dataDirectory, newLowMax)

	// 2. Create the new directory and move the lower-half files into it.
	if err = db.fs.MkdirAll(newPath, defaultDirPermissions); err != nil && !os.IsExist(err) {
		return fmt.Errorf("mkdir: %w", err)
	}
	for _, e := range lowerHalf {
		if err = db.fs.Rename(
			filepath.Join(oldPath, e),
			filepath.Join(newPath, e),
		); err != nil {
			return fmt.Errorf("rename: %w", err)
		}
	}

	// 3. Update the in-memory shard slice.
	updateSplitShards(db, idx, files)

	// 4. Sync the newly created directory.
	if db.syncWrites {
		// Sync the directory for more durability guarantees. See:
		// - https://lwn.net/Articles/457667/#:~:text=When%20should%20you%20Fsync
		//
		// This is best-effort: the split has already happened, so a sync
		// failure is deliberately ignored.
		_ = syncFile(db.fs, newPath)
	}
	return nil
}
// updateSplitShards brings the in-memory shard slice in line with a split of
// shard idx: a new shard for the lower half is inserted at idx, the original
// shard moves to idx+1, and both counts are set so that the next split
// happens at the correct boundary.
//
// The files argument must be sorted by file name and is the listing of the
// shard *before* the split.
func updateSplitShards(db *DB, idx int, files []string) {
	half := len(files) / 2
	lowerCount := uint32(len(files[:half]))
	upperCount := uint32(len(files[half:]))
	boundary := files[half-1]

	// Grow by one element, then shift everything from idx one slot right.
	db.shards = append(db.shards, shard{})
	copy(db.shards[idx+1:], db.shards[idx:])

	// Slot idx now describes the new lower shard ...
	db.shards[idx] = shard{maxKey: boundary, count: lowerCount}
	// ... and the old shard, now at idx+1, keeps only the upper half.
	db.shards[idx+1].count = upperCount
}
package shelve
import (
"bytes"
"encoding"
"encoding/gob"
"encoding/hex"
"encoding/json"
"fmt"
"reflect"
"strconv"
)
// Codec is the interface for encoding and decoding data stored by Shelf.
//
// The go-shelve module natively supports the following codecs:
// - [GobCodec]: Returns a Codec for the [gob] format.
// - [JSONCodec]: Returns a Codec for the [json] format.
// - [TextCodec]: Returns a Codec for values that can be represented as
// plain text.
//
// Additional codecs are provided by the packages in [driver/encoding].
//
// A Shelf uses one Codec for values and a separate one for keys.
//
// [driver/encoding]: https://pkg.go.dev/github.com/lucmq/go-shelve/driver/encoding
type Codec interface {
// Encode returns the Codec encoding of v as a byte slice.
Encode(v any) ([]byte, error)
// Decode parses the encoded data and stores the result in the value
// pointed to by v. It is the inverse of Encode.
Decode(data []byte, v any) error
}
// GobCodec returns a Codec for the [gob] format, a self-describing
// serialization format native to Go.
//
// Gob is a binary format and is more compact than text-based formats like
// JSON.
func GobCodec() Codec {
	var c gobCodec
	return c
}
// JSONCodec returns a Codec for the JSON format.
func JSONCodec() Codec {
	var c jsonCodec
	return c
}
// TextCodec returns a Codec for values that can be represented as plain text.
func TextCodec() Codec {
	var c textCodec
	return c
}
// Gob Codec

// gobCodec is the Codec implementation backed by encoding/gob.
type gobCodec struct{}

// Encode gob-encodes value. Encoder failures are wrapped with "encode gob".
func (gobCodec) Encode(value any) ([]byte, error) {
	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(value); err != nil {
		return nil, fmt.Errorf("encode gob: %w", err)
	}
	return buf.Bytes(), nil
}

// Decode gob-decodes data into value, which must be a pointer. Decoder
// failures are wrapped with "decode gob".
func (gobCodec) Decode(data []byte, value any) error {
	reader := bytes.NewReader(data)
	if err := gob.NewDecoder(reader).Decode(value); err != nil {
		return fmt.Errorf("decode gob: %w", err)
	}
	return nil
}
// Json Codec

// jsonCodec is the Codec implementation backed by encoding/json.
type jsonCodec struct{}

// Encode marshals value as indented JSON.
func (jsonCodec) Encode(value any) ([]byte, error) {
	const (
		prefix = ""
		indent = " "
	)
	return json.MarshalIndent(value, prefix, indent)
}

// Decode unmarshals the JSON in data into value, which must be a pointer.
func (jsonCodec) Decode(data []byte, value any) error {
	err := json.Unmarshal(data, value)
	return err
}
// Text Codec

// textCodec encodes scalar values, fixed-size byte arrays, and types that
// implement encoding.TextMarshaler.
//
// It supports strings, booleans, integers, floats, and [N]byte arrays (encoded
// as hex).
type textCodec struct{}

// Encode returns the plain-text form of value.
//
// Strings are returned as their bytes, integers in base 10, floats in the
// shortest 'g' representation that round-trips, booleans as "true"/"false",
// TextMarshaler values via MarshalText, and [N]byte arrays as lowercase hex.
// Any other type yields an "unsupported type" error.
func (textCodec) Encode(value any) ([]byte, error) {
	switch v := value.(type) {
	case string:
		return []byte(v), nil
	case int:
		return []byte(strconv.Itoa(v)), nil
	case int8:
		return []byte(strconv.FormatInt(int64(v), 10)), nil
	case int16:
		return []byte(strconv.FormatInt(int64(v), 10)), nil
	case int32:
		return []byte(strconv.FormatInt(int64(v), 10)), nil
	case int64:
		return []byte(strconv.FormatInt(v, 10)), nil
	case uint:
		return []byte(strconv.FormatUint(uint64(v), 10)), nil
	case uint8:
		return []byte(strconv.FormatUint(uint64(v), 10)), nil
	case uint16:
		return []byte(strconv.FormatUint(uint64(v), 10)), nil
	case uint32:
		return []byte(strconv.FormatUint(uint64(v), 10)), nil
	case uint64:
		return []byte(strconv.FormatUint(v, 10)), nil
	case float32:
		return []byte(strconv.FormatFloat(float64(v), 'g', -1, 32)), nil
	case float64:
		return []byte(strconv.FormatFloat(v, 'g', -1, 64)), nil
	case bool:
		return []byte(strconv.FormatBool(v)), nil
	case encoding.TextMarshaler:
		return v.MarshalText()
	default:
		if encoded, ok := encodeFixedByteArray(value); ok {
			return encoded, nil
		}
		return nil, fmt.Errorf("textCodec: unsupported type %T", value)
	}
}

// Decode parses data as plain text and stores the result in the value pointed
// to by value. It is the inverse of Encode.
//
// The target is only written when parsing succeeds; on error it keeps its
// previous contents. Supported targets are pointers to the scalar types
// handled by Encode, encoding.TextUnmarshaler implementations, and pointers
// to [N]byte arrays (hex). Malformed input for a supported target is reported
// with the underlying parse error; only genuinely unsupported targets yield
// the "unsupported decode target" error.
func (textCodec) Decode(data []byte, value any) error {
	str := string(data)
	switch v := value.(type) {
	case *string:
		*v = str
		return nil
	case *int:
		// Atoi is kept (rather than ParseInt) so error messages are unchanged.
		i, err := strconv.Atoi(str)
		if err != nil {
			return err
		}
		*v = i
		return nil
	case *int8:
		return decodeTextInt(v, str, 8)
	case *int16:
		return decodeTextInt(v, str, 16)
	case *int32:
		return decodeTextInt(v, str, 32)
	case *int64:
		return decodeTextInt(v, str, 64)
	case *uint:
		return decodeTextUint(v, str, 0)
	case *uint8:
		return decodeTextUint(v, str, 8)
	case *uint16:
		return decodeTextUint(v, str, 16)
	case *uint32:
		return decodeTextUint(v, str, 32)
	case *uint64:
		return decodeTextUint(v, str, 64)
	case *float32:
		return decodeTextFloat(v, str, 32)
	case *float64:
		return decodeTextFloat(v, str, 64)
	case *bool:
		b, err := strconv.ParseBool(str)
		if err != nil {
			return err
		}
		*v = b
		return nil
	default:
		if u, ok := value.(encoding.TextUnmarshaler); ok {
			return u.UnmarshalText(data)
		}
		if _, ok := fixedByteArrayTarget(value); ok {
			// The target type is supported, so surface the real reason
			// (bad hex, wrong length) instead of "unsupported".
			if err := decodeFixedByteArray(data, value); err != nil {
				return fmt.Errorf("textCodec: %w", err)
			}
			return nil
		}
		return fmt.Errorf("textCodec: unsupported decode target %T", value)
	}
}

// decodeTextInt parses s as a base-10 signed integer that fits in bitSize
// bits and stores it in dst. dst is left untouched when parsing fails.
func decodeTextInt[T int8 | int16 | int32 | int64](dst *T, s string, bitSize int) error {
	i, err := strconv.ParseInt(s, 10, bitSize)
	if err != nil {
		return err
	}
	*dst = T(i)
	return nil
}

// decodeTextUint parses s as a base-10 unsigned integer that fits in bitSize
// bits (0 means the size of uint) and stores it in dst. dst is left untouched
// when parsing fails.
func decodeTextUint[T uint | uint8 | uint16 | uint32 | uint64](dst *T, s string, bitSize int) error {
	u, err := strconv.ParseUint(s, 10, bitSize)
	if err != nil {
		return err
	}
	*dst = T(u)
	return nil
}

// decodeTextFloat parses s as a floating-point number of the given bit size
// and stores it in dst. dst is left untouched when parsing fails.
func decodeTextFloat[T float32 | float64](dst *T, s string, bitSize int) error {
	f, err := strconv.ParseFloat(s, bitSize)
	if err != nil {
		return err
	}
	*dst = T(f)
	return nil
}

// encodeFixedByteArray hex-encodes v when it is an array whose element kind
// is uint8. The boolean result reports whether v was such an array.
func encodeFixedByteArray(v any) ([]byte, bool) {
	val := reflect.ValueOf(v)
	if val.Kind() != reflect.Array || val.Type().Elem().Kind() != reflect.Uint8 {
		return nil, false
	}
	raw := make([]byte, val.Len())
	for i := range raw {
		raw[i] = byte(val.Index(i).Uint())
	}
	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return encoded, true
}

// fixedByteArrayTarget reports whether out is a non-nil pointer to an array
// whose element kind is uint8 and, if so, returns the settable array value.
func fixedByteArrayTarget(out any) (reflect.Value, bool) {
	val := reflect.ValueOf(out)
	if val.Kind() != reflect.Ptr || val.IsNil() {
		return reflect.Value{}, false
	}
	arr := val.Elem()
	if arr.Kind() != reflect.Array || arr.Type().Elem().Kind() != reflect.Uint8 {
		return reflect.Value{}, false
	}
	return arr, true
}

// decodeFixedByteArray hex-decodes data into the [N]byte array pointed to by
// out. It fails when out is not such a pointer, when data is not valid hex,
// or when the decoded length differs from N. The array is only modified when
// all checks pass.
func decodeFixedByteArray(data []byte, out any) error {
	arr, ok := fixedByteArrayTarget(out)
	if !ok {
		return fmt.Errorf("unsupported decode target: %T", out)
	}
	expectedLen := arr.Len()

	decoded, err := hex.DecodeString(string(data))
	if err != nil {
		return fmt.Errorf("hex decode failed: %w", err)
	}
	if len(decoded) != expectedLen {
		return fmt.Errorf("invalid hex length: got %d bytes, want %d", len(decoded), expectedLen)
	}

	// SetUint (rather than reflect.Copy) also works for arrays of named
	// byte types.
	for i := 0; i < expectedLen; i++ {
		arr.Index(i).SetUint(uint64(decoded[i]))
	}
	return nil
}
// Package shelve provides a persistent, map-like object called Shelf. It lets you
// store and retrieve Go objects directly, with the serialization and storage handled
// automatically by the Shelf. Additionally, you can customize the underlying
// key-value storage and serialization codec to better suit your application's needs.
//
// This package is inspired by the `shelve` module from the Python standard library
// and aims to provide a similar set of functionalities.
//
// By default, a Shelf serializes data using the JSON format and stores it using `sdb`
// (for "shelve-db"), a simple key-value storage created for this project. This
// database should be good enough for a broad range of applications, but the modules
// in [go-shelve/driver] provide additional options for configuring the `Shelf` with
// other databases and Codecs.
//
// [go-shelve/driver]: https://pkg.go.dev/github.com/lucmq/go-shelve/driver
package shelve
import (
"encoding"
"fmt"
"reflect"
"github.com/lucmq/go-shelve/sdb"
)
const (
// Asc and Desc can be used with the Shelf.Items method to make the
// iteration order ascending or descending respectively. They are meant to
// be passed as the step argument: a positive step iterates in ascending
// order and a negative step in descending order.
//
// They are just syntactic sugar to make the iteration order more
// explicit.
Asc = 1
// Desc is the opposite of Asc.
Desc = -1
// All can be used with the Shelf.Items method to iterate over all
// items in the database. It is meant to be passed as the n argument and
// is the same as the -1 value.
All = -1
)
// Yield is a function called when iterating over key-value pairs in the
// Shelf. If Yield returns false or an error, the iteration stops.
//
// It is the callback type accepted by Shelf.Items, Shelf.Keys and
// Shelf.Values.
type Yield[K, V any] func(key K, value V) (bool, error)
// A Shelf is a persistent, map-like object. It is used together with an
// underlying key-value storage to store Go objects directly.
//
// Stored values can be of arbitrary types, but the keys must be comparable.
//
// By default, values are encoded using the JSON codec, and keys using the
// TextCodec.
//
// For storage, the underlying database is an instance of the [sdb.DB]
// ("shelve-db") key-value store.
//
// The underlying storage and codec Shelf uses can be configured with the
// [Option] functions.
type Shelf[K comparable, V any] struct {
// db is the key-value store that holds the encoded keys and values.
db DB
// codec encodes and decodes values.
codec Codec
// keyCodec encodes and decodes keys.
keyCodec Codec
}
// Option is passed to the Open function to create a customized Shelf.
//
// Open invokes each Option with a pointer to its internal options value; the
// With* constructors in this package type-assert that argument accordingly.
type Option func(any)
// options collects the settings that Option functions can override before
// Open builds the Shelf.
type options struct {
// DB is the underlying database; when left nil, Open opens an sdb.DB.
DB DB
// Codec encodes and decodes values.
Codec Codec
// KeyCodec encodes and decodes keys.
KeyCodec Codec
}
// WithDatabase selects the underlying database the Shelf stores its data in.
// When this option is not given, the [sdb.DB] ("shelve-db") key-value storage
// is used.
//
// The packages in [driver/db] provide support for other databases in the Go
// ecosystem, like [Bolt] and [Badger].
//
// [driver/db]: https://pkg.go.dev/github.com/lucmq/go-shelve/driver/db
// [Bolt]: https://pkg.go.dev/github.com/etcd-io/bbolt
// [Badger]: https://pkg.go.dev/github.com/dgraph-io/badger
func WithDatabase(db DB) Option {
	return func(v any) {
		v.(*options).DB = db
	}
}
// WithCodec selects the Codec used for values. When this option is not given,
// a codec for the JSON format is used.
//
// Additional Codecs can be found in the packages in [driver/encoding].
//
// [driver/encoding]: https://pkg.go.dev/github.com/lucmq/go-shelve/driver/encoding
func WithCodec(c Codec) Option {
	return func(v any) {
		v.(*options).Codec = c
	}
}
// WithKeyCodec selects the Codec used for encoding keys.
//
// When this option is not given, keys of type string, boolean, integer
// (signed or unsigned), float, [N]byte arrays (e.g., [12]byte), or types that
// implement [encoding.TextMarshaler] are encoded using [TextCodec].
//
// Additional Codecs can be found in the packages in [driver/encoding].
//
// [driver/encoding]: https://pkg.go.dev/github.com/lucmq/go-shelve/driver/encoding
func WithKeyCodec(c Codec) Option {
	return func(v any) {
		v.(*options).KeyCodec = c
	}
}
// Open creates a new Shelf.
//
// The path parameter specifies the filesystem path to the database files. It
// can be a directory or a regular file, depending on the underlying database
// implementation. With the default database [sdb.DB], it will point to a
// directory.
//
// Options are applied first; anything they leave unset then falls back to a
// default: the JSON codec for values, the text codec for keys, and an
// [sdb.DB] opened at path for storage. The default key codec only supports
// the key types accepted by defaultKeyCodec, so Open returns an error for any
// other key type unless a key codec is supplied with [WithKeyCodec]. That
// check happens before the default database is opened, so no database is left
// open when it fails.
func Open[K comparable, V any](path string, opts ...Option) (
	*Shelf[K, V],
	error,
) {
	var o options
	for _, option := range opts {
		if option == nil {
			continue
		}
		option(&o)
	}

	if o.Codec == nil {
		o.Codec = JSONCodec()
	}
	if o.KeyCodec == nil {
		// Only require a default key codec when the caller did not
		// provide one explicitly.
		var k K
		keyCodec, err := defaultKeyCodec(k)
		if err != nil {
			return nil, err
		}
		o.KeyCodec = keyCodec
	}
	if o.DB == nil {
		db, err := sdb.Open(path)
		if err != nil {
			return nil, fmt.Errorf("open db: %w", err)
		}
		o.DB = db
	}

	return &Shelf[K, V]{
		db:       o.DB,
		codec:    o.Codec,
		keyCodec: o.KeyCodec,
	}, nil
}
// Close closes the Shelf by closing its underlying database and returns that
// database's error, if any.
func (s *Shelf[K, V]) Close() error {
	err := s.db.Close()
	return err
}
// Len returns the number of items in the Shelf as an int64, exactly as
// reported by the underlying database.
//
// The database is expected to report -1 when an error occurs (assumption
// carried over from the DB contract; confirm against the DB interface).
func (s *Shelf[K, V]) Len() int64 {
	count := s.db.Len()
	return count
}
// Sync asks the underlying database to synchronize the Shelf contents to
// persistent storage and returns its error, if any.
func (s *Shelf[K, V]) Sync() error {
	err := s.db.Sync()
	return err
}
// Has reports whether key exists in the Shelf.
//
// It returns false together with a wrapped error when the key cannot be
// encoded or when the database lookup fails.
func (s *Shelf[K, V]) Has(key K) (bool, error) {
	encodedKey, err := s.keyCodec.Encode(key)
	if err != nil {
		return false, fmt.Errorf("encode: %w", err)
	}

	found, err := s.db.Has(encodedKey)
	if err != nil {
		return false, fmt.Errorf("has: %w", err)
	}
	return found, nil
}
// Get retrieves the value associated with key from the Shelf.
//
// If the key is not found, Get returns the zero value of V with ok == false
// and a nil error. Key-encoding and database errors are returned wrapped,
// also with ok == false. When the stored bytes cannot be decoded, the
// codec's error is returned together with ok == true, because the key itself
// exists.
func (s *Shelf[K, V]) Get(key K) (value V, ok bool, err error) {
	var zero V

	encodedKey, err := s.keyCodec.Encode(key)
	if err != nil {
		return zero, false, fmt.Errorf("encode: %w", err)
	}

	raw, err := s.db.Get(encodedKey)
	switch {
	case err != nil:
		return zero, false, fmt.Errorf("get: %w", err)
	case raw == nil:
		return zero, false, nil
	}

	var decoded V
	err = s.codec.Decode(raw, &decoded)
	return decoded, true, err
}
// Put stores value under key in the Shelf, overwriting any value previously
// stored under the same key.
//
// Errors from encoding the key, encoding the value, or writing to the
// database are returned wrapped.
func (s *Shelf[K, V]) Put(key K, value V) error {
	encodedKey, err := s.keyCodec.Encode(key)
	if err != nil {
		return fmt.Errorf("encode key: %w", err)
	}

	encodedValue, err := s.codec.Encode(value)
	if err != nil {
		return fmt.Errorf("encode value: %w", err)
	}

	if err = s.db.Put(encodedKey, encodedValue); err != nil {
		return fmt.Errorf("put: %w", err)
	}
	return nil
}
// Delete removes key and its value from the Shelf.
//
// Errors from encoding the key or from the database are returned wrapped.
func (s *Shelf[K, V]) Delete(key K) error {
	encodedKey, err := s.keyCodec.Encode(key)
	if err != nil {
		return fmt.Errorf("encode: %w", err)
	}

	if err = s.db.Delete(encodedKey); err != nil {
		return fmt.Errorf("delete: %w", err)
	}
	return nil
}
// Items iterates over key-value pairs in the Shelf, calling fn(k, v) for each
// pair in the sequence. The iteration stops early if fn returns false or an
// error.
//
// The start parameter is the key from which the iteration should start. If
// start is nil, the iteration begins from the first key in the Shelf.
//
// The n parameter is the maximum number of items to iterate over. If n is
// All (-1) or less, all items will be iterated.
//
// The step parameter is the stride of the iteration: every step-th item is
// delivered. A negative step makes the iteration run in reverse order, and a
// step of zero yields nothing.
//
// An empty stored value is not decoded; fn receives the zero value of V for
// it.
//
// When iterating over key-value pairs in a Shelf, the order of iteration may
// not be sorted. Some database implementations may ignore the start parameter
// or not support iteration in reverse order.
//
// The default Shelf database (sdb.DB) yields items in deterministic lexical
// order and honours the start key. If the exact start key does not exist:
//   - Asc: position at the first key > start. If all keys < start, the
//     iterator is empty.
//   - Desc: position at the last key < start. If all keys > start, the
//     iterator is empty.
func (s *Shelf[K, V]) Items(start *K, n, step int, fn Yield[K, V]) error {
	return s.iterate(start, n, step, func(rawKey, rawValue []byte) (bool, error) {
		var (
			key   K
			value V
		)
		if err := s.keyCodec.Decode(rawKey, &key); err != nil {
			return false, fmt.Errorf("decode key: %w", err)
		}
		if len(rawValue) != 0 {
			if err := s.codec.Decode(rawValue, &value); err != nil {
				return false, fmt.Errorf("decode value: %w", err)
			}
		}
		return fn(key, value)
	})
}
// Keys iterates over the keys in the Shelf and calls the user-provided
// function fn for each of them. Stored values are not decoded. The details of
// the iteration are the same as for [Shelf.Items].
//
// The value parameter for fn will always be the zero value for the type V.
func (s *Shelf[K, V]) Keys(start *K, n, step int, fn Yield[K, V]) error {
	return s.iterate(start, n, step, func(rawKey, _ []byte) (bool, error) {
		var (
			key  K
			none V
		)
		if err := s.keyCodec.Decode(rawKey, &key); err != nil {
			return false, fmt.Errorf("decode: %w", err)
		}
		return fn(key, none)
	})
}
// Values iterates over all values in the Shelf and calls the user-provided
// function fn for each value. The details of the iteration are the same as
// for [Shelf.Items].
//
// The key parameter for fn will always be the zero value for the type K.
//
// As in [Shelf.Items], an empty stored value is not passed to the codec; fn
// receives the zero value of V for it.
func (s *Shelf[K, V]) Values(start *K, n, step int, fn Yield[K, V]) error {
	dbFn := func(_, v []byte) (bool, error) {
		var zero K
		var value V
		// Skip decoding empty payloads so Values and Items agree on how
		// they are reported.
		if len(v) != 0 {
			if err := s.codec.Decode(v, &value); err != nil {
				return false, fmt.Errorf("decode: %w", err)
			}
		}
		return fn(zero, value)
	}
	return s.iterate(start, n, step, dbFn)
}
// iterate drives the underlying database iteration shared by Items, Keys and
// Values.
//
// start, when non-nil, is encoded with the key codec and passed to the
// database as the starting key. The sign of step selects the order (positive:
// Asc, negative: Desc) and its magnitude is the stride: the first record is
// delivered to fn, then every step-th record after it. A step of zero
// delivers nothing. When n is positive, at most n records are delivered and
// the database scan is stopped as soon as the n-th one has been handed to fn;
// otherwise all records are visited.
//
// fn's own (bool, error) result is passed through to the database whenever fn
// asks to stop or fails.
func (s *Shelf[K, V]) iterate(
	start *K,
	n, step int,
	fn func(k, v []byte) (bool, error),
) error {
	var from []byte
	if start != nil {
		encoded, err := s.keyCodec.Encode(*start)
		if err != nil {
			return fmt.Errorf("encode start: %w", err)
		}
		from = encoded
	}

	var order int
	switch {
	case step > 0:
		order = Asc
	case step < 0:
		order = Desc
		step = -step
	default:
		return nil
	}

	var total int
	// counter cycles through 0, 1, ..., step-1. It starts at step-1 so that
	// the very first record is delivered.
	counter := step - 1
	return s.db.Items(from, order, func(k, v []byte) (bool, error) {
		if n > 0 && total >= n {
			return false, nil
		}
		// Skip records until the stride is reached.
		if counter < step-1 {
			counter++
			return true, nil
		}
		counter = 0
		total++

		more, err := fn(k, v)
		if err != nil || !more {
			return more, err
		}
		// Stop right after the n-th delivery instead of letting the
		// database read further records that would only be discarded.
		if n > 0 && total >= n {
			return false, nil
		}
		return true, nil
	})
}
// Helpers

// defaultKeyCodec picks the codec used for keys when none is configured.
//
// It returns TextCodec for strings, booleans, the built-in integer and float
// types, values implementing encoding.TextMarshaler, and [N]byte arrays. For
// every other key type it returns an error asking for an explicit key codec.
func defaultKeyCodec(key any) (Codec, error) {
	switch key.(type) {
	case string, bool:
		return TextCodec(), nil
	case int, int8, int16, int32, int64:
		return TextCodec(), nil
	case uint, uint8, uint16, uint32, uint64:
		return TextCodec(), nil
	case float32, float64:
		return TextCodec(), nil
	case encoding.TextMarshaler:
		// Support TextMarshaler types
		return TextCodec(), nil
	}

	// Handle [N]byte arrays via reflection
	rv := reflect.ValueOf(key)
	if rv.Kind() == reflect.Array && rv.Type().Elem().Kind() == reflect.Uint8 {
		return TextCodec(), nil
	}

	return nil, fmt.Errorf("unsupported key type %T: must explicitly set a key codec", key)
}