// Package cli provides the command-line interface for the batterdb application.
// It includes functionality for starting the server, outputting the OpenAPI specification,
// and handling various command-line options such as setting the port, enabling HTTPS,
// and persisting the database to disk.
//
// The package uses the kong library for command-line parsing and integrates with the
// handlers package to manage the server service.
package cli
import (
"context"
"errors"
"io"
"log/slog"
"net/http"
"os"
"runtime/debug"
"github.com/alecthomas/kong"
"github.com/jh125486/batterdb/handlers"
)
// Ctx represents the context for the CLI commands, including build information,
// service instance, writer for output, and a channel to handle OS signals.
//
// It is received by the command hooks in this file: CLI.AfterApply stores the
// constructed service on it, and the Run methods read service, Writer and Stop.
type Ctx struct {
	// BuildInfo is forwarded to the service through handlers.WithBuildInfo.
	*debug.BuildInfo
	// service is created by CLI.AfterApply and used by the command Run methods.
	service *handlers.Service
	// Writer receives command output (OpenAPICmd.Run writes the spec to it).
	io.Writer
	// Stop is read once by ServerCmd.Run; a value on it triggers graceful shutdown.
	Stop chan os.Signal
}
//nolint:govet // Order of fields must be maintained for CLI help output.
type (
	// CLI defines the structure for the command-line interface, including options
	// for port, persistence, repository file, and HTTPS, as well as commands for
	// starting the server and outputting the OpenAPI specification.
	//
	// Flags and commands are described entirely by the kong struct tags below.
	// Validate checks the port range and AfterApply builds the handlers.Service
	// from the parsed values. ServerCmd carries `default:"1"`, kong's marker for
	// the default command.
	//
	// NOTE(review): RepoFile's default is the kong variable ${RepoFile}, which
	// the caller must supply (presumably via kong.Vars) — confirm at the call site.
	CLI struct {
		Port     int32            `short:"p" default:"1205" help:"Port to listen on."`
		Store    bool             `short:"s" help:"Persist the database to disk."`
		RepoFile string           ` default:"${RepoFile}" help:"The file to persist the database to."`
		Secure   bool             `short:"S" help:"Enable HTTPS."`
		Server   ServerCmd        `default:"1" help:"Start the server." cmd:""`
		OpenAPI  OpenAPICmd       `help:"Output the OpenAPI specification version." cmd:"" optional:""`
		Version  kong.VersionFlag `short:"v" help:"Show version."`
	}
	// ServerCmd represents the command to start the server.
	// It has no options of its own; see ServerCmd.Run.
	ServerCmd struct {
	}
	// OpenAPICmd represents the command to output the OpenAPI specification.
	// Spec selects the output version and is restricted to "3.1" or "3.0.3".
	OpenAPICmd struct {
		Spec string `default:"3.1" help:"OpenAPI specification version." enum:"3.1,3.0.3"`
	}
)
// New builds a kong parser for the CLI grammar, applying any extra kong
// options, and parses args against it.
//
// It returns the parsed kong context, or the error from constructing the
// parser or from parsing.
func New(args []string, opts ...kong.Option) (*kong.Context, error) {
	grammar := new(CLI)
	parser, err := kong.New(grammar, opts...)
	if err != nil {
		return nil, err
	}
	return parser.Parse(args)
}
// Validate rejects a port outside the TCP range 1..65535 (inclusive).
func (cmd *CLI) Validate() error {
	const minPort, maxPort = 1, 65535
	if cmd.Port >= minPort && cmd.Port <= maxPort {
		return nil
	}
	return errors.New("port must be between 1 and 65535")
}
// AfterApply builds the handlers.Service from the parsed CLI options and the
// build information carried by ctx, then stores it on ctx for the command Run
// methods. It always returns nil.
func (cmd *CLI) AfterApply(ctx *Ctx) error {
	opts := []handlers.Option{
		handlers.WithBuildInfo(ctx.BuildInfo),
		handlers.WithPort(cmd.Port),
		handlers.WithPersistDB(cmd.Store),
		handlers.WithRepoFile(cmd.RepoFile),
		handlers.WithSecure(cmd.Secure),
	}
	ctx.service = handlers.New(opts...)
	return nil
}
// Run starts the server service in a goroutine and blocks until either a value
// arrives on ctx.Stop or the service fails to start.
//
// On a stop signal it performs a graceful shutdown through Service.Shutdown and
// returns its result. If Start fails first (for example the listener cannot be
// opened or the repository cannot be loaded), the failure is logged and
// returned to the caller rather than terminating the process from inside the
// goroutine. Shutdown is deliberately not called in that case, because it would
// save the (not loaded) repository over the persisted one.
func (cmd *ServerCmd) Run(ctx *Ctx) error {
	// Buffered so the serving goroutine can always deliver its result and exit,
	// even after Run has stopped waiting on the channel.
	startErr := make(chan error, 1)
	go func() {
		startErr <- ctx.service.Start()
	}()

	select {
	case err := <-startErr:
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			slog.Error("failed to start service", slog.String("err", err.Error()))
			return err
		}
		// The server stopped cleanly on its own; keep the contract of only
		// shutting down once a stop signal has been received.
		<-ctx.Stop
	case <-ctx.Stop:
	}

	// Begin graceful shutdown.
	return ctx.service.Shutdown(context.Background())
}
// Run renders the OpenAPI document in the version selected by cmd.Spec and
// writes it to the context's writer, returning any write error.
func (cmd *OpenAPICmd) Run(ctx *Ctx) error {
	spec := ctx.service.OpenAPI(cmd.Spec)
	if _, err := ctx.Write(spec); err != nil {
		return err
	}
	return nil
}
package text
import (
"encoding"
"fmt"
"io"
"github.com/danielgtaylor/huma/v2"
)
// DefaultTextFormat returns a huma.Format for plain-text content.
//
// Marshal writes the output of MarshalText when the value implements
// encoding.TextMarshaler, and falls back to fmt.Fprint otherwise. Unmarshal
// only supports targets implementing encoding.TextUnmarshaler; any other target
// yields a 501 Not Implemented error.
//
// The format must be registered explicitly in the API's `Config.Formats` map,
// for example:
//
//	config := huma.Config{}
//	config.Formats = map[string]huma.Format{
//		"text/plain": text.DefaultTextFormat(),
//		"text":       text.DefaultTextFormat(),
//	}
func DefaultTextFormat() huma.Format {
	marshal := func(w io.Writer, v any) error {
		tm, isText := v.(encoding.TextMarshaler)
		if !isText {
			_, err := fmt.Fprint(w, v)
			return err
		}
		raw, err := tm.MarshalText()
		if err != nil {
			return err
		}
		_, err = w.Write(raw)
		return err
	}
	unmarshal := func(data []byte, v any) error {
		if tu, isText := v.(encoding.TextUnmarshaler); isText {
			return tu.UnmarshalText(data)
		}
		return huma.Error501NotImplemented("text format not supported")
	}
	return huma.Format{Marshal: marshal, Unmarshal: unmarshal}
}
package yaml
import (
"io"
"github.com/danielgtaylor/huma/v2"
"gopkg.in/yaml.v3"
)
// DefaultYAMLFormat returns a huma.Format that encodes and decodes YAML using
// gopkg.in/yaml.v3.
//
// The format must be registered explicitly in the API's `Config.Formats` map,
// for example:
//
//	config := huma.Config{}
//	config.Formats = map[string]huma.Format{
//		"application/yaml": yaml.DefaultYAMLFormat(),
//		"yaml":             yaml.DefaultYAMLFormat(),
//	}
func DefaultYAMLFormat() huma.Format {
	return huma.Format{
		Marshal: func(w io.Writer, v any) error {
			enc := yaml.NewEncoder(w)
			if err := enc.Encode(v); err != nil {
				return err
			}
			// Close finishes the YAML stream and writes any data the encoder
			// still holds; skipping it relies on Encode having flushed already.
			return enc.Close()
		},
		Unmarshal: yaml.Unmarshal,
	}
}
// Package handlers provides HTTP handlers for managing database operations
// within the batterdb application. The handlers include listing all databases,
// showing a specific database, creating a new database, and deleting an existing
// database.
//
// The package utilizes the huma framework for handling HTTP requests and responses,
// and interacts with the repository package to perform database operations.
package handlers
import (
"context"
"errors"
"github.com/danielgtaylor/huma/v2"
"github.com/jh125486/batterdb/repository"
)
type (
	// DatabasesOutput represents the output structure for the ListDatabasesHandler.
	// It contains a list of databases and the total number of databases.
	// Body holds the response payload; its JSON field names are set by the tags.
	DatabasesOutput struct {
		Body struct {
			Databases         []Database `json:"databases"`
			NumberOfDatabases int        `json:"number_of_databases"`
		}
	}
	// Database represents the structure of a single database, including its ID, name,
	// and the number of stacks it contains.
	// The handlers fill ID from db.ID.String() and NumberOfStacks from db.Len().
	Database struct {
		ID             string `json:"id"`
		Name           string `json:"name"`
		NumberOfStacks int    `json:"number_of_stacks"`
	}
)
// ListDatabasesHandler handles the request to list all databases.
//
// Databases are listed in the order returned by Repository.SortDatabases, so
// the output is stable. NumberOfDatabases is the length of that list; the
// handler never returns an error.
func (s *Service) ListDatabasesHandler(_ context.Context, _ *struct{}) (*DatabasesOutput, error) {
	out := new(DatabasesOutput)
	// Len is only a capacity hint. The reported count is derived from the
	// listed databases so both fields agree even if the repository changes
	// between the two calls (handlers run concurrently).
	out.Body.Databases = make([]Database, 0, s.Repository.Len())
	for _, db := range s.Repository.SortDatabases() {
		out.Body.Databases = append(out.Body.Databases, Database{
			ID:             db.ID.String(),
			Name:           db.Name,
			NumberOfStacks: db.Len(),
		})
	}
	out.Body.NumberOfDatabases = len(out.Body.Databases)
	return out, nil
}
type (
	// URLParamDatabaseID represents the URL parameter for a database ID, which can
	// be either the database ID or name.
	// It is bound from the {database} path segment and embedded by other inputs.
	URLParamDatabaseID struct {
		DatabaseID string `doc:"can be the database ID or name" path:"database"`
	}
	// SingleDatabaseInput represents the input structure for the ShowDatabaseHandler
	// and DeleteDatabaseHandler, containing the database ID.
	SingleDatabaseInput struct {
		URLParamDatabaseID
	}
	// DatabaseOutput represents the output structure for the ShowDatabaseHandler,
	// containing the details of a single database.
	DatabaseOutput struct {
		Body Database
	}
)
// ShowDatabaseHandler returns the details of the database identified by the
// {database} path parameter (its ID or name). A missing database yields the
// 404 Not Found error produced by s.database.
func (s *Service) ShowDatabaseHandler(_ context.Context, input *SingleDatabaseInput) (*DatabaseOutput, error) {
	db, err := s.database(input.DatabaseID)
	if err != nil {
		return nil, err
	}
	return &DatabaseOutput{
		Body: Database{
			ID:             db.ID.String(),
			Name:           db.Name,
			NumberOfStacks: db.Len(),
		},
	}, nil
}
type (
	// CreateDatabaseInput represents the REST input request for the CreateDatabaseHandler.
	// Name comes from the required ?name= query parameter and is validated to be
	// at least 7 characters long.
	CreateDatabaseInput struct {
		Name string `minLength:"7" query:"name" required:"true"`
	}
	// CreateDatabaseOutput represents the REST output response for the CreateDatabaseHandler.
	CreateDatabaseOutput struct {
		Body Database
	}
)
// CreateDatabaseHandler handles the request to create a new database.
// It creates the database in the repository and returns its details.
//
// Errors:
//   - 409 Conflict if a database with that name already exists
//     (repository.ErrAlreadyExists);
//   - 500 Internal Server Error for any other repository failure.
func (s *Service) CreateDatabaseHandler(_ context.Context, input *CreateDatabaseInput) (*CreateDatabaseOutput, error) {
	db, err := s.Repository.New(input.Name)
	switch {
	case errors.Is(err, repository.ErrAlreadyExists):
		return nil, huma.Error409Conflict("database already exists", err)
	case err != nil:
		// Any other failure leaves db unusable (possibly nil), so it must not
		// be dereferenced below.
		return nil, huma.Error500InternalServerError("failed to create database", err)
	}
	return &CreateDatabaseOutput{
		Body: Database{
			ID:             db.ID.String(),
			Name:           db.Name,
			NumberOfStacks: db.Len(),
		},
	}, nil
}
// DeleteDatabaseHandler drops the database identified by the {database} path
// parameter. Any error from Repository.Drop is reported as 404 Not Found.
// On success it returns no output body.
func (s *Service) DeleteDatabaseHandler(_ context.Context, input *SingleDatabaseInput) (*struct{}, error) {
	dropErr := s.Repository.Drop(input.DatabaseID)
	if dropErr == nil {
		return nil, nil
	}
	return nil, huma.Error404NotFound("database not found", dropErr)
}
// database looks up a database by ID or name in the repository. Any lookup
// failure is converted into a 404 Not Found error suitable for returning
// straight from a handler.
func (s *Service) database(dbID string) (*repository.Database, error) {
	found, lookupErr := s.Repository.Database(dbID)
	if lookupErr == nil {
		return found, nil
	}
	return nil, huma.Error404NotFound("database not found", lookupErr)
}
// Package handlers provides HTTP handlers for the batterdb application, including
// endpoints for retrieving application status and handling ping requests.
//
// The package utilizes the Go standard library and external libraries for handling
// HTTP requests and responses, as well as for managing application status information.
package handlers
import (
"context"
"net/http"
"runtime"
"time"
"github.com/alecthomas/units"
"github.com/ccoveille/go-safecast"
)
type (
	// StatusOutput represents the output structure for the StatusHandler.
	// It contains detailed information about the application status.
	StatusOutput struct {
		Body StatusBody
	}
	// StatusBody represents the structure of the status information, including
	// the start time, status code, version, Go version, host, memory allocation,
	// runtime duration, process ID, and number of goroutines.
	//
	// As filled by StatusHandler: MemoryAlloc is a human-readable base-2 size
	// of runtime.MemStats.Alloc, and RunningFor is the uptime in seconds.
	StatusBody struct {
		StartedAt        time.Time `json:"started_at" yaml:"startedAt"`
		Code             string    `json:"status" yaml:"code"`
		Version          string    `json:"version" yaml:"version"`
		GoVersion        string    `json:"go_version" yaml:"goVersion"`
		Host             string    `json:"host" yaml:"host"`
		MemoryAlloc      string    `json:"memory_alloc" yaml:"memoryAlloc"`
		RunningFor       float64   `json:"running_for" yaml:"runningFor"`
		PID              int       `json:"pid" yaml:"pid"`
		NumberGoroutines int       `json:"number_goroutines" yaml:"numberGoroutines"`
	}
)
// StatusHandler handles the request to retrieve the application status.
// It gathers runtime statistics (heap allocation, goroutine count, uptime) and
// service metadata and returns them in the response.
//
// Build information is optional. When none was supplied it reports "unknown"
// as the version and runtime.Version() as the Go version, instead of
// dereferencing a nil pointer.
func (s *Service) StatusHandler(_ context.Context, _ *struct{}) (*StatusOutput, error) {
	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)
	// units.Base2Bytes is int64-based; reject the (theoretical) overflow
	// rather than wrapping around.
	allocs, err := safecast.ToInt64(mem.Alloc)
	if err != nil {
		return nil, err
	}

	version, goVersion := "unknown", runtime.Version()
	if bi := s.buildInfo; bi != nil {
		version, goVersion = bi.Main.Version, bi.GoVersion
	}

	out := new(StatusOutput)
	out.Body.Code = http.StatusText(http.StatusOK)
	out.Body.Version = version
	out.Body.GoVersion = goVersion
	out.Body.Host = s.platform
	out.Body.PID = s.pid
	out.Body.StartedAt = s.startedAt
	out.Body.RunningFor = time.Since(s.startedAt).Seconds()
	out.Body.NumberGoroutines = runtime.NumGoroutine()
	out.Body.MemoryAlloc = units.Base2Bytes(allocs).Round(1).String()
	return out, nil
}
// PingOutput is the response of PingHandler: a raw byte body served with the
// text/plain content type.
type PingOutput struct {
	Body []byte `contentType:"text/plain"`
}

