package mdk
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
)
// ActorType represents the type of entity performing an action. It is the
// value returned by Actor.GetType and stored in BaseActor.Type.
type ActorType string

// Known ActorType values.
const (
	// ActorHuman identifies a human actor.
	ActorHuman ActorType = "HUMAN"
	// ActorAIAgent identifies an AI agent acting as a principal.
	ActorAIAgent ActorType = "AI_AGENT"
	// ActorSystem identifies the system itself as the actor.
	ActorSystem ActorType = "SYSTEM"
)
// JSONMap is a map[string]string persisted as a JSON object. It implements
// driver.Valuer and sql.Scanner so GORM / database/sql can store it in a
// JSON, JSONB or text column.
type JSONMap map[string]string

// Value implements driver.Valuer. A nil or empty map is stored as SQL NULL;
// otherwise the map is JSON-encoded and returned as []byte.
func (m JSONMap) Value() (driver.Value, error) {
	if len(m) == 0 {
		return nil, nil
	}
	return json.Marshal(m)
}

// Scan implements sql.Scanner. SQL NULL yields a nil map; []byte and string
// sources are decoded as a JSON object of string values; any other source type
// is rejected. The receiver is replaced (never merged into) on success and is
// left untouched when decoding fails.
func (m *JSONMap) Scan(value interface{}) error {
	if value == nil {
		*m = nil
		return nil
	}
	var raw []byte
	switch v := value.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return errors.New(fmt.Sprint("Failed to unmarshal JSONB value: ", value))
	}
	// Decode into a fresh map: json.Unmarshal into an existing non-nil map
	// keeps its previous entries, which would leak stale keys whenever a
	// JSONMap value is reused across scans.
	var decoded map[string]string
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return err
	}
	*m = decoded
	return nil
}
// Actor represents the minimal interface for any security principal or identity.
// BaseActor is this package's concrete implementation; TokenValidator yields an
// Actor, and WithActor/ActorFromContext carry one through a context.Context.
type Actor interface {
	// GetID returns the actor's identifier.
	GetID() string
	// GetType returns the kind of actor (see the ActorType constants).
	GetType() ActorType
	// GetName returns the actor's display name.
	GetName() string
	// GetMetadata returns extra key/value attributes. It may be nil
	// (BaseActor returns its Metadata field as-is).
	GetMetadata() map[string]string
}
// BaseActor is a plain, JSON-serializable struct that satisfies Actor. Each
// getter returns the corresponding field unchanged.
type BaseActor struct {
	ID       string            `json:"id"`
	Type     ActorType         `json:"type"`
	Name     string            `json:"name"`
	Metadata map[string]string `json:"metadata,omitempty"`
}

// Compile-time assertion that *BaseActor implements Actor.
var _ Actor = (*BaseActor)(nil)

// GetID returns a.ID.
func (a *BaseActor) GetID() string { return a.ID }

// GetType returns a.Type.
func (a *BaseActor) GetType() ActorType { return a.Type }

// GetName returns a.Name.
func (a *BaseActor) GetName() string { return a.Name }

// GetMetadata returns a.Metadata itself (not a copy), so callers see and can
// mutate the actor's own map; it may be nil.
func (a *BaseActor) GetMetadata() map[string]string { return a.Metadata }
// contextKey is the unexported key type for values this package stores in a
// context.Context. Because the type is unexported, keys from other packages
// cannot collide with it.
type contextKey struct{}

// actorKey is the key under which WithActor stores the Actor.
var actorKey = contextKey{}

// WithActor returns a child of ctx carrying actor, retrievable with
// ActorFromContext.
func WithActor(ctx context.Context, actor Actor) context.Context {
	derived := context.WithValue(ctx, actorKey, actor)
	return derived
}

