// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package sim import ( "pgregory.net/rapid" ) // BiasedBool returns a rapid generator for boolean values that are true with probability p and false with probability 1-p. func BiasedBool(p float64) *rapid.Generator[bool] { notOne := func(v float64) bool { return v != 1 } return rapid.Custom(func(t *rapid.T) bool { return rapid.Float64Range(0, 1).Filter(notOne).Draw(t, "p") < p || p == 1.0 }) }
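// Illustrative sketch, not part of the original sources: it shows one way
// BiasedBool might be exercised from a rapid property test. The test name and
// the specific properties checked (p of 0 never yields true, p of 1 always
// does) are assumptions made for the example.
package sim

import (
	"testing"

	"pgregory.net/rapid"
)

func TestBiasedBoolSketch(t *testing.T) {
	rapid.Check(t, func(t *rapid.T) {
		// Draw a probability and a single biased boolean drawn at that bias.
		p := rapid.Float64Range(0, 1).Draw(t, "p")
		v := BiasedBool(p).Draw(t, "v")
		if p == 0 && v {
			t.Fatalf("BiasedBool(0) returned true")
		}
		if p == 1 && !v {
			t.Fatalf("BiasedBool(1) returned false")
		}
	})
}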
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package sim import ( "cmp" "slices" "time" "github.com/addrummond/heap" "github.com/gammazero/deque" "github.com/stretchr/testify/require" "pgregory.net/rapid" ) type JobConfig struct { MinJitter time.Duration MedJitter time.Duration MaxJitter time.Duration Debug bool } func EstimateJob(t *rapid.T, plan *Plan, trialCount int, config *JobConfig) map[*Plan]*ResultRange { chk := require.New(t) estimateMap := make(map[*Plan]*ResultRange) var estimateSubjobs func(tasks []*Task) runTrials := func(plan *Plan) { estimateMapLenOrigin := len(estimateMap) estimateSubjobs(plan.RootTasks) rr := &ResultRange{} durations := make([]time.Duration, trialCount) for tr := range trialCount { r := estimateJob(t, plan, config, estimateMap, 0, "") durations[tr] = r.OverallDuration if tr == 0 { rr.MinMaxConcurrencyByPool = slices.Clone(r.MaxConcurrencyByPool) rr.MaxMaxConcurrencyByPool = slices.Clone(r.MaxConcurrencyByPool) rr.MinOverallDuration = r.OverallDuration rr.MaxOverallDuration = r.OverallDuration } else { for i := range len(r.MaxConcurrencyByPool) { rr.MinMaxConcurrencyByPool[i] = min(rr.MinMaxConcurrencyByPool[i], r.MaxConcurrencyByPool[i]) rr.MaxMaxConcurrencyByPool[i] = max(rr.MaxMaxConcurrencyByPool[i], r.MaxConcurrencyByPool[i]) } rr.MinOverallDuration = min(rr.MinOverallDuration, r.OverallDuration) rr.MaxOverallDuration = max(rr.MaxOverallDuration, r.OverallDuration) } } slices.Sort(durations) rr.MedOverallDuration = durations[len(durations)/2] estimateMap[plan] = rr chk.Equal(1+plan.SubplanCount, len(estimateMap)-estimateMapLenOrigin) } estimateSubjobs = func(tasks []*Task) { for _, task := range tasks { for _, subjob := range task.Subjobs { runTrials(subjob) } estimateSubjobs(task.Children) } } runTrials(plan) return estimateMap } func estimateJob( t *rapid.T, plan *Plan, config *JobConfig, subjobEstimates map[*Plan]*ResultRange, simTimeOrigin time.Duration, indent string, ) *Result { simTime := simTimeOrigin chk := require.New(t) var eventHeap heap.Heap[taskEvent, heap.Min] if config.MinJitter < 0 { panic("MinJitter may not be less than zero") } if config.MedJitter < config.MinJitter { panic("MedJitter may not be less than MinJitter") } if config.MaxJitter < config.MedJitter { panic("MaxJitter may not be less than MedJitter") } jitter := func() time.Duration { return time.Duration(biasedInt64( int64(config.MinJitter), int64(config.MedJitter), int64(config.MaxJitter), ).Draw(t, "jitterNoise")) } poolCount := len(plan.Config.ConcurrencyLimits) waitersByPool := make([][]func(), poolCount) concurrencyByPool := make([]int, poolCount) maxConcurrencyByPool := make([]int64, poolCount) var startTask func(task *Task) var scatterTask func(task *Task, then func()) scatterTask = func(task *Task, then func()) { pool := task.Pool concurrency := &concurrencyByPool[pool] if *concurrency < plan.Config.ConcurrencyLimits[pool] { *concurrency++ maxConcurrencyByPool[pool] = max(maxConcurrencyByPool[pool], int64(*concurrency)) if config.Debug { t.Logf("%v%s %v launching on pool %d (concurrency now %d)", simTime, indent, task, pool, *concurrency) } startTask(task) then() } else { waiters := &waitersByPool[pool] if config.Debug { t.Logf("%v%s %v blocked on pool %d along with %d others", simTime, indent, task, task.Pool, len(*waiters)) } *waiters = append(*waiters, func() { scatterTask(task, then) }) } } var scatterRootTask func(i int) scatterRootTask = func(i int) { scatterNext := func() { next := i + 1 if next < len(plan.RootTasks) 
{ scatterRootTask(next) } } scatterTask(plan.RootTasks[i], scatterNext) } var endTask func(task *Task) startTask = func(task *Task) { endTime := simTime + jitter() if config.Debug { t.Logf("%v%s starting %v at %v", simTime, indent, task, endTime) } endTime += task.SelfTimes[0] for step, subjobPlan := range task.Subjobs { if config.Debug { t.Logf("%v%s estimating %v subjob %d of %d", simTime, indent, task, step+1, len(task.Subjobs)) } e := subjobEstimates[subjobPlan] endTime += time.Duration(biasedInt64( int64(e.MinOverallDuration), int64(e.MedOverallDuration), int64(e.MaxOverallDuration), ).Draw(t, "subjobDuration")) endTime += task.SelfTimes[step+1] } heap.PushOrderable(&eventHeap, taskEvent{ Time: endTime, Func: func() { endTask(task) }, }) if config.Debug { t.Logf("%v%s scheduled %v end at %v", simTime, indent, task, endTime) } } var activeGatherThreadCount int var gatherQueue deque.Deque[*Task] var postGather func(task *Task) endTask = func(task *Task) { pool := task.Pool concurrency := &concurrencyByPool[pool] *concurrency-- if config.Debug { t.Logf("%v%s %v released pool %d", simTime, indent, task, pool) } waiters := &waitersByPool[pool] if len(*waiters) > 0 { wi := rapid.IntRange(0, len(*waiters)-1).Draw(t, "waiter") waiter := (*waiters)[wi] *waiters = slices.Delete(*waiters, wi, wi+1) waiter() } chk.GreaterOrEqual(simTime, task.PathDurationAtTaskEnd) postGather(task) } var startGather func(task *Task) postGather = func(task *Task) { if activeGatherThreadCount < plan.Config.MaxGatherThreadCount { activeGatherThreadCount++ startGather(task) } else { if config.Debug { t.Logf("%v%s queuing %v gather (already %d active gather threads)", simTime, indent, task, activeGatherThreadCount) } gatherQueue.PushBack(task) } } var advanceGather func(task *Task, step int) startGather = func(task *Task) { advanceGather(task, 0) } var endGather func(task *Task) advanceGather = func(task *Task, step int) { endTime := simTime + jitter() if config.Debug && step == 0 { t.Logf("%v%s starting %v gather at %v", simTime, indent, task, endTime) } endTime += task.GatherTimes[step] if step < len(task.Children) { heap.PushOrderable(&eventHeap, taskEvent{ Time: endTime, Func: func() { scatterTask(task.Children[step], func() { advanceGather(task, step+1) }) }, }) } else { heap.PushOrderable(&eventHeap, taskEvent{ Time: endTime, Func: func() { endGather(task) }, }) } } stoppedGatherThreads := 0 endGather = func(task *Task) { // Gather complete, start next if needed if task.ReturnErrorFromGather { // Do not start next gather nor decrement activeGatherThreadCount if config.Debug { t.Logf("%v%s %v gather returns error, stopping", simTime, indent, task) } stoppedGatherThreads++ } else { if config.Debug { t.Logf("%v%s %v gather complete", simTime, indent, task) } if gatherQueue.Len() == 0 { activeGatherThreadCount-- chk.GreaterOrEqual(activeGatherThreadCount, 0) } else { startGather(gatherQueue.PopFront()) } } } scatterRootTask(0) var concurrentEvents []taskEvent for stoppedGatherThreads < plan.Config.MaxGatherThreadCount { event, ok := heap.PopOrderable(&eventHeap) if !ok { break } concurrentEvents = concurrentEvents[:0] for { concurrentEvents = append(concurrentEvents, event) event, ok = heap.Peek(&eventHeap) if !ok || event.Time != concurrentEvents[0].Time { break } _, _ = heap.PopOrderable(&eventHeap) } if config.Debug && len(concurrentEvents) > 1 { t.Logf("%v%s have %d concurrent events", simTime, indent, len(concurrentEvents)) } if len(concurrentEvents) > 1 { concurrentEvents = 
rapid.Permutation(concurrentEvents).Draw(t, "concurrentEvents") } for _, event := range concurrentEvents { simTime = event.Time event.Func() } } if config.Debug { for pool, waiters := range waitersByPool { t.Logf("%d waiters remain in pool %d", len(waiters), pool) } t.Logf("%d gathers in queue", gatherQueue.Len()) } result := &Result{ MaxConcurrencyByPool: maxConcurrencyByPool, OverallDuration: simTime - simTimeOrigin, } chk.GreaterOrEqual(result.OverallDuration, plan.MaxPathDuration) if config.Debug { t.Logf("%v %v estimate done: %v", simTime, plan, *result) } return result } type taskEvent struct { Time time.Duration Func func() } func (a *taskEvent) Cmp(b *taskEvent) int { return cmp.Compare(a.Time, b.Time) }
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package sim import ( "fmt" "slices" "time" "github.com/stretchr/testify/require" "pgregory.net/rapid" ) type Plan struct { ID int SubplanCount int Config PlanConfig PathCount int TaskCount int SubjobTaskCount int MaxPathDuration time.Duration RootTasks []*Task } type PlanConfig struct { ConcurrencyLimits []int MaxGatherThreadCount int SubjobProbability float64 MaxSubjobCount int MaxSubjobDepth int MaxPathCount int MaxPathLength int MinNewIntermediateChildProbability float64 MaxNewIntermediateChildProbability float64 MinTaskDuration time.Duration MedTaskDuration time.Duration MaxTaskDuration time.Duration TaskErrorProbability float64 MinGatherDuration time.Duration MedGatherDuration time.Duration MaxGatherDuration time.Duration GatherErrorProbability float64 OverallDurationBudget time.Duration } var DefaultPlanConfig = PlanConfig{ ConcurrencyLimits: []int{1}, MaxGatherThreadCount: 1, SubjobProbability: 0.1, MaxSubjobCount: 2, MaxSubjobDepth: 3, MaxPathCount: 100, MaxPathLength: 10, MinNewIntermediateChildProbability: 0, MaxNewIntermediateChildProbability: 0.1, MinTaskDuration: 10 * time.Microsecond, MedTaskDuration: 10_000 * time.Microsecond, MaxTaskDuration: 200_000 * time.Microsecond, TaskErrorProbability: 0.0, MinGatherDuration: 10 * time.Microsecond, MedGatherDuration: 100 * time.Microsecond, MaxGatherDuration: 1000 * time.Microsecond, GatherErrorProbability: 0.0, OverallDurationBudget: 500 * time.Millisecond, } var nextIDs idCounters func NewPlanConfig(t *rapid.T) *PlanConfig { config := &PlanConfig{} *config = DefaultPlanConfig config.ConcurrencyLimits = rapid.SliceOfN( rapid.Custom(func(t *rapid.T) int { mpc := 10 // max pool concurrency // bias toward pool size of 2 return 2 + rapid.IntRange(-1, mpc-2).Draw(t, "concurrencyLimit") }), 1, 3).Draw(t, "concurrencyLimits") //config.MaxGatherThreadCount = rapid.IntRange(1, 3).Draw(t, "maxGatherThreadCount") return config } // NewPlan creates a hierarchy of simulated tasks for testing. 
func NewPlan(t *rapid.T, config *PlanConfig) *Plan { plan := newPlan(t, config, &nextIDs, 0) t.Logf("NewPlan: %#v", plan) return plan } func newPlan(t *rapid.T, config *PlanConfig, nextIDs *idCounters, depth int) *Plan { plan := &Plan{ ID: nextIDs.Plan, Config: *config, } //t.Logf("NewPlanConfig: %+v", plan.Config) nextIDs.Plan++ nextIDsOrigin := *nextIDs config = &plan.Config // use copy from here on out chk := require.New(t) var minConcurrencyLimit int chk.Greater(len(config.ConcurrencyLimits), 0) for i, limit := range config.ConcurrencyLimits { chk.Greater(limit, 0) if i == 0 || limit < minConcurrencyLimit { minConcurrencyLimit = limit } } chk.Greater(config.MedTaskDuration, time.Duration(0)) chk.LessOrEqual(config.MinTaskDuration, config.MedTaskDuration) chk.LessOrEqual(config.MedTaskDuration, config.MaxTaskDuration) chk.Greater(config.MedGatherDuration, time.Duration(0)) chk.LessOrEqual(config.MinGatherDuration, config.MedGatherDuration) chk.LessOrEqual(config.MedGatherDuration, config.MaxGatherDuration) minStepDuration := config.MinTaskDuration + config.MinGatherDuration medStepDuration := config.MedTaskDuration + config.MedGatherDuration maxStepDuration := config.MaxTaskDuration + config.MaxGatherDuration maxPathLength := int64(config.MaxPathLength) chk.Greater(maxPathLength, int64(0)) if minStepDuration > 0 { maxPathLength = min(int64(config.OverallDurationBudget/minStepDuration), maxPathLength) } medPathLength := int64(1) if medStepDuration > 0 { medPathLength = min(int64(config.OverallDurationBudget/medStepDuration), maxPathLength) } medPathDuration := time.Duration(medPathLength) * medStepDuration budgetedPaths := int64(minConcurrencyLimit) if medPathDuration > 0 { budgetedPaths = max(1, min( int64(config.OverallDurationBudget)*int64(minConcurrencyLimit)/int64(medPathDuration), int64(config.MaxPathCount), )) } plan.PathCount = int(biasedInt64(1, max(1, budgetedPaths/2), budgetedPaths).Draw(t, "pathCount")) newIntermediateChildProbability := rapid.Float64Range( config.MinNewIntermediateChildProbability, config.MaxNewIntermediateChildProbability, ).Draw(t, "newIntermediateChildProbability") createNewIntermediateChild := func() bool { return BiasedBool(newIntermediateChildProbability).Draw(t, "createNewIntermediateChild") } newTask := func(plan *Plan, parent *Task, pathDurationBudget time.Duration) *Task { task := &Task{ ID: nextIDs.Task, Pool: rapid.IntRange(0, len(config.ConcurrencyLimits)-1).Draw(t, "poolIndex"), Parent: parent, ReturnErrorFromTask: BiasedBool(config.TaskErrorProbability).Draw(t, "returnErrorFromTask"), ReturnErrorFromGather: BiasedBool(config.GatherErrorProbability).Draw(t, "returnErrorFromGather"), } nextIDs.Task++ parent.Children = append(parent.Children, task) totalSelfTime := min( time.Duration(biasedInt64( int64(config.MinTaskDuration), int64(config.MedTaskDuration), int64(config.MaxTaskDuration), ).Draw(t, "totalSelfTime")), pathDurationBudget, ) task.PathDurationAtTaskEnd = parent.PathDurationAtTaskEnd + totalSelfTime totalGatherTime := min( time.Duration(biasedInt64( int64(config.MinGatherDuration), int64(config.MedGatherDuration), int64(config.MaxGatherDuration), ).Draw(t, "totalGatherTime")), pathDurationBudget-totalSelfTime, ) // Maybe add subjobs if config.MaxSubjobDepth > 0 { subjobDurationBudget := pathDurationBudget - totalSelfTime - totalGatherTime for subjobDurationBudget >= medStepDuration && BiasedBool(config.SubjobProbability).Draw(t, "addSubjob") { subjobConfig := plan.Config subjobConfig.MaxSubjobDepth-- subjobConfig.OverallDurationBudget 
= time.Duration( biasedInt64( int64(medStepDuration), int64(max(medStepDuration, subjobDurationBudget/2)), int64(subjobDurationBudget), ).Draw(t, "subjobDurationBudget"), ) subjobPlan := newPlan(t, &subjobConfig, nextIDs, depth+1) task.Subjobs = append(task.Subjobs, subjobPlan) subjobDurationBudget -= subjobPlan.MaxPathDuration plan.SubjobTaskCount += subjobPlan.TaskCount + subjobPlan.SubjobTaskCount } } // Interleave self time and subjobs task.SelfTimes = make([]time.Duration, len(task.Subjobs)+1) for i := range len(task.SelfTimes) - 1 { st := time.Duration( biasedInt64( 0, int64(totalSelfTime)/int64(len(task.SelfTimes)-i), int64(totalSelfTime), ).Draw(t, "selfTimeChunk"), ) task.SelfTimes[i] = st totalSelfTime -= st } task.SelfTimes[len(task.SelfTimes)-1] = totalSelfTime // GatherTimes will be interleaved with Children later, after Children // has been populated. For now, just record the total gather time. task.GatherTimes = []time.Duration{totalGatherTime} return task } var addPath func(parent *Task, maxSteps int, pathDurationBudget time.Duration) time.Duration addPath = func(parent *Task, maxSteps int, pathDurationBudget time.Duration) time.Duration { if maxSteps <= 0 || pathDurationBudget <= 0 { return parent.PathDurationAtTaskEnd + parent.GatherDuration() } var child *Task if len(parent.Children) == 0 || pathDurationBudget <= medStepDuration || createNewIntermediateChild() { child = newTask(plan, parent, pathDurationBudget) } else { // Avoid rapid.SampledFrom because it will print the long-form (%#v) // representation of the task in the log. child = parent.Children[rapid.IntRange(0, len(parent.Children)-1).Draw(t, "child")] } return addPath(child, maxSteps-1, pathDurationBudget-child.TaskDuration()) } var rootTask Task for range plan.PathCount { // Decide on a duration (length) for this particular path pathDurationBudget := time.Duration(biasedInt64( int64(medStepDuration), int64(min(maxStepDuration, config.OverallDurationBudget)), int64(max(maxStepDuration, config.OverallDurationBudget)), ).Draw(t, "pathDurationBudget")) pathDuration := addPath(&rootTask, config.MaxPathLength, pathDurationBudget) plan.MaxPathDuration = max(plan.MaxPathDuration, pathDuration) } var populateChildGatherTimes func(parent *Task) populateChildGatherTimes = func(parent *Task) { for _, child := range parent.Children { totalGatherTime := child.GatherTimes[0] child.GatherTimes = slices.Grow(child.GatherTimes[:0], len(child.Children)+1)[:len(child.Children)+1] for i := range len(child.Children) { gt := time.Duration( biasedInt64( 0, int64(totalGatherTime)/int64(len(child.GatherTimes)-i), int64(totalGatherTime), ).Draw(t, "gatherTimeChunk"), ) child.GatherTimes[i] = gt totalGatherTime -= gt } child.GatherTimes[len(child.GatherTimes)-1] = totalGatherTime populateChildGatherTimes(child) } } populateChildGatherTimes(&rootTask) // Create a root GatherTimes slice of the appropriate length so that // recalculateChildPathDurations doesn't get tripped up. rootTask.GatherTimes = make([]time.Duration, len(rootTask.Children)+1) // Recalculate accurate path durations now that the gather times have been // interleaved with the children. 
var recalculateChildPathDurations func(parent *Task) plan.MaxPathDuration = 0 recalculateChildPathDurations = func(parent *Task) { if len(parent.Children) == 0 { plan.MaxPathDuration = max(plan.MaxPathDuration, parent.PathDurationAtTaskEnd+parent.GatherDuration()) return } var parentGatherDuration time.Duration for i, child := range parent.Children { parentGatherDuration += parent.GatherTimes[i] child.PathDurationAtTaskEnd = parent.PathDurationAtTaskEnd + parentGatherDuration + child.TaskDuration() recalculateChildPathDurations(child) } } recalculateChildPathDurations(&rootTask) t.Logf("%v: nextIDs: %#v nextIDsOrigin: %#v subjobTaskCount: %d", plan, nextIDs, nextIDsOrigin, plan.SubjobTaskCount) plan.TaskCount = nextIDs.Task - nextIDsOrigin.Task - plan.SubjobTaskCount plan.SubplanCount = nextIDs.Plan - nextIDsOrigin.Plan plan.RootTasks = rootTask.Children return plan } type idCounters struct { Plan int Task int } func biasedInt64(minVal, medVal, maxVal int64) *rapid.Generator[int64] { if medVal < minVal || maxVal < medVal { panic("invalid biasedInt64 parameters") } return rapid.Custom(func(t *rapid.T) int64 { return medVal + rapid.Int64Range(minVal-medVal, maxVal-medVal).Draw(t, "biasedInt64") }) } func (p *Plan) AppendSubplans(s []*Plan) []*Plan { for _, t := range p.RootTasks { s = p.appendTaskSubplans(s, t) } return s } func (p *Plan) appendTaskSubplans(s []*Plan, t *Task) []*Plan { for _, sp := range t.Subjobs { s = append(s, sp) s = sp.AppendSubplans(s) } for _, t := range t.Children { s = p.appendTaskSubplans(s, t) } return s } // Format implements fmt.Formatter for pretty-printing a plan. func (p *Plan) Format(f fmt.State, verb rune) { if verb != 'v' { panic("unsupported verb") } if f.Flag('#') { p.formatInternal(f, " ") } else { _, _ = fmt.Fprintf(f, "Plan#%d", p.ID) } } func (p *Plan) formatInternal(f fmt.State, indent string) { _, _ = fmt.Fprintf(f, "Plan#%d: pathCount=%d taskCount=%d maxPathDuration=%v config=%+v", p.ID, p.PathCount, p.TaskCount, p.MaxPathDuration, p.Config) for _, child := range p.RootTasks { _, _ = fmt.Fprintf(f, "\n%s", indent) child.formatInternal(f, indent+" ") } }
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package sim import ( "slices" "time" "github.com/stretchr/testify/require" ) type Result struct { MaxConcurrencyByPool []int64 OverallDuration time.Duration } type ResultRange struct { MinMaxConcurrencyByPool []int64 MaxMaxConcurrencyByPool []int64 MinOverallDuration time.Duration MedOverallDuration time.Duration MaxOverallDuration time.Duration } func (rr *ResultRange) MergeResult(t require.TestingT, r *Result) { chk := require.New(t) chk.NotNil(rr) chk.NotNil(r) if len(rr.MinMaxConcurrencyByPool) == 0 && len(rr.MaxMaxConcurrencyByPool) == 0 { rr.MinMaxConcurrencyByPool = slices.Clone(r.MaxConcurrencyByPool) rr.MaxMaxConcurrencyByPool = slices.Clone(r.MaxConcurrencyByPool) rr.MinOverallDuration = r.OverallDuration rr.MaxOverallDuration = r.OverallDuration } else { chk.Equal(len(rr.MinMaxConcurrencyByPool), len(r.MaxConcurrencyByPool)) chk.Equal(len(rr.MaxMaxConcurrencyByPool), len(r.MaxConcurrencyByPool)) for i := range len(r.MaxConcurrencyByPool) { rr.MinMaxConcurrencyByPool[i] = min(rr.MinMaxConcurrencyByPool[i], r.MaxConcurrencyByPool[i]) rr.MaxMaxConcurrencyByPool[i] = max(rr.MaxMaxConcurrencyByPool[i], r.MaxConcurrencyByPool[i]) } rr.MinOverallDuration = min(rr.MinOverallDuration, r.OverallDuration) rr.MaxOverallDuration = max(rr.MaxOverallDuration, r.OverallDuration) } } func (rr *ResultRange) MergeRange(t require.TestingT, r *ResultRange) { chk := require.New(t) chk.Equal(len(r.MinMaxConcurrencyByPool), len(r.MaxMaxConcurrencyByPool)) if len(rr.MinMaxConcurrencyByPool) == 0 && len(rr.MaxMaxConcurrencyByPool) == 0 { rr.MinMaxConcurrencyByPool = slices.Clone(r.MinMaxConcurrencyByPool) rr.MaxMaxConcurrencyByPool = slices.Clone(r.MaxMaxConcurrencyByPool) rr.MinOverallDuration = r.MinOverallDuration rr.MaxOverallDuration = r.MaxOverallDuration } else { chk.Equal(len(rr.MinMaxConcurrencyByPool), len(r.MinMaxConcurrencyByPool)) chk.Equal(len(rr.MaxMaxConcurrencyByPool), len(r.MaxMaxConcurrencyByPool)) for i := range len(r.MinMaxConcurrencyByPool) { rr.MinMaxConcurrencyByPool[i] = min(rr.MinMaxConcurrencyByPool[i], r.MinMaxConcurrencyByPool[i]) rr.MaxMaxConcurrencyByPool[i] = max(rr.MaxMaxConcurrencyByPool[i], r.MaxMaxConcurrencyByPool[i]) } rr.MinOverallDuration = min(rr.MinOverallDuration, r.MinOverallDuration) rr.MaxOverallDuration = max(rr.MaxOverallDuration, r.MaxOverallDuration) } } func MergeResultMap(t require.TestingT, dst map[*Plan]*ResultRange, src map[*Plan]*Result) { for p, sr := range src { drr := dst[p] if drr == nil { drr = &ResultRange{} dst[p] = drr } drr.MergeResult(t, sr) } } func MergeResultRangeMap(t require.TestingT, dst, src map[*Plan]*ResultRange) { for p, srr := range src { drr := dst[p] if drr == nil { drr = &ResultRange{} dst[p] = drr } drr.MergeRange(t, srr) } }
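// Illustrative sketch, not part of the original sources: it shows the
// accumulation semantics of MergeResult, where per-pool max-concurrency values
// and overall durations from individual trial Results collapse into min/max
// envelopes on a ResultRange. The concrete numbers are invented for the
// example.
package sim

import (
	"testing"
	"time"
)

func TestResultRangeMergeSketch(t *testing.T) {
	rr := &ResultRange{}
	rr.MergeResult(t, &Result{
		MaxConcurrencyByPool: []int64{2, 1},
		OverallDuration:      40 * time.Millisecond,
	})
	rr.MergeResult(t, &Result{
		MaxConcurrencyByPool: []int64{3, 1},
		OverallDuration:      25 * time.Millisecond,
	})
	// After both merges the range brackets the observed values:
	// MinMaxConcurrencyByPool == {2, 1}, MaxMaxConcurrencyByPool == {3, 1}.
	if rr.MinOverallDuration != 25*time.Millisecond || rr.MaxOverallDuration != 40*time.Millisecond {
		t.Fatalf("unexpected duration range: %v..%v", rr.MinOverallDuration, rr.MaxOverallDuration)
	}
}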
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package sim import ( "context" "fmt" "math" "sync" "sync/atomic" "time" "github.com/petenewcomb/psg-go" "github.com/stretchr/testify/require" ) func Run(t require.TestingT, ctx context.Context, plan *Plan, debug bool) (map[*Plan]*Result, error) { pools := make([]*psg.Pool, len(plan.Config.ConcurrencyLimits)) for i, limit := range plan.Config.ConcurrencyLimits { pools[i] = psg.NewPool(limit) } c := &controller{ Plan: plan, Pools: pools, ConcurrencyByPool: make([]atomic.Int64, len(pools)), MaxConcurrencyByPool: make([]atomicMinMaxInt64, len(pools)), ResultMap: make(map[*Plan]*Result), Debug: debug, } c.MinScatterDelay.Store(math.MaxInt64) c.MinGatherDelay.Store(math.MaxInt64) return c.Run(t, ctx) } type controller struct { Plan *Plan Pools []*psg.Pool ConcurrencyByPool []atomic.Int64 MaxConcurrencyByPool []atomicMinMaxInt64 GatheredCount atomic.Int64 ResultMapMutex sync.Mutex ResultMap map[*Plan]*Result StartTime time.Time MinScatterDelay atomicMinMaxInt64 MinGatherDelay atomicMinMaxInt64 Debug bool } func (c *controller) Run(t require.TestingT, ctx context.Context) (map[*Plan]*Result, error) { c.StartTime = time.Now() c.debugf("%v starting %v", time.Since(c.StartTime), c.Plan) job := psg.NewJob(ctx, c.Pools...) defer job.CancelAndWait() for _, task := range c.Plan.RootTasks { c.scatterTask(t, ctx, task) } chk := require.New(t) err := job.CloseAndGatherAll(ctx) overallDuration := time.Since(c.StartTime) if ge, ok := err.(expectedGatherError); ok { chk.True(ge.task.ReturnErrorFromGather) } else { chk.NoError(err) } gatheredCount := c.GatheredCount.Load() chk.Equal(int64(c.Plan.TaskCount), gatheredCount) maxConcurrencyByPool := make([]int64, len(c.MaxConcurrencyByPool)) for i := range len(maxConcurrencyByPool) { maxConcurrencyByPool[i] = c.MaxConcurrencyByPool[i].Load() } c.addResultMap(t, map[*Plan]*Result{ c.Plan: { MaxConcurrencyByPool: maxConcurrencyByPool, OverallDuration: overallDuration, }, }) c.debugf("%v ended %v with min delays scatter=%v gather=%v", overallDuration, c.Plan, time.Duration(c.MinScatterDelay.Load()), time.Duration(c.MinGatherDelay.Load())) return c.ResultMap, nil } func (c *controller) addResultMap(t require.TestingT, rm map[*Plan]*Result) { chk := require.New(t) c.ResultMapMutex.Lock() defer c.ResultMapMutex.Unlock() for p, r := range rm { chk.Nil(c.ResultMap[p]) c.ResultMap[p] = r } } func (c *controller) scatterTask(t require.TestingT, ctx context.Context, task *Task) { err := psg.Scatter( ctx, c.Pools[task.Pool], c.newTaskFunc(task, &c.ConcurrencyByPool[task.Pool]), c.newGatherFunc(t, task), ) chk := require.New(t) if ge, ok := err.(expectedGatherError); ok { chk.True(ge.task.ReturnErrorFromGather) } else { chk.NoError(err) } } type localT struct { calls []func(require.TestingT) } func (lt *localT) Errorf(format string, args ...any) { lt.calls = append(lt.calls, func(t require.TestingT) { t.Errorf(format, args...) 
}) } func (lt *localT) FailNow() { lt.calls = append(lt.calls, func(t require.TestingT) { t.FailNow() }) panic(lt) } func (lt *localT) DrainTo(t require.TestingT) { for _, call := range lt.calls { call(t) } } func (lt *localT) Error() string { return "localT passthrough error" } func (c *controller) newTaskFunc(task *Task, concurrency *atomic.Int64) psg.TaskFunc[*taskResult] { lt := &localT{} scatterTime := time.Now() return func(ctx context.Context) (res *taskResult, err error) { c.MinScatterDelay.UpdateMin(int64(time.Since(scatterTime))) defer func() { if r := recover(); r != nil { if lt, ok := r.(*localT); ok { err = lt } else { panic(r) } } }() chk := require.New(lt) res = &taskResult{ Task: task, ConcurrencyAtStart: concurrency.Add(1), } c.debugf("%v starting %v on pool %d, concurrency now %d", time.Since(c.StartTime), task, task.Pool, res.ConcurrencyAtStart) chk.Greater(res.ConcurrencyAtStart, int64(0)) defer func() { res.ConcurrencyAfter = concurrency.Add(-1) elapsedTime := time.Since(c.StartTime) c.debugf("%v ended %v on pool %d, concurrency now %d", elapsedTime, task, task.Pool, res.ConcurrencyAfter) chk.GreaterOrEqual(elapsedTime, task.PathDurationAtTaskEnd) res.TaskEndTime = time.Now() }() for i, d := range task.SelfTimes { if i > 0 { subjobPlan := task.Subjobs[i-1] resultMap, err := Run(lt, ctx, subjobPlan, c.Debug) chk.NoError(err) c.addResultMap(lt, resultMap) } select { case <-time.After(d): case <-ctx.Done(): return res, ctx.Err() } } if task.ReturnErrorFromTask { return res, fmt.Errorf("%v error", task) } else { return res, nil } } } func (c *controller) newGatherFunc(t require.TestingT, task *Task) psg.GatherFunc[*taskResult] { chk := require.New(t) return func(ctx context.Context, res *taskResult, err error) error { c.MinGatherDelay.UpdateMin(int64(time.Since(res.TaskEndTime))) if lt, ok := err.(*localT); ok { lt.DrainTo(t) } else if task.ReturnErrorFromTask { chk.Error(err) } else { chk.NoError(err) } pool := task.Pool gatheredCount := c.GatheredCount.Add(1) c.debugf("%v gathering %v, gathered count now %d", time.Since(c.StartTime), task, gatheredCount) chk.LessOrEqual(gatheredCount, int64(c.Plan.TaskCount)) chk.Greater(res.ConcurrencyAtStart, int64(0)) chk.LessOrEqual(res.ConcurrencyAtStart, int64(c.Plan.Config.ConcurrencyLimits[pool])) chk.GreaterOrEqual(res.ConcurrencyAfter, int64(0)) chk.Less(res.ConcurrencyAfter, int64(c.Plan.Config.ConcurrencyLimits[pool])) c.MaxConcurrencyByPool[pool].UpdateMax(res.ConcurrencyAtStart) for i, d := range task.GatherTimes { if i > 0 { child := task.Children[i-1] c.scatterTask(t, ctx, child) } select { case <-time.After(d): case <-ctx.Done(): return ctx.Err() } } if task.ReturnErrorFromGather { return expectedGatherError{task} } else { return nil } } } func (c *controller) debugf(format string, args ...interface{}) { if c.Debug { fmt.Printf(format+"\n", args...) } } // Result represents the result of executing a simulated task. 
type taskResult struct { Task *Task ConcurrencyAtStart int64 ConcurrencyAfter int64 TaskEndTime time.Time } type expectedGatherError struct { task *Task } func (e expectedGatherError) Error() string { return fmt.Sprintf("%v gather error", e.task) } type atomicMinMaxInt64 struct { value atomic.Int64 } func (mm *atomicMinMaxInt64) Store(x int64) { mm.value.Store(x) } func (mm *atomicMinMaxInt64) Load() int64 { return mm.value.Load() } func (mm *atomicMinMaxInt64) UpdateMax(x int64) { mm.update(x, func(a, b int64) bool { return a > b }) } func (mm *atomicMinMaxInt64) UpdateMin(x int64) { mm.update(x, func(a, b int64) bool { return a < b }) } func (mm *atomicMinMaxInt64) update(x int64, t func(a, b int64) bool) { for { old := mm.value.Load() if !t(x, old) { break } if mm.value.CompareAndSwap(old, x) { break } } }
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package sim import ( "fmt" "time" ) // Task represents a simulated task and its gathering type Task struct { ID int Pool int Parent *Task SelfTimes []time.Duration // one more entry than Subjobs Subjobs []*Plan GatherTimes []time.Duration // one more entry than Children Children []*Task ReturnErrorFromTask bool ReturnErrorFromGather bool PathDurationAtTaskEnd time.Duration } func (t *Task) ParentGatherDuration() time.Duration { if t.Parent == nil { return 0 } var d time.Duration for i, gt := range t.Parent.GatherTimes { if t.Parent.Children[i] == t { break } d += gt } return d } func (t *Task) TaskDuration() time.Duration { var d time.Duration for _, st := range t.SelfTimes { d += st } for _, sj := range t.Subjobs { d += sj.MaxPathDuration } return d } func (t *Task) GatherDuration() time.Duration { var d time.Duration for _, gt := range t.GatherTimes { d += gt } return d } // Format implements fmt.Formatter for pretty-printing a task hierarchy. func (t *Task) Format(f fmt.State, verb rune) { if verb != 'v' { panic("unsupported verb") } if f.Flag('#') { t.formatInternal(f, "") } else { _, _ = fmt.Fprintf(f, "Task#%d", t.ID) } } func (t *Task) formatInternal(f fmt.State, indent string) { _, _ = fmt.Fprintf(f, "Task#%d: pool=%d minTaskEnd=%v", t.ID, t.Pool, t.PathDurationAtTaskEnd) t.formatSteps(f, indent+" ") } func (t *Task) formatSteps(f fmt.State, indent string) { if len(t.SelfTimes) > 0 { for i, subjob := range t.Subjobs { _, _ = fmt.Fprintf(f, "\n%s%v self time\n%s", indent, t.SelfTimes[i], indent) subjob.formatInternal(f, indent+" ") } _, _ = fmt.Fprintf(f, "\n%s%v self time", indent, t.SelfTimes[len(t.SelfTimes)-1]) } if len(t.GatherTimes) > 0 { for i, child := range t.Children { _, _ = fmt.Fprintf(f, "\n%s%v gather self time\n%s", indent, t.GatherTimes[i], indent) child.formatInternal(f, indent+" ") } _, _ = fmt.Fprintf(f, "\n%s%v gather self time", indent, t.GatherTimes[len(t.GatherTimes)-1]) } }
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package state import ( "sync/atomic" ) type InFlightCounter struct { v atomic.Int64 } func (c *InFlightCounter) Increment() { c.v.Add(1) } func (c *InFlightCounter) IncrementIfUnder(limit int) bool { // Tentatively increment the counter and check against limit. If over limit, // remove the tentative increment and try again if we notice that another // goroutine has made room between the increment and decrement. for c.v.Add(1) > int64(limit) { // Back out tentative increment and re-check. if c.v.Add(-1) >= int64(limit) { // Still at or over limit. return false } } // Incremented counter is within limit. return true } func (c *InFlightCounter) Decrement() bool { newValue := c.v.Add(-1) if newValue < 0 { panic("there were no tasks in flight") } return newValue == 0 } func (c *InFlightCounter) GreaterThanZero() bool { return c.v.Load() > 0 }
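// Illustrative sketch, not part of the original package: the intended call
// pattern pairs every successful IncrementIfUnder with exactly one Decrement,
// so the counter tracks how many guarded operations are currently in flight.
// The helper name and limit parameter are assumptions made for the example.
package state

func tryRunLimited(c *InFlightCounter, limit int, work func()) bool {
	if !c.IncrementIfUnder(limit) {
		// At or over the limit; the caller decides whether to retry,
		// block, or drop the work.
		return false
	}
	// Release the slot when the work finishes, even if it panics.
	defer c.Decrement()
	work()
	return true
}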
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package psg import ( "context" "maps" "slices" "sync" "sync/atomic" "github.com/petenewcomb/psg-go/internal/state" ) // Job represents a scatter-gather execution environment. It tracks tasks // launched with [Scatter] across a set of [Pool] instances and provides methods // for gathering their results. [Job.Cancel] and [Job.CancelAndWait] allow the // caller to terminate the environment early and ensure cleanup when the // environment is no longer needed. // // A Job must be created with [NewJob]; see that function for caveats and // important details. type Job struct { ctx context.Context cancelFunc context.CancelFunc pools []*Pool inFlight state.InFlightCounter gatherChannel chan boundGatherFunc wg sync.WaitGroup closed atomic.Bool done chan struct{} } type boundGatherFunc = func(ctx context.Context) error // NewJob creates an independent scatter-gather execution environment with the // specified context and set of pools. The context passed to NewJob is used as // the root of the context that will be passed to all task functions. (See // [TaskFunc] and [Job.Cancel] for more detail.) // // Pools may not be shared across jobs. NewJob panics if it detects such // sharing, but such detection is not guaranteed to work if the same pool is // passed to NewJob calls in different goroutines. // // Each call to NewJob should typically be followed by a deferred call to // [Job.CancelAndWait] to ensure that an early exit from the calling function // does not leave any outstanding goroutines. func NewJob( ctx context.Context, pools ...*Pool, ) *Job { ctx, cancelFunc := context.WithCancel(ctx) j := &Job{ cancelFunc: cancelFunc, pools: slices.Clone(pools), gatherChannel: make(chan boundGatherFunc), done: make(chan struct{}), } j.ctx = j.makeTaskContext(ctx) for _, p := range j.pools { if p.job != nil { panic("pool was already registered") } p.job = j } return j } type taskContextMarkerType struct{} var taskContextMarkerKey any = taskContextMarkerType{} func (j *Job) makeTaskContext(ctx context.Context) context.Context { return makeJobContext(ctx, j, taskContextMarkerKey) } func (j *Job) isTaskContext(ctx context.Context) bool { return isJobContext(ctx, j, taskContextMarkerKey) } func makeJobContext[K any](ctx context.Context, j *Job, key K) context.Context { // Accumulate the jobs to which the context belongs but avoid creating a // collection unless it's needed. var newValue any switch oldValue := ctx.Value(key).(type) { case nil: newValue = j case *Job: if oldValue == j { return ctx } newValue = map[*Job]struct{}{ oldValue: {}, j: {}, } case map[*Job]struct{}: if _, ok := oldValue[j]; ok { return ctx } // Build a copy of the existing set that also contains j. merged := make(map[*Job]struct{}, len(oldValue)+1) maps.Copy(merged, oldValue) merged[j] = struct{}{} newValue = merged default: panic("unexpected job context marker value type") } return context.WithValue(ctx, key, newValue) } func isJobContext[K any](ctx context.Context, j *Job, key K) bool { switch v := ctx.Value(key).(type) { case nil: return false case *Job: return v == j case map[*Job]struct{}: _, ok := v[j] return ok default: panic("unexpected job context marker value type") } } // Cancel terminates any in-flight tasks and forfeits any ungathered results. // Outstanding calls to [Scatter], [Job.GatherOne], [Job.TryGatherOne], // [Job.GatherAll], or [Job.TryGatherAll] using the job or any of its pools will // fail with [context.Canceled] or other error returned by a [GatherFunc].
// // While Cancel always returns immediately, any running [TaskFunc] or // [GatherFunc] will delay termination of its independent goroutine or caller // until it returns. This method cancels the context passed to each [TaskFunc], // but not the context passed to each [GatherFunc]. Gather functions instead // receive the context passed to the calling [Scatter], [Job.GatherOne], // [Job.TryGatherOne], [Job.GatherAll], or [Job.TryGatherAll] function. If it is // desirable to transmit a cancelation signal to a running [GatherFunc], one // must also cancel any contexts being passed to those callers. // // Cancel is always thread-safe and calling it more than once has no additional // effect. func (j *Job) Cancel() { j.cancelFunc() } // CancelAndWait cancels like [Job.Cancel], but then blocks until any // outstanding task goroutines exit. func (j *Job) CancelAndWait() { j.Cancel() j.wg.Wait() } // GatherOne processes at most a single result from a task previously launched // in one of the [Job]'s pools via [Scatter]. It will block until a completed // task is available or the context has been canceled. See [Job.TryGatherOne] for a // non-blocking alternative. // // Returns a boolean flag indicating whether a result was processed and an error // if one occurred: // // - true, nil: a task completed and was successfully gathered // - true, non-nil: a task completed but the gather function returned a // non-nil error // - false, nil: there were no tasks in flight // - false, non-nil: the argument or job-internal context was canceled // // If all gather functions are thread-safe, then GatherOne is thread-safe and // may be called concurrently from multiple goroutines. Blocking and // non-blocking calls may also be mixed, as can calls to any of the other gather // methods. // // NOTE: If a task result is gathered, this method will call the task's // [GatherFunc] and wait until it returns. func (j *Job) GatherOne(ctx context.Context) (bool, error) { return j.gatherOne(ctx, true) } // TryGatherOne processes at most a single result from a task previously // launched in one of the [Job]'s pools via [Scatter]. Unlike [Job.GatherOne], it // will not block if a completed task is not immediately available. // // Return values are the same as GatherOne, except that false, nil means that // there were no tasks ready to gather. // // See GatherOne for additional details. func (j *Job) TryGatherOne(ctx context.Context) (bool, error) { return j.gatherOne(ctx, false) } func (j *Job) gatherOne(ctx context.Context, block bool) (bool, error) { if block { select { case gather := <-j.gatherChannel: return true, j.executeGather(ctx, gather) case <-ctx.Done(): return false, ctx.Err() case <-j.ctx.Done(): return false, j.ctx.Err() case <-j.done: return false, nil } } else { // Identical to the blocking branch above except that it adds a default // clause so the select never blocks. select { case gather := <-j.gatherChannel: return true, j.executeGather(ctx, gather) case <-ctx.Done(): return false, ctx.Err() case <-j.ctx.Done(): return false, j.ctx.Err() case <-j.done: return false, nil default: // There were no in-flight tasks ready to gather. return false, nil } } } // GatherAll processes all results from previously scattered tasks, continuing // until there are no more in-flight tasks or an error occurs. It will block to // wait for in-flight tasks that are not yet complete. // // Returns nil unless the context is canceled or a task's [GatherFunc] returns a // non-nil error.
// // If all gather functions are thread-safe, then GatherAll is thread-safe and // can be called concurrently from multiple goroutines. In this case they will // collectively process all results, with each call handling a subset. Blocking // and non-blocking calls may also be mixed, as can calls to any of the other // gather methods. // // NOTE: This method will serially call each gathered task's [GatherFunc] and // wait until it returns. func (j *Job) GatherAll(ctx context.Context) error { return j.gatherAll(ctx, true) } // TryGatherAll processes all results from completed tasks, continuing until // there are no more immediately available or an error occurs. Unlike // [Job.GatherAll], TryGatherAll will not block to wait for in-flight tasks to // complete. // // See GatherAll for information about return values and thread safety. // // NOTE: If completed tasks are available, this method must still call each // task's [GatherFunc] and wait until it finishes processing. func (j *Job) TryGatherAll(ctx context.Context) error { return j.gatherAll(ctx, false) } func (j *Job) gatherAll(ctx context.Context, block bool) error { for { ok, err := j.gatherOne(ctx, block) if err != nil { return err } if !ok { return nil } } } func (j *Job) executeGather(ctx context.Context, gather boundGatherFunc) error { // Decrement the environment-wide in-flight counter only AFTER calling the // gather function. This ensures that the in-flight count never drops to // zero before the gather function has had a chance to scatter new tasks. defer j.decrementInFlight() return gather(ctx) } func (j *Job) decrementInFlight() { if j.inFlight.Decrement() { if j.closed.Load() { // Check again now that we know the job is already closed, in case // the job was closed after the decrement AND another increment. if !j.inFlight.GreaterThanZero() { close(j.done) } } } } // Wakes any goroutines that might be waiting on the gatherChannel. func (j *Job) wakeGatherers() { for { select { case j.gatherChannel <- nil: default: // No more waiters return } } } // Close must be called to signify that no more top-level tasks will be launched // and that [Job.GatherAll] should stop blocking to wait for more after the // results of all in-flight tasks have been gathered. See [Job.GatherAll] for // more detail. // // Close may be called from any goroutine and may safely be called more than // once. func (j *Job) Close() { j.closed.Store(true) if !j.inFlight.GreaterThanZero() { close(j.done) } } // CloseAndGatherAll closes the job via [Job.Close] and then waits for and // gathers the results of all in-flight tasks via [Job.GatherAll]. func (j *Job) CloseAndGatherAll(ctx context.Context) error { j.Close() return j.GatherAll(ctx) }
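// Illustrative usage sketch, not part of the original sources: it wires
// together the Job lifecycle documented above (NewPool, NewJob, Scatter,
// CloseAndGatherAll, CancelAndWait). The squaring task and the summed output
// are assumptions made for the example.
package psg_test

import (
	"context"
	"fmt"

	psg "github.com/petenewcomb/psg-go"
)

func ExampleJob_CloseAndGatherAll() {
	ctx := context.Background()
	pool := psg.NewPool(2)       // at most two tasks in flight at once
	job := psg.NewJob(ctx, pool) // binds the pool to this job
	defer job.CancelAndWait()    // ensure cleanup on early return

	sum := 0
	for _, n := range []int{1, 2, 3} {
		n := n // capture the loop variable for the closure (needed before Go 1.22)
		err := psg.Scatter(ctx, pool,
			func(ctx context.Context) (int, error) { return n * n, nil },
			func(ctx context.Context, v int, err error) error {
				if err != nil {
					return err
				}
				// Gather functions here run only inside the Scatter and
				// CloseAndGatherAll calls on this goroutine, so plain
				// variable updates are safe.
				sum += v
				return nil
			},
		)
		if err != nil {
			fmt.Println("scatter failed:", err)
			return
		}
	}
	if err := job.CloseAndGatherAll(ctx); err != nil {
		fmt.Println("gather failed:", err)
		return
	}
	fmt.Println(sum)
	// Output: 14
}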
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package psg import ( "context" "sync/atomic" "github.com/petenewcomb/psg-go/internal/state" ) // A Pool defines a virtual set of task execution slots and optionally places a // limit on its size. Use [Scatter] to launch tasks into a Pool. A Pool must be // bound to a [Job] before a task can be launched into it. // // The zero value of Pool is unbound and has a limit of zero. [NewPool] // provides a convenient way to create a new pool with a non-zero limit. type Pool struct { limit atomic.Int64 job *Job inFlight state.InFlightCounter } // Creates a new [Pool] with the given limit. See [Pool.SetLimit] for the range // of allowed values and their semantics. func NewPool(limit int) *Pool { p := &Pool{} p.limit.Store(int64(limit)) return p } // Sets the active concurrency limit for the pool. A negative value means no // limit (tasks will always be launched regardless of how many are currently // running). Zero means no new tasks will be launched (i.e., [Scatter] will block // indefinitely) until SetLimit is called with a non-zero value. SetLimit is // always thread-safe, even for a Pool in a single-threaded [Job]. func (p *Pool) SetLimit(limit int) { if p.limit.Swap(int64(limit)) == 0 && limit != 0 { j := p.job if j != nil { j.wakeGatherers() } } } func (p *Pool) launch(ctx context.Context, task boundTaskFunc, block bool) (bool, error) { j := p.job if j == nil { panic("pool not bound to a job") } if j.isTaskContext(ctx) { // Don't launch if the provided context is a task context within the // current job, since that may lead to deadlock. panic("psg.Scatter called from within TaskFunc; move call to GatherFunc instead") } // Don't launch if the provided context has been canceled. if err := ctx.Err(); err != nil { return false, err } // Don't launch if the job context has been canceled. if err := j.ctx.Err(); err != nil { return false, err } // Register the task with the job to make sure that any calls to gather will // block until the task is completed. j.inFlight.Increment() // Bookkeeping: make sure that the job-scope count incremented above gets // decremented unless the launch actually happens launched := false defer func() { if !launched { j.decrementInFlight() } }() // Apply backpressure if launching a new task would exceed the pool's // concurrency limit. for !p.incrementInFlightIfUnderLimit() { if !block { return false, nil } // Gather a result to make room to launch the new task. As long as there // wasn't an error, we don't care whether a task was actually gathered // by this call. Either way, it's time to re-check the in-flight count // for this pool. if _, err := j.GatherOne(ctx); err != nil { return false, err } } // Launch the task in a new goroutine. launched = true j.wg.Add(1) go func() { defer j.wg.Done() task(j.ctx) }() return true, nil } type boundTaskFunc func(ctx context.Context) func (p *Pool) incrementInFlightIfUnderLimit() bool { limit := p.limit.Load() switch { case limit < 0: p.inFlight.Increment() return true case limit == 0: return false default: return p.inFlight.IncrementIfUnder(int(limit)) } } func (p *Pool) postGather(gather boundGatherFunc) { // Decrement the pool's in-flight count BEFORE waiting on the gather // channel. This makes it safe for gatherFunc to call `Scatter` with this // same `Pool` instance without deadlock, as there is guaranteed to be at // least one slot available. 
p.inFlight.Decrement() j := p.job select { case j.gatherChannel <- gather: case <-j.ctx.Done(): } }
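// Illustrative sketch, not part of the original sources: it checks the
// property that the pool's launch backpressure is documented to enforce,
// namely that no more than `limit` task functions ever run at once. The
// counters, limit, and sleep duration are assumptions made for the example.
package psg_test

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"

	psg "github.com/petenewcomb/psg-go"
)

func ExamplePool_limit() {
	ctx := context.Background()
	const limit = 2
	pool := psg.NewPool(limit)
	job := psg.NewJob(ctx, pool)
	defer job.CancelAndWait()

	var running, peak atomic.Int64
	for i := 0; i < 8; i++ {
		err := psg.Scatter(ctx, pool,
			func(ctx context.Context) (struct{}, error) {
				n := running.Add(1)
				// Record the highest concurrency observed so far.
				for {
					p := peak.Load()
					if n <= p || peak.CompareAndSwap(p, n) {
						break
					}
				}
				time.Sleep(time.Millisecond)
				running.Add(-1)
				return struct{}{}, nil
			},
			func(ctx context.Context, _ struct{}, err error) error { return err },
		)
		if err != nil {
			fmt.Println("scatter failed:", err)
			return
		}
	}
	if err := job.CloseAndGatherAll(ctx); err != nil {
		fmt.Println("gather failed:", err)
		return
	}
	fmt.Println(peak.Load() <= limit)
	// Output: true
}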
// Copyright (c) Peter Newcomb. All rights reserved. // Licensed under the MIT License. package psg import ( "context" ) // Scatter initiates asynchronous execution of the provided task function in a // new goroutine. After the task completes, the task's result and error will be // passed to the provided gather function within a subsequent call to Scatter or // any of the gathering methods of [Job] (i.e., [Job.GatherOne], // [Job.TryGatherOne], [Job.GatherAll], or [Job.TryGatherAll]). // // Scatter blocks to delay launch as needed to ensure compliance with the // concurrency limit for the given pool. This backpressure is applied by // gathering other tasks in the job until a slot becomes available. The // context passed to Scatter may be used to cancel (e.g., with a timeout) both // gathering and launch, but only the context associated with the pool's job // will be passed to the task. // // WARNING: Scatter must not be called from within a TaskFunc launched by the same // job, as this may lead to deadlock when a concurrency limit is reached. // Instead, call Scatter from the associated GatherFunc after the TaskFunc // completes. // // Scatter will panic if the given pool is not yet associated with a job. // Scatter returns a non-nil error if the context is canceled or if a non-nil // error is returned by a gather function. If the returned error is non-nil, the // task function supplied to the call will not have been launched and will therefore // also not result in a call to the supplied gather function. // // See [TaskFunc] and [GatherFunc] for important caveats and additional detail. func Scatter[T any]( ctx context.Context, pool *Pool, taskFunc TaskFunc[T], gatherFunc GatherFunc[T], ) error { _, err := scatter(ctx, pool, taskFunc, gatherFunc, true) return err } // TryScatter attempts to initiate asynchronous execution of the provided task // function in a new goroutine like [Scatter]. Unlike Scatter, TryScatter will // return instead of blocking if the given pool is already at its concurrency // limit. // // Returns (true, nil) if the task was successfully launched, (false, nil) if // the pool was at its limit, and (false, non-nil) if the task could not be // launched for any other reason. // // See Scatter for more detail about how scattering works. func TryScatter[T any]( ctx context.Context, pool *Pool, taskFunc TaskFunc[T], gatherFunc GatherFunc[T], ) (bool, error) { return scatter(ctx, pool, taskFunc, gatherFunc, false) } func scatter[T any]( ctx context.Context, pool *Pool, taskFunc TaskFunc[T], gatherFunc GatherFunc[T], block bool, ) (bool, error) { if taskFunc == nil { panic("task function must be non-nil") } if gatherFunc == nil { panic("gather function must be non-nil") } // Bind the task and gather functions together into a top-level function for // the new goroutine and hand it to the pool to launch. return pool.launch(ctx, func(ctx context.Context) { // Don't launch if the context has been canceled by the time the // goroutine starts. if ctx.Err() != nil { return } // Actually execute the task function. Since this is the top-level // function of a goroutine, if the task function panics the whole // program will terminate. The user can avoid this behavior by // recovering from the panic within the task function itself and then // returning normally with whatever results they want to pass to the // GatherFunc to represent the failure.
We therefore do not defer // posting a gather to the job's channel or otherwise attempt to // maintain the integrity of the pool or overall job in case of task // panics. value, err := taskFunc(ctx) // Build the gather function, binding the supplied gatherFunc to the // result. gather := func(ctx context.Context) error { return gatherFunc(ctx, value, err) } // Post the gather to the gather channel. pool.postGather(gather) }, block) } // A TaskFunc represents a task to be executed asynchronously within the context // of a [Pool]. It returns a result of type T and an error value. The provided // context should be respected for cancellation. Any other inputs to the task // are expected to be provided by specifying the TaskFunc as a [function // literal] that references and therefore captures local variables via [lexical // closure]. // // Each TaskFunc is executed in a new goroutine spawned by the [Scatter] // function and must therefore be thread-safe. This includes access to any // captured variables. // // Also, because they are executed in their own goroutines, if a TaskFunc panics, // the whole program will terminate as per [Handling panics] in The Go // Programming Language Specification. If you need to avoid this behavior, // recover from the panic within the task function itself and then return // whatever results you want passed to the associated [GatherFunc] to // represent the failure. // // WARNING: If a TaskFunc needs to spawn new tasks, it must not call [Scatter] // directly, as this would lead to deadlock when a concurrency limit is reached. // Instead, [Scatter] should be called from the associated [GatherFunc] after // the TaskFunc completes. [Scatter] attempts to recognize this situation and // panic, but this detection works only if the context passed to [Scatter] is // the one passed to the TaskFunc or is a subcontext thereof. // // A TaskFunc may, however, create its own sub-[Job] within which to run // concurrent tasks. This serves a different use case: tasks created in such a // sub-job should complete or be canceled before the outer TaskFunc returns, // while tasks spawned from a GatherFunc on behalf of a TaskFunc necessarily // form a sequence (or pipeline). Both patterns can be used together as needed. // // [function literal]: https://go.dev/ref/spec#Function_literals // [lexical closure]: https://en.wikipedia.org/wiki/Closure_(computer_programming) // [Handling panics]: https://go.dev/ref/spec#Handling_panics type TaskFunc[T any] = func(context.Context) (T, error) // A GatherFunc is a function that processes the result of a completed // [TaskFunc]. It receives the result and error values from the [TaskFunc] // execution, allowing it to handle both successful and failed task executions. // // The GatherFunc is called when completed task results are processed by // [Scatter], [Job.GatherOne], [Job.TryGatherOne], [Job.GatherAll], or // [Job.TryGatherAll]. Execution of a GatherFunc will block processing of // subsequent task results, adding to backpressure. If such backpressure is // undesirable, consider launching expensive gathering logic in another // asynchronous task using [Scatter]. Unlike [TaskFunc], it is safe to call // [Scatter] from within a GatherFunc. // // If multiple goroutines may call [Scatter], [Job.GatherOne], // [Job.TryGatherOne], [Job.GatherAll], or [Job.TryGatherAll] concurrently, then // every GatherFunc used in the job must be thread-safe. type GatherFunc[T any] = func(context.Context, T, error) error
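// Illustrative sketch, not part of the original sources: it shows the pattern
// recommended above for spawning follow-up work, calling Scatter from a
// GatherFunc rather than from the TaskFunc itself. The two-stage pipeline and
// its values are hypothetical.
package psg_test

import (
	"context"
	"fmt"

	psg "github.com/petenewcomb/psg-go"
)

func ExampleScatter_fromGatherFunc() {
	ctx := context.Background()
	pool := psg.NewPool(1)
	job := psg.NewJob(ctx, pool)
	defer job.CancelAndWait()

	total := 0
	summarize := func(ctx context.Context, n int, err error) error {
		if err != nil {
			return err
		}
		total += n
		return nil
	}
	// The stage 1 task produces a value; its gather function launches stage 2
	// into the same pool, which is safe because the stage 1 slot has already
	// been released by the time the gather function runs.
	err := psg.Scatter(ctx, pool,
		func(ctx context.Context) (int, error) { return 20, nil },
		func(ctx context.Context, n int, err error) error {
			if err != nil {
				return err
			}
			return psg.Scatter(ctx, pool,
				func(ctx context.Context) (int, error) { return n + 1, nil },
				summarize,
			)
		},
	)
	if err != nil {
		fmt.Println("scatter failed:", err)
		return
	}
	if err := job.CloseAndGatherAll(ctx); err != nil {
		fmt.Println("gather failed:", err)
		return
	}
	fmt.Println(total)
	// Output: 21
}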