// PingHandler is a liveness probe. It never fails and always answers with the
// plain-text body "pong"; it needs no Service state.
func PingHandler(_ context.Context, _ *struct{}) (*PingOutput, error) {
	return &PingOutput{Body: []byte("pong")}, nil
}
// Package handlers provides HTTP middleware for the batterdb application,
// including logging of HTTP requests and responses.
//
// The package includes functionality to wrap HTTP handlers with additional
// behaviors, such as logging request details and response status codes.
package handlers
import (
"fmt"
"log/slog"
"net/http"
)
// loggingResponseWriter wraps an http.ResponseWriter and records the final
// status code sent to the client, so it can be logged after the handler returns.
// A StatusCode of 0 means no final status has been written yet.
type loggingResponseWriter struct {
	http.ResponseWriter
	StatusCode int
}

// WriteHeader records the first final status code and forwards the call to
// the wrapped ResponseWriter.
//
// net/http sends only the first final status and ignores later WriteHeader
// calls. 1xx informational codes (other than 101 Switching Protocols) may
// precede that status without being final. Recording every call would
// therefore log a status the client never received.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	final := code >= http.StatusOK || code == http.StatusSwitchingProtocols
	if lrw.StatusCode == 0 && final {
		lrw.StatusCode = code
	}
	lrw.ResponseWriter.WriteHeader(code)
}