// ActorFromContext returns the Actor stored by WithActor together with true.
// It returns (nil, false) when ctx carries no Actor, including when a nil
// Actor was stored.
func ActorFromContext(ctx context.Context) (actor Actor, ok bool) {
	actor, ok = ctx.Value(actorKey).(Actor)
	return actor, ok
}
// TokenValidator defines the interface for validating authentication tokens.
type TokenValidator interface {
	// ValidateToken resolves token to the Actor it identifies. The error
	// result reports validation failure; the exact error values are
	// implementation-defined.
	ValidateToken(ctx context.Context, token string) (Actor, error)
}
package mdk
import (
"context"
"errors"
"sync"
"time"
"gorm.io/gorm"
)
// EventBusCloser extends EventBus to support graceful shutdowns. It is the
// type a BusProvider constructs.
type EventBusCloser interface {
	EventBus
	// Close shuts the bus down. Whether in-flight deliveries are drained
	// first is implementation-defined.
	Close() error
}
// Locker defines the interface for distributed locking. Instances are built by
// a LockerProvider.
type Locker interface {
	// Acquire tries to take the lock named key, waiting up to timeout; the
	// bool reports whether it was obtained. ttl presumably bounds how long the
	// lock is held before it expires.
	// NOTE(review): whether a timeout surfaces as (false, nil) or as
	// ErrLockAcquisitionTimeout is backend-specific — confirm per implementation.
	Acquire(ctx context.Context, key string, ttl time.Duration, timeout time.Duration) (bool, error)
	// Release gives up the lock named key.
	// NOTE(review): implementations presumably return ErrLockNotHeld when the
	// caller does not hold it — confirm.
	Release(ctx context.Context, key string) error
	// Close releases the Locker's own resources.
	Close() error
}
// Sentinel errors for lock operations; compare with errors.Is.
var (
	// ErrLockAcquisitionTimeout reports that a lock was not acquired before
	// the acquisition timeout elapsed.
	ErrLockAcquisitionTimeout = errors.New("lock acquisition timed out")
	// ErrLockNotHeld reports an operation on a lock the caller does not hold.
	ErrLockNotHeld = errors.New("lock not held")
)

// lockContextKey is the unexported type for lock-related context keys, so
// other packages cannot construct colliding keys.
type lockContextKey string

// LockOwnerKey is the context key for the lock owner.
// NOTE(review): the type and format of the value stored under this key are not
// defined here — confirm against the Locker implementations.
const LockOwnerKey lockContextKey = "lock_owner"
// StateStore defines the interface for checkpointing workflow states. Every
// method is scoped to a workflow execution ID (execID). Instances are built by a
// StoreProvider.
type StateStore interface {
	// SaveState records state for step stepID of execution execID.
	SaveState(ctx context.Context, execID string, stepID string, state string) error
	// GetState returns the recorded states of execID (presumably keyed by
	// step ID — confirm against implementations).
	GetState(ctx context.Context, execID string) (map[string]string, error)
	// InitializeExecution creates the record for execID with its input payload.
	InitializeExecution(ctx context.Context, execID string, input []byte) error
	// SaveInput stores the input payload of execID.
	SaveInput(ctx context.Context, execID string, input []byte) error
	// GetInput returns the input payload stored for execID.
	GetInput(ctx context.Context, execID string) ([]byte, error)
	// SetTTL presumably sets how long execID's data is retained.
	SetTTL(ctx context.Context, execID string, ttl time.Duration) error
	// SaveStepOutput stores the output payload of step stepID of execID.
	SaveStepOutput(ctx context.Context, execID string, stepID string, output []byte) error
	// GetStepOutput returns the output payload of step stepID of execID.
	GetStepOutput(ctx context.Context, execID string, stepID string) ([]byte, error)
	// ListExecutions returns execution IDs, presumably those whose state
	// equals state — confirm.
	ListExecutions(ctx context.Context, state string) ([]string, error)
	// RecordEventEmitted marks eventType as emitted for execID, and
	// IsEventEmitted reports whether it was; presumably used to make event
	// emission idempotent across retries.
	RecordEventEmitted(ctx context.Context, execID string, eventType string) error
	IsEventEmitted(ctx context.Context, execID string, eventType string) (bool, error)
}
// DialectProvider constructor signature for database dialects: it builds the
// gorm.Dialector for the given DSN.
type DialectProvider func(dsn string) gorm.Dialector

// BusProvider constructor signature for event buses: it connects to the bus
// at url.
type BusProvider func(url string) (EventBusCloser, error)

// LockerProvider constructor signature for lockers. bucketOrPrefix presumably
// names a bucket or key prefix, depending on the backend.
type LockerProvider func(url string, bucketOrPrefix string) (Locker, error)

