package backlite
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"sync"
"time"
"github.com/mikestefanello/backlite/internal/query"
"github.com/mikestefanello/backlite/internal/task"
)
// now yields the current time. It is a package-level variable rather than a direct call to time.Now
// so that tests can substitute their own clock.
var now = time.Now
type (
// Client is a client used to register queues and add tasks to them for execution.
// It also owns the dispatcher that executes queued tasks once Start is called.
Client struct {
// db is the database in which tasks are stored.
db *sql.DB
// log is the logger.
log Logger
// queues stores the registered queues which tasks can be added to.
queues queues
// buffers is a pool of byte buffers that are reused when JSON-encoding tasks.
buffers sync.Pool
// dispatcher is used to fetch and dispatch queued tasks to the workers for execution.
dispatcher Dispatcher
}
// ClientConfig contains configuration for the Client.
ClientConfig struct {
// DB is the open database connection used for storing tasks. It is required.
DB *sql.DB
// Logger is the logger used to log task execution. If nil, nothing is logged.
Logger Logger
// NumWorkers is the number of goroutines to open to use for executing queued tasks concurrently.
// At least one worker is required.
NumWorkers int
// ReleaseAfter is the duration after which a task is released back to a queue if it has not finished executing.
// This value should be much higher than the timeout setting used for each queue and exists as a fail-safe
// just in case tasks become stuck. It must be greater than zero.
ReleaseAfter time.Duration
// CleanupInterval is how often to run cleanup operations on the database in order to remove expired completed
// tasks. If omitted, no cleanup operations will be performed and the task retention duration will be ignored.
CleanupInterval time.Duration
}
// ctxKeyClient is the context key under which a Client is stored in a context.
ctxKeyClient struct{}
)
// FromContext extracts the Client stored in the given context, or returns nil if the context does not
// carry one. The dispatcher places the Client in the context passed to queue processor callbacks so they
// can create additional tasks.
func FromContext(ctx context.Context) *Client {
	// A failed type assertion leaves client as the nil pointer, which is exactly the "not found" result.
	client, _ := ctx.Value(ctxKeyClient{}).(*Client)
	return client
}
// NewClient validates the provided configuration and builds a Client along with its dispatcher.
// An error is returned if the database is missing, fewer than one worker is requested, or the release
// duration is not positive. When no logger is provided, a logger that discards everything is used.
func NewClient(cfg ClientConfig) (*Client, error) {
	if cfg.DB == nil {
		return nil, errors.New("missing database")
	}
	if cfg.NumWorkers < 1 {
		return nil, errors.New("at least one worker required")
	}
	if cfg.ReleaseAfter <= 0 {
		return nil, errors.New("release duration must be greater than zero")
	}

	logger := cfg.Logger
	if logger == nil {
		logger = &noLogger{}
	}

	client := &Client{
		db:     cfg.DB,
		log:    logger,
		queues: queues{registry: make(map[string]Queue)},
	}
	client.buffers.New = func() any {
		return new(bytes.Buffer)
	}

	// The dispatcher needs a reference back to the client, so it is attached after the client exists.
	client.dispatcher = &dispatcher{
		client:          client,
		log:             logger,
		numWorkers:      cfg.NumWorkers,
		releaseAfter:    cfg.ReleaseAfter,
		cleanupInterval: cfg.CleanupInterval,
	}

	return client, nil
}
// Register adds the given Queue to the client's registry so that tasks can be added to it.
// It panics if the queue has no name or if a queue with the same name is already registered.
func (c *Client) Register(queue Queue) {
	registry := &c.queues
	registry.add(queue)
}
// Add begins an operation to add one or many tasks. Nothing is written to the database until Save is
// called on the returned operation.
func (c *Client) Add(tasks ...Task) *TaskAddOp {
	op := new(TaskAddOp)
	op.client = c
	op.tasks = tasks
	return op
}
// Start launches the dispatcher so queued tasks are executed automatically in the background.
// Call Stop() for a graceful shutdown, or cancel the provided context to hard-stop.
func (c *Client) Start(ctx context.Context) {
	d := c.dispatcher
	d.Start(ctx)
}
// Stop asks the dispatcher to shut down gracefully, waiting until either all workers are done or the
// provided context is cancelled. It reports whether all workers finished their tasks before shutdown.
func (c *Client) Stop(ctx context.Context) bool {
	finished := c.dispatcher.Stop(ctx)
	return finished
}
// Install executes the embedded schema against the client's database.
// TODO provide migrations
func (c *Client) Install() error {
	if _, err := c.db.Exec(query.Schema); err != nil {
		return err
	}
	return nil
}
// Notify tells the dispatcher that a new task has been added.
// Calling it is only necessary, and is required, when a task was added using a caller-supplied database
// transaction. See TaskAddOp.Tx().
func (c *Client) Notify() {
	d := c.dispatcher
	d.Notify()
}
// save persists every task of a task add operation.
//
// If the operation carries a caller-supplied transaction, the tasks are inserted with it and neither
// committed nor rolled back here; the caller must commit and then call Notify. Otherwise a transaction is
// created for the duration of this call, committed on success (followed by a dispatcher notification) and
// rolled back on failure.
//
// The internally created transaction and the default context are kept local rather than stored on the
// operation, so the operation is not left holding a finished transaction after this call returns.
func (c *Client) save(op *TaskAddOp) error {
	// Get a buffer for the encoding and return it to the pool for re-use when done.
	buf := c.buffers.Get().(*bytes.Buffer)
	defer func() {
		buf.Reset()
		c.buffers.Put(buf)
	}()

	ctx := op.ctx
	if ctx == nil {
		ctx = context.Background()
	}

	// Start a transaction if one isn't provided.
	tx := op.tx
	owned := tx == nil
	finished := false
	if owned {
		var err error
		if tx, err = c.db.BeginTx(ctx, nil); err != nil {
			return err
		}

		defer func() {
			if finished {
				return
			}
			// ErrTxDone means the transaction was already ended (for example because the context was
			// cancelled), so there is nothing left to roll back and nothing worth reporting.
			if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
				c.log.Error("failed to rollback task creation transaction",
					"error", rbErr,
				)
			}
		}()
	}

	// Insert the tasks.
	for _, t := range op.tasks {
		buf.Reset()
		if err := json.NewEncoder(buf).Encode(t); err != nil {
			return err
		}

		m := task.Task{
			Queue:     t.Config().Name,
			Task:      buf.Bytes(),
			WaitUntil: op.wait,
			CreatedAt: now(),
		}

		if err := m.InsertTx(ctx, tx); err != nil {
			return err
		}
	}

	// A caller-supplied transaction is committed by the caller.
	if !owned {
		return nil
	}

	// Commit ends the transaction whether or not it succeeds, so no rollback is attempted afterwards.
	finished = true
	if err := tx.Commit(); err != nil {
		return err
	}

	// Tell the dispatcher that a new task has been added.
	c.Notify()
	return nil
}
package backlite
import (
"context"
"database/sql"
"fmt"
"sync/atomic"
"time"
"github.com/mikestefanello/backlite/internal/task"
)
type (
// Dispatcher handles automatically pulling queued tasks and executing them via queue processors.
Dispatcher interface {
// Start starts the dispatcher.
Start(context.Context)
// Stop stops the dispatcher and reports whether all workers finished before the context was cancelled.
Stop(context.Context) bool
// Notify notifies the dispatcher that a new task has been added.
Notify()
}
// dispatcher implements Dispatcher.
dispatcher struct {
// client is the Client that this dispatcher belongs to.
client *Client
// log is the logger.
log Logger
// ctx stores the context used to start the dispatcher. Cancelling it hard-stops the dispatcher.
ctx context.Context
// shutdownCtx stores an internal context that is used when attempting to gracefully shut down the dispatcher.
shutdownCtx context.Context
// shutdown is the cancel function for cancelling shutdownCtx.
shutdown context.CancelFunc
// numWorkers is the amount of goroutines opened to execute tasks.
numWorkers int
// releaseAfter is the duration to reclaim a task for execution if it has not completed.
releaseAfter time.Duration
// cleanupInterval is how often to run cleanup operations on the database in order to remove expired completed
// tasks. The cleaner is only started when this is greater than zero.
cleanupInterval time.Duration
// running indicates if the dispatcher is currently running.
running atomic.Bool
// ticker will fetch tasks from the database if the next task is delayed.
ticker *time.Ticker
// tasks transmits tasks to the workers.
tasks chan *task.Task
// availableWorkers tracks the amount of workers available to receive a task to execute.
// It holds one token per idle worker.
availableWorkers chan struct{}
// ready tells the dispatcher that fetching tasks from the database is required.
ready chan struct{}
// trigger instructs the dispatcher to fetch tasks from the database now.
trigger chan struct{}
// triggered indicates that a trigger was sent but not yet received.
// This is used to allow multiple signals on the ready channel, which will happen whenever a task is added,
// but only 1 database fetch since that is all that is needed for the dispatcher to be aware of the
// current state of the queues.
triggered atomic.Bool
}
)
// Start initializes the dispatcher's channels and ticker, launches the worker, triggerer and fetcher
// goroutines (plus the cleaner when a cleanup interval is configured) and queues an initial fetch.
// It does nothing if the dispatcher is already running.
// To hard-stop, cancel the provided context. To gracefully stop, call Stop().
func (d *dispatcher) Start(ctx context.Context) {
// Abort if the dispatcher is already running.
if d.running.Load() {
return
}
d.ctx = ctx
// The shutdown context is derived from Background, not ctx, so it is cancelled only by Stop().
d.shutdownCtx, d.shutdown = context.WithCancel(context.Background())
d.tasks = make(chan *task.Task, d.numWorkers)
// The ticker is created stopped; schedule() resets it when the next task is delayed.
d.ticker = time.NewTicker(time.Second)
d.ticker.Stop() // No need to tick yet
d.ready = make(chan struct{}, 1000) // Prevent blocking task creation
d.trigger = make(chan struct{}, 10) // Should never need more than 1 but just in case
d.availableWorkers = make(chan struct{}, d.numWorkers)
d.running.Store(true)
// Start each worker and mark it as available by adding a token.
for range d.numWorkers {
go d.worker()
d.availableWorkers <- struct{}{}
}
if d.cleanupInterval > 0 {
go d.cleaner()
}
go d.triggerer()
go d.fetcher()
// Request an initial fetch so tasks already in the database are picked up.
d.ready <- struct{}{}
d.log.Info("task dispatcher started")
}
// Stop attempts to gracefully shut down the dispatcher by blocking until either the context is cancelled or all
// workers are done with their task. If all workers are able to complete, true will be returned.
// If the dispatcher is not running, true is returned immediately.
func (d *dispatcher) Stop(ctx context.Context) bool {
if !d.running.Load() {
return true
}
// Call the internal shutdown to gracefully close all goroutines.
d.shutdown()
// Collect one availability token per worker; a worker returns its token once it finishes its task.
var count int
for {
select {
case <-ctx.Done():
return false
case <-d.availableWorkers:
count++
if count == d.numWorkers {
return true
}
}
}
}
// triggerer listens to the ready channel and sends a trigger to the fetcher only when it is needed which is
// controlled by the triggered flag. This allows the dispatcher to track database fetches and when one is made,
// it can account for all incoming tasks that sent a signal to the ready channel before it, rather than fetching
// from the database every single time a new task is added.
// It returns when either the shutdown context or the start context is done.
func (d *dispatcher) triggerer() {
for {
select {
case <-d.ready:
// Only send a trigger if one is not already pending; fetch() clears the flag.
if d.triggered.CompareAndSwap(false, true) {
d.trigger <- struct{}{}
}
case <-d.shutdownCtx.Done():
return
case <-d.ctx.Done():
return
}
}
}
// fetcher fetches tasks from the database to be executed either when the ticker ticks or when the trigger signal
// is sent by the triggerer. It returns when either the shutdown context or the start context is done, at which
// point it marks the dispatcher as no longer running, stops the ticker and closes the tasks channel.
func (d *dispatcher) fetcher() {
defer func() {
d.running.Store(false)
d.ticker.Stop()
close(d.tasks)
d.log.Info("shutting down dispatcher")
}()
for {
select {
case <-d.ticker.C:
// The ticker is used as a one-shot timer; fetch() re-arms it via schedule() if needed.
d.ticker.Stop()
d.fetch()
case <-d.trigger:
d.fetch()
case <-d.shutdownCtx.Done():
return
case <-d.ctx.Done():
return
}
}
}
// worker processes incoming tasks until the tasks channel is closed or either the shutdown context or the
// start context is done. After each processed task it returns its token to availableWorkers to signal that
// it can accept another task.
func (d *dispatcher) worker() {
	for {
		select {
		case row, ok := <-d.tasks:
			if !ok {
				// The fetcher closed the channel, so no more tasks will arrive. Return explicitly: a bare
				// break here would only leave the select and spin on the closed channel.
				return
			}
			if row == nil {
				continue
			}

			d.processTask(row)
			d.availableWorkers <- struct{}{}

		case <-d.shutdownCtx.Done():
			return

		case <-d.ctx.Done():
			return
		}
	}
}
// cleaner deletes expired completed tasks from the database every cleanupInterval until either the shutdown
// context or the start context is done. Deletion failures are logged and do not stop the loop.
func (d *dispatcher) cleaner() {
	ticker := time.NewTicker(d.cleanupInterval)
	// Release the ticker on every exit path, including graceful shutdown.
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := task.DeleteExpiredCompleted(d.ctx, d.client.db); err != nil {
				d.log.Error("failed to delete expired completed tasks",
					"error", err,
				)
			}

		case <-d.shutdownCtx.Done():
			return

		case <-d.ctx.Done():
			return
		}
	}
}
// waitForWorkers blocks until at least one worker is available to execute a task and reports how many
// were available at that moment. Availability is polled every 100 milliseconds.
func (d *dispatcher) waitForWorkers() int {
	for {
		available := len(d.availableWorkers)
		if available < 1 {
			time.Sleep(100 * time.Millisecond)
			continue
		}
		return available
	}
}
// fetch fetches tasks from the database to be executed and/or coordinate the dispatcher, so it is aware of when it
// needs to fetch again. It loads one more task than there are available workers, claims and dispatches the
// tasks that are ready, and hands the first task that is not dispatched to schedule().
func (d *dispatcher) fetch() {
var err error
// If we failed at any point, we need to tell the dispatcher to try again.
defer func() {
if err != nil {
// Wait and try again.
time.Sleep(100 * time.Millisecond)
d.ready <- struct{}{}
}
}()
// Indicate that incoming task additions from this point on should trigger another fetch.
d.triggered.Store(false)
// Determine how many workers are available, so we only fetch that many tasks.
workers := d.waitForWorkers()
// Fetch tasks for each available worker plus the next upcoming task so the scheduler knows when to
// query the database again without having to continually poll.
// Tasks claimed earlier than releaseAfter ago are included again.
tasks, err := task.GetScheduledTasks(
d.ctx,
d.client.db,
now().Add(-d.releaseAfter),
int(workers)+1,
)
if err != nil {
d.log.Error("fetch tasks query failed",
"error", err,
)
return
}
// next is the first task that will not be dispatched now; nextUp records it and truncates tasks to
// only those before it.
var next *task.Task
nextUp := func(i int) {
next = tasks[i]
tasks = tasks[:i]
}
for i := range tasks {
// Check if the workers are full.
if (i + 1) > workers {
nextUp(i)
break
}
// Check if this task is not ready yet.
if tasks[i].WaitUntil != nil {
if tasks[i].WaitUntil.After(now()) {
nextUp(i)
break
}
}
}
// Claim the tasks that are ready to be processed.
if err = tasks.Claim(d.ctx, d.client.db); err != nil {
d.log.Error("failed to claim tasks",
"error", err,
)
return
}
// Send the ready tasks to the workers.
for i := range tasks {
// Mirror the attempts increment that Claim applied in the database.
tasks[i].Attempts++
// Take one availability token per dispatched task; the worker returns it when done.
<-d.availableWorkers
d.tasks <- tasks[i]
}
// Adjust the schedule based on the next up task.
d.schedule(next)
}
// schedule adjusts when the dispatcher will fetch again based on the next up task provided by fetch.
//
// The ticker is always stopped first. If there is no next task, nothing further happens. If the next task
// has no wait time, or its wait time has already been reached, another fetch is requested immediately via
// the ready channel. Otherwise the ticker is reset to fire when the task becomes due.
func (d *dispatcher) schedule(t *task.Task) {
	d.ticker.Stop()

	if t == nil {
		return
	}

	if t.WaitUntil != nil {
		// Ticker.Reset panics on a non-positive duration, so a task that is due exactly now must be
		// treated as ready rather than scheduled.
		if dur := t.WaitUntil.Sub(now()); dur > 0 {
			d.ticker.Reset(dur)
			return
		}
	}

	d.ready <- struct{}{}
}
// processTask attempts to execute a given task using the processor of the queue it belongs to.
// On success the task is handed to taskSuccess; on error or panic it is handed to taskFailure.
func (d *dispatcher) processTask(t *task.Task) {
var err error
var ctx context.Context
var cancel context.CancelFunc
// NOTE(review): get returns a nil Queue when the task's queue name is not registered, in which case
// q.Config() below panics before the recover handler is installed — confirm this cannot happen.
q := d.client.queues.get(t.Queue)
cfg := q.Config()
// Set a context timeout, if desired.
if cfg.Timeout > 0 {
ctx, cancel = context.WithDeadline(d.ctx, now().Add(cfg.Timeout))
defer cancel()
} else {
ctx = d.ctx
}
// Store the client in the context so the processor can use it.
// TODO include the attempt number
ctx = context.WithValue(ctx, ctxKeyClient{}, d.client)
// NOTE(review): start comes from the overridable now() while the duration uses time.Since, which reads
// the real clock — confirm this is intended.
start := now()
defer func() {
// Recover from panics from within the task processor.
if rec := recover(); rec != nil {
d.log.Error("panic processing task",
"id", t.ID,
"queue", t.Queue,
"error", rec,
)
err = fmt.Errorf("%v", rec)
}
// If panic or error, handle the task as a failure.
if err != nil {
d.taskFailure(q, t, start, time.Since(start), err)
}
}()
// Process the task.
if err = q.Process(ctx, t.Task); err == nil {
d.taskSuccess(q, t, start, time.Since(start))
}
}
// taskSuccess finalizes a task that executed successfully. Within a single transaction it deletes the task
// from the task table and, if the queue's retention settings call for it, records it as a completed task.
// Any failure is logged and the transaction, if one was opened, is rolled back.
func (d *dispatcher) taskSuccess(q Queue, t *task.Task, started time.Time, dur time.Duration) {
	d.log.Info("task processed",
		"id", t.ID,
		"queue", t.Queue,
		"duration", dur,
		"attempt", t.Attempts,
	)

	var tx *sql.Tx

	// persist performs the database work and stops at the first error.
	persist := func() (err error) {
		if tx, err = d.client.db.Begin(); err != nil {
			return err
		}
		if err = t.DeleteTx(d.ctx, tx); err != nil {
			return err
		}
		if err = d.taskComplete(tx, q, t, started, dur, nil); err != nil {
			return err
		}
		return tx.Commit()
	}

	failure := persist()
	if failure == nil {
		return
	}

	d.log.Error("failed to update task success",
		"id", t.ID,
		"queue", t.Queue,
		"error", failure,
	)

	if tx == nil {
		return
	}

	if rbErr := tx.Rollback(); rbErr != nil {
		d.log.Error("failed to rollback task success",
			"id", t.ID,
			"queue", t.Queue,
			"error", rbErr,
		)
	}
}
// taskFailure handles a task whose execution failed.
//
// If the queue's maximum attempts have not been reached, the task is released back to the queue to be
// retried after the queue's backoff, and the dispatcher is told to fetch again. Otherwise the task is
// deleted from the task table and, if the queue has retention enabled, recorded as a completed task, all
// within a single transaction.
//
// The processing error is included in the failure log so the cause is visible even when the task is not
// retained.
func (d *dispatcher) taskFailure(q Queue, t *task.Task, started time.Time, dur time.Duration, taskErr error) {
	cfg := q.Config()
	remaining := cfg.MaxAttempts - t.Attempts

	d.log.Error("task processing failed",
		"id", t.ID,
		"queue", t.Queue,
		"duration", dur,
		"attempt", t.Attempts,
		"remaining", remaining,
		"error", taskErr,
	)

	// Attempts remain: release the task so it can be retried.
	if remaining >= 1 {
		t.LastExecutedAt = &started

		err := t.Fail(
			d.ctx,
			d.client.db,
			now().Add(cfg.Backoff),
		)
		if err != nil {
			d.log.Error("failed to update task failure",
				"id", t.ID,
				"queue", t.Queue,
				"error", err,
			)
		}

		d.ready <- struct{}{}
		return
	}

	// No attempts remain: remove the task and optionally retain it as completed.
	var tx *sql.Tx
	var err error

	defer func() {
		if err == nil {
			return
		}

		d.log.Error("failed to update task failure",
			"id", t.ID,
			"queue", t.Queue,
			"error", err,
		)

		if tx == nil {
			return
		}

		if rbErr := tx.Rollback(); rbErr != nil {
			d.log.Error("failed to rollback task failure",
				"id", t.ID,
				"queue", t.Queue,
				"error", rbErr,
			)
		}
	}()

	if tx, err = d.client.db.Begin(); err != nil {
		return
	}

	if err = t.DeleteTx(d.ctx, tx); err != nil {
		return
	}

	if err = d.taskComplete(tx, q, t, started, dur, taskErr); err != nil {
		return
	}

	err = tx.Commit()
}
// taskComplete records a task in the completed tasks table as part of the given transaction, subject to the
// queue's retention policy. Nothing is recorded when the queue has no retention policy, or when the policy
// only retains failed tasks and this task succeeded. A nil taskErr means the task succeeded.
func (d *dispatcher) taskComplete(
	tx *sql.Tx,
	q Queue,
	t *task.Task,
	started time.Time,
	dur time.Duration,
	taskErr error) error {
	failed := taskErr != nil

	policy := q.Config().Retention
	if policy == nil || (policy.OnlyFailed && !failed) {
		return nil
	}

	record := task.Completed{
		ID:             t.ID,
		Queue:          t.Queue,
		Attempts:       t.Attempts,
		Succeeded:      !failed,
		LastDuration:   dur,
		CreatedAt:      t.CreatedAt,
		LastExecutedAt: started,
	}

	if failed {
		msg := taskErr.Error()
		record.Error = &msg
	}

	// A zero retention duration leaves the expiration unset.
	if policy.Duration != 0 {
		expiry := now().Add(policy.Duration)
		record.ExpiresAt = &expiry
	}

	// The payload is only kept when the policy asks for it.
	if data := policy.Data; data != nil && (failed || !data.OnlyFailed) {
		record.Task = t.Task
	}

	return record.InsertTx(d.ctx, tx)
}
// Notify is called by the client to signal that a new task was added. The signal is only sent on the
// ready channel while the dispatcher is running; otherwise the call does nothing.
func (d *dispatcher) Notify() {
	if !d.running.Load() {
		return
	}
	d.ready <- struct{}{}
}
package query
import (
_ "embed"
"fmt"
"strings"
)
// Schema is the SQL schema, embedded from schema.sql.
//go:embed schema.sql
var Schema string
// InsertTask inserts a new task.
// Parameters: id, created_at, queue, task, wait_until.
const InsertTask = `
INSERT INTO backlite_tasks
(id, created_at, queue, task, wait_until)
VALUES (?, ?, ?, ?, ?)
`
// SelectScheduledTasks selects tasks that are either unclaimed or were claimed before a given time, ordered
// by wait time and then ID. The claimed time is always returned as null.
// Parameters: claimed-before threshold, limit, offset.
const SelectScheduledTasks = `
SELECT
id, queue, task, attempts, wait_until, created_at, last_executed_at, null
FROM
backlite_tasks
WHERE
claimed_at IS NULL
OR claimed_at < ?
ORDER BY
wait_until ASC,
id ASC
LIMIT ?
OFFSET ?
`
// DeleteTask deletes a task by ID.
// Parameters: id.
const DeleteTask = `
DELETE FROM backlite_tasks
WHERE id = ?
`
// InsertCompletedTask inserts a completed task.
// Parameters: id, created_at, queue, last_executed_at, attempts, last_duration_micro, succeeded, task,
// expires_at, error.
const InsertCompletedTask = `
INSERT INTO backlite_tasks_completed
(id, created_at, queue, last_executed_at, attempts, last_duration_micro, succeeded, task, expires_at, error)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
// TaskFailed releases a task's claim and updates its wait time and last execution time.
// Parameters: wait_until, last_executed_at, id.
const TaskFailed = `
UPDATE backlite_tasks
SET
claimed_at = NULL,
wait_until = ?,
last_executed_at = ?
WHERE id = ?
`
// DeleteExpiredCompletedTasks deletes completed tasks that have an expiration at or before a given time.
// Parameters: expiration threshold.
const DeleteExpiredCompletedTasks = `
DELETE FROM backlite_tasks_completed
WHERE
expires_at IS NOT NULL
AND expires_at <= ?
`
// ClaimTasks builds the query that claims tasks by setting their claimed time and incrementing their
// attempts. The query expects the claimed time as its first parameter followed by count task IDs.
//
// If count is less than one, the ID list is the literal NULL, which matches no rows, so the returned
// query is still well-formed and takes only the claimed time parameter.
func ClaimTasks(count int) string {
	const query = `
