package tasq
import (
"context"
"fmt"
"time"
)
// defaultTaskAgeLimit is the age a finished task must reach before Clean
// treats it as eligible for removal, unless overridden with WithTaskAge.
const defaultTaskAgeLimit time.Duration = 15 * time.Minute
// Cleaner removes finished tasks through the repository of the Client that
// created it. Only tasks older than taskAgeLimit are removed.
type Cleaner struct {
	client       *Client       // owning client; its repository performs the deletion
	taskAgeLimit time.Duration // minimum age since creation for a finished task to be cleaned
}
// NewCleaner returns a Cleaner bound to this client. The Cleaner starts out with
// the default task age limit (15 minutes); change it with WithTaskAge.
func (c *Client) NewCleaner() *Cleaner {
	cleaner := &Cleaner{client: c}
	cleaner.taskAgeLimit = defaultTaskAgeLimit

	return cleaner
}
// WithTaskAge sets how long ago a finished task must have been created before a
// call to Clean may remove it. It returns the same Cleaner so calls can be chained.
//
// Default value: 15 minutes.
func (c *Cleaner) WithTaskAge(taskAge time.Duration) *Cleaner {
	cleaner := c
	cleaner.taskAgeLimit = taskAge

	return cleaner
}
// Clean removes finished (succeeded or failed) tasks whose creation time lies at
// least the configured task age limit in the past.
//
// It returns the number of removed tasks. If the repository fails, it returns 0
// and the wrapped repository error.
func (c *Cleaner) Clean(ctx context.Context) (int64, error) {
	removed, err := c.client.repository.CleanTasks(ctx, c.taskAgeLimit)
	if err == nil {
		return removed, nil
	}

	return 0, fmt.Errorf("failed to clean tasks: %w", err)
}
// Package tasq provides a task queue implementation compatible with multiple repositories.
package tasq
// Client holds the repository that the tasq services (Producer, Consumer,
// Inspector, Cleaner) use to reach the task storage.
type Client struct {
	repository IRepository // storage backend shared by every service created from this client
}
// NewClient returns a tasq client that uses the supplied repository as its
// task storage.
func NewClient(repository IRepository) *Client {
	client := new(Client)
	client.repository = repository

	return client
}
package tasq
import (
"context"
"errors"
"fmt"
"io"
"log"
"sync"
"time"
"github.com/benbjohnson/clock"
"github.com/google/uuid"
)
// Sentinel errors returned (possibly wrapped) by the Consumer. Compare them with errors.Is.
var (
	// Lifecycle errors.
	ErrConsumerAlreadyRunning = errors.New("consumer has already been started")
	ErrConsumerAlreadyStopped = errors.New("consumer has already been stopped")

	// Errors raised while moving tasks through the processing loop.
	ErrCouldNotActivateTasks = errors.New("a number of tasks could not be activated")
	ErrCouldNotPollTasks     = errors.New("could not poll tasks")
	ErrCouldNotPingTasks     = errors.New("could not ping tasks")

	// Handler registration errors.
	ErrTaskTypeAlreadyLearned = errors.New("task with this type already learned")
	ErrTaskTypeNotFound       = errors.New("task with this type not found")
	ErrTaskTypeNotKnown       = errors.New("task with this type is not known by this consumer")

	// Configuration errors.
	ErrUnknownPollStrategy       = errors.New("unknown poll strategy")
	ErrVisibilityTimeoutTooShort = errors.New("visibility timeout must be longer than poll interval")
)
// Logger receives the consumer's event messages. Its method set matches the
// Print/Printf pair of the standard library's *log.Logger.
type Logger interface {
	Print(v ...any)
	Printf(format string, v ...any)
}
// HandlerFunc processes a single task. A nil return marks the task as
// successful; a non-nil error is recorded on the task.
type HandlerFunc func(task *Task) error
// PollStrategy names the ordering in which the consumer polls tasks.
type PollStrategy string

