Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions internal/sync/errgroup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sync

// ErrGroup provides a way to run functions concurrently and collect the first error.
//
// It is conceptually similar to golang.org/x/sync/errgroup.Group but adapted to the
// workflow scheduler and Context. It cancels the derived Context when the first function
// returns a non-nil error. Wait waits for all functions to finish and returns the first
// error that was observed. If the Context passed to Wait is canceled before completion,
// Wait returns that context error instead.
type ErrGroup interface {
// Go starts the given function in a new workflow coroutine.
// The started coroutine receives the group's derived Context, which is canceled when the
// first function returns a non-nil error.
Go(f func(Context) error)

// Wait waits for all launched functions to complete. It returns the first non-nil error
// returned by any function. If the provided ctx is canceled before completion, the context
// error is returned.
Wait(ctx Context) error
}

type errGroup struct {
// count of running functions
n int

// future that gets set when the count drops to zero
done SettableFuture[struct{}]

// first error encountered
firstErr error

// cancel the derived context
cancel CancelFunc

// context associated with this group (child of parent)
ctx Context

// track if Wait was called to detect certain misuses (optional)
waiting bool

// coroutine creator captured from the parent context when the group is created
creator CoroutineCreator
}

// WithErrGroup creates a child Context and an ErrGroup. The returned Context is canceled
// automatically when any function started with g.Go returns a non-nil error.
func WithErrGroup(parent Context) (Context, ErrGroup) {
ctx, cancel := WithCancel(parent)
cs := getCoState(parent)
return ctx, &errGroup{
done: NewFuture[struct{}](),
cancel: cancel,
ctx: ctx,
creator: cs.creator,
}
}

func (g *errGroup) Go(f func(Context) error) {
g.n += 1

g.creator.NewCoroutine(g.ctx, func(ctx Context) error {
// Execute user function
if err := f(ctx); err != nil {
if g.firstErr == nil {
g.firstErr = err
// cancel group context on first error
if g.cancel != nil {
g.cancel()
}
}
}

g.n -= 1
if g.n < 0 {
panic("negative ErrGroup counter")
}
if g.n == 0 {
g.done.Set(struct{}{}, nil)
}

return nil
})
}

func (g *errGroup) Wait(ctx Context) error {
g.waiting = true

if g.n == 0 {
return g.firstErr
}

if _, err := g.done.Get(ctx); err != nil {
return err
}
return g.firstErr
}
71 changes: 71 additions & 0 deletions internal/sync/errgroup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package sync

import (
"errors"
"testing"

"github.com/stretchr/testify/require"
)

func Test_ErrGroup_Success(t *testing.T) {
s := NewScheduler()
ctx := Background()

s.NewCoroutine(ctx, func(ctx Context) error {
gctx, g := WithErrGroup(ctx)

g.Go(func(ctx Context) error { return nil })
g.Go(func(ctx Context) error { return nil })

err := g.Wait(gctx)
require.NoError(t, err)
return nil
})

err := s.Execute()
require.NoError(t, err)
require.Equal(t, 0, s.RunningCoroutines())
}

func Test_ErrGroup_FirstError(t *testing.T) {
s := NewScheduler()
ctx := Background()

s.NewCoroutine(ctx, func(ctx Context) error {
gctx, g := WithErrGroup(ctx)

e1 := errors.New("boom")

g.Go(func(ctx Context) error { return e1 })
g.Go(func(ctx Context) error { return nil })

err := g.Wait(gctx)
require.Equal(t, e1, err)
return nil
})

err := s.Execute()
require.NoError(t, err)
}

func Test_ErrGroup_MultipleErrors_FirstWins(t *testing.T) {
s := NewScheduler()
ctx := Background()

s.NewCoroutine(ctx, func(ctx Context) error {
gctx, g := WithErrGroup(ctx)

e1 := errors.New("first")
e2 := errors.New("second")

g.Go(func(ctx Context) error { return e1 })
g.Go(func(ctx Context) error { return e2 })

err := g.Wait(gctx)
require.Equal(t, e1, err)
return nil
})

err := s.Execute()
require.NoError(t, err)
}
124 changes: 124 additions & 0 deletions samples/concurrent-errgroup/concurrent_errgroup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package main