UPDATE backlite_tasks
SET
claimed_at = ?,
attempts = attempts + 1
WHERE id IN (%s)
`

	// Guard against slicing an empty string (count == 0) and strings.Repeat panicking (count < 0).
	if count < 1 {
		return fmt.Sprintf(query, "NULL")
	}

	param := strings.Repeat("?,", count)

	// Drop the trailing comma.
	return fmt.Sprintf(query, param[:len(param)-1])
}
package task
import (
"context"
"database/sql"
"time"
"github.com/mikestefanello/backlite/internal/query"
)
type (
// Completed is a completed task, as stored in the completed tasks table.
Completed struct {
// ID is the Task ID
ID string
// Queue is the name of the queue this Task belongs to.
Queue string
// Task is the task data.
Task []byte
// Attempts are the amount of times this Task was executed.
Attempts int
// Succeeded indicates if the Task execution was a success.
Succeeded bool
// LastDuration is the last execution duration. It is stored with microsecond precision.
LastDuration time.Duration
// ExpiresAt is when this record should be removed from the database.
// If omitted, the record should not be removed.
ExpiresAt *time.Time
// CreatedAt is when the Task was originally created.
CreatedAt time.Time
// LastExecutedAt is the last time this Task executed.
LastExecutedAt time.Time
// Error is the error message provided by the Task processor.
Error *string
}
// CompletedTasks contains multiple completed tasks.
CompletedTasks []*Completed
)
// InsertTx writes the completed task to the database using the given transaction. Times are stored as
// Unix milliseconds and the duration as microseconds.
func (c *Completed) InsertTx(ctx context.Context, tx *sql.Tx) error {
	// A nil pointer is passed when the record has no expiration.
	var expiryMillis *int64
	if exp := c.ExpiresAt; exp != nil {
		ms := exp.UnixMilli()
		expiryMillis = &ms
	}

	args := []any{
		c.ID,
		c.CreatedAt.UnixMilli(),
		c.Queue,
		c.LastExecutedAt.UnixMilli(),
		c.Attempts,
		c.LastDuration.Microseconds(),
		c.Succeeded,
		c.Task,
		expiryMillis,
		c.Error,
	}

	if _, err := tx.ExecContext(ctx, query.InsertCompletedTask, args...); err != nil {
		return err
	}
	return nil
}
// GetCompletedTasks runs the given query with the given arguments and converts each resulting row to a
// completed task. The query must return the columns in the order: id, created time, queue, last executed
// time, attempts, last duration (microseconds), succeeded, task, expiration, error.
func GetCompletedTasks(ctx context.Context, db *sql.DB, query string, args ...any) (CompletedTasks, error) {
	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	results := CompletedTasks{}

	for rows.Next() {
		var (
			item                         Completed
			createdMillis, executeMillis int64
			expiryMillis                 *int64
		)

		scanErr := rows.Scan(
			&item.ID,
			&createdMillis,
			&item.Queue,
			&executeMillis,
			&item.Attempts,
			&item.LastDuration,
			&item.Succeeded,
			&item.Task,
			&expiryMillis,
			&item.Error,
		)
		if scanErr != nil {
			return nil, scanErr
		}

		item.CreatedAt = time.UnixMilli(createdMillis)
		item.LastExecutedAt = time.UnixMilli(executeMillis)

		// The stored value is in microseconds; scale it to the nanoseconds of a time.Duration.
		item.LastDuration = item.LastDuration * 1000

		if expiryMillis != nil {
			expiry := time.UnixMilli(*expiryMillis)
			item.ExpiresAt = &expiry
		}

		results = append(results, &item)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}

	return results, nil
}
// DeleteExpiredCompleted removes completed tasks whose expiration is at or before the current time.
func DeleteExpiredCompleted(ctx context.Context, db *sql.DB) error {
	cutoff := time.Now().UnixMilli()
	if _, err := db.ExecContext(ctx, query.DeleteExpiredCompletedTasks, cutoff); err != nil {
		return err
	}
	return nil
}
package task
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/mikestefanello/backlite/internal/query"
)
// Task is a task that is queued for execution, as stored in the tasks table.
type Task struct {
// ID is the Task ID. If empty when inserted, one is generated.
ID string
// Queue is the name of the queue this Task belongs to.
Queue string
// Task is the task data.
Task []byte
// Attempts are the amount of times this Task was executed.
Attempts int
// WaitUntil is the time the task should not be executed until. If nil, there is no wait.
WaitUntil *time.Time
// CreatedAt is when the Task was originally created. If zero when inserted, the current time is used.
CreatedAt time.Time
// LastExecutedAt is the last time this Task executed.
LastExecutedAt *time.Time
// ClaimedAt is the time this Task was claimed for execution.
ClaimedAt *time.Time
}
// InsertTx writes the task to the database using the given transaction. A missing ID is filled with a
// newly generated UUIDv7 and a zero creation time is filled with the current time; both are stored back
// on the task. Times are stored as Unix milliseconds.
func (t *Task) InsertTx(ctx context.Context, tx *sql.Tx) error {
	if t.ID == "" {
		// UUID is used because it's faster and more reliable than having the DB generate a random string.
		// And since it's time-sortable, we avoid needing a separate index on the created time.
		generated, err := uuid.NewV7()
		if err != nil {
			return fmt.Errorf("unable to generate task ID: %w", err)
		}
		t.ID = generated.String()
	}

	if t.CreatedAt.IsZero() {
		t.CreatedAt = time.Now()
	}

	// A nil pointer is passed when the task has no wait time.
	var waitMillis *int64
	if until := t.WaitUntil; until != nil {
		ms := until.UnixMilli()
		waitMillis = &ms
	}

	args := []any{t.ID, t.CreatedAt.UnixMilli(), t.Queue, t.Task, waitMillis}
	if _, err := tx.ExecContext(ctx, query.InsertTask, args...); err != nil {
		return err
	}
	return nil
}
// DeleteTx removes the task, identified by its ID, from the database using the given transaction.
func (t *Task) DeleteTx(ctx context.Context, tx *sql.Tx) error {
	if _, err := tx.ExecContext(ctx, query.DeleteTask, t.ID); err != nil {
		return err
	}
	return nil
}
// Fail marks a task as failed in the database and queues it to be executed again by releasing its claim,
// setting its wait time to waitUntil and storing its last execution time.
//
// If LastExecutedAt is nil, a NULL last execution time is stored rather than dereferencing a nil pointer.
func (t *Task) Fail(ctx context.Context, db *sql.DB, waitUntil time.Time) error {
	var lastExecutedAt *int64
	if t.LastExecutedAt != nil {
		v := t.LastExecutedAt.UnixMilli()
		lastExecutedAt = &v
	}

	_, err := db.ExecContext(
		ctx,
		query.TaskFailed,
		waitUntil.UnixMilli(),
		lastExecutedAt,
		t.ID,
	)
	return err
}
package task
import (
"context"
"database/sql"
"time"
"github.com/mikestefanello/backlite/internal/query"
)
// Tasks is a slice of tasks that can be operated on together, such as being claimed in a single query.
type Tasks []*Task
// Claim marks every task in the slice as claimed in the database, using the current time as the claim
// time, to indicate that a processor is about to execute them. An empty slice is a no-op.
func (t Tasks) Claim(ctx context.Context, db *sql.DB) error {
	count := len(t)
	if count == 0 {
		return nil
	}

	// The first argument is the claim time; the remaining arguments are the task IDs.
	args := make([]any, count+1)
	args[0] = time.Now().UnixMilli()
	for i, item := range t {
		args[i+1] = item.ID
	}

	if _, err := db.ExecContext(ctx, query.ClaimTasks(count), args...); err != nil {
		return err
	}
	return nil
}
// GetTasks runs the given query with the given arguments and converts each resulting row to a task.
// The query must return the columns in the order: id, queue, task, attempts, wait time, created time, last
// executed time, claimed time. Times are read as Unix milliseconds, with NULL mapped to a nil time.
func GetTasks(ctx context.Context, db *sql.DB, query string, args ...any) (Tasks, error) {
	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// optionalTime converts nullable milliseconds to a nullable time.
	optionalTime := func(ms *int64) *time.Time {
		if ms != nil {
			converted := time.UnixMilli(*ms)
			return &converted
		}
		return nil
	}

	results := Tasks{}

	for rows.Next() {
		var (
			item                                     Task
			createdMillis                            int64
			waitMillis, executedMillis, claimedMillis *int64
		)

		scanErr := rows.Scan(
			&item.ID,
			&item.Queue,
			&item.Task,
			&item.Attempts,
			&waitMillis,
			&createdMillis,
			&executedMillis,
			&claimedMillis,
		)
		if scanErr != nil {
			return nil, scanErr
		}

		item.CreatedAt = time.UnixMilli(createdMillis)
		item.WaitUntil = optionalTime(waitMillis)
		item.LastExecutedAt = optionalTime(executedMillis)
		item.ClaimedAt = optionalTime(claimedMillis)

		results = append(results, &item)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}

	return results, nil
}
// GetScheduledTasks loads up to limit tasks that are next up to be executed, in order of execution time,
// starting from the first one.
// It's important to note that this does not filter out tasks that are not yet ready based on their wait time.
// The deadline provided is used to include tasks that were claimed before it, in addition to unclaimed tasks.
func GetScheduledTasks(ctx context.Context, db *sql.DB, deadline time.Time, limit int) (Tasks, error) {
	const noOffset = 0
	return GetScheduledTasksWithOffset(ctx, db, deadline, limit, noOffset)
}
// GetScheduledTasksWithOffset behaves like GetScheduledTasks but skips the first offset tasks, which
// allows for paging.
func GetScheduledTasksWithOffset(
	ctx context.Context,
	db *sql.DB,
	deadline time.Time,
	limit,
	offset int) (Tasks, error) {
	// The query compares claim times stored as Unix milliseconds.
	claimedBefore := deadline.UnixMilli()
	return GetTasks(ctx, db, query.SelectScheduledTasks, claimedBefore, limit, offset)
}
package testutil
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
"github.com/mikestefanello/backlite/internal/query"
"github.com/mikestefanello/backlite/internal/task"
)
// GetTasks loads every task in the tasks table, ordered by ID, failing the test immediately if the
// query fails.
func GetTasks(t *testing.T, db *sql.DB) task.Tasks {
	const selectAll = `