// Supported poll strategies.
const (
	// PollStrategyByCreatedAt polls the oldest tasks first.
	PollStrategyByCreatedAt PollStrategy = "pollByCreatedAt"
	// PollStrategyByPriority polls the highest-priority tasks first.
	PollStrategyByPriority PollStrategy = "pollByPriority"
)
// Default consumer settings, applied by NewConsumer.
const (
	// Queue and delivery.
	defaultQueue       = ""
	defaultChannelSize = 10

	// Polling behaviour.
	defaultPollInterval = 5 * time.Second
	defaultPollStrategy = PollStrategyByCreatedAt
	defaultPollLimit    = 10

	// Task handling.
	defaultAutoDeleteOnSuccess = false
	defaultMaxActiveTasks      = 10
	defaultVisibilityTimeout   = 15 * time.Second
)
// NoopLogger returns a *log.Logger with no prefix and no flags that writes to
// io.Discard, so every message is dropped.
func NoopLogger() *log.Logger {
	var sink io.Writer = io.Discard

	return log.New(sink, "", 0)
}
// Consumer polls tasks from the repository of the Client that created it and
// emits them as runnable jobs on a channel. Its fields hold the consumption
// settings configured through the With* methods.
type Consumer struct {
	running             bool                   // set by Start; cleared when processLoop handles a stop signal
	autoDeleteOnSuccess bool                   // delete successful tasks instead of marking them finished
	channelSize         int                    // buffer size of channel, applied in Start
	pollLimit           int                    // upper bound on tasks requested per poll
	maxActiveTasks      int                    // upper bound on tasks held by this consumer at once
	pollInterval        time.Duration          // period of the ping/poll loop
	pollStrategy        PollStrategy           // ordering used when polling
	wg                  sync.WaitGroup         // tracks the processing loop goroutine
	channel             chan *func()           // jobs ready to run; closed once the consumer stops
	client              *Client                // owning client; provides the repository
	clock               clock.Clock            // source of the poll ticker
	logger              Logger                 // receives loop events
	handlerFuncMap      map[string]HandlerFunc // handlers keyed by task type
	activeMutex         sync.RWMutex           // lock taken around activeTasks updates
	activeTasks         map[uuid.UUID]struct{} // IDs of tasks polled but not yet finished
	visibilityTimeout   time.Duration          // how far each poll/ping pushes a task's visibility
	queues              []string               // queues this consumer polls
	stop                chan struct{}          // stop signal sent by Stop
}
// NewConsumer returns a stopped Consumer bound to this client, populated with
// the default consumption settings.
//
// Fields left out of the literal (running, wg, channel, activeMutex) start at
// their zero values. The job channel is created by Start.
func (c *Client) NewConsumer() *Consumer {
	return &Consumer{
		client: c,
		clock:  clock.New(),
		logger: NoopLogger(),

		autoDeleteOnSuccess: defaultAutoDeleteOnSuccess,
		channelSize:         defaultChannelSize,
		pollLimit:           defaultPollLimit,
		maxActiveTasks:      defaultMaxActiveTasks,
		pollInterval:        defaultPollInterval,
		pollStrategy:        defaultPollStrategy,
		visibilityTimeout:   defaultVisibilityTimeout,
		queues:              []string{defaultQueue},

		handlerFuncMap: make(map[string]HandlerFunc),
		activeTasks:    make(map[uuid.UUID]struct{}),
		stop:           make(chan struct{}, 1),
	}
}
// WithChannelSize sets the buffer size of the job channel that Start creates
// and that Channel exposes. It returns the consumer for chaining.
//
// Default value: 10.
func (c *Consumer) WithChannelSize(channelSize int) *Consumer {
	consumer := c
	consumer.channelSize = channelSize

	return consumer
}
// WithLogger sets the Logger that receives the consumer's event messages.
// It returns the consumer for chaining.
//
// Default value: NoopLogger.
func (c *Consumer) WithLogger(logger Logger) *Consumer {
	consumer := c
	consumer.logger = logger

	return consumer
}
// WithPollInterval sets how often the consumer pings its active tasks and polls
// for new ones. Start rejects an interval that is not strictly shorter than the
// visibility timeout. It returns the consumer for chaining.
//
// Default value: 5 seconds.
func (c *Consumer) WithPollInterval(pollInterval time.Duration) *Consumer {
	consumer := c
	consumer.pollInterval = pollInterval

	return consumer
}
// WithPollLimit sets the maximum number of tasks requested from the queue in a
// single poll. It returns the consumer for chaining.
//
// Default value: 10.
func (c *Consumer) WithPollLimit(pollLimit int) *Consumer {
	consumer := c
	consumer.pollLimit = pollLimit

	return consumer
}
// WithPollStrategy selects the ordering used when polling tasks from the queue.
// It returns the consumer for chaining.
//
// Default value: PollStrategyByCreatedAt.
func (c *Consumer) WithPollStrategy(pollStrategy PollStrategy) *Consumer {
	consumer := c
	consumer.pollStrategy = pollStrategy

	return consumer
}
// WithAutoDeleteOnSuccess controls whether the consumer deletes successful tasks
// from the queue (true) or marks them as successful (false). It returns the
// consumer for chaining.
//
// Default value: false.
func (c *Consumer) WithAutoDeleteOnSuccess(autoDeleteOnSuccess bool) *Consumer {
	consumer := c
	consumer.autoDeleteOnSuccess = autoDeleteOnSuccess

	return consumer
}
// WithMaxActiveTasks caps how many polled-but-unfinished tasks the consumer may
// hold at once. Polls request only the remaining capacity. It returns the
// consumer for chaining.
//
// Default value: 10.
func (c *Consumer) WithMaxActiveTasks(maxActiveTasks int) *Consumer {
	consumer := c
	consumer.maxActiveTasks = maxActiveTasks

	return consumer
}
// WithVisibilityTimeout sets how far each poll or ping pushes a task's
// visibility into the future. Once that time passes without another ping, a
// consumer may receive the task again. It returns the consumer for chaining.
//
// Default value: 15 seconds.
func (c *Consumer) WithVisibilityTimeout(visibilityTimeout time.Duration) *Consumer {
	consumer := c
	consumer.visibilityTimeout = visibilityTimeout

	return consumer
}
// WithQueues sets the queues from which the consumer may poll for tasks.
// The supplied queue names are copied, so later modifications of the caller's
// slice do not affect the consumer. It returns the consumer for chaining.
//
// Default value: a single queue named "" (the default queue).
func (c *Consumer) WithQueues(queues ...string) *Consumer {
	// Copy so a slice passed as WithQueues(names...) is not aliased by the consumer.
	c.queues = append([]string(nil), queues...)

	return c
}
// Learn registers f as the handler for taskType.
//
// If a handler is already registered for taskType and override is false, the
// existing handler is kept and an error wrapping ErrTaskTypeAlreadyLearned is
// returned. Otherwise f replaces any previous handler.
func (c *Consumer) Learn(taskType string, f HandlerFunc, override bool) error {
	_, known := c.handlerFuncMap[taskType]
	if known && !override {
		return fmt.Errorf("%w: %s", ErrTaskTypeAlreadyLearned, taskType)
	}

	c.handlerFuncMap[taskType] = f

	return nil
}
// Forget unregisters the handler for taskType.
//
// It returns an error wrapping ErrTaskTypeNotFound if no handler is registered
// for taskType.
func (c *Consumer) Forget(taskType string) error {
	_, known := c.handlerFuncMap[taskType]
	if !known {
		return fmt.Errorf("%w: %s", ErrTaskTypeNotFound, taskType)
	}

	delete(c.handlerFuncMap, taskType)

	return nil
}
// Start launches the goroutine that pings active tasks and polls for new ones.
//
// It returns ErrConsumerAlreadyRunning if the consumer is already running. It
// returns ErrVisibilityTimeoutTooShort if the visibility timeout is not strictly
// longer than the poll interval. On success it creates a fresh job channel
// (sized by channelSize) before the loop starts.
func (c *Consumer) Start(ctx context.Context) error {
	switch {
	case c.isRunning():
		return ErrConsumerAlreadyRunning
	case c.visibilityTimeout <= c.pollInterval:
		return ErrVisibilityTimeoutTooShort
	}

	c.setRunning(true)
	c.channel = make(chan *func(), c.channelSize)

	// The ticker argument is evaluated here, before the goroutine starts.
	go c.processLoop(ctx, c.clock.Ticker(c.pollInterval))

	return nil
}
// Stop signals the processing loop to stop polling for new tasks. The loop
// closes the job channel when it handles the signal.
//
// It returns ErrConsumerAlreadyStopped if the consumer is not running.
func (c *Consumer) Stop() error {
	if c.isRunning() {
		c.stop <- struct{}{}

		return nil
	}

	return ErrConsumerAlreadyStopped
}
// Channel exposes the job channel, receive-only. Each received value is a job
// that processes one polled task when invoked. The channel is created by Start.
func (c *Consumer) Channel() <-chan *func() {
	var jobs <-chan *func() = c.channel

	return jobs
}
// isRunning reports whether the consumer is currently in the running state.
func (c *Consumer) isRunning() bool {
	running := c.running

	return running
}
// setRunning records whether the consumer is in the running state.
func (c *Consumer) setRunning(isRunning bool) {
	consumer := c
	consumer.running = isRunning
}
// registerTaskStart records in the repository that task has started.
// A repository failure panics.
func (c *Consumer) registerTaskStart(ctx context.Context, task *Task) {
	if _, err := c.client.repository.RegisterStart(ctx, task); err != nil {
		panic(err)
	}
}
// registerTaskError stores taskError on the task, then either fails the task or
// requeues it.
//
// The task is failed when it has a positive MaxReceives and has already been
// received that many times; otherwise it is requeued. A repository failure panics.
func (c *Consumer) registerTaskError(ctx context.Context, task *Task, taskError error) {
	if _, err := c.client.repository.RegisterError(ctx, task, taskError); err != nil {
		panic(err)
	}

	receivesExhausted := task.MaxReceives > 0 && task.ReceiveCount >= task.MaxReceives
	if receivesExhausted {
		c.registerTaskFail(ctx, task)

		return
	}

	c.requeueTask(ctx, task)
}
// registerTaskSuccess records a successful task and drops it from the active set.
//
// With autoDeleteOnSuccess the task is deleted (without safe-delete
// conditions); otherwise it is marked StatusSuccessful. A repository failure panics.
func (c *Consumer) registerTaskSuccess(ctx context.Context, task *Task) {
	var err error

	if c.autoDeleteOnSuccess {
		err = c.client.repository.DeleteTask(ctx, task, false)
	} else {
		_, err = c.client.repository.RegisterFinish(ctx, task, StatusSuccessful)
	}

	if err != nil {
		panic(err)
	}

	c.removeFromActiveTasks(task)
}
// registerTaskFail marks task as StatusFailed and drops it from the active set.
// A repository failure panics.
func (c *Consumer) registerTaskFail(ctx context.Context, task *Task) {
	if _, err := c.client.repository.RegisterFinish(ctx, task, StatusFailed); err != nil {
		panic(err)
	}

	c.removeFromActiveTasks(task)
}
// requeueTask hands task back to the queue through the repository and drops it
// from the active set. A repository failure panics.
func (c *Consumer) requeueTask(ctx context.Context, task *Task) {
	if _, err := c.client.repository.RequeueTask(ctx, task); err != nil {
		panic(err)
	}

	c.removeFromActiveTasks(task)
}
// getActiveTaskCount returns how many tasks the consumer currently holds.
//
// The length is read under activeMutex. Jobs, run by whoever receives from
// Channel() and possibly on other goroutines, delete entries via
// removeFromActiveTasks. An unlocked read would be a data race.
func (c *Consumer) getActiveTaskCount() int {
	c.activeMutex.RLock()
	defer c.activeMutex.RUnlock()

	return len(c.activeTasks)
}
// removeFromActiveTasks deletes task's ID from the active set while holding
// activeMutex.
func (c *Consumer) removeFromActiveTasks(task *Task) {
	c.activeMutex.Lock()
	defer c.activeMutex.Unlock()

	delete(c.activeTasks, task.ID)
}
// getActiveTaskIDs returns a snapshot of the IDs of all tasks the consumer
// currently holds, in no particular order.
//
// The map is iterated under activeMutex. Jobs (possibly running on other
// goroutines) remove entries via removeFromActiveTasks. Iterating without the
// lock is a data race that the Go runtime can abort with
// "concurrent map iteration and map write".
func (c *Consumer) getActiveTaskIDs() []uuid.UUID {
	c.activeMutex.RLock()
	defer c.activeMutex.RUnlock()

	activeTaskIDs := make([]uuid.UUID, 0, len(c.activeTasks))
	for taskID := range c.activeTasks {
		activeTaskIDs = append(activeTaskIDs, taskID)
	}

	return activeTaskIDs
}
// getKnownTaskTypes lists every task type that has a registered handler, in
// map iteration (i.e. unspecified) order.
func (c *Consumer) getKnownTaskTypes() []string {
	known := make([]string, len(c.handlerFuncMap))

	i := 0
	for taskType := range c.handlerFuncMap {
		known[i] = taskType
		i++
	}

	return known
}
// getPollOrdering maps the configured PollStrategy to the repository Ordering.
// An unrecognised strategy yields -1 and an error wrapping ErrUnknownPollStrategy.
func (c *Consumer) getPollOrdering() (Ordering, error) {
	if c.pollStrategy == PollStrategyByPriority {
		return OrderingPriorityFirst, nil
	}

	if c.pollStrategy == PollStrategyByCreatedAt {
		return OrderingCreatedAtFirst, nil
	}

	return -1, fmt.Errorf("%w: %s", ErrUnknownPollStrategy, c.pollStrategy)
}
// getPollQuantity returns how many tasks the next poll should request. This is
// the smaller of pollLimit and the remaining capacity (maxActiveTasks minus the
// tasks currently held), never below zero.
//
// The active count is read under activeMutex because jobs remove entries
// concurrently. The result is clamped at zero so a misconfigured negative
// pollLimit/maxActiveTasks, or an over-full active set, cannot produce a
// negative SQL LIMIT. A zero result makes the repository skip the poll.
func (c *Consumer) getPollQuantity() int {
	c.activeMutex.RLock()
	activeCount := len(c.activeTasks)
	c.activeMutex.RUnlock()

	quantity := c.maxActiveTasks - activeCount
	if c.pollLimit < quantity {
		quantity = c.pollLimit
	}

	if quantity < 0 {
		return 0
	}

	return quantity
}
// processLoop is the consumer's background loop, launched by Start. Each
// iteration extends the visibility of the active tasks. While the consumer is
// running, it also polls for new tasks and pushes them onto the job channel.
//
// After a stop signal the job channel is closed. The loop keeps pinging until
// every active task has finished, then returns.
//
// When ctx is done the loop closes the job channel (if still open) and returns.
// Previously it kept ticking forever, presumably failing every repository call
// made with the dead context.
//
// NOTE(review): wg.Add runs inside this goroutine, so a concurrent wg.Wait could
// observe a zero counter before the loop registers; consider moving Add to Start.
// NOTE(review): sending on the job channel blocks when its buffer is full and
// nobody is receiving; the loop cannot observe ctx or stop while blocked there.
func (c *Consumer) processLoop(ctx context.Context, ticker *clock.Ticker) {
	c.wg.Add(1)
	defer c.wg.Done()
	defer c.logger.Print("processing stopped")
	defer ticker.Stop()

	for {
		if err := c.pingActiveTasks(ctx); err != nil {
			c.logger.Printf("error pinging active tasks: %s", err)
		}

		if c.isRunning() {
			tasks, err := c.pollForTasks(ctx)
			if err != nil {
				c.logger.Printf("error polling for tasks: %s", err)
			}

			if err := c.activateTasks(ctx, tasks); err != nil {
				c.logger.Printf("error activating tasks: %s", err)
			}
		} else if c.getActiveTaskCount() == 0 {
			// Stopped and fully drained.
			return
		}

		select {
		case <-ctx.Done():
			if c.isRunning() {
				c.setRunning(false)
				close(c.channel)
			}

			return
		case <-c.stop:
			// Two Stop calls racing before the first signal is handled can queue a
			// second signal; guard so the channel is never closed twice (panic).
			if c.isRunning() {
				c.setRunning(false)
				close(c.channel)
			}
		case <-ticker.C:
		}
	}
}
// pollForTasks asks the repository for up to getPollQuantity tasks of the known
// types from the configured queues, using the configured ordering and
// visibility timeout.
//
// An invalid poll strategy is returned unwrapped. Repository errors are wrapped
// with ErrCouldNotPollTasks.
func (c *Consumer) pollForTasks(ctx context.Context) ([]*Task, error) {
	ordering, err := c.getPollOrdering()
	if err != nil {
		return nil, err
	}

	knownTypes := c.getKnownTaskTypes()
	quantity := c.getPollQuantity()

	polled, err := c.client.repository.PollTasks(ctx, knownTypes, c.queues, c.visibilityTimeout, ordering, quantity)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrCouldNotPollTasks, err)
	}

	return polled, nil
}
// pingActiveTasks extends the visibility of every active task by the visibility
// timeout. Repository errors are wrapped with ErrCouldNotPingTasks.
func (c *Consumer) pingActiveTasks(ctx context.Context) error {
	activeIDs := c.getActiveTaskIDs()

	if _, err := c.client.repository.PingTasks(ctx, activeIDs, c.visibilityTimeout); err != nil {
		return fmt.Errorf("%w: %w", ErrCouldNotPingTasks, err)
	}

	return nil
}
// activateTasks turns each polled task into a job and queues it on the job
// channel. A task that cannot be activated (e.g. no handler is registered for
// its type) is immediately registered as failed.
//
// If any activation fails, the returned error wraps ErrCouldNotActivateTasks,
// reports the number of failures, and joins the individual errors, so callers
// can inspect the causes with errors.Is / errors.As.
func (c *Consumer) activateTasks(ctx context.Context, tasks []*Task) error {
	// Named activationErrs (not "errors") so the errors package is not shadowed.
	var activationErrs []error

	for _, task := range tasks {
		if err := c.activateTask(ctx, task); err != nil {
			activationErrs = append(activationErrs, err)
			c.registerTaskFail(ctx, task)
		}
	}

	if len(activationErrs) > 0 {
		return fmt.Errorf("%w: %d: %w", ErrCouldNotActivateTasks, len(activationErrs), errors.Join(activationErrs...))
	}

	return nil
}
// activateTask builds the job for task, adds the task's ID to the active set,
// and sends the job on the job channel. The send blocks while the channel's
// buffer is full. An unknown task type is returned as an error and nothing is
// marked active.
func (c *Consumer) activateTask(ctx context.Context, task *Task) error {
	job, err := c.createJobFromTask(ctx, task)
	if err != nil {
		return err
	}

	func() {
		c.activeMutex.Lock()
		defer c.activeMutex.Unlock()

		c.activeTasks[task.ID] = struct{}{}
	}()

	c.channel <- job

	return nil
}
// createJobFromTask looks up the handler for task's type and wraps it into a
// job. If no handler is registered, it returns an error wrapping ErrTaskTypeNotKnown.
func (c *Consumer) createJobFromTask(ctx context.Context, task *Task) (*func(), error) {
	handlerFunc, ok := c.handlerFuncMap[task.Type]
	if !ok {
		return nil, fmt.Errorf("%w: %s", ErrTaskTypeNotKnown, task.Type)
	}

	return c.newJob(ctx, c, handlerFunc, task), nil
}
// newJob wraps task and its handler f into a job closure bound to consumer.
//
// When invoked, the job:
//  1. registers the task start,
//  2. runs f,
//  3. records success when f returns nil, or the error otherwise.
//
// Repository failures inside the register* helpers panic.
func (c *Consumer) newJob(ctx context.Context, consumer *Consumer, f HandlerFunc, task *Task) *func() {
	var job func()

	job = func() {
		consumer.registerTaskStart(ctx, task)

		handlerErr := f(task)
		if handlerErr != nil {
			consumer.registerTaskError(ctx, task, handlerErr)

			return
		}

		consumer.registerTaskSuccess(ctx, task)
	}

	return &job
}
package tasq
import (
"context"
"fmt"
)
// Inspector offers read and maintenance access to tasks (count, scan, purge,
// delete) through the repository of the Client that created it.
type Inspector struct {
	client *Client // owning client; provides the repository
}
// NewInspector returns an Inspector bound to this client.
func (c *Client) NewInspector() *Inspector {
	inspector := new(Inspector)
	inspector.client = c

	return inspector
}
// Count returns the total number of tasks matching the supplied status, type,
// and queue filters. On a repository error it returns 0 and the wrapped error.
func (o *Inspector) Count(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) {
	total, err := o.client.repository.CountTasks(ctx, taskStatuses, taskTypes, queues)
	if err == nil {
		return total, nil
	}

	return 0, fmt.Errorf("error counting tasks: %w", err)
}
// Scan returns up to limit tasks matching the supplied status, type, and queue
// filters, in the given ordering. On a repository error it returns nil and the
// wrapped error.
func (o *Inspector) Scan(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, ordering Ordering, limit int) ([]*Task, error) {
	found, err := o.client.repository.ScanTasks(ctx, taskStatuses, taskTypes, queues, ordering, limit)
	if err == nil {
		return found, nil
	}

	return nil, fmt.Errorf("error scanning tasks: %w", err)
}
// Purge removes every task matching the supplied status, type, and queue
// filters. safeDelete is forwarded to the repository. It returns the number of
// removed tasks, or 0 and the wrapped error on failure.
func (o *Inspector) Purge(ctx context.Context, safeDelete bool, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) {
	purged, err := o.client.repository.PurgeTasks(ctx, taskStatuses, taskTypes, queues, safeDelete)
	if err == nil {
		return purged, nil
	}

	return 0, fmt.Errorf("error purging tasks: %w", err)
}
// Delete removes the supplied tasks one by one, forwarding safeDelete to the
// repository. It stops at the first failure and returns that error wrapped;
// tasks before it have already been deleted.
func (o *Inspector) Delete(ctx context.Context, safeDelete bool, tasks ...*Task) error {
	for i := range tasks {
		err := o.client.repository.DeleteTask(ctx, tasks[i], safeDelete)
		if err != nil {
			return fmt.Errorf("error removing task: %w", err)
		}
	}

	return nil
}
package tasq
import (
"context"
"fmt"
)
// Producer submits new tasks through the repository of the Client that created it.
type Producer struct {
	client *Client // owning client; provides the repository
}
// NewProducer returns a Producer bound to this client.
func (c *Client) NewProducer() *Producer {
	producer := new(Producer)
	producer.client = c

	return producer
}
// Submit builds a task from the supplied arguments via NewTask and submits it
// with SubmitTask. Construction errors are wrapped. Submission errors are
// returned as produced by SubmitTask.
func (p *Producer) Submit(ctx context.Context, taskType string, taskArgs any, queue string, priority int16, maxReceives int32) (*Task, error) {
	task, err := NewTask(taskType, taskArgs, queue, priority, maxReceives)
	if err == nil {
		return p.SubmitTask(ctx, task)
	}

	return nil, fmt.Errorf("error creating task: %w", err)
}
// SubmitTask stores an already-constructed task in the queue. It returns the
// task as stored by the repository, or nil and the wrapped error.
func (p *Producer) SubmitTask(ctx context.Context, task *Task) (*Task, error) {
	stored, err := p.client.repository.SubmitTask(ctx, task)
	if err == nil {
		return stored, nil
	}

	return nil, fmt.Errorf("error submitting task: %w", err)
}
// Package mysql provides the implementation of a tasq repository in MySQL.
package mysql
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"strings"
"text/template"
"time"
_ "github.com/go-sql-driver/mysql" // import mysql driver
"github.com/google/uuid"
"github.com/greencoda/tasq"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
// driverName is the database/sql driver name used to open and wrap connections.
const driverName string = "mysql"
// Sentinel errors wrapped by the MySQL repository.
var (
	// Construction.
	errUnexpectedDataSourceType = errors.New("unexpected dataSource type")

	// Transaction lifecycle.
	errFailedToBeginTx  = errors.New("failed to begin transaction")
	errFailedToCommitTx = errors.New("failed to commit transaction")

	// Statement execution.
	errFailedToExecuteSelect      = errors.New("failed to execute select query")
	errFailedToExecuteUpdate      = errors.New("failed to execute update query")
	errFailedToExecuteDelete      = errors.New("failed to execute delete query")
	errFailedToExecuteInsert      = errors.New("failed to execute insert query")
	errFailedToExecuteCreateTable = errors.New("failed to execute create table query")
	errFailedGetRowsAffected      = errors.New("failed to get rows affected by query")
)
// Repository implements the methods tasq needs to store and query tasks in MySQL.
type Repository struct {
	db        *sqlx.DB // connection pool used for all queries
	tableName string   // tasks table name, derived from the prefix given at construction
}
// NewRepository creates a MySQL Repository from dataSource, which must be either
// a DSN string or an existing *sql.DB. prefix is used to derive the tasks table
// name. Any other dataSource type yields an error wrapping
// errUnexpectedDataSourceType.
func NewRepository(dataSource any, prefix string) (*Repository, error) {
	if dsn, ok := dataSource.(string); ok {
		return newRepositoryFromDSN(dsn, prefix)
	}

	if db, ok := dataSource.(*sql.DB); ok {
		return newRepositoryFromDB(db, prefix)
	}

	return nil, fmt.Errorf("%w: %T", errUnexpectedDataSourceType, dataSource)
}
// newRepositoryFromDSN opens a MySQL handle for dsn via sqlx.Open and wraps it
// in a Repository whose table name is derived from prefix.
func newRepositoryFromDSN(dsn string, prefix string) (*Repository, error) {
	dbx, err := sqlx.Open(driverName, dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open DB from dsn: %w", err)
	}

	repo := &Repository{db: dbx}
	repo.tableName = tableName(prefix)

	return repo, nil
}
// newRepositoryFromDB wraps an existing *sql.DB in sqlx and builds a Repository
// whose table name is derived from prefix. It never returns an error.
func newRepositoryFromDB(db *sql.DB, prefix string) (*Repository, error) {
	repo := &Repository{db: sqlx.NewDb(db, driverName)}
	repo.tableName = tableName(prefix)

	return repo, nil
}
// Migrate prepares the database by creating the tasks table. It returns
// whatever migrateTable reports.
func (d *Repository) Migrate(ctx context.Context) error {
	return d.migrateTable(ctx)
}
// PingTasks pings a list of tasks by their ID
// and extends their invisibility timestamp with the supplied timeout parameter.
//
// visible_at of every listed task is set to now+visibilityTimeout. The updated
// rows are then read back and returned. Both steps run in one transaction bound
// to ctx. An empty taskIDs slice is a no-op that returns an empty result.
func (d *Repository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) {
	if len(taskIDs) == 0 {
		return []*tasq.Task{}, nil
	}

	const (
		updatePingedTasksSQLTemplate = `UPDATE
{{.tableName}}
SET
visible_at = :visibleAt
WHERE
id IN (:pingedTaskIDs);`
		selectPingedTasksSQLTemplate = `SELECT
*
FROM
{{.tableName}}
WHERE
id IN (:pingedTaskIDs);`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		pingTime                                       = time.Now()
		updatePingedTasksQuery, updatePingedTasksArgs = d.getQueryWithTableName(updatePingedTasksSQLTemplate, map[string]any{
			"visibleAt":     timeToString(pingTime.Add(visibilityTimeout)),
			"pingedTaskIDs": taskIDs,
		})
	)

	_, err = tx.ExecContext(ctx, updatePingedTasksQuery, updatePingedTasksArgs...)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	var (
		pingedMySQLTasks                               []*mySQLTask
		selectPingedTasksQuery, selectPingedTasksArgs = d.getQueryWithTableName(selectPingedTasksSQLTemplate, map[string]any{
			"pingedTaskIDs": taskIDs,
		})
	)

	err = tx.SelectContext(ctx, &pingedMySQLTasks, selectPingedTasksQuery, selectPingedTasksArgs...)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	return mySQLTasksToTasks(pingedMySQLTasks), nil
}
// PollTasks polls for available tasks matching the supplied parameters
// and sets their invisibility the supplied timeout parameter to the future.
//
// Within one transaction bound to ctx, it:
//  1. selects up to pollLimit visible, open tasks of the given types and queues
//     (FOR UPDATE SKIP LOCKED),
//  2. marks them enqueued, increments receive_count, and pushes visible_at to
//     now+visibilityTimeout,
//  3. returns the updated rows.
//
// A non-positive pollLimit, or an empty types or queues list, can match nothing.
// In those cases an empty result is returned without touching the database;
// this also avoids a negative LIMIT and empty IN () lists.
//
// NOTE(review): ORDER BY binds :pollOrdering as a query parameter. MySQL then
// orders by a constant value rather than by the named columns, so the poll
// strategy is likely ignored. Confirm against getOrderingDirectives and render
// the directives into the SQL text instead.
func (d *Repository) PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, pollLimit int) ([]*tasq.Task, error) {
	if pollLimit <= 0 || len(types) == 0 || len(queues) == 0 {
		return []*tasq.Task{}, nil
	}

	const (
		selectPolledTasksSQLTemplate = `SELECT
id
FROM
{{.tableName}}
WHERE
type IN (:pollTypes) AND
queue IN (:pollQueues) AND
status IN (:pollStatuses) AND
visible_at <= :pollTime
ORDER BY
:pollOrdering
LIMIT :pollLimit
FOR UPDATE SKIP LOCKED;`
		updatePolledTasksSQLTemplate = `UPDATE
{{.tableName}}
SET
status = :status,
receive_count = receive_count + 1,
visible_at = :visibleAt
WHERE
id IN (:polledTaskIDs);`
		selectUpdatedPolledTasksSQLTemplate = `SELECT
*
FROM
{{.tableName}}
WHERE
id IN (:polledTaskIDs);`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		polledTaskIDs                                  []TaskID
		pollTime                                       = time.Now()
		selectPolledTasksQuery, selectPolledTasksArgs = d.getQueryWithTableName(selectPolledTasksSQLTemplate, map[string]any{
			"pollTypes":    types,
			"pollQueues":   queues,
			"pollStatuses": tasq.GetTaskStatuses(tasq.OpenTasks),
			"pollTime":     timeToString(pollTime),
			"pollOrdering": getOrderingDirectives(ordering),
			"pollLimit":    pollLimit,
		})
	)

	err = tx.SelectContext(ctx, &polledTaskIDs, selectPolledTasksQuery, selectPolledTasksArgs...)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	if len(polledTaskIDs) == 0 {
		return []*tasq.Task{}, nil
	}

	updatePolledTasksQuery, updatePolledTasksArgs := d.getQueryWithTableName(updatePolledTasksSQLTemplate, map[string]any{
		"status":        tasq.StatusEnqueued,
		"visibleAt":     timeToString(pollTime.Add(visibilityTimeout)),
		"polledTaskIDs": polledTaskIDs,
	})

	_, err = tx.ExecContext(ctx, updatePolledTasksQuery, updatePolledTasksArgs...)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	var (
		polledMySQLTasks                                 []*mySQLTask
		selectUpdatedTasksQuery, selectUpdatedTasksArgs = d.getQueryWithTableName(selectUpdatedPolledTasksSQLTemplate, map[string]any{
			"polledTaskIDs": polledTaskIDs,
		})
	)

	err = tx.SelectContext(ctx, &polledMySQLTasks, selectUpdatedTasksQuery, selectUpdatedTasksArgs...)
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	return mySQLTasksToTasks(polledMySQLTasks), nil
}
// CleanTasks deletes finished tasks whose created_at is at least cleanAge in the
// past. It returns the number of deleted rows.
func (d *Repository) CleanTasks(ctx context.Context, cleanAge time.Duration) (int64, error) {
	const cleanTasksSQLTemplate = `DELETE FROM
{{.tableName}}
WHERE
status IN (:statuses) AND
created_at <= :cleanAt;`

	cutoff := time.Now().Add(-cleanAge)

	query, args := d.getQueryWithTableName(cleanTasksSQLTemplate, map[string]any{
		"statuses": tasq.GetTaskStatuses(tasq.FinishedTasks),
		"cleanAt":  timeToString(cutoff),
	})

	result, err := d.db.ExecContext(ctx, query, args...)
	if err != nil {
		return 0, fmt.Errorf("%w: %w", errFailedToExecuteDelete, err)
	}

	deleted, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("%w: %w", errFailedGetRowsAffected, err)
	}

	return deleted, nil
}
// RegisterStart marks a task as started with the 'in progress' status
// and records the time of start.
//
// The update and the re-read of the row run in one transaction bound to ctx.
// The returned task reflects the row as stored after the update.
func (d *Repository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	const (
		updateTaskSQLTemplate = `UPDATE
{{.tableName}}
SET
status = :status,
started_at = :startTime
WHERE
id = :taskID;`
		selectUpdatedTaskSQLTemplate = `SELECT *
FROM
{{.tableName}}
WHERE
id = :taskID;`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		mySQLTask                       = newFromTask(task)
		startTime                       = time.Now()
		updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
			"status":    tasq.StatusInProgress,
			"startTime": timeToString(startTime),
			"taskID":    mySQLTask.ID,
		})
	)

	_, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
		"taskID": mySQLTask.ID,
	})

	err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...).
		StructScan(mySQLTask)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	return mySQLTask.toTask(), nil
}
// RegisterError records errTask's message as the task's last error.
//
// The update and the re-read of the row run in one transaction bound to ctx.
// The returned task reflects the row as stored after the update. errTask must
// be non-nil; its Error() text is stored.
func (d *Repository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) {
	const (
		updateTaskSQLTemplate = `UPDATE
{{.tableName}}
SET
last_error = :errorMessage
WHERE
id = :taskID;`
		selectUpdatedTaskSQLTemplate = `SELECT *
FROM
{{.tableName}}
WHERE
id = :taskID;`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		mySQLTask                       = newFromTask(task)
		updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
			"errorMessage": errTask.Error(),
			"taskID":       mySQLTask.ID,
		})
	)

	_, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
		"taskID": mySQLTask.ID,
	})

	err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...).
		StructScan(mySQLTask)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	return mySQLTask.toTask(), nil
}
// RegisterFinish marks a task as finished with the supplied status
// and records the time of finish.
//
// The update and the re-read of the row run in one transaction bound to ctx.
// The returned task reflects the row as stored after the update.
func (d *Repository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) {
	const (
		updateTaskSQLTemplate = `UPDATE
{{.tableName}}
SET
status = :status,
finished_at = :finishTime
WHERE
id = :taskID;`
		selectUpdatedTaskSQLTemplate = `SELECT *
FROM
{{.tableName}}
WHERE
id = :taskID;`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		mySQLTask                         = newFromTask(task)
		finishTime                        = time.Now()
		updateTasksQuery, updateTasksArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
			"status":     finishStatus,
			"finishTime": timeToString(finishTime),
			"taskID":     mySQLTask.ID,
		})
	)

	_, err = tx.ExecContext(ctx, updateTasksQuery, updateTasksArgs...)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	selectUpdatedTasksQuery, selectUpdatedTasksArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
		"taskID": mySQLTask.ID,
	})

	err = tx.QueryRowxContext(ctx, selectUpdatedTasksQuery, selectUpdatedTasksArgs...).
		StructScan(mySQLTask)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	return mySQLTask.toTask(), nil
}
// SubmitTask adds the supplied task to the queue.
//
// The insert and the re-read of the row run in one transaction bound to ctx.
// The returned task reflects the row as stored by the database.
func (d *Repository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	const (
		insertTaskSQLTemplate = `INSERT INTO
{{.tableName}}
(id, type, args, queue, priority, status, max_receives, created_at, visible_at)
VALUES
(:id, :type, :args, :queue, :priority, :status, :maxReceives, :createdAt, :visibleAt);`
		selectInsertedTaskSQLTemplate = `SELECT *
FROM
{{.tableName}}
WHERE
id = :taskID;`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		mySQLTask                       = newFromTask(task)
		insertTaskQuery, insertTaskArgs = d.getQueryWithTableName(insertTaskSQLTemplate, map[string]any{
			"id":          mySQLTask.ID,
			"type":        mySQLTask.Type,
			"args":        mySQLTask.Args,
			"queue":       mySQLTask.Queue,
			"priority":    mySQLTask.Priority,
			"status":      mySQLTask.Status,
			"maxReceives": mySQLTask.MaxReceives,
			"createdAt":   mySQLTask.CreatedAt,
			"visibleAt":   mySQLTask.VisibleAt,
		})
	)

	_, err = tx.ExecContext(ctx, insertTaskQuery, insertTaskArgs...)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteInsert, err)
	}

	selectInsertedTaskQuery, selectInsertedTaskArgs := d.getQueryWithTableName(selectInsertedTaskSQLTemplate, map[string]any{
		"taskID": mySQLTask.ID,
	})

	err = tx.QueryRowxContext(ctx, selectInsertedTaskQuery, selectInsertedTaskArgs...).
		StructScan(mySQLTask)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	return mySQLTask.toTask(), nil
}
// DeleteTask removes the supplied task from the queue.
func (d *Repository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error {
var (
mySQLTask = newFromTask(task)
conditions = []string{
`id = :taskID`,
}
parameters = map[string]any{
"taskID": mySQLTask.ID,
}
)
if safeDelete {
d.applySafeDeleteConditions(&conditions, ¶meters)
}
deleteTaskSQLTemplate := `DELETE FROM {{.tableName}} WHERE ` + strings.Join(conditions, ` AND `) + `;`
deleteTaskQuery, deleteTaskArgs := d.getQueryWithTableName(deleteTaskSQLTemplate, parameters)
_, err := d.db.ExecContext(ctx, deleteTaskQuery, deleteTaskArgs...)
if err != nil {
return fmt.Errorf("%w: %w", errFailedToExecuteDelete, err)
}
return nil
}
// RequeueTask marks a task as new, so it can be picked up again.
//
// Only the status is reset; visible_at is left unchanged by this statement. The
// update and the re-read of the row run in one transaction bound to ctx.
func (d *Repository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	const (
		updateTaskSQLTemplate = `UPDATE
{{.tableName}}
SET
status = :status
WHERE
id = :taskID;`
		selectUpdatedTaskSQLTemplate = `SELECT *
FROM
{{.tableName}}
WHERE
id = :taskID;`
	)

	// BeginTxx (instead of Beginx) ties the transaction to ctx: database/sql
	// rolls it back if ctx is cancelled before Commit.
	tx, err := d.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToBeginTx, err)
	}

	defer rollback(tx)

	var (
		mySQLTask                       = newFromTask(task)
		updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
			"status": tasq.StatusNew,
			"taskID": mySQLTask.ID,
		})
	)

	_, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
		"taskID": mySQLTask.ID,
	})

	err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...).
		StructScan(mySQLTask)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	err = tx.Commit()
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToCommitTx, err)
	}

	// Return a literal nil error rather than the reused err variable.
	return mySQLTask.toTask(), nil
}
// CountTasks returns how many tasks match the supplied status, type, and queue
// filters. The query is built by buildCountSQLTemplate.
func (d *Repository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (int64, error) {
	query, args := d.getQueryWithTableName(d.buildCountSQLTemplate(taskStatuses, taskTypes, queues))

	var count int64
	if err := d.db.GetContext(ctx, &count, query, args...); err != nil {
		return 0, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	return count, nil
}
// ScanTasks lists tasks matching the supplied filters. ordering and scanLimit
// are handed to buildScanSQLTemplate. On failure an empty, non-nil slice is
// returned together with the error.
func (d *Repository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) ([]*tasq.Task, error) {
	query, args := d.getQueryWithTableName(
		d.buildScanSQLTemplate(taskStatuses, taskTypes, queues, ordering, scanLimit),
	)

	var rows []*mySQLTask
	if err := d.db.SelectContext(ctx, &rows, query, args...); err != nil {
		return []*tasq.Task{}, fmt.Errorf("%w: %w", errFailedToExecuteSelect, err)
	}

	return mySQLTasksToTasks(rows), nil
}
// PurgeTasks deletes every task matching the supplied filters and returns the
// number of deleted rows. When safeDelete is set, buildPurgeSQLTemplate adds
// the guard produced by applySafeDeleteConditions.
func (d *Repository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) {
	query, args := d.getQueryWithTableName(
		d.buildPurgeSQLTemplate(taskStatuses, taskTypes, queues, safeDelete),
	)

	result, err := d.db.ExecContext(ctx, query, args...)
	if err != nil {
		return 0, fmt.Errorf("%w: %w", errFailedToExecuteDelete, err)
	}

	deleted, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("%w: %w", errFailedGetRowsAffected, err)
	}

	return deleted, nil
}
// buildCountSQLTemplate returns a COUNT(*) query template over the tasks
// table, restricted by the non-empty filters, plus the named parameters those
// filters reference.
func (d *Repository) buildCountSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (string, map[string]any) {
	conditions, parameters := d.buildFilterConditions(taskStatuses, taskTypes, queues)

	var sb strings.Builder
	sb.WriteString(`SELECT COUNT(*) FROM {{.tableName}}`)
	if len(conditions) != 0 {
		sb.WriteString(` WHERE `)
		sb.WriteString(strings.Join(conditions, " AND "))
	}
	sb.WriteString(`;`)

	return sb.String(), parameters
}
// buildScanSQLTemplate returns a SELECT query template over the tasks table,
// restricted by the non-empty filters, ordered according to ordering and
// limited to scanLimit rows, plus the named parameters it references.
//
// The ORDER BY list is written into the SQL text rather than bound as a
// parameter: a bound value in ORDER BY is a constant expression and does not
// sort by the named columns. This is safe because getOrderingDirectives only
// ever returns fixed, hard-coded directives. The statement ends with exactly
// one semicolon.
func (d *Repository) buildScanSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) (string, map[string]any) {
	conditions, parameters := d.buildFilterConditions(taskStatuses, taskTypes, queues)

	sqlTemplate := `SELECT * FROM {{.tableName}}`
	if len(conditions) > 0 {
		sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
	}

	sqlTemplate += ` ORDER BY ` + strings.Join(getOrderingDirectives(ordering), ", ")
	sqlTemplate += ` LIMIT :limit`
	parameters["limit"] = scanLimit

	return sqlTemplate + `;`, parameters
}
func (d *Repository) buildPurgeSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (string, map[string]any) {
var (
conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
sqlTemplate = `DELETE FROM {{.tableName}}`
)
if safeDelete {
d.applySafeDeleteConditions(&conditions, ¶meters)
}
if len(conditions) > 0 {
sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
}
return sqlTemplate + `;`, parameters
}
// applySafeDeleteConditions appends a guard to *conditions. The guard only
// matches rows whose visible_at is not after the current time, or rows that
// are still NEW with a visible_at in the future. The :visibleAt (time.Now()
// at call time) and :statuses parameters are stored in *parameters.
func (d *Repository) applySafeDeleteConditions(conditions *[]string, parameters *map[string]any) {
	const safeDeleteCondition = `(
(
visible_at <= :visibleAt
) OR (
status IN (:statuses) AND
visible_at > :visibleAt
)
)`

	*conditions = append(*conditions, safeDeleteCondition)

	params := *parameters
	params["statuses"] = []tasq.TaskStatus{tasq.StatusNew}
	params["visibleAt"] = time.Now()
}
// buildFilterConditions turns the non-empty filter slices into `IN (...)`
// conditions and a parameter map keyed by their named placeholders. Slices
// are stored as-is so that sqlx.In can expand them later. Empty filters
// produce no condition.
func (d *Repository) buildFilterConditions(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) ([]string, map[string]any) {
	filters := []struct {
		condition string
		name      string
		value     any
		present   bool
	}{
		{`status IN (:filterStatuses)`, "filterStatuses", taskStatuses, len(taskStatuses) > 0},
		{`type IN (:filterTypes)`, "filterTypes", taskTypes, len(taskTypes) > 0},
		{`queue IN (:filterQueues)`, "filterQueues", queues, len(queues) > 0},
	}

	var conditions []string
	parameters := make(map[string]any)

	for _, filter := range filters {
		if !filter.present {
			continue
		}
		conditions = append(conditions, filter.condition)
		parameters[filter.name] = filter.value
	}

	return conditions, parameters
}
// getQueryWithTableName renders sqlTemplate with the repository's table name,
// binds its :named placeholders, expands slice arguments (for `IN (...)`
// clauses) via sqlx.In, and rebinds the result to the driver's placeholder
// style.
//
// args is normally a single map[string]any of named parameters, and that one
// value is bound directly. With no argument, an empty map is bound, so the
// template must not reference any named placeholder. Previously, passing no
// argument panicked with "length of array is 0". More than one argument is
// still passed to sqlx.Named as a batch, as before.
//
// It panics if binding or expansion fails.
func (d *Repository) getQueryWithTableName(sqlTemplate string, args ...any) (string, []any) {
	query := interpolateSQL(sqlTemplate, map[string]any{
		"tableName": d.tableName,
	})

	var namedArg any
	switch len(args) {
	case 0:
		namedArg = map[string]any{}
	case 1:
		namedArg = args[0]
	default:
		namedArg = args
	}

	boundQuery, boundArgs, err := sqlx.Named(query, namedArg)
	if err != nil {
		panic(err)
	}

	expandedQuery, expandedArgs, err := sqlx.In(boundQuery, boundArgs...)
	if err != nil {
		panic(err)
	}

	return d.db.Rebind(expandedQuery), expandedArgs
}
// migrateTable creates the tasks table if it does not exist yet. The status
// column is an ENUM whose members come from tasq.GetTaskStatuses(tasq.AllTasks),
// rendered by sliceToMySQLValueList. Execution errors are wrapped with
// errFailedToExecuteCreateTable.
func (d *Repository) migrateTable(ctx context.Context) error {
	const sqlTemplate = `CREATE TABLE IF NOT EXISTS {{.tableName}} (
id binary(16) NOT NULL,
type text NOT NULL,
args longblob NOT NULL,
queue text NOT NULL,
priority smallint NOT NULL,
status enum({{.enumValues}}) NOT NULL,
receive_count int NOT NULL DEFAULT '0',
max_receives int NOT NULL DEFAULT '0',
last_error text,
created_at datetime(6) NOT NULL DEFAULT '0001-01-01 00:00:00.000000',
started_at datetime(6),
finished_at datetime(6),
visible_at datetime(6) NOT NULL DEFAULT '0001-01-01 00:00:00.000000',
PRIMARY KEY (id)
);`

	templateData := map[string]any{
		"tableName":  d.tableName,
		"enumValues": sliceToMySQLValueList(tasq.GetTaskStatuses(tasq.AllTasks)),
	}
	createTableQuery := interpolateSQL(sqlTemplate, templateData)

	if _, err := d.db.ExecContext(ctx, createTableQuery); err != nil {
		return fmt.Errorf("%w: %w", errFailedToExecuteCreateTable, err)
	}

	return nil
}
// getOrderingDirectives maps an ordering option to its ORDER BY directives.
// Unknown orderings fall back to created_at-first ordering. A fresh slice is
// returned on every call.
func getOrderingDirectives(ordering tasq.Ordering) []string {
	switch ordering {
	case tasq.OrderingPriorityFirst:
		return []string{"priority DESC", "created_at ASC"}
	default:
		return []string{"created_at ASC", "priority DESC"}
	}
}
// rollback rolls tx back and is deferred right after a transaction begins.
// sql.ErrTxDone (the transaction was already committed or rolled back) is
// ignored; any other rollback error causes a panic.
func rollback(tx *sqlx.Tx) {
	err := tx.Rollback()
	if err == nil || errors.Is(err, sql.ErrTxDone) {
		return
	}

	panic(err)
}
// sliceToMySQLValueList renders slice as a comma-separated list of MySQL
// string literals for use in DDL such as ENUM(...), e.g.
// []string{"A", "B"} -> 'A', 'B'. Elements are formatted with fmt.Sprint.
//
// Single quotes are used because MySQL treats double-quoted text as an
// identifier when the ANSI_QUOTES SQL mode is enabled. Embedded single quotes
// are doubled so that each element stays one literal.
func sliceToMySQLValueList[T any](slice []T) string {
	escaped := make([]string, 0, len(slice))
	for _, s := range slice {
		escaped = append(escaped, strings.ReplaceAll(fmt.Sprint(s), "'", "''"))
	}

	return fmt.Sprintf(`'%s'`, strings.Join(escaped, `', '`))
}
// tableName returns the tasks table name, prefixed with "<prefix>_" when
// prefix is non-empty.
func tableName(prefix string) string {
	const baseName = "tasks"

	if prefix == "" {
		return baseName
	}

	return prefix + "_" + baseName
}
// interpolateSQL executes sql as a text/template with params as its data and
// returns the rendered text. It panics if the template fails to parse or to
// execute.
//
// text/template performs no SQL escaping, so params should only carry trusted
// values such as table names.
func interpolateSQL(sql string, params map[string]any) string {
	tmpl := template.Must(template.New("sql").Parse(sql))

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, params); err != nil {
		panic(err)
	}

	return rendered.String()
}
package mysql
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/greencoda/tasq"
)
const (
	// idLength is the byte length of a UUID stored in a binary(16) column.
	idLength = 16
	// timeFormat is the layout used to write and read datetime(6) values as strings.
	timeFormat = "2006-01-02 15:04:05.999999"
)

