diff --git a/internal/sync/errgroup.go b/internal/sync/errgroup.go
new file mode 100644
index 00000000..edaf4a02
--- /dev/null
+++ b/internal/sync/errgroup.go
@@ -0,0 +1,111 @@
+package sync
+
+// ErrGroup runs functions concurrently as workflow coroutines and collects the first error.
+//
+// It is conceptually similar to golang.org/x/sync/errgroup.Group but adapted to the
+// workflow scheduler and Context. The derived Context is canceled when the first function
+// returns a non-nil error or when Wait observes that all functions have finished, whichever
+// happens first. Wait returns the first error that was observed.
+type ErrGroup interface {
+	// Go starts the given function in a new workflow coroutine.
+	// The started coroutine receives the group's derived Context, which is canceled when the
+	// first function returns a non-nil error. Go must not be called after Wait has seen the
+	// group finish, because the derived Context is canceled at that point.
+	Go(f func(Context) error)
+
+	// Wait waits for all launched functions to complete and returns the first non-nil error
+	// returned by any of them. If waiting on ctx fails before the functions complete, that
+	// error is returned instead. Because the derived Context is canceled on the first error,
+	// pass the parent Context here rather than the derived one.
+	Wait(ctx Context) error
+}
+
+type errGroup struct {
+	// number of started functions that have not returned yet
+	n int
+
+	// future a blocked Wait is parked on; nil while nobody is waiting
+	done SettableFuture[struct{}]
+
+	// first non-nil error returned by any function
+	firstErr error
+
+	// cancels ctx; reset to nil once it has been called
+	cancel CancelFunc
+
+	// derived context handed to every function started with Go
+	ctx Context
+
+	// coroutine creator captured from the parent context when the group is created
+	creator CoroutineCreator
+}
+
+// WithErrGroup creates a child Context and an ErrGroup. The returned Context is canceled
+// when any function started with g.Go returns a non-nil error, or once Wait has seen all
+// functions finish.
+func WithErrGroup(parent Context) (Context, ErrGroup) {
+	ctx, cancel := WithCancel(parent)
+	cs := getCoState(parent)
+
+	return ctx, &errGroup{
+		cancel:  cancel,
+		ctx:     ctx,
+		creator: cs.creator,
+	}
+}
+
+// Go implements ErrGroup. It counts f as running and starts it in a new coroutine.
+func (g *errGroup) Go(f func(Context) error) {
+	g.n++
+
+	g.creator.NewCoroutine(g.ctx, func(ctx Context) error {
+		if err := f(ctx); err != nil && g.firstErr == nil {
+			g.firstErr = err
+			// Only the first error cancels the group context.
+			g.release()
+		}
+
+		g.n--
+		if g.n < 0 {
+			panic("negative ErrGroup counter")
+		}
+
+		// Wake a blocked Wait. The future is created per wait, so functions that finish
+		// before a later Go call cannot leave a stale, already-set future behind.
+		if g.n == 0 && g.done != nil {
+			done := g.done
+			g.done = nil
+			done.Set(struct{}{}, nil)
+		}
+
+		// Errors are reported through Wait, never through the coroutine itself.
+		return nil
+	})
+}
+
+// Wait implements ErrGroup. It blocks until no started function is running anymore.
+func (g *errGroup) Wait(ctx Context) error {
+	// Loop instead of checking once: more functions may have been started while blocked.
+	for g.n > 0 {
+		if g.done == nil {
+			g.done = NewFuture[struct{}]()
+		}
+
+		if _, err := g.done.Get(ctx); err != nil {
+			return err
+		}
+	}
+
+	// Everything has finished, so the derived context is no longer needed.
+	g.release()
+
+	return g.firstErr
+}
+
+// release cancels the derived context at most once.
+func (g *errGroup) release() {
+	if g.cancel != nil {
+		g.cancel()
+		g.cancel = nil
+	}
+}
diff --git a/internal/sync/errgroup_test.go b/internal/sync/errgroup_test.go
new file mode 100644
index 00000000..67e8ac03
--- /dev/null
+++ b/internal/sync/errgroup_test.go
@@ -0,0 +1,72 @@
+package sync
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// runErrGroup starts fns in a new ErrGroup on a fresh scheduler and returns what Wait
+// reported. Checks run after the scheduler has executed, so a coroutine that never gets
+// past Wait fails the test instead of silently skipping its assertions.
+func runErrGroup(t *testing.T, fns ...func(Context) error) error {
+	t.Helper()
+
+	s := NewScheduler()
+
+	var waitErr error
+	waited := false
+
+	s.NewCoroutine(Background(), func(ctx Context) error {
+		_, g := WithErrGroup(ctx)
+		for _, fn := range fns {
+			g.Go(fn)
+		}
+
+		// Wait on the parent context; the derived one is canceled by the first error.
+		waitErr = g.Wait(ctx)
+		waited = true
+
+		return nil
+	})
+
+	require.NoError(t, s.Execute())
+	require.True(t, waited, "Wait did not return")
+	require.Equal(t, 0, s.RunningCoroutines())
+
+	return waitErr
+}
+
+func Test_ErrGroup_NoFunctions(t *testing.T) {
+	require.NoError(t, runErrGroup(t))
+}
+
+func Test_ErrGroup_Success(t *testing.T) {
+	err := runErrGroup(t,
+		func(ctx Context) error { return nil },
+		func(ctx Context) error { return nil },
+	)
+	require.NoError(t, err)
+}
+
+func Test_ErrGroup_FirstError(t *testing.T) {
+	e1 := errors.New("boom")
+
+	err := runErrGroup(t,
+		func(ctx Context) error { return e1 },
+		func(ctx Context) error { return nil },
+	)
+	require.Equal(t, e1, err)
+}
+
+func Test_ErrGroup_MultipleErrors_FirstWins(t *testing.T) {
+	e1 := errors.New("first")
+	e2 := errors.New("second")
+
+	err := runErrGroup(t,
+		func(ctx Context) error { return e1 },
+		func(ctx Context) error { return e2 },
+	)
+	require.Equal(t, e1, err)
+}
diff --git a/samples/concurrent-errgroup/concurrent_errgroup.go b/samples/concurrent-errgroup/concurrent_errgroup.go
new file mode 100644
index 00000000..3803606c
--- /dev/null
+++ b/samples/concurrent-errgroup/concurrent_errgroup.go
@@ -0,0 +1,130 @@
+package main
+
+import (
+	"context"
+	"log"
+	"os"
+	"os/signal"
+	"time"
+
+	"github.com/cschleiden/go-workflows/backend"
+	"github.com/cschleiden/go-workflows/client"
+	"github.com/cschleiden/go-workflows/samples"
+	"github.com/cschleiden/go-workflows/worker"
+	"github.com/cschleiden/go-workflows/workflow"
+	"github.com/google/uuid"
+)
+
+func main() {
+	ctx := context.Background()
+
+	b := samples.GetBackend("concurrent-errgroup", true)
+
+	// Run worker
+	go RunWorker(ctx, b)
+
+	// Start workflow via client
+	c := client.New(b)
+
+	startWorkflow(ctx, c)
+
+	c2 := make(chan os.Signal, 1)
+	signal.Notify(c2, os.Interrupt)
+	<-c2
+}
+
+// startWorkflow creates a new WorkflowErrGroup instance with a random instance ID.
+func startWorkflow(ctx context.Context, c *client.Client) {
+	wf, err := c.CreateWorkflowInstance(ctx, client.WorkflowInstanceOptions{
+		InstanceID: uuid.NewString(),
+	}, WorkflowErrGroup, "Hello world")
+	if err != nil {
+		log.Panicf("could not start workflow: %v", err)
+	}
+
+	log.Println("Started workflow", wf.InstanceID)
+}
+
+// RunWorker registers the sample workflow and its activities and starts a worker on mb.
+func RunWorker(ctx context.Context, mb backend.Backend) {
+	w := worker.New(mb, nil)
+
+	w.RegisterWorkflow(WorkflowErrGroup)
+
+	w.RegisterActivity(Activity1)
+	w.RegisterActivity(Activity2)
+
+	if err := w.Start(ctx); err != nil {
+		log.Panicf("could not start worker: %v", err)
+	}
+}
+
+// WorkflowErrGroup demonstrates running two concurrent branches using the workflow-native
+// error group. If any branch returns an error, the group's context is canceled and the
+// first error is returned from Wait.
+func WorkflowErrGroup(ctx workflow.Context, msg string) (string, error) {
+	logger := workflow.Logger(ctx)
+	logger.Debug("Entering WorkflowErrGroup")
+	logger.Debug("\tWorkflow instance input:", "msg", msg)
+
+	defer func() {
+		logger.Debug("Leaving WorkflowErrGroup")
+	}()
+
+	// Each branch receives the group's derived context from g.Go, so it is not needed here.
+	_, g := workflow.WithErrGroup(ctx)
+
+	g.Go(func(ctx workflow.Context) error {
+		a1 := workflow.ExecuteActivity[int](ctx, workflow.DefaultActivityOptions, Activity1, 35, 12)
+		r, err := a1.Get(ctx)
+		if err != nil {
+			return err
+		}
+
+		logger.Debug("A1 result", "r", r)
+		return nil
+	})
+
+	g.Go(func(ctx workflow.Context) error {
+		a2 := workflow.ExecuteActivity[int](ctx, workflow.DefaultActivityOptions, Activity2)
+		r, err := a2.Get(ctx)
+		if err != nil {
+			return err
+		}
+
+		logger.Debug("A2 result", "r", r)
+		return nil
+	})
+
+	// Wait for both branches to finish and return the first error, if any. Wait on the
+	// workflow context: the derived context is canceled as soon as one branch fails.
+	if err := g.Wait(ctx); err != nil {
+		return "", err
+	}
+
+	return "result", nil
+}
+
+// Activity1 returns the sum of a and b.
+func Activity1(ctx context.Context, a, b int) (int, error) {
+	log.Println("Entering Activity1")
+
+	defer func() {
+		log.Println("Leaving Activity1")
+	}()
+
+	return a + b, nil
+}
+
+// Activity2 sleeps for five seconds to simulate slow work and then returns 12.
+func Activity2(ctx context.Context) (int, error) {
+	log.Println("Entering Activity2")
+
+	defer func() {
+		log.Println("Leaving Activity2")
+	}()
+
+	time.Sleep(5 * time.Second)
+
+	return 12, nil
+}
diff --git a/workflow/sync.go b/workflow/sync.go
index 8e2e2ac8..45ee9594 100644
--- a/workflow/sync.go
+++ b/workflow/sync.go
@@ -7,6 +7,7 @@ import (
 type (
 	Context   = sync.Context
 	WaitGroup = sync.WaitGroup
+	ErrGroup  = sync.ErrGroup
 )
 
 // NewWaitGroup creates a new WaitGroup instance.
@@ -14,6 +15,11 @@ func NewWaitGroup() WaitGroup {
 	return sync.NewWaitGroup()
 }
 
+// WithErrGroup returns a child Context and an ErrGroup by delegating to sync.WithErrGroup; the child Context is canceled when a function started with Go returns a non-nil error.
+func WithErrGroup(ctx Context) (Context, ErrGroup) {
+	return sync.WithErrGroup(ctx)
+}
+
 // Go spawns a workflow "goroutine".
 func Go(ctx Context, f func(ctx Context)) {
 	sync.Go(ctx, f)