SELECT
id, queue, task, attempts, wait_until, created_at, last_executed_at, claimed_at
FROM
backlite_tasks
ORDER BY
id ASC
`
	tasks, err := task.GetTasks(context.Background(), db, selectAll)
	if err != nil {
		t.Fatal(err)
	}
	return tasks
}
// InsertTask inserts the given task in its own committed transaction, failing the test immediately if
// any step fails.
func InsertTask(t *testing.T, db *sql.DB, tk *task.Task) {
	tx, err := db.Begin()
	if err == nil {
		err = tk.InsertTx(context.Background(), tx)
	}
	if err == nil {
		err = tx.Commit()
	}
	if err != nil {
		t.Fatal(err)
	}
}
// TaskIDsExist reports a test error listing any of the given IDs that are not present in the tasks table.
func TaskIDsExist(t *testing.T, db *sql.DB, ids []string) {
	existing := make(map[string]struct{})
	for _, tc := range GetTasks(t, db) {
		existing[tc.ID] = struct{}{}
	}

	missing := make(map[string]struct{}, len(ids))
	for _, id := range ids {
		if _, found := existing[id]; !found {
			missing[id] = struct{}{}
		}
	}

	if len(missing) > 0 {
		t.Errorf("ids do not exist: %v", missing)
	}
}
// DeleteTasks removes every row from the tasks table, failing the test immediately on error.
func DeleteTasks(t *testing.T, db *sql.DB) {
	if _, err := db.Exec("DELETE FROM backlite_tasks"); err != nil {
		t.Fatal(err)
	}
}
// GetCompletedTasks loads every completed task, ordered by ID, failing the test immediately if the
// query fails.
func GetCompletedTasks(t *testing.T, db *sql.DB) task.CompletedTasks {
	const selectAll = `