var (
	errIncorrectLength = errors.New("Scan: MySQLTaskID is of incorrect length")
	errUnableToScan    = errors.New("Scan: unable to scan type into MySQLTaskID")
)

// TaskID is a task UUID in MySQL's binary(16) form. It implements sql.Scanner
// and driver.Valuer so it can be read from and written to such a column
// directly.
type TaskID [idLength]byte

// Scan implements sql.Scanner. A NULL or zero-length value leaves the ID
// unchanged. A []byte whose length is not idLength, or any non-[]byte source,
// yields an error.
func (i *TaskID) Scan(src any) error {
	if src == nil {
		return nil
	}

	raw, isBytes := src.([]byte)
	if !isBytes {
		return fmt.Errorf("%w: %T", errUnableToScan, src)
	}

	switch n := len(raw); {
	case n == 0:
		return nil
	case n != idLength:
		return fmt.Errorf("%w: %v", errIncorrectLength, n)
	}

	copy(i[:], raw)

	return nil
}

// Value implements driver.Valuer and returns the raw ID bytes.
func (i TaskID) Value() (driver.Value, error) {
	raw := i[:]

	return raw, nil
}
// mySQLTask is the MySQL row representation of a tasq.Task. The db tags name
// the column each field is read from and written to.
type mySQLTask struct {
	// ID is the task UUID in binary(16) form.
	ID           TaskID          `db:"id"`
	Type         string          `db:"type"`
	Args         []byte          `db:"args"`
	Queue        string          `db:"queue"`
	Priority     int16           `db:"priority"`
	Status       tasq.TaskStatus `db:"status"`
	ReceiveCount int32           `db:"receive_count"`
	MaxReceives  int32           `db:"max_receives"`
	// LastError is NULL when no error has been recorded.
	LastError sql.NullString `db:"last_error"`
	// Timestamps are carried as strings in timeFormat; see timeToString,
	// timeToSQLNullString and parseTime for the conversions. The nullable
	// ones use sql.NullString.
	CreatedAt  string         `db:"created_at"`
	StartedAt  sql.NullString `db:"started_at"`
	FinishedAt sql.NullString `db:"finished_at"`
	VisibleAt  string         `db:"visible_at"`
}
// newFromTask builds the MySQL row representation of task. The UUID becomes a
// TaskID, timestamps are formatted with timeToString / timeToSQLNullString,
// and the optional last error becomes a sql.NullString.
func newFromTask(task *tasq.Task) *mySQLTask {
	row := new(mySQLTask)

	row.ID = TaskID(task.ID)
	row.Type = task.Type
	row.Args = task.Args
	row.Queue = task.Queue
	row.Priority = task.Priority
	row.Status = task.Status
	row.ReceiveCount = task.ReceiveCount
	row.MaxReceives = task.MaxReceives
	row.LastError = stringToSQLNullString(task.LastError)
	row.CreatedAt = timeToString(task.CreatedAt)
	row.StartedAt = timeToSQLNullString(task.StartedAt)
	row.FinishedAt = timeToSQLNullString(task.FinishedAt)
	row.VisibleAt = timeToString(task.VisibleAt)

	return row
}
// toTask converts the MySQL row back into a tasq.Task. Timestamp strings are
// parsed with parseTime / parseNullableTime, which yield the zero time for
// unparsable input.
func (t *mySQLTask) toTask() *tasq.Task {
	task := new(tasq.Task)

	task.ID = uuid.UUID(t.ID)
	task.Type = t.Type
	task.Args = t.Args
	task.Queue = t.Queue
	task.Priority = t.Priority
	task.Status = t.Status
	task.ReceiveCount = t.ReceiveCount
	task.MaxReceives = t.MaxReceives
	task.LastError = parseNullableString(t.LastError)
	task.CreatedAt = parseTime(t.CreatedAt)
	task.StartedAt = parseNullableTime(t.StartedAt)
	task.FinishedAt = parseNullableTime(t.FinishedAt)
	task.VisibleAt = parseTime(t.VisibleAt)

	return task
}
// mySQLTasksToTasks converts each row with toTask, preserving order. The
// result is never nil, even for empty input.
func mySQLTasksToTasks(mySQLTasks []*mySQLTask) []*tasq.Task {
	converted := make([]*tasq.Task, 0, len(mySQLTasks))
	for _, row := range mySQLTasks {
		converted = append(converted, row.toTask())
	}

	return converted
}
// stringToSQLNullString maps a nil pointer to an invalid (NULL) sql.NullString
// and a non-nil pointer to a valid one holding the pointed-to value.
func stringToSQLNullString(input *string) sql.NullString {
	var out sql.NullString
	if input != nil {
		out.String, out.Valid = *input, true
	}

	return out
}
// timeToString renders input in UTC using timeFormat, the string form used
// for the datetime(6) columns.
//
// The value is converted to UTC first because parseTime reads stored strings
// back with time.Parse, which treats zone-less input as UTC. Without the
// conversion, a time in any other location (e.g. time.Now() on a non-UTC host)
// would come back shifted by its UTC offset.
func timeToString(input time.Time) string {
	return input.UTC().Format(timeFormat)
}