// StoreProvider constructor signature for workflow state stores.
// bucketOrPrefix is interpreted as for LockerProvider.
type StoreProvider func(url string, bucketOrPrefix string) (StateStore, error)
// Provider registries keyed by provider name. Each map has its own RWMutex:
// Register* functions take the write lock and Get* functions the read lock.
var (
	dialectsMu sync.RWMutex
	dialects   = make(map[string]DialectProvider)

	busProvidersMu sync.RWMutex
	busProviders   = make(map[string]BusProvider)

	lockersMu sync.RWMutex
	lockers   = make(map[string]LockerProvider)

	storesMu sync.RWMutex
	stores   = make(map[string]StoreProvider)
)
// RegisterDialect registers a database dialect provider under name.
// Registering the same name again replaces the earlier provider. It panics if
// provider is nil: a nil provider would otherwise be returned by GetDialect
// with ok == true and only fail once called.
func RegisterDialect(name string, provider DialectProvider) {
	if provider == nil {
		panic("mdk: RegisterDialect called with nil provider for " + name)
	}
	dialectsMu.Lock()
	defer dialectsMu.Unlock()
	dialects[name] = provider
}

// GetDialect retrieves the database dialect provider registered under name and
// reports whether one was found.
func GetDialect(name string) (DialectProvider, bool) {
	dialectsMu.RLock()
	defer dialectsMu.RUnlock()
	d, ok := dialects[name]
	return d, ok
}

// RegisterEventBusProvider registers an event bus provider under name,
// replacing any earlier one. It panics if provider is nil.
func RegisterEventBusProvider(name string, provider BusProvider) {
	if provider == nil {
		panic("mdk: RegisterEventBusProvider called with nil provider for " + name)
	}
	busProvidersMu.Lock()
	defer busProvidersMu.Unlock()
	busProviders[name] = provider
}

// GetEventBusProvider retrieves the event bus provider registered under name
// and reports whether one was found.
func GetEventBusProvider(name string) (BusProvider, bool) {
	busProvidersMu.RLock()
	defer busProvidersMu.RUnlock()
	bp, ok := busProviders[name]
	return bp, ok
}

// RegisterLocker registers a lock manager provider under name, replacing any
// earlier one. It panics if provider is nil.
func RegisterLocker(name string, provider LockerProvider) {
	if provider == nil {
		panic("mdk: RegisterLocker called with nil provider for " + name)
	}
	lockersMu.Lock()
	defer lockersMu.Unlock()
	lockers[name] = provider
}

// GetLocker retrieves the lock manager provider registered under name and
// reports whether one was found.
func GetLocker(name string) (LockerProvider, bool) {
	lockersMu.RLock()
	defer lockersMu.RUnlock()
	l, ok := lockers[name]
	return l, ok
}

// RegisterStateStore registers a state store provider under name, replacing
// any earlier one. It panics if provider is nil.
func RegisterStateStore(name string, provider StoreProvider) {
	if provider == nil {
		panic("mdk: RegisterStateStore called with nil provider for " + name)
	}
	storesMu.Lock()
	defer storesMu.Unlock()
	stores[name] = provider
}

// GetStateStore retrieves the state store provider registered under name and
// reports whether one was found.
func GetStateStore(name string) (StoreProvider, bool) {
	storesMu.RLock()
	defer storesMu.RUnlock()
	s, ok := stores[name]
	return s, ok
}
package mdktest
import (
"context"
"fmt"
"log/slog"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/GoHyperrr/mdk"
"gorm.io/gorm"
)
// TestRuntime is a concrete, in-memory implementation of Runtime designed for unit testing.
// It pairs a caller-supplied *gorm.DB with a TestEventBus and a
// TestWorkflowEngine that both point back at it. Build it with NewTestRuntime;
// the zero value has nil maps.
type TestRuntime struct {
	db             *gorm.DB
	bus            *TestEventBus
	workflowEngine *TestWorkflowEngine
	configs        map[string]any // values set via SetConfig
	logger         *slog.Logger
	modules        map[string]mdk.Module // modules set via SetModule
	mu             sync.RWMutex          // guards the mutable maps (configs, modules)
}
// NewTestRuntime returns a TestRuntime backed by db, with empty config and
// module maps, slog.Default() as its logger, and a fresh TestEventBus and
// TestWorkflowEngine wired back to the new runtime.
func NewTestRuntime(db *gorm.DB) *TestRuntime {
	rt := new(TestRuntime)
	rt.db = db
	rt.configs = map[string]any{}
	rt.logger = slog.Default()
	rt.modules = map[string]mdk.Module{}
	rt.bus = NewTestEventBus(rt)
	rt.workflowEngine = NewTestWorkflowEngine(rt)
	return rt
}
// DB returns the *gorm.DB supplied to NewTestRuntime.
func (tr *TestRuntime) DB() *gorm.DB {
	return tr.db
}