SELECT
*
FROM
backlite_tasks_completed
ORDER BY
id ASC
`
	completed, err := task.GetCompletedTasks(context.Background(), db, selectAll)
	if err != nil {
		t.Fatal(err)
	}
	return completed
}
// DeleteCompletedTasks removes every row from the completed tasks table, failing the test immediately
// on error.
func DeleteCompletedTasks(t *testing.T, db *sql.DB) {
	if _, err := db.Exec("DELETE FROM backlite_tasks_completed"); err != nil {
		t.Fatal(err)
	}
}
// InsertCompleted inserts the given completed task in its own committed transaction, failing the test
// immediately if any step fails.
func InsertCompleted(t *testing.T, db *sql.DB, completed task.Completed) {
	tx, err := db.Begin()
	if err == nil {
		err = completed.InsertTx(context.Background(), tx)
	}
	if err == nil {
		err = tx.Commit()
	}
	if err != nil {
		t.Fatal(err)
	}
}
// CompleteTaskIDsExist reports a test error listing any of the given IDs that are not present in the
// completed tasks table.
func CompleteTaskIDsExist(t *testing.T, db *sql.DB, ids []string) {
	existing := make(map[string]struct{})
	for _, tc := range GetCompletedTasks(t, db) {
		existing[tc.ID] = struct{}{}
	}

	missing := make(map[string]struct{}, len(ids))
	for _, id := range ids {
		if _, found := existing[id]; !found {
			missing[id] = struct{}{}
		}
	}

	if len(missing) > 0 {
		t.Errorf("ids do not exist: %v", missing)
	}
}
// Equal reports a test error, labelled with name, if expected and got differ when compared with ==.
// It is marked as a test helper so the failure is attributed to the calling test line.
func Equal[T comparable](t *testing.T, name string, expected, got T) {
	t.Helper()

	if expected != got {
		t.Errorf("%s; expected %v, got %v", name, expected, got)
	}
}
// Length reports a test error if the slice does not contain exactly expectedLength items.
// It is marked as a test helper so the failure is attributed to the calling test line.
func Length[T any](t *testing.T, obj []T, expectedLength int) {
	t.Helper()

	if len(obj) != expectedLength {
		t.Errorf("expected %d items, got %d", expectedLength, len(obj))
	}
}
// IsTask reports a test error for every field, other than the ID, in which got differs from expected.
// It is marked as a test helper so failures are attributed to the calling test line.
func IsTask(t *testing.T, expected, got task.Task) {
	t.Helper()

	Equal(t, "Queue", expected.Queue, got.Queue)
	Equal(t, "Attempts", expected.Attempts, got.Attempts)
	Equal(t, "CreatedAt", expected.CreatedAt, got.CreatedAt)

	if !bytes.Equal(expected.Task, got.Task) {
		t.Error("Task bytes not equal")
	}

	optionalTimeEqual(t, "WaitUntil", expected.WaitUntil, got.WaitUntil)
	optionalTimeEqual(t, "LastExecutedAt", expected.LastExecutedAt, got.LastExecutedAt)
	optionalTimeEqual(t, "ClaimedAt", expected.ClaimedAt, got.ClaimedAt)
}

// optionalTimeEqual reports a test error if exactly one of the times is nil, or if both are set and differ.
func optionalTimeEqual(t *testing.T, name string, expected, got *time.Time) {
	t.Helper()

	switch {
	case expected == nil && got == nil:
	case expected != nil && got != nil:
		Equal(t, name, *expected, *got)
	default:
		t.Error(name + " not equal")
	}
}
// Encode returns the JSON encoding of v as produced by a json.Encoder, which includes a trailing
// newline. The test fails immediately if encoding fails.
func Encode(t *testing.T, v any) []byte {
	var out bytes.Buffer
	if err := json.NewEncoder(&out).Encode(v); err != nil {
		t.Fatal(err)
	}
	return out.Bytes()
}
// Pointer returns a pointer to a copy of the given value.
func Pointer[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// Wait pauses the calling goroutine for 100 milliseconds.
func Wait() {
	const pause = 100 * time.Millisecond
	time.Sleep(pause)
}
// NewDB opens a SQLite database using the memdb VFS under a randomly generated name and installs the
// schema in it. The test fails immediately if either step fails.
func NewDB(t *testing.T) *sql.DB {
	dsn := fmt.Sprintf("file:/%s?vfs=memdb&_timeout=1000", uuid.New().String())

	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		t.Fatal(err)
	}

	if _, err = db.Exec(query.Schema); err != nil {
		t.Fatal(err)
	}

	return db
}
// WaitForChan waits up to 500 milliseconds to receive a value from the given channel and reports a test
// error if none arrives in time. It is marked as a test helper so the failure is attributed to the
// calling test line.
func WaitForChan[T any](t *testing.T, signal chan T) {
	t.Helper()

	select {
	case <-signal:
	case <-time.After(500 * time.Millisecond):
		t.Error("signal not received")
	}
}
package backlite
type (
// Logger is used to log operations.
Logger interface {
// Info logs info messages. The params are passed as alternating keys and values.
Info(message string, params ...any)
// Error logs error messages. The params are passed as alternating keys and values.
Error(message string, params ...any)
}
// noLogger is the default logger and will log nothing.
noLogger struct{}
)
// Info discards the message and its parameters.
func (noLogger) Info(message string, params ...any) {
}
// Error discards the message and its parameters.
func (noLogger) Error(message string, params ...any) {
}
package backlite
import (
"bytes"
"context"
"encoding/json"
"fmt"
"sync"
"time"
)
type (
// Queue represents a queue which contains tasks to be executed.
Queue interface {
// Config returns the configuration for the queue.
Config() *QueueConfig
// Process processes the Task, given its encoded payload.
Process(ctx context.Context, payload []byte) error
}
// QueueConfig is the configuration options for a queue.
QueueConfig struct {
// Name is the name of the queue and must be unique.
Name string
// MaxAttempts are the maximum number of attempts to execute this task before it's marked as completed.
MaxAttempts int
// Timeout is the duration set on the context while executing a given task.
// If zero or less, no timeout is set.
Timeout time.Duration
// Backoff is the duration a failed task will be held in the queue until being retried.
Backoff time.Duration
// Retention dictates if and how completed tasks will be retained in the database.
// If nil, no completed tasks will be retained.
Retention *Retention
}
// Retention is the policy for how completed tasks will be retained in the database.
Retention struct {
// Duration is the amount of time to retain a task for after completion.
// If omitted, the task will be retained forever.
Duration time.Duration
// OnlyFailed indicates if only failed tasks should be retained.
OnlyFailed bool
// Data provides options for retaining Task payload data.
// If nil, no task payload data will be retained.
Data *RetainData
}
// RetainData is the policy for how Task payload data will be retained in the database after the task is complete.
RetainData struct {
// OnlyFailed indicates if Task payload data should only be retained for failed tasks.
OnlyFailed bool
}
// queue provides a type-safe implementation of Queue for tasks of type T.
queue[T Task] struct {
// config is the queue configuration, taken from the task type.
config *QueueConfig
// processor is the callback that executes a decoded task.
processor QueueProcessor[T]
}
// QueueProcessor is a generic processor callback for a given queue to process Tasks
QueueProcessor[T Task] func(context.Context, T) error
// queues stores a registry of queues, keyed by queue name and guarded by the embedded mutex.
queues struct {
registry map[string]Queue
sync.RWMutex
}
)
// NewQueue builds a type-safe Queue for the Task type T that executes tasks with the given processor.
// The queue's configuration is obtained by calling Config on the zero value of T.
func NewQueue[T Task](processor QueueProcessor[T]) Queue {
	var zero T
	cfg := zero.Config()

	return &queue[T]{
		config:    &cfg,
		processor: processor,
	}
}
// Config returns the queue's configuration.
func (q *queue[T]) Config() *QueueConfig {
	cfg := q.config
	return cfg
}
// Process decodes the JSON payload into a task of type T and passes it to the queue's processor.
// A decoding error is returned without calling the processor.
func (q *queue[T]) Process(ctx context.Context, payload []byte) error {
	var decoded T

	decoder := json.NewDecoder(bytes.NewReader(payload))
	if err := decoder.Decode(&decoded); err != nil {
		return err
	}

	return q.processor(ctx, decoded)
}
// add adds a queue to the registry. It panics if the queue's name is empty or has already been registered.
//
// The name is read from the queue's configuration once, so the name that is validated is guaranteed to
// be the name the queue is registered under.
func (q *queues) add(queue Queue) {
	name := queue.Config().Name
	if len(name) == 0 {
		panic("queue name is missing")
	}

	q.Lock()
	defer q.Unlock()

	if _, exists := q.registry[name]; exists {
		panic(fmt.Sprintf("queue '%s' already registered", name))
	}

	q.registry[name] = queue
}
// get looks up a queue in the registry by name while holding the read lock.
// The result is nil if no queue is registered under that name.
func (q *queues) get(name string) Queue {
	q.RLock()
	defer q.RUnlock()

	found := q.registry[name]
	return found
}
package backlite
import (
"context"
"database/sql"
"time"
)
type (
// Task represents a task that will be placed in to a queue for execution.
Task interface {
// Config returns the configuration options for the queue that this Task will be placed in.
Config() QueueConfig
}
// TaskAddOp facilitates adding Tasks to the queue.
// It is configured through its chainable methods (Ctx, At, Wait, Tx) and executed with Save.
TaskAddOp struct {
// client is the Client that Save hands this operation to.
client *Client
// ctx is the request context, set via Ctx.
ctx context.Context
// tasks holds the Tasks being added by this operation.
tasks []Task
// wait is the time before which the tasks should not be executed, set via At or Wait.
// It is nil unless one of those methods was called.
wait *time.Time
// tx is the database transaction to include the tasks in, set via Tx.
tx *sql.Tx
}
)
// Ctx stores the request context on the operation and returns the operation for chaining.
func (t *TaskAddOp) Ctx(ctx context.Context) *TaskAddOp {
	t.ctx = ctx

	return t
}
// At sets the earliest time at which the task may be executed and returns the operation for chaining.
func (t *TaskAddOp) At(processAt time.Time) *TaskAddOp {
	when := processAt
	t.wait = &when

	return t
}
// Wait delays execution of the task by the given duration, measured from the current time,
// and returns the operation for chaining.
func (t *TaskAddOp) Wait(duration time.Duration) *TaskAddOp {
	processAt := now().Add(duration)

	return t.At(processAt)
}
// Tx makes the task part of the given database transaction and returns the operation for chaining.
// When using this, you must call Notify() on the client after committing the transaction so the
// dispatcher learns that a new task exists; otherwise the task may not be executed.
// This is required because outsiders have no way to know if or when a transaction is committed,
// and the dispatcher avoids continuous polling, so it has to be told when tasks are added.
func (t *TaskAddOp) Tx(tx *sql.Tx) *TaskAddOp {
	t.tx = tx

	return t
}
// Save hands the operation to its client to be saved, so the task can be queued for execution.
// It returns whatever error the client reports.
func (t *TaskAddOp) Save() error {
	err := t.client.save(t)
	return err
}
package ui
import (
"database/sql"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strconv"
"text/template"
"time"
"github.com/mikestefanello/backlite/internal/task"
)
// itemLimit is the limit of items to fetch from the database.
// It is passed as the limit to each task list query and is the page size used by getOffset.
// TODO allow this to be configurable via the UI.
const itemLimit = 25
type (
// Handler handles HTTP requests for the Backlite UI.
// Create one with NewHandler and attach its routes to a mux with Register.
Handler struct {
// db stores the Backlite database.
db *sql.DB
}
// templateData is a wrapper of data sent to templates for rendering.
templateData struct {
// Path is the current request URL path.
Path string
// Content is the data to render.
Content any
// Page is the page number, as returned by getPage for the request URL
// (1 when the page query parameter is missing or invalid).
Page int
}
// handleFunc is an HTTP handle func that returns an error.
// It is adapted to an http.HandlerFunc by handle, which reports any returned error.
handleFunc func(http.ResponseWriter, *http.Request) error
)
// NewHandler returns a Handler for the Backlite web UI that reads from the given database.
func NewHandler(db *sql.DB) *Handler {
	h := new(Handler)
	h.db = db

	return h
}
// Register attaches every UI route to the given mux and returns that same mux.
// Each route is wrapped with handle so that returned errors are reported to the client and logged.
func (h *Handler) Register(mux *http.ServeMux) *http.ServeMux {
	routes := []struct {
		pattern string
		handler handleFunc
	}{
		{"GET /", h.Running},
		{"GET /upcoming", h.Upcoming},
		{"GET /succeeded", h.Succeeded},
		{"GET /failed", h.Failed},
		{"GET /task/{task}", h.Task},
		{"GET /completed/{task}", h.TaskCompleted},
	}

	for _, route := range routes {
		mux.HandleFunc(route.pattern, handle(route.handler))
	}

	return mux
}
// handle adapts a handleFunc to an http.HandlerFunc.
// If the wrapped func returns an error, a 500 status is written, followed by the error text
// as the response body, and the error is also logged.
func handle(hf handleFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		err := hf(w, r)
		if err == nil {
			return
		}

		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, err)
		log.Println(err)
	}
}
// Running renders one page of running tasks, using the page taken from the request URL.
func (h *Handler) Running(w http.ResponseWriter, req *http.Request) error {
	ctx := req.Context()
	offset := getOffset(req.URL)

	running, err := task.GetTasks(ctx, h.db, selectRunningTasks, itemLimit, offset)
	if err != nil {
		return err
	}

	return h.render(req, w, tmplTasksRunning, running)
}
// Upcoming renders one page of upcoming tasks, using the page taken from the request URL.
// The time given to the query is one hour before the current time.
func (h *Handler) Upcoming(w http.ResponseWriter, req *http.Request) error {
	ctx := req.Context()
	since := time.Now().Add(-time.Hour) // TODO use actual time from the client
	offset := getOffset(req.URL)

	upcoming, err := task.GetScheduledTasksWithOffset(ctx, h.db, since, itemLimit, offset)
	if err != nil {
		return err
	}

	return h.render(req, w, tmplTasksUpcoming, upcoming)
}
// Succeeded renders one page of completed tasks that have succeeded.
// The literal 1 is passed to the completed tasks query ahead of the limit and offset
// (presumably the value matched against the succeeded flag; confirm against the query).
func (h *Handler) Succeeded(w http.ResponseWriter, req *http.Request) error {
	ctx := req.Context()
	offset := getOffset(req.URL)

	succeeded, err := task.GetCompletedTasks(ctx, h.db, selectCompletedTasks, 1, itemLimit, offset)
	if err != nil {
		return err
	}

	return h.render(req, w, tmplTasksCompleted, succeeded)
}
// Failed renders one page of completed tasks that have failed.
// The literal 0 is passed to the completed tasks query ahead of the limit and offset
// (presumably the value matched against the succeeded flag; confirm against the query).
func (h *Handler) Failed(w http.ResponseWriter, req *http.Request) error {
	ctx := req.Context()
	offset := getOffset(req.URL)

	failed, err := task.GetCompletedTasks(ctx, h.db, selectCompletedTasks, 0, itemLimit, offset)
	if err != nil {
		return err
	}

	return h.render(req, w, tmplTasksCompleted, failed)
}
// Task renders the task identified by the "task" path value.
// If no such task is found, the same ID is tried as a completed task instead.
func (h *Handler) Task(w http.ResponseWriter, req *http.Request) error {
	found, err := task.GetTasks(req.Context(), h.db, selectTask, req.PathValue("task"))
	if err != nil {
		return err
	}

	if len(found) == 0 {
		// Nothing matched, so fall back to the completed task lookup.
		return h.TaskCompleted(w, req)
	}

	return h.render(req, w, tmplTask, found[0])
}
// TaskCompleted renders the completed task identified by the "task" path value.
// If no completed task matches, the template is still rendered, with a nil *task.Completed as its content.
func (h *Handler) TaskCompleted(w http.ResponseWriter, req *http.Request) error {
	found, err := task.GetCompletedTasks(req.Context(), h.db, selectCompletedTask, req.PathValue("task"))
	if err != nil {
		return err
	}

	var completed *task.Completed
	if len(found) != 0 {
		completed = found[0]
	}

	return h.render(req, w, tmplTaskCompleted, completed)
}
// render executes the "layout.gohtml" template of tmpl in to w.
// The template receives the request path, the given data as its content, and the page number
// taken from the request URL.
func (h *Handler) render(req *http.Request, w io.Writer, tmpl *template.Template, data any) error {
	payload := templateData{
		Path:    req.URL.Path,
		Content: data,
		Page:    getPage(req.URL),
	}

	return tmpl.ExecuteTemplate(w, "layout.gohtml", payload)
}
// getPage returns the page number from the "page" query parameter of the URL.
// It returns 1 if the parameter is missing, empty, not an integer, or not greater than zero.
func getPage(u *url.URL) int {
	raw := u.Query().Get("page")
	if raw == "" {
		return 1
	}

	page, err := strconv.Atoi(raw)
	if err != nil || page <= 0 {
		return 1
	}

	return page
}
// getOffset converts the page number of the URL in to a row offset, using itemLimit as the page size.
// The first page has an offset of zero.
func getOffset(u *url.URL) int {
	pagesToSkip := getPage(u) - 1

	return pagesToSkip * itemLimit
}
package ui
import (
"embed"
"fmt"
"text/template"
"time"
)
// templates holds the embedded .gohtml files from the templates directory.
//go:embed templates/*.gohtml
var templates embed.FS
// Parsed templates, one per page of the UI.
// Each is built by mustParse when the package is initialized, so a template that fails to parse
// causes a panic at startup.
var (
tmplTasksRunning = mustParse("running")
tmplTasksUpcoming = mustParse("upcoming")
tmplTasksCompleted = mustParse("completed_tasks")
tmplTask = mustParse("task")
tmplTaskCompleted = mustParse("completed_task")
)
// mustParse parses the shared layout template together with the template for the given page
// from the embedded templates, registering the bytestring, datetime and add template funcs.
// It panics if parsing fails.
// NOTE(review): this uses text/template, which does not HTML-escape output; confirm that the
// templates escape any task data they print.
func mustParse(page string) *template.Template {
	funcs := template.FuncMap{
		"bytestring": bytestring,
		"datetime":   datetime,
		"add":        add,
	}
	pageFile := fmt.Sprintf("templates/%s.gohtml", page)

	tmpl, err := template.New("layout.gohtml").
		Funcs(funcs).
		ParseFS(templates, "templates/layout.gohtml", pageFile)
	if err != nil {
		panic(err)
	}

	return tmpl
}
// bytestring is a template func that converts a byte slice to a string.
func bytestring(b []byte) string {
	s := string(b)
	return s
}
// datetime is a template func that formats a time in the local time zone,
// for example "05 Mar 2024 14:07:09 UTC".
func datetime(t time.Time) string {
	const layout = "02 Jan 2006 15:04:05 MST"

	local := t.Local()
	return local.Format(layout)
}
// add is a template func that returns the sum of two integers.
func add(a, b int) int {
	sum := a + b
	return sum
}