// timeToSQLNullString is the nullable form of timeToString: a nil input
// becomes an invalid (NULL) sql.NullString.
func timeToSQLNullString(input *time.Time) sql.NullString {
	if input == nil {
		return sql.NullString{
			String: "",
			Valid:  false,
		}
	}

	return sql.NullString{
		String: timeToString(*input),
		Valid:  true,
	}
}
// parseNullableString returns nil for a NULL value and otherwise a pointer to
// a copy of the string.
func parseNullableString(input sql.NullString) *string {
	if input.Valid {
		value := input.String
		return &value
	}

	return nil
}
// parseTime parses input with timeFormat. Input that does not match the layout
// yields the zero time.Time rather than an error.
func parseTime(input string) time.Time {
	if parsed, err := time.Parse(timeFormat, input); err == nil {
		return parsed
	}

	return time.Time{}
}
// parseNullableTime returns nil for a NULL value and otherwise a pointer to
// the result of parseTime on the stored string.
func parseNullableTime(input sql.NullString) *time.Time {
	var result *time.Time

	if input.Valid {
		parsed := parseTime(input.String)
		result = &parsed
	}

	return result
}
// Package postgres provides the implementation of a tasq repository in PostgreSQL
package postgres
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"strings"
"text/template"
"time"
"github.com/google/uuid"
"github.com/greencoda/tasq"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
// driverName is the database/sql driver name used when opening a DSN.
// NOTE(review): the driver is presumably registered by the github.com/lib/pq
// import; confirm.
const driverName = "postgres"