// Bus returns the runtime's in-memory TestEventBus.
func (tr *TestRuntime) Bus() mdk.EventBus {
	return tr.bus
}

// Workflows returns the runtime's TestWorkflowEngine.
func (tr *TestRuntime) Workflows() mdk.WorkflowEngine {
	return tr.workflowEngine
}

// Config returns the value stored under key by SetConfig, or nil if none.
func (tr *TestRuntime) Config(key string) any {
	tr.mu.RLock()
	defer tr.mu.RUnlock()
	return tr.configs[key]
}

// SetConfig stores val under key, replacing any previous value.
func (tr *TestRuntime) SetConfig(key string, val any) {
	tr.mu.Lock()
	defer tr.mu.Unlock()
	tr.configs[key] = val
}

// Logger returns the runtime logger. It reads under mu because step handlers
// can call it from the goroutine started by TestWorkflowEngine.Execute while a
// test calls SetLogger; the unguarded read was a data race.
func (tr *TestRuntime) Logger() *slog.Logger {
	tr.mu.RLock()
	defer tr.mu.RUnlock()
	return tr.logger
}

// SetLogger replaces the runtime logger. Passing nil restores slog.Default()
// so that Logger does not start returning nil to callers.
func (tr *TestRuntime) SetLogger(l *slog.Logger) {
	if l == nil {
		l = slog.Default()
	}
	tr.mu.Lock()
	defer tr.mu.Unlock()
	tr.logger = l
}

// Module returns the module registered under id via SetModule and whether it
// was found.
func (tr *TestRuntime) Module(id string) (mdk.Module, bool) {
	tr.mu.RLock()
	defer tr.mu.RUnlock()
	m, ok := tr.modules[id]
	return m, ok
}

// SetModule registers m under id, replacing any previous module.
func (tr *TestRuntime) SetModule(id string, m mdk.Module) {
	tr.mu.Lock()
	defer tr.mu.Unlock()
	tr.modules[id] = m
}
// TestEventBus is an in-memory, synchronous implementation of mdk.EventBus for
// tests. Publish records every event in Published and delivers it to matching
// subscribers on the caller's goroutine before returning.
type TestEventBus struct {
	rt       *TestRuntime
	mu       sync.RWMutex
	handlers map[string][]testSubscription // subscriptions keyed by "namespace.eventType"
	nextID   uint64                        // last subscription ID handed out; guarded by mu
	// Published lists every event passed to Publish, in order, with
	// OccurredAt already defaulted. Appends happen under mu, so read it only
	// once no Publish call can be in flight.
	Published []mdk.Event
}

// testSubscription ties a handler to a bus-unique ID so that unsubscribing
// removes exactly the registration it came from. Go func values cannot be
// compared, and comparing code pointers (reflect.Value.Pointer) conflates
// distinct closures created from the same function literal.
type testSubscription struct {
	id      uint64
	handler mdk.EventHandler
}

// NewTestEventBus returns an empty TestEventBus bound to rt; rt's logger is
// used to report handler errors.
func NewTestEventBus(rt *TestRuntime) *TestEventBus {
	return &TestEventBus{
		rt:       rt,
		handlers: make(map[string][]testSubscription),
	}
}

// Publish records e and then synchronously calls every handler subscribed to
// e.Namespace+"."+e.Type, followed by every handler subscribed to the
// namespace wildcard e.Namespace+".*". A zero OccurredAt is first set to
// time.Now(), and handlers receive that stamped copy.
//
// Handler errors neither stop delivery nor propagate: Publish always returns
// nil, like a fire-and-forget bus. Failures are logged through the runtime
// logger instead of being dropped silently.
func (teb *TestEventBus) Publish(ctx context.Context, e mdk.Event) error {
	if e.OccurredAt.IsZero() {
		e.OccurredAt = time.Now()
	}
	key := e.Namespace + "." + e.Type
	wildcard := e.Namespace + ".*"

	// Record the event and snapshot the subscriber lists in one critical
	// section, then call handlers without the lock so they may publish or
	// (un)subscribe themselves.
	teb.mu.Lock()
	teb.Published = append(teb.Published, e)
	targets := append([]testSubscription(nil), teb.handlers[key]...)
	if wildcard != key {
		// When e.Type is literally "*", key already names the wildcard list;
		// appending it again would deliver the event twice.
		targets = append(targets, teb.handlers[wildcard]...)
	}
	teb.mu.Unlock()

	for _, sub := range targets {
		if err := sub.handler(ctx, e); err != nil {
			teb.logHandlerError(e, sub, err)
		}
	}
	return nil
}

