diff --git a/fn/goroutine_manager.go b/fn/goroutine_manager.go
new file mode 100644
index 00000000000..8c9ad8b2d0a
--- /dev/null
+++ b/fn/goroutine_manager.go
@@ -0,0 +1,77 @@
+package fn
+
+import (
+	"context"
+	"errors"
+	"sync"
+)
+
+// ErrStopping is returned when trying to add a new goroutine while stopping.
+var ErrStopping = errors.New("cannot add goroutine, stopping")
+
+// GoroutineManager is used to launch goroutines until the context expires or
+// the manager is stopped. Stop blocks until all started goroutines finish.
+type GoroutineManager struct {
+	wg     sync.WaitGroup
+	mu     sync.Mutex
+	ctx    context.Context
+	cancel func()
+}
+
+// NewGoroutineManager constructs and returns a new instance of
+// GoroutineManager.
+func NewGoroutineManager(ctx context.Context) *GoroutineManager {
+	ctx, cancel := context.WithCancel(ctx)
+
+	return &GoroutineManager{
+		ctx:    ctx,
+		cancel: cancel,
+	}
+}
+
+// Go starts a new goroutine if the manager is not stopping.
+func (g *GoroutineManager) Go(f func(ctx context.Context)) error {
+	// Calling wg.Add(1) concurrently with wg.Wait() while the counter is
+	// zero is a race condition, because it is unclear whether Wait()
+	// should block or not. The Go runtime detects this and crashes when
+	// running with `-race`. To prevent this, the whole Go method is
+	// protected with a mutex. The wg.Wait() call inside Stop can still
+	// run in parallel with Go, but in that case g.ctx has already
+	// expired, since cancel() was called in Stop, so Go returns before
+	// reaching the wg.Add(1) call.
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	if g.ctx.Err() != nil {
+		return ErrStopping
+	}
+
+	g.wg.Add(1)
+	go func() {
+		defer g.wg.Done()
+		f(g.ctx)
+	}()
+
+	return nil
+}
+
+// Stop prevents new goroutines from being added and waits for all running
+// goroutines to finish.
+func (g *GoroutineManager) Stop() {
+	g.mu.Lock()
+	g.cancel()
+	g.mu.Unlock()
+
+	// Wait for all goroutines to finish. This wg.Wait() call is safe,
+	// because it can't race with the wg.Add(1) call in Go: the context
+	// was just cancelled, so even if Go acquires the mutex after this
+	// point, it sees the expired context and returns ErrStopping instead
+	// of calling wg.Add(1).
+	g.wg.Wait()
}
+
+// Done returns a channel which is closed when either the context passed to
+// NewGoroutineManager expires or when Stop is called.
+func (g *GoroutineManager) Done() <-chan struct{} {
+	return g.ctx.Done()
+}
diff --git a/fn/goroutine_manager_test.go b/fn/goroutine_manager_test.go
new file mode 100644
index 00000000000..d06a62b4a25
--- /dev/null
+++ b/fn/goroutine_manager_test.go
@@ -0,0 +1,121 @@
+package fn
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+// TestGoroutineManager tests that the GoroutineManager starts goroutines
+// until its context expires. It also makes sure that new goroutines fail to
+// start after the context has expired and the GoroutineManager is waiting
+// for already started goroutines in the Stop method.
+func TestGoroutineManager(t *testing.T) {
+	t.Parallel()
+
+	m := NewGoroutineManager(context.Background())
+
+	taskChan := make(chan struct{})
+
+	require.NoError(t, m.Go(func(ctx context.Context) {
+		<-taskChan
+	}))
+
+	t1 := time.Now()
+
+	// Close taskChan in 1s, causing the goroutine to stop.
+	time.AfterFunc(time.Second, func() {
+		close(taskChan)
+	})
+
+	m.Stop()
+	stopDelay := time.Since(t1)
+
+	// Make sure Stop was waiting for the goroutine to stop.
+	require.Greater(t, stopDelay, time.Second)
+
+	// Make sure new goroutines do not start after Stop.
+	require.ErrorIs(t, m.Go(func(ctx context.Context) {}), ErrStopping)
+
+	// When Stop() is called, the internal context expires and m.Done() is
+	// closed. Test this.
+	select {
+	case <-m.Done():
+	default:
+		t.Errorf("Done() channel must be closed at this point")
+	}
+}
+
+// TestGoroutineManagerContextExpires tests the effect of context expiry.
+func TestGoroutineManagerContextExpires(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	m := NewGoroutineManager(ctx)
+
+	require.NoError(t, m.Go(func(ctx context.Context) {
+		<-ctx.Done()
+	}))
+
+	// The Done channel of the manager should not be closed yet, so the
+	// following select must fall through to default.
+	select {
+	case <-m.Done():
+		t.Errorf("Done() channel must not be closed at this point")
+	default:
+	}
+
+	cancel()
+
+	// The Done channel of the manager should now be closed, so the
+	// following select must not fall through to default.
+	select {
+	case <-m.Done():
+	default:
+		t.Errorf("Done() channel must be closed at this point")
+	}
+
+	// Make sure new goroutines do not start after context expiry.
+	require.ErrorIs(t, m.Go(func(ctx context.Context) {}), ErrStopping)
+
+	// Stop will wait for all goroutines to stop.
+	m.Stop()
+}
+
+// TestGoroutineManagerStress starts many goroutines while calling Stop. It
+// is needed to make sure the GoroutineManager does not crash if this
+// happens. If the mutex were not used, this would crash because of a race
+// condition between wg.Add(1) and wg.Wait().
+func TestGoroutineManagerStress(t *testing.T) {
+	t.Parallel()
+
+	m := NewGoroutineManager(context.Background())
+
+	stopChan := make(chan struct{})
+
+	time.AfterFunc(1*time.Millisecond, func() {
+		m.Stop()
+		close(stopChan)
+	})
+
+	// Start 100 goroutines sequentially. Sequential order is needed to
+	// keep wg's counter low (0 or 1), increasing the probability of
+	// catching the race condition if it exists. If the mutex is removed
+	// from the implementation, this test crashes under `-race`.
+	for i := 0; i < 100; i++ {
+		taskChan := make(chan struct{})
+		err := m.Go(func(ctx context.Context) {
+			close(taskChan)
+		})
+		// If the goroutine was started, wait for its completion.
+		if err == nil {
+			<-taskChan
+		}
+	}
+
+	// Wait for Stop to complete.
+	<-stopChan
+}
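
For context, a minimal usage sketch of the new API follows. It is illustrative
only and not part of the patch; the package main wrapper, the import path
(assumed here to be the lnd fn package; the diff itself does not name the
module), and the one-second task are invented for the example:

	package main

	import (
		"context"
		"fmt"
		"time"

		"github.com/lightningnetwork/lnd/fn"
	)

	func main() {
		m := fn.NewGoroutineManager(context.Background())

		// Go succeeds while the manager is running and hands the
		// goroutine the manager's context, so the task can react to
		// cancellation.
		err := m.Go(func(ctx context.Context) {
			select {
			case <-time.After(time.Second):
				fmt.Println("task finished")
			case <-ctx.Done():
				fmt.Println("task cancelled")
			}
		})
		if err != nil {
			// ErrStopping: the manager was already stopping.
			fmt.Println(err)
		}

		// Stop cancels the context handed to the goroutine above and
		// blocks until it returns; any later Go call returns
		// ErrStopping.
		m.Stop()
	}

Note that Stop both cancels the shared context and waits on the WaitGroup, so
callers get cancellation and joining in a single call.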