// Sentinel errors wrapped (with %w) by this repository's methods, so callers
// can match them with errors.Is.
var (
	errUnexpectedDataSourceType   = errors.New("unexpected dataSource type")
	errFailedToExecuteUpdate      = errors.New("failed to execute update query")
	errFailedToExecuteDelete      = errors.New("failed to execute delete query")
	errFailedToExecuteInsert      = errors.New("failed to execute insert query")
	errFailedToExecuteCreateTable = errors.New("failed to execute create table query")
	errFailedToExecuteCreateType  = errors.New("failed to execute create type query")
)
// Repository implements the methods necessary for tasq to work in PostgreSQL.
type Repository struct {
	// db is the sqlx handle that every statement is prepared and run on.
	db *sqlx.DB
	// statusTypeName is the (optionally prefixed) name of the ENUM type used
	// for the status column; see migrateStatus.
	statusTypeName string
	// tableName is the (optionally prefixed) name of the tasks table.
	tableName string
}
// NewRepository creates a PostgreSQL Repository. dataSource is either a DSN
// string (opened with the "postgres" driver) or an existing *sql.DB. A
// non-empty prefix is prepended to the table and status type names. Any other
// dataSource type yields an error wrapping errUnexpectedDataSourceType.
func NewRepository(dataSource any, prefix string) (*Repository, error) {
	if dsn, isDSN := dataSource.(string); isDSN {
		return newRepositoryFromDSN(dsn, prefix)
	}

	if db, isDB := dataSource.(*sql.DB); isDB {
		return newRepositoryFromDB(db, prefix)
	}

	return nil, fmt.Errorf("%w: %T", errUnexpectedDataSourceType, dataSource)
}
// newRepositoryFromDSN opens a sqlx handle for dsn using driverName and wraps
// it in a Repository whose table and status type names carry prefix.
//
// sqlx.Open may not connect yet, but it does fail when the driver is not
// registered. That error is returned here; ignoring it would store a nil
// *sqlx.DB that panics on first use.
func newRepositoryFromDSN(dsn string, prefix string) (*Repository, error) {
	dbx, err := sqlx.Open(driverName, dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	return &Repository{
		db:             dbx,
		statusTypeName: statusTypeName(prefix),
		tableName:      tableName(prefix),
	}, nil
}
// newRepositoryFromDB wraps an existing *sql.DB in a sqlx handle bound to
// driverName and returns a Repository whose names carry prefix. It never fails.
func newRepositoryFromDB(db *sql.DB, prefix string) (*Repository, error) {
	repo := &Repository{
		db:             sqlx.NewDb(db, driverName),
		statusTypeName: statusTypeName(prefix),
		tableName:      tableName(prefix),
	}

	return repo, nil
}
// Migrate prepares the database by creating the task status ENUM type and then
// the tasks table, both only if they do not exist yet. The first failing step's
// error is returned unchanged.
func (d *Repository) Migrate(ctx context.Context) error {
	steps := []func(context.Context) error{
		d.migrateStatus,
		d.migrateTable,
	}

	for _, step := range steps {
		if err := step(ctx); err != nil {
			return err
		}
	}

	return nil
}
// PingTasks pushes the visible_at of the tasks with the given IDs to
// now+visibilityTimeout. An empty ID list returns an empty slice without
// querying.
//
// Only the id column is returned (RETURNING id), so the returned tasks carry
// their ID and zero values in every other field.
func (d *Repository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) {
	if len(taskIDs) == 0 {
		return []*tasq.Task{}, nil
	}

	const sqlTemplate = `UPDATE
{{.tableName}}
SET
"visible_at" = :visibleAt
WHERE
"id" = ANY(:pingedTaskIDs)
RETURNING id;`

	visibleAt := time.Now().Add(visibilityTimeout)

	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	var pinged []*postgresTask
	err := stmt.SelectContext(ctx, &pinged, map[string]any{
		"visibleAt":     visibleAt,
		"pingedTaskIDs": pq.Array(taskIDs),
	})
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return []*tasq.Task{}, fmt.Errorf("failed to update tasks: %w", err)
	}

	return postgresTasksToTasks(pinged), nil
}
// PollTasks claims up to pollLimit tasks whose type is in types, whose queue
// is in queues, whose status is one of the open statuses, and whose
// visible_at is not after the poll time. Claimed tasks are set to ENQUEUED,
// have receive_count incremented, and get visible_at moved to
// now+visibilityTimeout. Candidates are picked in the order selected by
// ordering. A pollLimit of 0 returns an empty slice without querying.
func (d *Repository) PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, pollLimit int) ([]*tasq.Task, error) {
	if pollLimit == 0 {
		return []*tasq.Task{}, nil
	}

	// The ORDER BY list is spliced into the SQL text: a bound parameter in
	// ORDER BY is a constant expression and would not sort by the columns.
	// getOrderingDirectives only returns fixed, hard-coded directives.
	orderBy := strings.Join(getOrderingDirectives(ordering), ", ")

	var (
		polledTasks []*postgresTask
		pollTime    = time.Now()
		sqlTemplate = `UPDATE {{.tableName}} SET
"status" = :status,
"receive_count" = "receive_count" + 1,
"visible_at" = :visibleAt
WHERE
"id" IN (
SELECT
"id" FROM {{.tableName}}
WHERE
"type" = ANY(:pollTypes) AND
"queue" = ANY(:pollQueues) AND
"status" = ANY(:pollStatuses) AND
"visible_at" <= :pollTime
ORDER BY
` + orderBy + `
LIMIT :pollLimit
FOR UPDATE )
RETURNING *;`
		stmt = d.prepareWithTableName(sqlTemplate)
	)
	defer d.closeStmt(stmt)

	err := stmt.SelectContext(ctx, &polledTasks, map[string]any{
		"status":       tasq.StatusEnqueued,
		"visibleAt":    pollTime.Add(visibilityTimeout),
		"pollTypes":    pq.Array(types),
		"pollQueues":   pq.Array(queues),
		"pollStatuses": pq.Array(tasq.GetTaskStatuses(tasq.OpenTasks)),
		"pollTime":     pollTime,
		"pollLimit":    pollLimit,
	})
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return []*tasq.Task{}, fmt.Errorf("failed to update tasks: %w", err)
	}

	return postgresTasksToTasks(polledTasks), nil
}
// CleanTasks deletes tasks in a finished status (successful or failed) whose
// created_at is at least cleanAge in the past, and returns the number of rows
// removed.
func (d *Repository) CleanTasks(ctx context.Context, cleanAge time.Duration) (int64, error) {
	const sqlTemplate = `DELETE FROM {{.tableName}}
WHERE
"status" = ANY(:statuses) AND
"created_at" <= :cleanAt;`

	cleanAt := time.Now().Add(-cleanAge)

	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	result, err := stmt.ExecContext(ctx, map[string]any{
		"statuses": pq.Array(tasq.GetTaskStatuses(tasq.FinishedTasks)),
		"cleanAt":  cleanAt,
	})
	if err != nil {
		return 0, fmt.Errorf("failed to delete tasks: %w", err)
	}

	removed, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get number of affected rows: %w", err)
	}

	return removed, nil
}
// RegisterStart sets the task's status to IN_PROGRESS, stamps started_at with
// the current time, and returns the updated row.
func (d *Repository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	const sqlTemplate = `UPDATE {{.tableName}} SET
"status" = :status,
"started_at" = :startTime
WHERE
"id" = :taskID
RETURNING *;`

	startedAt := time.Now()

	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	params := map[string]any{
		"status":    tasq.StatusInProgress,
		"startTime": startedAt,
		"taskID":    task.ID,
	}

	started := new(postgresTask)
	if err := stmt.QueryRowContext(ctx, params).StructScan(started); err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	return started.toTask(), nil
}
// RegisterError stores errTask.Error() in the task's last_error column and
// returns the updated row. errTask must be non-nil.
func (d *Repository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) {
	const sqlTemplate = `UPDATE {{.tableName}} SET
"last_error" = :errorMessage
WHERE
"id" = :taskID
RETURNING *;`

	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	params := map[string]any{
		"errorMessage": errTask.Error(),
		"taskID":       task.ID,
	}

	updated := new(postgresTask)
	if err := stmt.QueryRowContext(ctx, params).StructScan(updated); err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	return updated.toTask(), nil
}
// RegisterFinish sets the task's status to finishStatus, stamps finished_at
// with the current time, and returns the updated row.
func (d *Repository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) {
	const sqlTemplate = `UPDATE {{.tableName}} SET
"status" = :status,
"finished_at" = :finishTime
WHERE
"id" = :taskID
RETURNING *;`

	finishedAt := time.Now()

	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	params := map[string]any{
		"status":     finishStatus,
		"finishTime": finishedAt,
		"taskID":     task.ID,
	}

	finished := new(postgresTask)
	if err := stmt.QueryRowContext(ctx, params).StructScan(finished); err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	return finished.toTask(), nil
}
// SubmitTask inserts task into the tasks table and returns the inserted row
// as read back via RETURNING *. receive_count, last_error, started_at and
// finished_at are not supplied; they take their column defaults.
func (d *Repository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	const sqlTemplate = `INSERT INTO {{.tableName}}
(id, type, args, queue, priority, status, max_receives, created_at, visible_at)
VALUES
(:id, :type, :args, :queue, :priority, :status, :maxReceives, :createdAt, :visibleAt)
RETURNING *;`

	row := newFromTask(task)

	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	params := map[string]any{
		"id":          row.ID,
		"type":        row.Type,
		"args":        row.Args,
		"queue":       row.Queue,
		"priority":    row.Priority,
		"status":      row.Status,
		"maxReceives": row.MaxReceives,
		"createdAt":   row.CreatedAt,
		"visibleAt":   row.VisibleAt,
	}

	if err := stmt.QueryRowContext(ctx, params).StructScan(row); err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteInsert, err)
	}

	return row.toTask(), nil
}
// DeleteTask removes the supplied task from the queue.
func (d *Repository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error {
var (
conditions = []string{
`"id" = :taskID`,
}
parameters = map[string]any{
"taskID": task.ID,
}
)
if safeDelete {
d.applySafeDeleteConditions(&conditions, ¶meters)
}
sqlTemplate := `DELETE FROM {{.tableName}} WHERE ` + strings.Join(conditions, ` AND `) + `;`
_, err := d.prepareWithTableName(sqlTemplate).ExecContext(ctx, parameters)
if err != nil {
return fmt.Errorf("%w: %w", errFailedToExecuteDelete, err)
}
return nil
}
// RequeueTask marks a task as new, so it can be picked up again, and returns
// the updated row.
func (d *Repository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	const sqlTemplate = `UPDATE {{.tableName}} SET
"status" = :status
WHERE
"id" = :taskID
RETURNING *;`

	stmt := d.prepareWithTableName(sqlTemplate)
	// Release the prepared statement once the update has run.
	defer d.closeStmt(stmt)

	requeuedTask := new(postgresTask)
	err := stmt.
		QueryRowContext(ctx, map[string]any{
			"status": tasq.StatusNew,
			"taskID": task.ID,
		}).
		StructScan(requeuedTask)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", errFailedToExecuteUpdate, err)
	}

	return requeuedTask.toTask(), nil
}
// CountTasks returns the number of tasks matching the supplied filters. A
// filter whose slice is empty does not restrict the count.
func (d *Repository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (int64, error) {
	sqlTemplate, parameters := d.buildCountSQLTemplate(taskStatuses, taskTypes, queues)

	stmt := d.prepareWithTableName(sqlTemplate)
	// Release the prepared statement once the query has run.
	defer d.closeStmt(stmt)

	var count int64

	err := stmt.GetContext(ctx, &count, parameters)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return 0, fmt.Errorf("failed to count tasks: %w", err)
	}

	return count, nil
}
// ScanTasks returns the tasks matching the supplied filters, using the query
// that buildScanSQLTemplate builds from ordering and scanLimit. On failure an
// empty, non-nil slice is returned with the error.
func (d *Repository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) ([]*tasq.Task, error) {
	sqlTemplate, parameters := d.buildScanSQLTemplate(taskStatuses, taskTypes, queues, ordering, scanLimit)

	stmt := d.prepareWithTableName(sqlTemplate)
	// Release the prepared statement once the query has run.
	defer d.closeStmt(stmt)

	var scannedTasks []*postgresTask

	err := stmt.SelectContext(ctx, &scannedTasks, parameters)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return []*tasq.Task{}, fmt.Errorf("failed to scan tasks: %w", err)
	}

	return postgresTasksToTasks(scannedTasks), nil
}
// PurgeTasks deletes every task matching the supplied filters (every task when
// all filters are empty and safeDelete is false) and returns the number of
// deleted rows. With safeDelete, rows are further restricted by
// applySafeDeleteConditions.
func (d *Repository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) {
	sqlTemplate, parameters := d.buildPurgeSQLTemplate(taskStatuses, taskTypes, queues, safeDelete)

	stmt := d.prepareWithTableName(sqlTemplate)
	// Release the prepared statement once the delete has run.
	defer d.closeStmt(stmt)

	result, err := stmt.ExecContext(ctx, parameters)
	if err != nil {
		// Treated as "nothing deleted". There is no usable result to inspect
		// here, so RowsAffected must not be called.
		if errors.Is(err, sql.ErrNoRows) {
			return 0, nil
		}

		return 0, fmt.Errorf("failed to purge tasks: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get number of affected rows: %w", err)
	}

	return rowsAffected, nil
}
// buildCountSQLTemplate returns a COUNT(*) query template (without a trailing
// semicolon) restricted by the non-empty filters, plus the named parameters
// those filters reference.
func (d *Repository) buildCountSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (string, map[string]any) {
	conditions, parameters := d.buildFilterConditions(taskStatuses, taskTypes, queues)

	query := `SELECT COUNT(*) FROM {{.tableName}}`
	if len(conditions) == 0 {
		return query, parameters
	}

	return query + ` WHERE ` + strings.Join(conditions, " AND "), parameters
}
// buildScanSQLTemplate returns a SELECT query template restricted by the
// non-empty filters, ordered according to ordering and limited to scanLimit
// rows, plus the named parameters it references.
//
// The ORDER BY list is written into the SQL text rather than bound as a
// parameter: a bound value in ORDER BY is a constant expression and does not
// sort by the named columns. This is safe because getOrderingDirectives only
// returns fixed, hard-coded directives.
func (d *Repository) buildScanSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) (string, map[string]any) {
	conditions, parameters := d.buildFilterConditions(taskStatuses, taskTypes, queues)

	sqlTemplate := `SELECT * FROM {{.tableName}}`
	if len(conditions) > 0 {
		sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
	}

	sqlTemplate += ` ORDER BY ` + strings.Join(getOrderingDirectives(ordering), ", ") + ` LIMIT :limit;`
	parameters["limit"] = scanLimit

	return sqlTemplate, parameters
}
func (d *Repository) buildPurgeSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (string, map[string]any) {
var (
conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
sqlTemplate = `DELETE FROM {{.tableName}}`
)
if safeDelete {
d.applySafeDeleteConditions(&conditions, ¶meters)
}
if len(conditions) > 0 {
sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
}
return sqlTemplate + `;`, parameters
}
// applySafeDeleteConditions appends a guard to *conditions. The guard only
// matches rows whose visible_at is not after the current time, or rows that
// are still NEW with a visible_at in the future. The :statuses and :visibleAt
// (time.Now() at call time) parameters are stored in *parameters.
func (d *Repository) applySafeDeleteConditions(conditions *[]string, parameters *map[string]any) {
	const safeDeleteCondition = `(
(
"visible_at" <= :visibleAt
) OR (
"status" = ANY(:statuses) AND
"visible_at" > :visibleAt
)
)`

	*conditions = append(*conditions, safeDeleteCondition)

	params := *parameters
	params["statuses"] = pq.Array([]tasq.TaskStatus{tasq.StatusNew})
	params["visibleAt"] = time.Now()
}
// buildFilterConditions turns each non-empty filter slice into an
// `= ANY(...)` condition whose parameter is wrapped in pq.Array, and returns
// the conditions with their parameter map. Empty filters produce no condition.
func (d *Repository) buildFilterConditions(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) ([]string, map[string]any) {
	var conditions []string
	parameters := make(map[string]any)

	addFilter := func(condition, name string, value any) {
		conditions = append(conditions, condition)
		parameters[name] = value
	}

	if len(taskStatuses) != 0 {
		addFilter(`"status" = ANY(:filterStatuses)`, "filterStatuses", pq.Array(taskStatuses))
	}
	if len(taskTypes) != 0 {
		addFilter(`"type" = ANY(:filterTypes)`, "filterTypes", pq.Array(taskTypes))
	}
	if len(queues) != 0 {
		addFilter(`"queue" = ANY(:filterQueues)`, "filterQueues", pq.Array(queues))
	}

	return conditions, parameters
}
// migrateStatus creates the task status ENUM type named d.statusTypeName,
// with every tasq status as a member, unless a type with that name already
// exists in pg_type. Execution errors are wrapped with
// errFailedToExecuteCreateType.
func (d *Repository) migrateStatus(ctx context.Context) error {
	const sqlTemplate = `DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = '{{.statusTypeName}}') THEN
CREATE TYPE {{.statusTypeName}} AS ENUM ({{.enumValues}});
END IF;
END$$;`

	templateData := map[string]any{
		"statusTypeName": d.statusTypeName,
		"enumValues":     sliceToPostgreSQLValueList(tasq.GetTaskStatuses(tasq.AllTasks)),
	}

	if _, err := d.db.ExecContext(ctx, interpolateSQL(sqlTemplate, templateData)); err != nil {
		return fmt.Errorf("%w: %w", errFailedToExecuteCreateType, err)
	}

	return nil
}
// migrateTable creates the tasks table if it does not exist yet; the status
// column uses the ENUM type created by migrateStatus. Execution errors are
// wrapped with errFailedToExecuteCreateTable.
func (d *Repository) migrateTable(ctx context.Context) error {
	const sqlTemplate = `CREATE TABLE IF NOT EXISTS {{.tableName}} (
"id" UUID NOT NULL PRIMARY KEY,
"type" TEXT NOT NULL,
"args" BYTEA NOT NULL,
"queue" TEXT NOT NULL,
"priority" SMALLINT NOT NULL,
"status" {{.statusTypeName}} NOT NULL,
"receive_count" INTEGER NOT NULL DEFAULT 0,
"max_receives" INTEGER NOT NULL DEFAULT 0,
"last_error" TEXT,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01 00:00:00.000000',
"started_at" TIMESTAMPTZ,
"finished_at" TIMESTAMPTZ,
"visible_at" TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01 00:00:00.000000'
);`

	templateData := map[string]any{
		"tableName":      d.tableName,
		"statusTypeName": d.statusTypeName,
	}

	if _, err := d.db.ExecContext(ctx, interpolateSQL(sqlTemplate, templateData)); err != nil {
		return fmt.Errorf("%w: %w", errFailedToExecuteCreateTable, err)
	}

	return nil
}
// prepareWithTableName renders sqlTemplate with the repository's table name
// and prepares it as a named statement. It panics if preparation fails.
// Callers are responsible for closing the returned statement.
func (d *Repository) prepareWithTableName(sqlTemplate string) *sqlx.NamedStmt {
	query := interpolateSQL(sqlTemplate, map[string]any{"tableName": d.tableName})

	stmt, err := d.db.PrepareNamed(query)
	if err != nil {
		panic(err)
	}

	return stmt
}
// closeableStmt is anything with a Close method, such as *sqlx.NamedStmt.
type closeableStmt interface {
	Close() error
}