// logHandlerError reports a failed handler through the runtime logger, when
// one is available. handler_pc is the handler's code address, which
// runtime.FuncForPC can resolve to a function name.
func (teb *TestEventBus) logHandlerError(e mdk.Event, sub testSubscription, err error) {
	if teb.rt == nil {
		return
	}
	logger := teb.rt.Logger()
	if logger == nil {
		return
	}
	logger.Error("test event bus: handler failed",
		"namespace", e.Namespace,
		"type", e.Type,
		"subscription", sub.id,
		"handler_pc", fmt.Sprintf("%#x", reflect.ValueOf(sub.handler).Pointer()),
		"error", err,
	)
}

// Subscribe registers handler for events whose Namespace and Type equal
// namespace and eventType; an eventType of "*" receives every event of the
// namespace. It returns an idempotent function that removes exactly this
// registration. A nil handler is rejected with an error (and a no-op
// unsubscribe function) instead of panicking later inside Publish.
func (teb *TestEventBus) Subscribe(namespace, eventType string, handler mdk.EventHandler) (func(), error) {
	if handler == nil {
		return func() {}, fmt.Errorf("test event bus: nil handler for %s.%s", namespace, eventType)
	}
	key := namespace + "." + eventType

	teb.mu.Lock()
	teb.nextID++
	id := teb.nextID
	teb.handlers[key] = append(teb.handlers[key], testSubscription{id: id, handler: handler})
	teb.mu.Unlock()

	return func() {
		teb.mu.Lock()
		defer teb.mu.Unlock()
		subs := teb.handlers[key]
		for i, sub := range subs {
			if sub.id != id {
				continue
			}
			// Publish iterates over its own copy, so shifting in place is safe.
			if len(subs) == 1 {
				delete(teb.handlers, key)
			} else {
				teb.handlers[key] = append(subs[:i], subs[i+1:]...)
			}
			return
		}
	}, nil
}
// runIDCounter is a process-wide sequence number, incremented atomically by
// Execute, so run IDs stay unique even when two runs start in the same
// nanosecond.
var runIDCounter int64

// TestWorkflowEngine is a simple, synchronous implementation of WorkflowEngine for unit tests.
// ExecuteSync runs steps on the calling goroutine; Execute wraps it in a
// background goroutine. Every map below is accessed under mu.
type TestWorkflowEngine struct {
	rt        *TestRuntime
	mu        sync.RWMutex
	workflows map[string]mdk.Workflow    // registered workflows by Workflow.ID
	handlers  map[string]mdk.StepHandler // step and Saga handlers by name (Step.Uses / Saga.Uses)
	runs      map[string]mdk.StepStatus  // last recorded status per run ID
	outputs   map[string]map[string]any  // final results of completed runs, by run ID
}
// NewTestWorkflowEngine returns an engine with no workflows, handlers, runs or
// outputs. Step contexts it creates reference rt as their Runtime.
func NewTestWorkflowEngine(rt *TestRuntime) *TestWorkflowEngine {
	engine := new(TestWorkflowEngine)
	engine.rt = rt
	engine.workflows = map[string]mdk.Workflow{}
	engine.handlers = map[string]mdk.StepHandler{}
	engine.runs = map[string]mdk.StepStatus{}
	engine.outputs = map[string]map[string]any{}
	return engine
}
// Register stores w under w.ID, replacing any workflow already registered with
// that ID. It always returns nil.
func (twe *TestWorkflowEngine) Register(w mdk.Workflow) error {
	id := w.ID
	twe.mu.Lock()
	defer twe.mu.Unlock()
	twe.workflows[id] = w
	return nil
}