import (
"context"
"log"
"os"
"os/signal"
"time"

"github.com/cschleiden/go-workflows/backend"
"github.com/cschleiden/go-workflows/client"
"github.com/cschleiden/go-workflows/samples"
"github.com/cschleiden/go-workflows/worker"
"github.com/cschleiden/go-workflows/workflow"
"github.com/google/uuid"
)

func main() {
ctx := context.Background()

b := samples.GetBackend("concurrent-errgroup", true)

// Run worker
go RunWorker(ctx, b)

// Start workflow via client
c := client.New(b)

startWorkflow(ctx, c)

c2 := make(chan os.Signal, 1)
signal.Notify(c2, os.Interrupt)
<-c2
}

func startWorkflow(ctx context.Context, c *client.Client) {
wf, err := c.CreateWorkflowInstance(ctx, client.WorkflowInstanceOptions{
InstanceID: uuid.NewString(),
}, WorkflowErrGroup, "Hello world")
if err != nil {
panic("could not start workflow")
}

log.Println("Started workflow", wf.InstanceID)
}

func RunWorker(ctx context.Context, mb backend.Backend) {
w := worker.New(mb, nil)

w.RegisterWorkflow(WorkflowErrGroup)

w.RegisterActivity(Activity1)
w.RegisterActivity(Activity2)

if err := w.Start(ctx); err != nil {
panic("could not start worker")
}
}

// WorkflowErrGroup demonstrates running two concurrent branches using the workflow-native
// error group. If any branch returns an error, the group's context is canceled and the
// first error is returned from Wait.
func WorkflowErrGroup(ctx workflow.Context, msg string) (string, error) {
logger := workflow.Logger(ctx)
logger.Debug("Entering WorkflowErrGroup")
logger.Debug("\tWorkflow instance input:", "msg", msg)

defer func() {
logger.Debug("Leaving WorkflowErrGroup")
}()

gctx, g := workflow.WithErrGroup(ctx)

g.Go(func(ctx workflow.Context) error {
a1 := workflow.ExecuteActivity[int](ctx, workflow.DefaultActivityOptions, Activity1, 35, 12)
r, err := a1.Get(ctx)
if err != nil {
return err
}

logger.Debug("A1 result", "r", r)
return nil
})

g.Go(func(ctx workflow.Context) error {
a2 := workflow.ExecuteActivity[int](ctx, workflow.DefaultActivityOptions, Activity2)
r, err := a2.Get(ctx)
if err != nil {
return err
}

logger.Debug("A2 result", "r", r)
return nil
})

// Wait for both goroutines to finish and return the first error, if any
if err := g.Wait(gctx); err != nil {
return "", err
}

return "result", nil
}

func Activity1(ctx context.Context, a, b int) (int, error) {
log.Println("Entering Activity1")

defer func() {
log.Println("Leaving Activity1")
}()

return a + b, nil
}

func Activity2(ctx context.Context) (int, error) {
log.Println("Entering Activity2")

time.Sleep(5 * time.Second)

defer func() {
log.Println("Leaving Activity2")
}()

return 12, nil
}
6 changes: 6 additions & 0 deletions workflow/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ import (
type (
Context = sync.Context
WaitGroup = sync.WaitGroup
ErrGroup = sync.ErrGroup
)

// NewWaitGroup creates a new WaitGroup instance.
func NewWaitGroup() WaitGroup {
return sync.NewWaitGroup()
}

// WithErrGroup creates a child context and errgroup for running workflow goroutines that return errors.
func WithErrGroup(ctx Context) (Context, ErrGroup) {
return sync.WithErrGroup(ctx)
}

// Go spawns a workflow "goroutine".
func Go(ctx Context, f func(ctx Context)) {
sync.Go(ctx, f)
Expand Down