// closeStmt closes stmt and panics if closing fails. It is deferred right
// after a statement is prepared.
func (d *Repository) closeStmt(stmt closeableStmt) {
	err := stmt.Close()
	if err == nil {
		return
	}

	panic(err)
}
// getOrderingDirectives maps an ordering option to its ORDER BY directives.
// Unknown orderings fall back to created_at-first ordering. A fresh slice is
// returned on every call.
func getOrderingDirectives(ordering tasq.Ordering) []string {
	switch ordering {
	case tasq.OrderingPriorityFirst:
		return []string{"priority DESC", "created_at ASC"}
	default:
		return []string{"created_at ASC", "priority DESC"}
	}
}
// sliceToPostgreSQLValueList renders slice as a comma-separated list of
// single-quoted PostgreSQL string literals for DDL such as ENUM (...), e.g.
// []string{"A", "B"} -> 'A','B'. Elements are formatted with fmt.Sprint.
// Embedded single quotes are doubled so that each element stays one literal.
func sliceToPostgreSQLValueList[T any](slice []T) string {
	escaped := make([]string, 0, len(slice))
	for _, s := range slice {
		escaped = append(escaped, strings.ReplaceAll(fmt.Sprint(s), "'", "''"))
	}

	return fmt.Sprintf("'%s'", strings.Join(escaped, "','"))
}
// statusTypeName returns the name of the status ENUM type, prefixed with
// "<prefix>_" when prefix is non-empty.
func statusTypeName(prefix string) string {
	const baseName = "task_status"

	if prefix == "" {
		return baseName
	}

	return prefix + "_" + baseName
}
// tableName returns the tasks table name, prefixed with "<prefix>_" when
// prefix is non-empty.
func tableName(prefix string) string {
	const baseName = "tasks"

	if prefix == "" {
		return baseName
	}

	return prefix + "_" + baseName
}
// interpolateSQL executes sql as a text/template with params as its data and
// returns the rendered text. It panics if the template fails to parse or to
// execute.
//
// text/template performs no SQL escaping, so params should only carry trusted
// values such as table and type names.
func interpolateSQL(sql string, params map[string]any) string {
	tmpl := template.Must(template.New("sql").Parse(sql))

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, params); err != nil {
		panic(err)
	}

	return rendered.String()
}
package postgres
import (
"database/sql"
"time"
"github.com/google/uuid"
"github.com/greencoda/tasq"
)
// postgresTask is the PostgreSQL row representation of a tasq.Task. The db
// tags name the column each field is read from and written to.
type postgresTask struct {
	ID           uuid.UUID       `db:"id"`
	Type         string          `db:"type"`
	Args         []byte          `db:"args"`
	Queue        string          `db:"queue"`
	Priority     int16           `db:"priority"`
	Status       tasq.TaskStatus `db:"status"`
	ReceiveCount int32           `db:"receive_count"`
	MaxReceives  int32           `db:"max_receives"`
	// LastError is NULL when no error has been recorded.
	LastError sql.NullString `db:"last_error"`
	CreatedAt time.Time      `db:"created_at"`
	// StartedAt and FinishedAt are nil while the corresponding column is NULL.
	StartedAt  *time.Time `db:"started_at"`
	FinishedAt *time.Time `db:"finished_at"`
	VisibleAt  time.Time  `db:"visible_at"`
}
// newFromTask builds the PostgreSQL row representation of task. Only the
// optional last error needs converting (to a sql.NullString); every other
// field is copied as-is, including the timestamp pointers.
func newFromTask(task *tasq.Task) *postgresTask {
	row := new(postgresTask)

	row.ID = task.ID
	row.Type = task.Type
	row.Args = task.Args
	row.Queue = task.Queue
	row.Priority = task.Priority
	row.Status = task.Status
	row.ReceiveCount = task.ReceiveCount
	row.MaxReceives = task.MaxReceives
	row.LastError = stringToSQLNullString(task.LastError)
	row.CreatedAt = task.CreatedAt
	row.StartedAt = task.StartedAt
	row.FinishedAt = task.FinishedAt
	row.VisibleAt = task.VisibleAt

	return row
}
// toTask converts the PostgreSQL row back into a tasq.Task. Only last_error
// needs converting (from sql.NullString to *string).
func (t *postgresTask) toTask() *tasq.Task {
	task := new(tasq.Task)

	task.ID = t.ID
	task.Type = t.Type
	task.Args = t.Args
	task.Queue = t.Queue
	task.Priority = t.Priority
	task.Status = t.Status
	task.ReceiveCount = t.ReceiveCount
	task.MaxReceives = t.MaxReceives
	task.LastError = parseNullableString(t.LastError)
	task.CreatedAt = t.CreatedAt
	task.StartedAt = t.StartedAt
	task.FinishedAt = t.FinishedAt
	task.VisibleAt = t.VisibleAt

	return task
}
// postgresTasksToTasks converts each row with toTask, preserving order. The
// result is never nil, even for empty input.
func postgresTasksToTasks(postgresTasks []*postgresTask) []*tasq.Task {
	converted := make([]*tasq.Task, 0, len(postgresTasks))
	for _, row := range postgresTasks {
		converted = append(converted, row.toTask())
	}

	return converted
}
// stringToSQLNullString maps a nil pointer to an invalid (NULL) sql.NullString
// and a non-nil pointer to a valid one holding the pointed-to value.
func stringToSQLNullString(input *string) sql.NullString {
	var out sql.NullString
	if input != nil {
		out.String, out.Valid = *input, true
	}

	return out
}
// parseNullableString returns nil for a NULL value and otherwise a pointer to
// a copy of the string.
func parseNullableString(input sql.NullString) *string {
	if input.Valid {
		value := input.String
		return &value
	}

	return nil
}
package tasq
import (
"bytes"
"encoding/gob"
"fmt"
"time"
"github.com/google/uuid"
)
// TaskStatus is an enum type describing the status a task is currently in.
type TaskStatus string