// RegisterHandler makes handler available to steps and Saga compensations
// whose Uses field equals name, replacing any handler already registered
// under that name. It always returns nil.
func (twe *TestWorkflowEngine) RegisterHandler(name string, handler mdk.StepHandler) error {
	twe.mu.Lock()
	defer twe.mu.Unlock()
	registry := twe.handlers
	registry[name] = handler
	return nil
}
// Execute starts workflowID in the background and returns a fresh run ID.
//
// An unregistered workflowID is reported immediately as an error, instead of
// handing out a run ID whose status would never be set. On success the run is
// marked mdk.StepRunning before Execute returns, so Status is meaningful right
// away. The steps then run via ExecuteSync on a new goroutine, which receives
// ctx unchanged; the outcome is observable only through Status.
func (twe *TestWorkflowEngine) Execute(ctx context.Context, workflowID string, input map[string]any) (string, error) {
	seq := atomic.AddInt64(&runIDCounter, 1)
	runID := fmt.Sprintf("wf_run_%d_%d", time.Now().UnixNano(), seq)

	twe.mu.Lock()
	if _, ok := twe.workflows[workflowID]; !ok {
		twe.mu.Unlock()
		return "", fmt.Errorf("workflow not found: %s", workflowID)
	}
	twe.runs[runID] = mdk.StepRunning
	twe.mu.Unlock()

	go func() {
		// ExecuteSync records the final status (and outputs) itself; callers
		// observe them through Status, so the return values are not needed.
		_, _ = twe.ExecuteSync(ctx, runID, workflowID, input)
	}()
	return runID, nil
}
// Status returns the last recorded status of runID. An unknown run ID yields
// the zero mdk.StepStatus; the error is always nil.
func (twe *TestWorkflowEngine) Status(ctx context.Context, runID string) (mdk.StepStatus, error) {
	twe.mu.RLock()
	status := twe.runs[runID]
	twe.mu.RUnlock()
	return status, nil
}

// Cancel records runID as mdk.StepFailed (creating the entry if runID is
// unknown). It only records the status and does not itself interrupt a step
// that is already running. It always returns nil.
func (twe *TestWorkflowEngine) Cancel(ctx context.Context, runID string) error {
	twe.mu.Lock()
	defer twe.mu.Unlock()
	runs := twe.runs
	runs[runID] = mdk.StepFailed
	return nil
}
// ExecuteSync runs the registered workflow workflowID to completion on the
// calling goroutine, recording progress under runID, and returns the
// accumulated results map.
//
// Scheduling: each round launches, in declaration order, every step not yet
// launched whose DependsOn steps have all completed. If a round finds nothing
// runnable while steps remain (missing or cyclic dependencies, duplicate step
// IDs), the run fails.
//
// Data flow: results starts as a shallow copy of input plus "input" (the
// original map) and "_workflow_id" (set to runID). Every handler receives this
// same map as StepContext.Input. After a step succeeds, its Output entries are
// merged into results and the whole Output is also stored at results[step.ID].
//
// Status: the run is marked mdk.StepRunning on start, then ends as
// mdk.StepCompleted (with results saved in outputs) or mdk.StepFailed. Before
// every step the run aborts if ctx is done, or if the run was marked
// mdk.StepFailed in the meantime (e.g. by Cancel).
//
// Compensation: when a step returns an error or the run aborts, the Saga
// handlers of the steps that already completed run newest-first with the same
// ctx. An unregistered step handler fails the run without compensation.
//
// An unknown workflowID returns (nil, error) without touching run status;
// every other failure returns the partial results alongside the error.
func (twe *TestWorkflowEngine) ExecuteSync(ctx context.Context, runID, workflowID string, input map[string]any) (map[string]any, error) {
	twe.mu.RLock()
	wf, ok := twe.workflows[workflowID]
	twe.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("workflow not found: %s", workflowID)
	}
	twe.setRunStatus(runID, mdk.StepRunning)

	results := make(map[string]any, len(input)+2)
	for k, v := range input {
		results[k] = v
	}
	results["input"] = input
	results["_workflow_id"] = runID

	completed := make(map[string]bool)
	launched := make(map[string]bool)
	var history []mdk.Step
	for len(completed) < len(wf.Steps) {
		ready := twe.readySteps(wf.Steps, launched, completed)
		if len(ready) == 0 {
			twe.setRunStatus(runID, mdk.StepFailed)
			return results, fmt.Errorf("deadlock detected or unresolved dependencies in test workflow execution")
		}
		for _, step := range ready {
			if err := twe.abortErr(ctx, runID); err != nil {
				twe.setRunStatus(runID, mdk.StepFailed)
				twe.compensate(ctx, workflowID, runID, history, results)
				return results, fmt.Errorf("workflow run %s aborted before step %s: %w", runID, step.ID, err)
			}
			launched[step.ID] = true

			twe.mu.RLock()
			handler := twe.handlers[step.Uses]
			twe.mu.RUnlock()
			if handler == nil {
				twe.setRunStatus(runID, mdk.StepFailed)
				return results, fmt.Errorf("handler not found for step %s (uses %s)", step.ID, step.Uses)
			}

			res := handler(mdk.StepContext{
				Ctx:        ctx,
				Runtime:    twe.rt,
				WorkflowID: workflowID,
				RunID:      runID,
				StepID:     step.ID,
				Input:      results,
			})
			if res.Err != nil {
				twe.setRunStatus(runID, mdk.StepFailed)
				twe.compensate(ctx, workflowID, runID, history, results)
				return results, fmt.Errorf("step %s failed: %w", step.ID, res.Err)
			}
			for k, v := range res.Output {
				results[k] = v
			}
			results[step.ID] = res.Output
			completed[step.ID] = true
			history = append(history, step)
		}
	}

	twe.mu.Lock()
	twe.runs[runID] = mdk.StepCompleted
	twe.outputs[runID] = results
	twe.mu.Unlock()
	return results, nil
}