// Write records the implicit 200 (OK) that net/http sends when the body is
// written before any final status, then writes b to the wrapped ResponseWriter.
func (lrw *loggingResponseWriter) Write(b []byte) (int, error) {
	if lrw.StatusCode == 0 {
		lrw.StatusCode = http.StatusOK
	}
	return lrw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter. http.ResponseController uses it
// to reach optional capabilities (flushing, deadlines, ...) that this wrapper
// does not re-export itself.
func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
	return lrw.ResponseWriter
}

// LoggingHandler is a middleware that logs HTTP requests and their response status codes.
// It bypasses logging for WebSocket upgrade requests.
//
// Each request is logged as "<method> <path> <status>". A handler that returns
// without writing anything is logged as 200, matching what net/http sends.
func LoggingHandler(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Check if this is a WebSocket upgrade request.
		if upgrade := r.Header.Get("Upgrade"); upgrade == "websocket" {
			// If it is, bypass the logging and pass the request directly to the next handler.
			h.ServeHTTP(w, r)
			return
		}
		// If it's not a WebSocket upgrade request, proceed with the logging as usual.
		lrw := &loggingResponseWriter{ResponseWriter: w}
		h.ServeHTTP(lrw, r)
		status := lrw.StatusCode
		if status == 0 {
			// Nothing was written; net/http finishes the response with 200 OK.
			status = http.StatusOK
		}
		slog.Info(fmt.Sprintf("%s %v %d", r.Method, r.URL.Path, status))
	})
}
// Package handlers provides the core HTTP service and middleware for the batterdb
// application. This includes setting up the HTTP server, configuring the API, handling
// secure connections, and managing the lifecycle of the service.
//
// The package utilizes the huma framework for API routing and Prometheus for metrics.
// It also supports self-signed TLS certificate generation for secure connections.
package handlers
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"math/big"
"net"
"net/http"
"os"
"runtime"
"runtime/debug"
"strings"
"sync/atomic"
"time"
"github.com/alecthomas/units"
"github.com/arl/statsviz"
"github.com/ccoveille/go-safecast"
"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/adapters/humago"
_ "github.com/danielgtaylor/huma/v2/formats/cbor" // Register the CBOR format.
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/jh125486/batterdb/formats/text"
"github.com/jh125486/batterdb/formats/yaml"
"github.com/jh125486/batterdb/repository"
)
const logo = `
______ _ _ ____________
| ___ \ | | | | | _ \ ___ \
| |_/ / __ _| |_| |_ ___ _ __| | | | |_/ /
| ___ \/ _' | __| __/ _ \ '__| | | | ___ \
| |_/ / (_| | |_| || __/ | | |/ /| |_/ /
\____/ \__,_|\__|\__\___|_| |___/ \____/
`
type (
	// Service represents the main service structure which holds the repository, API, server configuration,
	// build information, platform details, and other service-related configurations.
	//
	// Construct it with New; the zero value has no API or HTTP server and cannot be started.
	Service struct {
		// Repository holds the databases served by the handlers.
		Repository *repository.Repository
		// API is the huma API the routes are registered on in New.
		API huma.API
		// server is built by New (optionally with TLS) and serves the API mux.
		server *http.Server
		// startedAt is the UTC time New ran; reported by the status endpoint.
		startedAt time.Time
		// buildInfo is set by WithBuildInfo and stays nil otherwise.
		buildInfo *debug.BuildInfo
		// platform is "GOOS_GOARCH" of the running binary.
		platform string
		// savefile is the persistence path (default ".batterdb.gob", see WithRepoFile).
		savefile string
		// port is accessed atomically; Start replaces it with the actually bound port.
		port atomic.Int32
		// persistDB enables loading in Start and saving in Shutdown.
		persistDB bool
		// secure makes New serve HTTPS with a generated self-signed certificate.
		secure bool
		// pid is the process ID captured in New.
		pid int
	}
	// Option represents a configuration option for the Service.
	// Options are applied in order by New, before the API and server are built.
	Option func(*Service)
)
// New creates a new instance of the Service with the provided options.
//
// Options are applied over the defaults (current platform and PID, start time,
// an empty repository, ".batterdb.gob" as the save file). New then builds the
// HTTP mux with the huma API routes, the Prometheus /metrics endpoint and the
// statsviz debug pages, and creates the (optionally TLS) HTTP server.
func New(opts ...Option) *Service {
	// defaults.
	s := &Service{
		platform:   fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
		pid:        os.Getpid(),
		startedAt:  time.Now().UTC(),
		Repository: repository.New(),
		savefile:   ".batterdb.gob",
	}
	for _, opt := range opts {
		opt(s)
	}

	mux := http.NewServeMux()
	// Create the API with the app info, contact and formats.
	s.API = humago.New(mux, config(
		"BatterDB", "1.0.0", "A simple in-memory stack database.",
		&huma.Contact{
			Name:  "Jacob Hochstetler",
			URL:   "https://github.com/jh125486",
			Email: "jacob.hochstetler@gmail.com",
		},
		map[string]huma.Format{
			"application/json": huma.DefaultJSONFormat,
			"json":             huma.DefaultJSONFormat,
			"application/yaml": yaml.DefaultYAMLFormat(),
			"yaml":             yaml.DefaultYAMLFormat(),
			// "text/plain" is the registered media type. "plain/text" is kept
			// so clients that already send it keep working.
			"text/plain": text.DefaultTextFormat(),
			"plain/text": text.DefaultTextFormat(),
			"text":       text.DefaultTextFormat(),
		},
	))

	// Register Prometheus metric.
	mux.Handle("/metrics", promhttp.Handler())
	// Register the API routes.
	s.AddRoutes(s.API)
	// Register statsviz. It is a debugging aid, so a failure is logged rather
	// than preventing the service from being built.
	if err := statsviz.Register(mux); err != nil {
		slog.Warn("failed to register statsviz", slog.String("err", err.Error()))
	}
	// Create the server.
	s.server = server(s.secure, mux)
	return s
}
// config assembles the huma.Config for the service.
//
// It sets up:
//   - OpenAPI 3.1.0 metadata (title, version, description, contact) with a map
//     schema registry rooted at #/components/schemas/;
//   - the /openapi, /docs and /schemas endpoints;
//   - the given formats, with application/json as the default;
//   - a create hook that installs huma's schema link transformer for the
//     configured schemas path.
func config(title, version, description string, contact *huma.Contact, formats map[string]huma.Format) huma.Config {
	const (
		schemaPrefix = "#/components/schemas/"
		schemasPath  = "/schemas"
	)
	spec := &huma.OpenAPI{
		OpenAPI: "3.1.0",
		Info: &huma.Info{
			Title:       title,
			Version:     version,
			Description: description,
			Contact:     contact,
		},
		Components: &huma.Components{
			Schemas: huma.NewMapRegistry(schemaPrefix, huma.DefaultSchemaNamer),
		},
	}
	addSchemaLinks := func(c huma.Config) huma.Config {
		lt := huma.NewSchemaLinkTransformer(schemaPrefix, c.SchemasPath)
		c.OnAddOperation = append(c.OnAddOperation, lt.OnAddOperation)
		c.Transformers = append(c.Transformers, lt.Transform)
		return c
	}
	return huma.Config{
		OpenAPI:       spec,
		OpenAPIPath:   "/openapi",
		DocsPath:      "/docs",
		SchemasPath:   schemasPath,
		Formats:       formats,
		DefaultFormat: "application/json",
		CreateHooks:   []func(huma.Config) huma.Config{addSchemaLinks},
	}
}
// server creates a new HTTP server with optional TLS configuration.
func server(secure bool, mux *http.ServeMux) *http.Server {
var tlsConfig *tls.Config
if secure {
cert, err := generateSelfSignedCert()
if err != nil {
slog.Error(err.Error())
os.Exit(1)
}
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
}
return &http.Server{
Handler: LoggingHandler(mux),
TLSConfig: tlsConfig,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
MaxHeaderBytes: int(units.MiB),
}
}
// WithBuildInfo returns an Option that stores buildInfo on the Service. It is
// used for the version fields of the status endpoint and the startup banner.
func WithBuildInfo(buildInfo *debug.BuildInfo) Option {
	return Option(func(svc *Service) {
		svc.buildInfo = buildInfo
	})
}
// WithPort returns an Option that sets the TCP port the Service listens on.
// A port of 0 lets Start bind an OS-chosen port, which Port reports afterwards.
func WithPort(port int32) Option {
	return Option(func(svc *Service) {
		svc.port.Store(port)
	})
}
// WithRepoFile returns an Option that sets the file the repository is loaded
// from and saved to when persistence is enabled.
func WithRepoFile(repofile string) Option {
	return Option(func(svc *Service) {
		svc.savefile = repofile
	})
}
// WithSecure returns an Option that enables HTTPS with a generated
// self-signed certificate when secure is true.
func WithSecure(secure bool) Option {
	return Option(func(svc *Service) {
		svc.secure = secure
	})
}
// WithPersistDB returns an Option that controls whether the repository is
// loaded from disk in Start and saved to disk in Shutdown.
func WithPersistDB(persist bool) Option {
	return Option(func(svc *Service) {
		svc.persistDB = persist
	})
}
// AddRoutes registers every API route of the Service on api: the main
// status/ping routes first, then the database routes, then the stack routes.
func (s *Service) AddRoutes(api huma.API) {
	registrars := []func(huma.API){
		s.registerMain,
		s.registerDatabases,
		s.registerStacks,
	}
	for _, register := range registrars {
		register(api)
	}
}
// Port reports the Service's port. Once Start has bound its listener, this is
// the actually bound port (relevant when 0 was requested). It reads an atomic
// value, so it may be called concurrently with Start.
func (s *Service) Port() int32 {
	current := s.port.Load()
	return current
}
// Start starts the Service and serves incoming requests.
//
// Steps, in order:
//  1. Bind a TCP listener on the configured port on all interfaces.
//  2. Record the actually bound port, so a requested port of 0 resolves to the
//     OS-chosen one.
//  3. Load the persisted repository if persistence is enabled.
//  4. Log the startup banner.
//  5. Serve.
//
// Start blocks while serving. It returns nil after a graceful Shutdown and a
// non-nil error otherwise. If any step before serving fails, the listener is
// closed so a failed start does not leak the socket.
func (s *Service) Start() error {
	l, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Port()))
	if err != nil {
		return fmt.Errorf("failed to start listener: %w", err)
	}

	// Save the actual port from the listener.
	port, err := safecast.ToInt32(l.Addr().(*net.TCPAddr).Port)
	if err != nil {
		_ = l.Close() // Nothing is serving on it yet; release the socket.
		return err
	}
	s.port.Store(port)
	s.server.Addr = fmt.Sprintf("localhost:%d", s.port.Load())

	if err := s.LoadToFile(); err != nil {
		_ = l.Close() // Nothing is serving on it yet; release the socket.
		return fmt.Errorf("failed to load repository: %w", err)
	}

	s.loadInitMsg()
	// From here on http.Server owns the listener and closes it when serving ends.
	return s.serve(l)
}
// serve runs the server on l, using TLS (with the certificate from the
// server's TLSConfig) when the Service is secure. It returns nil when serving
// ended because of Shutdown/Close (http.ErrServerClosed), and the serve error
// otherwise.
func (s *Service) serve(l net.Listener) error {
	run := func() error { return s.server.Serve(l) }
	if s.secure {
		run = func() error { return s.server.ServeTLS(l, "", "") }
	}
	if err := run(); !errors.Is(err, http.ErrServerClosed) {
		return err
	}
	return nil
}
// OpenAPI renders the service's OpenAPI document as YAML bytes in the
// requested version: "3.1" (native) or "3.0.3" (downgraded). It returns nil
// for any other version. Rendering errors are discarded, so a failed render
// also yields whatever bytes the renderer returned (typically nil).
func (s *Service) OpenAPI(openapi string) []byte {
	var render func() ([]byte, error)
	switch openapi {
	case "3.1":
		render = s.API.OpenAPI().YAML
	case "3.0.3":
		// Use downgrade to return OpenAPI 3.0.3 YAML since oapi-codegen doesn't
		// support OpenAPI 3.1 fully yet.
		render = s.API.OpenAPI().DowngradeYAML
	}
	if render == nil {
		return nil
	}
	out, _ := render()
	return out
}
// Shutdown gracefully shuts down the Service and saves the repository to file.
//
// The HTTP server gets at most 5 seconds (bounded further by ctx) to finish
// in-flight requests. The repository is saved (when persistence is enabled)
// even if that graceful shutdown fails or times out, so data is not lost on a
// slow shutdown. Errors from both steps are joined; nil means both succeeded.
func (s *Service) Shutdown(ctx context.Context) error {
	// Create a deadline to wait for.
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	// Doesn't block if no connections, but will otherwise wait until the timeout deadline.
	shutdownErr := s.server.Shutdown(ctx)
	saveErr := s.SaveToFile()
	return errors.Join(shutdownErr, saveErr)
}
// registerMain registers the main API routes for the Service.
//
// Both operations are tagged "Main": GET /_status is served by
// s.StatusHandler, and GET /_ping by the package-level PingHandler, which
// needs no Service state.
func (s *Service) registerMain(api huma.API) {
	huma.Register(api, huma.Operation{
		OperationID: "get-status",
		Method:      http.MethodGet,
		Path:        "/_status",
		Summary:     "Status",
		Description: "Show server status.",
		Tags:        []string{"Main"},
	}, s.StatusHandler)
	huma.Register(api, huma.Operation{
		OperationID: "get-ping",
		Method:      http.MethodGet,
		Path:        "/_ping",
		Summary:     "Ping",
		Description: "Sends a ping to the server, that will answer pong if it is running.",
		Tags:        []string{"Main"},
	}, PingHandler)
}
// registerDatabases registers the API routes for database operations.
//
// All four operations are tagged "Databases":
//   - POST /databases (responds 201 Created);
//   - GET /databases;
//   - GET /databases/{database};
//   - DELETE /databases/{database}.
//
// {database} may be a database ID or name.
func (s *Service) registerDatabases(api huma.API) {
	huma.Register(api, huma.Operation{
		OperationID:   "post-database",
		Method:        http.MethodPost,
		Path:          "/databases",
		Summary:       "Create",
		DefaultStatus: http.StatusCreated,
		Description:   "Create a database.",
		Tags:          []string{"Databases"},
	}, s.CreateDatabaseHandler)
	huma.Register(api, huma.Operation{
		OperationID: "get-databases",
		Method:      http.MethodGet,
		Path:        "/databases",
		Summary:     "Databases",
		Description: "Show databases.",
		Tags:        []string{"Databases"},
	}, s.ListDatabasesHandler)
	huma.Register(api, huma.Operation{
		OperationID: "get-database",
		Method:      http.MethodGet,
		Path:        "/databases/{database}",
		Summary:     "Database",
		Description: "Show a database.",
		Tags:        []string{"Databases"},
	}, s.ShowDatabaseHandler)
	huma.Register(api, huma.Operation{
		OperationID: "delete-database",
		Method:      http.MethodDelete,
		Path:        "/databases/{database}",
		Summary:     "Delete",
		Description: "Delete a database.",
		Tags:        []string{"Databases"},
	}, s.DeleteDatabaseHandler)
}
// registerStacks registers the API routes for stack operations.
//
// It first registers the stack CRUD routes (registerStacksCRUD), then the
// "Stack Operations" endpoints:
//   - PEEK: GET .../{stack}/peek;
//   - PUSH: PUT .../{stack};
//   - POP: DELETE .../{stack};
//   - FLUSH: DELETE .../{stack}/flush.
//
// Because DELETE on the bare stack path means POP, removing a whole stack
// uses the separate .../nuke path registered by registerStacksCRUD.
func (s *Service) registerStacks(api huma.API) {
	s.registerStacksCRUD(api)
	huma.Register(api, huma.Operation{
		OperationID: "peek-stack",
		Method:      http.MethodGet,
		Path:        "/databases/{database}/stacks/{stack}/peek",
		Summary:     "Peek",
		Description: "`PEEK` operation on a stack.",
		Tags:        []string{"Stack Operations"},
	}, s.PeekDatabaseStackHandler)
	huma.Register(api, huma.Operation{
		OperationID: "push-stack",
		Method:      http.MethodPut,
		Path:        "/databases/{database}/stacks/{stack}",
		Summary:     "Push",
		Description: "`PUSH` operation on a stack.",
		Tags:        []string{"Stack Operations"},
	}, s.PushDatabaseStackHandler)
	huma.Register(api, huma.Operation{
		OperationID: "pop-stack",
		Method:      http.MethodDelete,
		Path:        "/databases/{database}/stacks/{stack}",
		Summary:     "Pop",
		Description: "`POP` operation on a stack.",
		Tags:        []string{"Stack Operations"},
	}, s.PopDatabaseStackHandler)
	huma.Register(api, huma.Operation{
		OperationID: "flush-stack",
		Method:      http.MethodDelete,
		Path:        "/databases/{database}/stacks/{stack}/flush",
		Summary:     "Flush",
		Description: "`FLUSH` operation on a stack.",
		Tags:        []string{"Stack Operations"},
	}, s.FlushDatabaseStackHandler)
}
// registerStacksCRUD registers the CRUD API routes for stacks.
//
// All four operations are tagged "Stacks":
//   - create: POST /databases/{database}/stacks (responds 201 Created);
//   - list: GET /databases/{database}/stacks;
//   - show: GET .../stacks/{stack};
//   - delete: DELETE .../stacks/{stack}/nuke.
//
// Delete uses the /nuke suffix because DELETE on the bare stack path is the
// POP operation registered by registerStacks.
func (s *Service) registerStacksCRUD(api huma.API) {
	huma.Register(api, huma.Operation{
		OperationID:   "create-stack",
		Method:        http.MethodPost,
		Path:          "/databases/{database}/stacks",
		Summary:       "Create",
		Description:   "Create a stack from a database.",
		DefaultStatus: http.StatusCreated,
		Tags:          []string{"Stacks"},
	}, s.CreateDatabaseStackHandler)
	huma.Register(api, huma.Operation{
		OperationID: "get-stacks",
		Method:      http.MethodGet,
		Path:        "/databases/{database}/stacks",
		Summary:     "Stacks",
		Description: "Show stacks of a database.",
		Tags:        []string{"Stacks"},
	}, s.ListDatabaseStacksHandler)
	huma.Register(api, huma.Operation{
		OperationID: "get-stack",
		Method:      http.MethodGet,
		Path:        "/databases/{database}/stacks/{stack}",
		Summary:     "Stack",
		Description: "Show a stack of a database.",
		Tags:        []string{"Stacks"},
	}, s.ShowDatabaseStackHandler)
	huma.Register(api, huma.Operation{
		OperationID: "delete-stack",
		Method:      http.MethodDelete,
		Path:        `/databases/{database}/stacks/{stack}/nuke`,
		Summary:     "Delete",
		Description: "Delete a stack from a database.",
		Tags:        []string{"Stacks"},
	}, s.DeleteDatabaseStackHandler)
}
// loadInitMsg logs the initial message with service details when the service starts.
//
// It logs the ASCII logo line by line, then version, Go version, host, port
// and PID. When persistence is enabled it also logs the loaded file and the
// database count. Finally it logs the base, docs, metrics and statsviz URLs.
//
// Build information is optional. Without it, the version is logged as
// "unknown" and the Go version comes from runtime.Version(), instead of
// dereferencing a nil pointer.
func (s *Service) loadInitMsg() {
	for _, l := range strings.Split(logo, "\n") {
		slog.Info(l)
	}

	version, goVersion := "unknown", runtime.Version()
	if bi := s.buildInfo; bi != nil {
		version, goVersion = bi.Main.Version, bi.GoVersion
	}
	slog.Info(fmt.Sprintf("Version: %v", version))
	slog.Info(fmt.Sprintf("Go version: %v", goVersion))
	slog.Info(fmt.Sprintf("Host: %v", s.platform))
	slog.Info(fmt.Sprintf("Port: %v", s.Port()))
	slog.Info(fmt.Sprintf("PID: %v", s.pid))
	if s.persistDB {
		slog.Info(fmt.Sprintf("Loaded repo: %v", s.savefile))
		slog.Info(fmt.Sprintf("Databases: %v", s.Repository.Len()))
	}

	baseURL := "http://" + s.server.Addr
	if s.secure {
		baseURL = "https://" + s.server.Addr
	}
	slog.Info(fmt.Sprintf("Serving: %v", baseURL))
	slog.Info(fmt.Sprintf("Docs: %v/docs#/", baseURL))
	slog.Info(fmt.Sprintf("Metrics: %v/metrics", baseURL))
	slog.Info(fmt.Sprintf("StatsViz: %v/debug/statsviz", baseURL))
}
// SaveToFile persists the repository to the configured save file when
// persistence is enabled, and logs how many databases were saved. It is a
// no-op returning nil when persistence is disabled.
func (s *Service) SaveToFile() error {
	if s.persistDB {
		if err := s.Repository.Persist(s.savefile); err != nil {
			return err
		}
		slog.Info("Repository saved to disk", slog.Int("databases", s.Repository.Len()))
	}
	return nil
}
// LoadToFile loads the repository from the configured save file when
// persistence is enabled. Despite its name it reads from the file; the name
// is kept for compatibility. It is a no-op returning nil when persistence is
// disabled.
//
// NOTE(review): whether a missing save file (e.g. a first run with
// persistence enabled) is an error depends on repository.Load — confirm.
func (s *Service) LoadToFile() error {
	if s.persistDB {
		return s.Repository.Load(s.savefile)
	}
	return nil
}
// generateSelfSignedCert generates a self-signed TLS certificate for secure connections.
//
// The certificate has these properties:
//   - backed by a fresh ECDSA P-256 key;
//   - valid for one year from now;
//   - restricted to TLS server authentication;
//   - covers "localhost", 127.0.0.1 and ::1 as subject alternative names, so it
//     matches the localhost address the service advertises.
//
// Only the digital-signature key usage is set. KeyEncipherment applies to RSA
// key transport and is not meaningful for an ECDSA key.
func generateSelfSignedCert() (tls.Certificate, error) {
	// Generate a new private key.
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, err
	}

	// Create a new random serial number for the certificate.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return tls.Certificate{}, err
	}

	// Create a simple certificate template.
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"github.com/jh125486"},
		},
		NotBefore: now,
		NotAfter:  now.Add(365 * 24 * time.Hour), // Valid for one year.
		KeyUsage:  x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
		},
		BasicConstraintsValid: true,
		// Modern TLS clients match hostnames against SANs only (not the CN).
		DNSNames:    []string{"localhost"},
		IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
	}

	// Create a self-signed certificate.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return tls.Certificate{}, err
	}

	// PEM encode the certificate and private key.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return tls.Certificate{}, err
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})

	// Load the certificate and private key to create a tls.Certificate.
	return tls.X509KeyPair(certPEM, keyPEM)
}
// Package handlers provides HTTP handlers for managing stacks within databases
// in the batterdb application. The handlers include operations for listing,
// creating, showing, peeking, pushing, popping, flushing, and deleting stacks.
//
// The package utilizes the huma framework for handling HTTP requests and responses,
// and interacts with the repository package to perform stack operations within databases.
package handlers
import (
"context"
"errors"
"net/http"
"time"
"github.com/danielgtaylor/huma/v2"
"github.com/jh125486/batterdb/repository"
)
type (
	// StackInput represents the input structure for listing stacks in a database.
	// It includes the database ID and an optional key-value query parameter.
	// With ?kv=true ListDatabaseStacksHandler returns a name -> top-element map
	// instead of a list of Stack details.
	StackInput struct {
		URLParamDatabaseID
		KV bool `default:"false" query:"kv"`
	}
	// StacksOutput represents the output structure for listing stacks in a database.
	// It contains a list of stacks.
	// Stacks is typed any because it holds either a map[string]any (kv mode)
	// or a []any of Stack values.
	StacksOutput struct {
		Body struct {
			Stacks any `json:"stacks"`
		}
	}
	// Stack represents the structure of a single stack, including its ID, name,
	// size, and timestamps for creation, update, and last read.
	// Peek holds the stack's current top element as returned by its Peek method.
	Stack struct {
		CreatedAt time.Time `json:"created_at"`
		UpdatedAt time.Time `json:"updated_at"`
		ReadAt    time.Time `json:"read_at"`
		Peek      any       `json:"peek"`
		ID        string    `json:"id"`
		Name      string    `json:"name"`
		Size      int       `json:"size"`
	}
)
// ListDatabaseStacksHandler handles the request to list all stacks in a database.
//
// With ?kv=true the response maps each stack name to its top element (Peek).
// Otherwise it is a list of Stack details in the order returned by SortStacks.
// A missing database yields 404 Not Found.
func (s *Service) ListDatabaseStacksHandler(_ context.Context, input *StackInput) (*StacksOutput, error) {
	db, err := s.database(input.DatabaseID)
	if err != nil {
		return nil, err
	}
	out := new(StacksOutput)
	if input.KV {
		stacks := make(map[string]any, db.Len())
		for _, stack := range db.SortStacks() {
			stacks[stack.Name] = stack.Peek()
		}
		out.Body.Stacks = stacks
		return out, nil
	}
	// db.Len() is only a capacity hint. Appending, rather than indexing into a
	// slice of that length, cannot panic or leave nil entries if the number of
	// stacks changes between the two calls.
	stacks := make([]any, 0, db.Len())
	for _, stack := range db.SortStacks() {
		stacks = append(stacks, Stack{
			ID:        stack.ID.String(),
			Name:      stack.Name,
			Peek:      stack.Peek(),
			Size:      stack.Size(),
			CreatedAt: stack.CreatedAt,
			UpdatedAt: stack.UpdatedAt,
			ReadAt:    stack.ReadAt,
		})
	}
	out.Body.Stacks = stacks
	return out, nil
}
type (
	// CreateDatabaseStackInput represents the input structure for creating a new stack in a database.
	// It includes the database ID and the name of the new stack.
	// Name comes from the required ?name= query parameter (minimum 7 characters).
	CreateDatabaseStackInput struct {
		URLParamDatabaseID
		Name string `minLength:"7" query:"name" required:"true"`
	}
	// StackOutput represents the output structure for operations involving a single stack.
	// It contains the details of the stack.
	// NOTE(review): the json tag on Body is presumably ignored because Body
	// itself is the response payload — confirm against huma's behavior.
	StackOutput struct {
		Body Stack `json:"stack"`
	}
)
// CreateDatabaseStackHandler handles the request to create a new stack in a database.
// It creates the stack in the repository and returns its details.
//
// Errors:
//   - 404 Not Found if the database does not exist;
//   - 409 Conflict if the stack already exists (repository.ErrAlreadyExists);
//   - 500 Internal Server Error for any other failure creating the stack.
func (s *Service) CreateDatabaseStackHandler(_ context.Context, input *CreateDatabaseStackInput) (*StackOutput, error) {
	db, err := s.database(input.DatabaseID)
	if err != nil {
		return nil, err
	}
	stack, err := db.New(input.Name)
	switch {
	case errors.Is(err, repository.ErrAlreadyExists):
		return nil, huma.Error409Conflict("stack already exists", err)
	case err != nil:
		// Any other failure leaves stack unusable (possibly nil), so it must
		// not be dereferenced below.
		return nil, huma.Error500InternalServerError("failed to create stack", err)
	}
	out := new(StackOutput)
	out.Body = Stack{
		ID:        stack.ID.String(),
		Name:      stack.Name,
		Peek:      stack.Peek(),
		Size:      stack.Size(),
		CreatedAt: stack.CreatedAt,
		UpdatedAt: stack.UpdatedAt,
		ReadAt:    stack.ReadAt,
	}
	return out, nil
}
type (
	// DatabaseStackInput represents the input structure for operations involving a single stack in a database.
	// It includes the database ID and the stack ID.
	// Both are taken from the {database} and {stack} path segments.
	DatabaseStackInput struct {
		URLParamDatabaseID
		URLParamStackID
	}
	// URLParamStackID represents the URL parameter for a stack ID, which can be either the stack ID or name.
	URLParamStackID struct {
		StackID string `doc:"can be the stack ID or name" path:"stack"`
	}
)
// ShowDatabaseStackHandler returns the details of one stack in a database.
// The database and stack (each by ID or name) are resolved through s.stack,
// whose error is returned unchanged.
func (s *Service) ShowDatabaseStackHandler(_ context.Context, input *DatabaseStackInput) (*StackOutput, error) {
	_, stack, err := s.stack(input.DatabaseID, input.StackID)
	if err != nil {
		return nil, err
	}
	details := Stack{
		ID:        stack.ID.String(),
		Name:      stack.Name,
		Peek:      stack.Peek(),
		Size:      stack.Size(),
		CreatedAt: stack.CreatedAt,
		UpdatedAt: stack.UpdatedAt,
		ReadAt:    stack.ReadAt,
	}
	return &StackOutput{Body: details}, nil
}
// StackElement represents the output structure for peeking at the top element of a stack.
// It contains the top element of the stack.
//
// It is also returned by the push handler, where Element echoes the value
// that was pushed.
type StackElement struct {
	// Body is serialized as the response payload; Element is emitted under
	// the "element" JSON key and is nil when the stack is empty.
	Body struct {
		Element any `json:"element"`
	}
}
// PeekDatabaseStackHandler handles the request to peek at the top element of a specific stack
// in a database. It retrieves the top element from the stack and returns it.
func (s *Service) PeekDatabaseStackHandler(_ context.Context, input *DatabaseStackInput) (*StackElement, error) {
_, stack, err := s.stack(input.DatabaseID, input.StackID)
if err != nil {
return nil, err
}
out := new(StackElement)
out.Body.Element = stack.Peek()
return out, nil
}
// PushDatabaseStackElementInput represents the input structure for pushing a new element
// onto a stack in a database. It includes the database ID, stack ID, and the new element.
type PushDatabaseStackElementInput struct {
	// Body is the request payload; the value to push is read from the
	// "element" JSON key and may be any JSON value.
	Body struct {
		Element any `json:"element"`
	}
	// DatabaseStackInput supplies the database and stack URL parameters.
	DatabaseStackInput
}
// PushDatabaseStackHandler places the request's element on top of the
// addressed stack and echoes that element back in the response.
func (s *Service) PushDatabaseStackHandler(_ context.Context, input *PushDatabaseStackElementInput) (*StackElement, error) {
	_, target, lookupErr := s.stack(input.DatabaseID, input.StackID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	element := input.Body.Element
	target.Push(element)
	resp := &StackElement{}
	resp.Body.Element = element
	return resp, nil
}
// PopDatabaseStackElementOutput represents the output structure for popping an element
// from a stack. It contains the popped element and the status code.
type PopDatabaseStackElementOutput struct {
	// Body is the response payload; Element is emitted under the "element"
	// JSON key.
	Body struct {
		Element any `json:"element"`
	}
	// Status is set by the pop handler to 200 or 204. NOTE(review): this
	// relies on huma's convention of reading a `Status int` output field as
	// the response status code — confirm against the huma version in use.
	Status int
}
// PopDatabaseStackHandler removes the top element of the addressed stack and
// returns it with status 200. When Pop yields nil (an empty stack), the
// response carries status 204 and no element.
func (s *Service) PopDatabaseStackHandler(_ context.Context, input *DatabaseStackInput) (*PopDatabaseStackElementOutput, error) {
	_, target, lookupErr := s.stack(input.DatabaseID, input.StackID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	resp := &PopDatabaseStackElementOutput{Status: http.StatusNoContent}
	if popped := target.Pop(); popped != nil {
		resp.Status = http.StatusOK
		resp.Body.Element = popped
	}
	return resp, nil
}
// FlushDatabaseStackHandler empties the addressed stack and then reports the
// stack's metadata as it stands after the flush.
func (s *Service) FlushDatabaseStackHandler(_ context.Context, input *DatabaseStackInput) (*StackOutput, error) {
	_, target, lookupErr := s.stack(input.DatabaseID, input.StackID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	target.Flush()
	return &StackOutput{
		Body: Stack{
			ID:        target.ID.String(),
			Name:      target.Name,
			Peek:      target.Peek(),
			Size:      target.Size(),
			CreatedAt: target.CreatedAt,
			UpdatedAt: target.UpdatedAt,
			ReadAt:    target.ReadAt,
		},
	}, nil
}
// DeleteDatabaseStackHandler handles the request to delete a specific stack
// (addressed by ID or name) from a database.
//
// It returns a 404 when the database or stack cannot be found. This includes
// the case where another request drops the stack between the lookup and the
// Drop call; that case is mapped to 404 rather than leaking the raw
// repository error. On success it returns a nil body.
func (s *Service) DeleteDatabaseStackHandler(_ context.Context, input *DatabaseStackInput) (*struct{}, error) {
	db, stack, err := s.stack(input.DatabaseID, input.StackID)
	if err != nil {
		return nil, err
	}
	if err := db.Drop(stack.ID.String()); err != nil {
		if errors.Is(err, repository.ErrNotFound) {
			return nil, huma.Error404NotFound("stack not found", err)
		}
		return nil, err
	}
	return nil, nil
}
// stack resolves a database (by ID or name) and then a stack within it (by
// ID or name). A miss at either level is reported as a huma 404 that wraps
// the underlying repository error.
func (s *Service) stack(dbID, sID string) (*repository.Database, *repository.Stack, error) {
	database, dbErr := s.Repository.Database(dbID)
	if dbErr != nil {
		return nil, nil, huma.Error404NotFound("database not found", dbErr)
	}
	found, stackErr := database.Stack(sID)
	if stackErr != nil {
		return nil, nil, huma.Error404NotFound("stack not found", stackErr)
	}
	return database, found, nil
}
// batterdb is a stack-based database engine.
// Databases are created with a unique name and can contain multiple stacks.
// Stacks are created within a database and can contain multiple elements.
// Docs are served at /docs.
//
// In **batterdb**, data is stored by pushing **_Elements_** onto **_Stacks_**, so only the _Element_ on top is accessible,
// with the rest of them kept underneath.
package main
import (
"log/slog"
"os"
"os/signal"
"runtime/debug"
"syscall"
"github.com/alecthomas/kong"
"github.com/jh125486/batterdb/cli"
)
// TODO: add OpenTelemetry (otel) instrumentation.

// main wires OS signal handling and build metadata into the kong-based CLI
// from package cli, then runs the selected command.
func main() {
	// Listen for interrupt signals. The buffer of one ensures signal.Notify
	// does not drop a signal that arrives before anyone is receiving.
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

	// Read build info; its main-module version is exposed to kong as ${version}.
	info, ok := debug.ReadBuildInfo()
	if !ok {
		slog.Error("couldn't read build info")
		os.Exit(1)
	}

	ctx, err := cli.New(
		os.Args[1:],
		kong.Name("batterdb"),
		kong.Description("A simple stack-based database 🔋."),
		kong.Vars{
			"RepoFile": ".batterdb.gob",
			"version":  info.Main.Version,
		},
		kong.Bind(&cli.Ctx{
			Stop:      stop,
			BuildInfo: info,
			Writer:    os.Stdout,
		}),
	)
	if err != nil {
		slog.Error("parsing command line", slog.Any("error", err))
		os.Exit(1)
	}
	ctx.FatalIfErrorf(ctx.Run())
}
// Package repository provides the core data structures and methods for managing
// databases and stacks in the batterdb application. It includes functionality for
// creating, retrieving, sorting, and deleting stacks within a database.
//
// The package uses UUIDs for unique identification of databases and stacks,
// and employs mutex locks for concurrent access to shared data structures.
package repository
import (
"sort"
"sync"
"time"
"github.com/google/uuid"
)
// Database represents a collection of stacks, identified by a unique UUID.
// It includes methods for managing stacks within the database.
type Database struct {
	// Stacks maps each stack's name to the stack itself.
	Stacks map[name]*Stack
	// Name is the database's human-readable name.
	Name string
	// ID uniquely identifies the database.
	ID uuid.UUID
	// mx guards Stacks. It is taken by every Database method that reads or
	// writes the map.
	mx sync.RWMutex
}
// Len reports how many stacks the database currently holds.
func (db *Database) Len() int {
	db.mx.RLock()
	count := len(db.Stacks)
	db.mx.RUnlock()
	return count
}
// SortStacks snapshots the database's stacks into a new slice ordered by
// stack name. The slice is freshly allocated on every call, but its elements
// are the same *Stack values stored in the database.
func (db *Database) SortStacks() []*Stack {
	db.mx.RLock()
	defer db.mx.RUnlock()

	ordered := make([]*Stack, len(db.Stacks))
	idx := 0
	for _, st := range db.Stacks {
		ordered[idx] = st
		idx++
	}
	sort.Slice(ordered, func(a, b int) bool { return ordered[a].Name < ordered[b].Name })
	return ordered
}
// Stack retrieves a stack by its ID or name. If the stack is not found, it returns an error.
//
// When id parses as a UUID, stacks are matched by ID first. Otherwise, or
// when no ID matches, id is treated as a stack name. An unparsable id is
// never compared against IDs, so it cannot accidentally match a stack whose
// ID is uuid.Nil. It returns ErrNotFound when nothing matches.
//
// The returned stack's back-reference to db is (re)established here if it is
// missing or stale. It is an unexported field, so encoding/gob does not
// restore it. That write is done under the exclusive lock; writing it under
// the shared read lock would be a data race between concurrent lookups.
func (db *Database) Stack(id string) (*Stack, error) {
	db.mx.RLock()
	var found *Stack
	if uid, err := uuid.Parse(id); err == nil {
		for _, st := range db.Stacks {
			if st.ID == uid {
				found = st
				break
			}
		}
	}
	if found == nil {
		found = db.Stacks[name(id)]
	}
	// Read the back-reference while still holding the read lock; writers
	// of this field hold db.mx exclusively.
	linked := found == nil || found.database == db
	db.mx.RUnlock()

	if found == nil {
		return nil, ErrNotFound
	}
	if !linked {
		db.mx.Lock()
		found.database = db
		db.mx.Unlock()
	}
	return found, nil
}
// New creates a new stack with the given name and adds it to the database.
// If a stack with the same name already exists, it returns ErrAlreadyExists.
//
// The new stack's CreatedAt, UpdatedAt and ReadAt are all set to the same
// instant. A nil Stacks map is allocated on demand, so a zero-value
// Database can accept stacks instead of panicking on a nil-map write.
func (db *Database) New(n string) (*Stack, error) {
	db.mx.Lock()
	defer db.mx.Unlock()
	if _, ok := db.Stacks[name(n)]; ok {
		return nil, ErrAlreadyExists
	}
	if db.Stacks == nil {
		db.Stacks = make(map[name]*Stack)
	}
	t := time.Now()
	stack := &Stack{
		ID:        uuid.New(),
		Name:      n,
		database:  db,
		CreatedAt: t,
		UpdatedAt: t,
		ReadAt:    t,
	}
	db.Stacks[name(n)] = stack
	return stack, nil
}
// Drop deletes the first stack (in map iteration order) whose ID string or
// name equals id. It returns ErrNotFound when no stack matches.
func (db *Database) Drop(id string) error {
	db.mx.Lock()
	defer db.mx.Unlock()
	for _, st := range db.Stacks {
		if st.ID.String() != id && st.Name != id {
			continue
		}
		delete(db.Stacks, name(st.Name))
		return nil
	}
	return ErrNotFound
}
// Package repository provides the core data structures and methods for managing
// databases and stacks in the batterdb application. It includes functionality for
// creating, retrieving, sorting, and deleting databases, as well as persisting
// and loading the repository state to and from a file.
//
// The package uses UUIDs for unique identification of databases and stacks,
// and employs mutex locks for concurrent access to shared data structures.
package repository
import (
"encoding/gob"
"errors"
"log/slog"
"os"
"sort"
"sync"
"github.com/google/uuid"
)
// Repository is the top-level container of databases, keyed by database
// name. Its methods guard the Databases map with mx.
type Repository struct {
	Databases map[name]*Database
	mx        sync.RWMutex
}

// name is the map-key type used for database and stack names.
type name string
// ErrNotFound reports that no database or stack matched the requested ID or name.
var ErrNotFound = errors.New("not found")

// ErrAlreadyExists reports that a database or stack with the requested name is already present.
var ErrAlreadyExists = errors.New("already exists")
// New returns an empty Repository whose Databases map is already allocated.
func New() *Repository {
	repo := new(Repository)
	repo.Databases = make(map[name]*Database)
	return repo
}
// Len reports how many databases the repository currently holds.
func (r *Repository) Len() int {
	r.mx.RLock()
	count := len(r.Databases)
	r.mx.RUnlock()
	return count
}
// SortDatabases snapshots the repository's databases into a new slice
// ordered by database name. The elements are the stored *Database values.
func (r *Repository) SortDatabases() []*Database {
	r.mx.RLock()
	defer r.mx.RUnlock()

	ordered := make([]*Database, len(r.Databases))
	idx := 0
	for _, d := range r.Databases {
		ordered[idx] = d
		idx++
	}
	sort.Slice(ordered, func(a, b int) bool { return ordered[a].Name < ordered[b].Name })
	return ordered
}
// Database retrieves a database by its ID or name. If the database is not found, it returns an error.
//
// When id parses as a UUID, databases are matched by ID first. Otherwise, or
// when no ID matches, id is treated as a database name. An unparsable id is
// never compared against IDs. Previously the zero uuid.Nil left behind by a
// failed parse could match a database with a Nil ID. It returns ErrNotFound
// when nothing matches.
func (r *Repository) Database(id string) (*Database, error) {
	r.mx.RLock()
	defer r.mx.RUnlock()
	if uid, err := uuid.Parse(id); err == nil {
		for _, db := range r.Databases {
			if db.ID == uid {
				return db, nil
			}
		}
	}
	if db, ok := r.Databases[name(id)]; ok {
		return db, nil
	}
	return nil, ErrNotFound
}
// New creates a new database with the given name and adds it to the repository.
// If a database with the same name already exists, it returns ErrAlreadyExists.
//
// A nil Databases map is allocated on demand, so a zero-value Repository
// (one not built with the package-level New) can accept databases instead of
// panicking on a nil-map write.
func (r *Repository) New(n string) (*Database, error) {
	r.mx.Lock()
	defer r.mx.Unlock()
	if _, ok := r.Databases[name(n)]; ok {
		return nil, ErrAlreadyExists
	}
	if r.Databases == nil {
		r.Databases = make(map[name]*Database)
	}
	db := &Database{
		ID:     uuid.New(),
		Name:   n,
		Stacks: make(map[name]*Stack),
	}
	r.Databases[name(n)] = db
	return db, nil
}
// Drop deletes the first database (in map iteration order) whose ID string
// or name equals id. It returns ErrNotFound when no database matches.
func (r *Repository) Drop(id string) error {
	r.mx.Lock()
	defer r.mx.Unlock()
	for _, d := range r.Databases {
		if d.ID.String() != id && d.Name != id {
			continue
		}
		delete(r.Databases, name(d.Name))
		return nil
	}
	return ErrNotFound
}
// Persist saves the repository state to filename using encoding/gob.
//
// The snapshot is first written to filename+".tmp" and then renamed over
// filename, so an encoding or I/O failure leaves any previous snapshot intact
// rather than truncated. The Close error of the temporary file is checked,
// because a failed Close can mean written data never reached the disk. On
// failure the temporary file is removed on a best-effort basis.
//
// NOTE(review): only r.mx is held while encoding. The per-database and
// per-stack locks are not taken, so writes to stacks are not excluded while
// Persist runs. Confirm Persist is only called when the server is quiescent.
func (r *Repository) Persist(filename string) (err error) {
	r.mx.RLock()
	defer r.mx.RUnlock()

	tmpName := filename + ".tmp"
	tmp, err := os.Create(tmpName)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// Best-effort cleanup; tmp may already be closed, and the
			// original error is what the caller needs to see.
			_ = tmp.Close()
			_ = os.Remove(tmpName)
		}
	}()

	if err = gob.NewEncoder(tmp).Encode(r); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, filename)
}
// Load loads the repository state from a gob file previously written by
// Persist, decoding it into r. A missing file is not an error; it is logged
// and Load returns nil, leaving r unchanged.
//
// r.mx is held for writing while decoding, since decoding mutates r. After
// decoding, every stack's unexported back-reference to its owning database
// is restored, because encoding/gob does not serialize unexported fields.
func (r *Repository) Load(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			// File doesn't exist yet.
			slog.Info("No repository file found", slog.String("filename", filename))
			return nil
		}
		return err
	}
	defer func() {
		// Read-only handle: a Close error cannot lose data.
		_ = file.Close()
	}()

	r.mx.Lock()
	defer r.mx.Unlock()
	if err := gob.NewDecoder(file).Decode(r); err != nil {
		return err
	}
	for _, db := range r.Databases {
		db.mx.Lock()
		for _, st := range db.Stacks {
			st.database = db
		}
		db.mx.Unlock()
	}
	return nil
}
// Package repository provides the core data structures and methods for managing
// stacks in the batterdb application. It includes functionality for creating,
// interacting with, and managing stacks within a database.
//
// The package uses UUIDs for unique identification of stacks and employs mutex
// locks for concurrent access to shared data structures.
package repository
import (
"sync"
"time"
"github.com/google/uuid"
)
// Stack represents a stack data structure with metadata including
// creation, update, and read timestamps, and a reference to the database
// it belongs to. It supports concurrent access through a mutex lock.
type Stack struct {
	// CreatedAt is when the stack was created.
	CreatedAt time.Time
	// UpdatedAt is the last time the contents changed (push, pop of an
	// element, or flush).
	UpdatedAt time.Time
	// ReadAt is the last time the stack was read (peek, pop) or modified.
	ReadAt time.Time
	// database is a back-reference to the owning Database. Because it is
	// unexported, encoding/gob does not serialize it.
	database *Database
	// Name is the stack's name, unique within its database.
	Name string
	// Data holds the elements; the top of the stack is the last element.
	Data []any
	// mx guards Data and the timestamps within this type's methods.
	mx sync.RWMutex
	// ID uniquely identifies the stack.
	ID uuid.UUID
}
// setUpdateTime records t as both the last-modification and last-read time.
// All visible callers hold s.mx for writing when calling it.
func (s *Stack) setUpdateTime(t time.Time) {
	s.UpdatedAt = t
	s.setReadTime(t)
}
// setReadTime records t as the stack's last-read time.
func (s *Stack) setReadTime(t time.Time) {
	s.ReadAt = t
}
// Database returns the database this stack belongs to.
//
// NOTE(review): the back-reference is an unexported field, so it is not
// restored by encoding/gob. It may be nil for a stack decoded from disk until
// something re-links it. It is also read here without any lock.
func (s *Stack) Database() *Database {
	owner := s.database
	return owner
}
// Push adds an element to the top of the stack and updates the timestamps.
//
// UpdatedAt and ReadAt are stamped with a single instant. The original code
// called time.Now twice, leaving UpdatedAt slightly later than ReadAt, which
// was inconsistent with Pop and Flush.
func (s *Stack) Push(element any) {
	s.mx.Lock()
	defer s.mx.Unlock()
	s.Data = append(s.Data, element)
	s.setUpdateTime(time.Now())
}
// Pop removes and returns the top element of the stack.
//
// On an empty stack it returns nil and only refreshes ReadAt. Otherwise it
// refreshes both UpdatedAt and ReadAt. The vacated slot in the backing array
// is cleared so the popped element can be garbage collected even though the
// slice keeps its capacity.
func (s *Stack) Pop() any {
	s.mx.Lock()
	defer s.mx.Unlock()
	if len(s.Data) == 0 {
		s.setReadTime(time.Now())
		return nil
	}
	s.setUpdateTime(time.Now())
	last := len(s.Data) - 1
	res := s.Data[last]
	s.Data[last] = nil // drop the backing array's reference to the element
	s.Data = s.Data[:last]
	return res
}
// Size reports how many elements are currently on the stack.
func (s *Stack) Size() int {
	s.mx.RLock()
	count := len(s.Data)
	s.mx.RUnlock()
	return count
}
// Peek returns the top element of the stack without removing it (nil when
// the stack is empty) and updates the read timestamp.
//
// The exclusive lock is required because Peek writes ReadAt. Under the
// shared read lock, concurrent Peek calls would race on that write.
func (s *Stack) Peek() any {
	s.mx.Lock()
	defer s.mx.Unlock()
	s.setReadTime(time.Now())
	if len(s.Data) == 0 {
		return nil
	}
	return s.Data[len(s.Data)-1]
}
// Flush discards every element, releasing the backing array, and stamps both
// UpdatedAt and ReadAt with the current time.
func (s *Stack) Flush() {
	s.mx.Lock()
	s.Data = nil
	s.setUpdateTime(time.Now())
	s.mx.Unlock()
}