// The collection of possible task statuses.
const (
	StatusNew        TaskStatus = "NEW"
	StatusEnqueued   TaskStatus = "ENQUEUED"
	StatusInProgress TaskStatus = "IN_PROGRESS"
	StatusSuccessful TaskStatus = "SUCCESSFUL"
	StatusFailed     TaskStatus = "FAILED"
)

// TaskStatusGroup selects a predefined subset of task statuses; see
// GetTaskStatuses.
type TaskStatusGroup int

// The collection of possible task status groupings.
const (
	// AllTasks selects every status.
	AllTasks TaskStatusGroup = iota
	// OpenTasks selects NEW, ENQUEUED and IN_PROGRESS.
	OpenTasks
	// FinishedTasks selects SUCCESSFUL and FAILED.
	FinishedTasks
)

// GetTaskStatuses returns the statuses belonging to taskStatusGroup, or nil
// for an unknown group. A new slice is returned on every call, so callers may
// modify it.
func GetTaskStatuses(taskStatusGroup TaskStatusGroup) []TaskStatus {
	switch taskStatusGroup {
	case AllTasks:
		return []TaskStatus{StatusNew, StatusEnqueued, StatusInProgress, StatusSuccessful, StatusFailed}
	case OpenTasks:
		return []TaskStatus{StatusNew, StatusEnqueued, StatusInProgress}
	case FinishedTasks:
		return []TaskStatus{StatusSuccessful, StatusFailed}
	default:
		return nil
	}
}
// Task is the struct used to represent an atomic task managed by tasq.
type Task struct {
	// ID identifies the task; NewTask assigns a random UUID.
	ID uuid.UUID
	// Type is the task type; repositories filter on it when polling.
	Type string
	// Args holds the gob-encoded task arguments; decode them with UnmarshalArgs.
	Args     []byte
	Queue    string
	Priority int16
	Status   TaskStatus
	// ReceiveCount is compared against MaxReceives by IsLastReceive.
	ReceiveCount int32
	MaxReceives  int32
	// LastError is the most recently recorded error message, or nil.
	LastError *string
	CreatedAt time.Time
	// StartedAt and FinishedAt are nil until the task is started / finished.
	StartedAt  *time.Time
	FinishedAt *time.Time
	// VisibleAt is the time from which the task may be polled again;
	// NewTask leaves it at the zero time.
	VisibleAt time.Time
}
// NewTask creates a NEW task with a random UUID, gob-encoded taskArgs and
// CreatedAt set to the current time. Every other field is left at its zero
// value. Errors from ID generation or argument encoding are returned.
func NewTask(taskType string, taskArgs any, queue string, priority int16, maxReceives int32) (*Task, error) {
	id, err := uuid.NewRandom()
	if err != nil {
		return nil, fmt.Errorf("failed to generate new task ID: %w", err)
	}

	args, err := encodeTaskArgs(taskArgs)
	if err != nil {
		return nil, err
	}

	task := &Task{
		ID:          id,
		Type:        taskType,
		Args:        args,
		Queue:       queue,
		Priority:    priority,
		Status:      StatusNew,
		MaxReceives: maxReceives,
		CreatedAt:   time.Now(),
	}

	return task, nil
}
// IsLastReceive reports whether the task's receive count has reached (or
// passed) its maximum number of receives.
func (t *Task) IsLastReceive() bool {
	return t.MaxReceives <= t.ReceiveCount
}

// SetVisibility records visibleAt as the time from which the task becomes
// visible again. It only updates the in-memory value.
func (t *Task) SetVisibility(visibleAt time.Time) {
	t.VisibleAt = visibleAt
}
// UnmarshalArgs gob-decodes the task's Args into target, which must be a
// pointer suitable for gob.Decoder.Decode. Decoding errors are wrapped.
func (t *Task) UnmarshalArgs(target any) error {
	decoder := gob.NewDecoder(bytes.NewReader(t.Args))

	if err := decoder.Decode(target); err != nil {
		return fmt.Errorf("failed to decode task arguments: %w", err)
	}

	return nil
}
// encodeTaskArgs gob-encodes taskArgs. On failure it returns an empty, non-nil
// byte slice and the wrapped encoding error.
func encodeTaskArgs(taskArgs any) ([]byte, error) {
	buffer := new(bytes.Buffer)

	if err := gob.NewEncoder(buffer).Encode(taskArgs); err != nil {
		return []byte{}, fmt.Errorf("failed to encode task arguments: %w", err)
	}

	return buffer.Bytes(), nil
}