// setRunStatus records status for runID under the engine lock.
func (twe *TestWorkflowEngine) setRunStatus(runID string, status mdk.StepStatus) {
	twe.mu.Lock()
	twe.runs[runID] = status
	twe.mu.Unlock()
}

// readySteps returns, in declaration order, the steps not yet launched whose
// DependsOn entries have all completed.
func (twe *TestWorkflowEngine) readySteps(steps []mdk.Step, launched, completed map[string]bool) []mdk.Step {
	var ready []mdk.Step
	for _, step := range steps {
		if launched[step.ID] {
			continue
		}
		runnable := true
		for _, dep := range step.DependsOn {
			if !completed[dep] {
				runnable = false
				break
			}
		}
		if runnable {
			ready = append(ready, step)
		}
	}
	return ready
}

// abortErr reports why runID must stop before its next step: ctx is done, or
// the run was marked mdk.StepFailed (by Cancel) while in progress. A nil ctx is
// tolerated and treated as never done.
func (twe *TestWorkflowEngine) abortErr(ctx context.Context, runID string) error {
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return err
		}
	}
	twe.mu.RLock()
	status := twe.runs[runID]
	twe.mu.RUnlock()
	if status == mdk.StepFailed {
		return fmt.Errorf("run was cancelled")
	}
	return nil
}

// compensate runs, newest first, the Saga handler of each step in history.
// Steps without a Saga, or whose Saga.Uses handler is not registered, are
// skipped. Compensation results are ignored: this is best-effort rollback on
// a run that has already failed.
func (twe *TestWorkflowEngine) compensate(ctx context.Context, workflowID, runID string, history []mdk.Step, results map[string]any) {
	for i := len(history) - 1; i >= 0; i-- {
		step := history[i]
		if step.Saga == nil || step.Saga.Uses == "" {
			continue
		}
		twe.mu.RLock()
		undo := twe.handlers[step.Saga.Uses]
		twe.mu.RUnlock()
		if undo == nil {
			continue
		}
		_ = undo(mdk.StepContext{
			Ctx:        ctx,
			Runtime:    twe.rt,
			WorkflowID: workflowID,
			RunID:      runID,
			StepID:     step.ID,
			Input:      results,
		})
	}
}
// TestLineageData is a plain-struct implementation of mdk.LineageData for
// tests; every getter returns the matching field verbatim.
type TestLineageData struct {
	ID        string
	Name      string
	State     string
	Error     string
	StartedAt time.Time
	EndedAt   *time.Time
	Events    []mdk.Event
}

// GetID returns d.ID.
func (d TestLineageData) GetID() string {
	return d.ID
}

// GetName returns d.Name.
func (d TestLineageData) GetName() string {
	return d.Name
}

// GetState returns d.State.
func (d TestLineageData) GetState() string {
	return d.State
}

// GetError returns d.Error.
func (d TestLineageData) GetError() string {
	return d.Error
}

// GetStartedAt returns d.StartedAt.
func (d TestLineageData) GetStartedAt() time.Time {
	return d.StartedAt
}

// GetEndedAt returns d.EndedAt, which is nil while the lineage is open.
func (d TestLineageData) GetEndedAt() *time.Time {
	return d.EndedAt
}

// GetEvents returns d.Events (the slice itself, not a copy).
func (d TestLineageData) GetEvents() []mdk.Event {
	return d.Events
}
// TestProjector implements Projector for testing by serving the fixed set of
// lineages stored in Lineages.
type TestProjector struct {
	Lineages []mdk.LineageData
}

// ListLineages returns tp.Lineages itself (not a copy), so callers must not
// modify the result.
func (tp *TestProjector) ListLineages() []mdk.LineageData {
	return tp.Lineages
}

// QueryLineages returns, in order, the lineages for which filter reports true,
// as a new slice (nil when nothing matches). A nil filter matches every
// lineage instead of panicking.
func (tp *TestProjector) QueryLineages(filter func(mdk.LineageData) bool) []mdk.LineageData {
	var out []mdk.LineageData
	for _, l := range tp.Lineages {
		if filter == nil || filter(l) {
			out = append(out, l)
		}
	}
	return out
}
// ProjectorModule is a minimal mock module whose only capability is exposing
// the Projector stored in Proj. All lifecycle hooks are no-ops.
type ProjectorModule struct {
	ModuleID string
	Proj     mdk.Projector
}

// ID returns ModuleID, or "core.context" when ModuleID is empty.
func (pm *ProjectorModule) ID() string {
	id := pm.ModuleID
	if id == "" {
		id = "core.context"
	}
	return id
}

// Init does nothing and returns nil.
func (pm *ProjectorModule) Init(ctx context.Context, rt mdk.Runtime) error { return nil }

// Shutdown does nothing and returns nil.
func (pm *ProjectorModule) Shutdown(ctx context.Context) error { return nil }

// Models reports no persistence models.
func (pm *ProjectorModule) Models() []any { return nil }

// Routes reports no HTTP routes.
func (pm *ProjectorModule) Routes() []mdk.Route { return nil }

// Projector returns pm.Proj.
func (pm *ProjectorModule) Projector() mdk.Projector { return pm.Proj }
package mdk
import (
"database/sql/driver"
"encoding/json"
"errors"
)
// Metadata is free-form JSON metadata persisted in a text/JSON column. It
// implements driver.Valuer and sql.Scanner.
type Metadata map[string]interface{}

// Value encodes m as a JSON string. A nil map is stored as SQL NULL; an empty
// non-nil map is stored as "{}".
func (m Metadata) Value() (driver.Value, error) {
	if m == nil {
		return nil, nil
	}
	var out driver.Value
	encoded, err := json.Marshal(m)
	if err == nil {
		out = string(encoded)
	}
	return out, err
}

// Scan decodes a []byte or string JSON object into a freshly allocated map and
// replaces *m with it. SQL NULL yields a nil map; any other source type is
// rejected.
func (m *Metadata) Scan(val interface{}) error {
	var raw []byte
	switch v := val.(type) {
	case nil:
		*m = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return errors.New("failed to scan Metadata: invalid type")
	}
	decoded := map[string]interface{}{}
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return err
	}
	*m = decoded
	return nil
}
package mdk
import (
"fmt"
"sync"
)
var global = ®istry{
factories: make(map[string]Factory),
}
type registry struct {
mu sync.RWMutex
factories map[string]Factory
}
// Register adds a module factory to the global registry.
// Call this inside an init() function in your module package.
//
// Register invokes factory once, immediately, to learn the module's ID. It
// panics if factory is nil, if factory returns a nil Module, or if a module
// with the same ID is already registered.
//
// Example:
//
//	func init() {
//		mdk.Register(func() mdk.Module { return &MyModule{} })
//	}
func Register(factory Factory) {
	if factory == nil {
		panic("mdk: Register called with nil factory")
	}
	m := factory()
	if m == nil {
		panic("mdk: Register factory returned nil module")
	}
	id := m.ID()
	global.mu.Lock()
	defer global.mu.Unlock()
	if _, exists := global.factories[id]; exists {
		panic(fmt.Sprintf("mdk: module %q already registered", id))
	}
	global.factories[id] = factory
}
// Registered returns a copy of the module factories registered so far, keyed
// by module ID. Mutating the returned map does not affect the registry.
func Registered() map[string]Factory {
	global.mu.RLock()
	snapshot := make(map[string]Factory, len(global.factories))
	for id, f := range global.factories {
		snapshot[id] = f
	}
	global.mu.RUnlock()
	return snapshot
}