diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 35ffd304..f793ce8c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -86,6 +86,46 @@ jobs: ${{ github.workspace }}/report.xml if: always() + test_valkey: + runs-on: ubuntu-latest + needs: build + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: 1.24 + check-latest: true + cache: true + + - name: Start Valkey (Docker) + run: | + docker run -d --name valkey -p 6379:6379 valkey/valkey:latest valkey-server --requirepass ValkeyPassw0rd + + - name: Wait for Valkey readiness + run: | + for i in {1..60}; do + if docker exec valkey valkey-cli -a ValkeyPassw0rd PING | grep -q PONG; then + echo "Valkey is ready"; exit 0; fi; + sleep 1; + done + echo "Valkey did not become ready in time"; + docker logs valkey || true + exit 1 + + - name: Tests (valkey backend, integration) + run: | + go test -tags=valkey_integration -timeout 240s -race -count 1 -v github.com/cschleiden/go-workflows/backend/valkey 2>&1 | go tool go-junit-report -set-exit-code -iocopy -out "${{ github.workspace }}/report.xml" + + - name: Test Summary + uses: test-summary/action@v2 + with: + paths: | + ${{ github.workspace }}/report.xml + if: always() + test_sqlite: runs-on: ubuntu-latest needs: build diff --git a/backend/valkey/activity.go b/backend/valkey/activity.go new file mode 100644 index 00000000..3e740f18 --- /dev/null +++ b/backend/valkey/activity.go @@ -0,0 +1,80 @@ +package valkey + +import ( + "context" + "fmt" + + "github.com/cschleiden/go-workflows/backend" + "github.com/cschleiden/go-workflows/backend/history" + "github.com/cschleiden/go-workflows/workflow" +) + +func (vb *valkeyBackend) PrepareActivityQueues(ctx context.Context, queues []workflow.Queue) error { + return vb.activityQueue.Prepare(ctx, vb.client, queues) +} + +func (vb *valkeyBackend) GetActivityTask(ctx context.Context, queues []workflow.Queue) (*backend.ActivityTask, error) { + activityTask, err := vb.activityQueue.Dequeue(ctx, vb.client, queues, vb.options.ActivityLockTimeout, vb.options.BlockTimeout) + if err != nil { + return nil, err + } + + if activityTask == nil { + return nil, nil + } + + return &backend.ActivityTask{ + WorkflowInstance: activityTask.Data.Instance, + Queue: workflow.Queue(activityTask.Data.Queue), + ID: activityTask.TaskID, + ActivityID: activityTask.Data.ID, + Event: activityTask.Data.Event, + }, nil +} + +func (vb *valkeyBackend) ExtendActivityTask(ctx context.Context, task *backend.ActivityTask) error { + if err := vb.activityQueue.Extend(ctx, vb.client, task.Queue, task.ID); err != nil { + return err + } + + return nil +} + +func (vb *valkeyBackend) CompleteActivityTask(ctx context.Context, task *backend.ActivityTask, result *history.Event) error { + instance, err := readInstance(ctx, vb.client, vb.keys.instanceKey(task.WorkflowInstance)) + if err != nil { + return err + } + + eventData, payload, err := marshalEvent(result) + if err != nil { + return err + } + + activityQueueKeys := vb.activityQueue.Keys(task.Queue) + workflowQueueKeys := vb.workflowQueue.Keys(workflow.Queue(instance.Queue)) + + err = completeActivityTaskScript.Exec(ctx, vb.client, []string{ + activityQueueKeys.SetKey, + activityQueueKeys.StreamKey, + vb.keys.pendingEventsKey(task.WorkflowInstance), + vb.keys.payloadKey(task.WorkflowInstance), + vb.workflowQueue.queueSetKey, + workflowQueueKeys.SetKey, + workflowQueueKeys.StreamKey, + }, []string{ + task.ID, + vb.activityQueue.groupName, + result.ID,
eventData, + payload, + vb.workflowQueue.groupName, + instanceSegment(task.WorkflowInstance), + }).Error() + + if err != nil { + return fmt.Errorf("completing activity task: %w", err) + } + + return nil +} diff --git a/backend/valkey/delete.go b/backend/valkey/delete.go new file mode 100644 index 00000000..9e23a450 --- /dev/null +++ b/backend/valkey/delete.go @@ -0,0 +1,31 @@ +package valkey + +import ( + "context" + "fmt" + + "github.com/cschleiden/go-workflows/core" +) + +// deleteInstance deletes an instance from Valkey. It does not attempt to remove any future events or pending +// workflow tasks. It's assumed that the instance is in the finished state. +// +// Note: might want to revisit this in the future if we want to support removing hung instances. +func (vb *valkeyBackend) deleteInstance(ctx context.Context, instance *core.WorkflowInstance) error { + err := deleteInstanceScript.Exec(ctx, vb.client, []string{ + vb.keys.instanceKey(instance), + vb.keys.pendingEventsKey(instance), + vb.keys.historyKey(instance), + vb.keys.payloadKey(instance), + vb.keys.activeInstanceExecutionKey(instance.InstanceID), + vb.keys.instancesByCreation(), + }, []string{ + instanceSegment(instance), + }).Error() + + if err != nil { + return fmt.Errorf("failed to delete instance: %w", err) + } + + return nil +} diff --git a/backend/valkey/diagnostics.go b/backend/valkey/diagnostics.go new file mode 100644 index 00000000..2a27e845 --- /dev/null +++ b/backend/valkey/diagnostics.go @@ -0,0 +1,100 @@ +package valkey + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/cschleiden/go-workflows/core" + "github.com/cschleiden/go-workflows/diag" + "github.com/cschleiden/go-workflows/internal/log" +) + +var _ diag.Backend = (*valkeyBackend)(nil) + +func (vb *valkeyBackend) GetWorkflowInstances(ctx context.Context, afterInstanceID, afterExecutionID string, count int) ([]*diag.WorkflowInstanceRef, error) { + zrangeCmd := vb.client.B().Zrange().Key(vb.keys.instancesByCreation()).Min("+inf").Max("-inf").Byscore().Rev().Limit(0, int64(count)) + if afterInstanceID != "" { + afterSegmentID := instanceSegment(core.NewWorkflowInstance(afterInstanceID, afterExecutionID)) + scores, err := vb.client.Do(ctx, vb.client.B().Zmscore().Key(vb.keys.instancesByCreation()).Member(afterSegmentID).Build()).AsFloatSlice() + if err != nil { + return nil, fmt.Errorf("getting instance score for %v: %w", afterSegmentID, err) + } + + if len(scores) == 0 || scores[0] == 0 { + vb.Options().Logger.Error("could not find instance %v", + log.NamespaceKey+".valkey.afterInstanceID", afterInstanceID, + log.NamespaceKey+".valkey.afterExecutionID", afterExecutionID, + ) + return nil, nil + } + + zrangeCmd = vb.client.B().Zrange().Key(vb.keys.instancesByCreation()).Min("+inf").Max(fmt.Sprintf("(%f", scores[0])).Byscore().Rev().Limit(0, int64(count)) + } + + instanceSegments, err := vb.client.Do(ctx, zrangeCmd.Build()).AsStrSlice() + if err != nil { + return nil, fmt.Errorf("getting instances: %w", err) + } + + if len(instanceSegments) == 0 { + return nil, nil + } + + instanceKeys := make([]string, 0) + for _, r := range instanceSegments { + instanceKeys = append(instanceKeys, vb.keys.instanceKeyFromSegment(r)) + } + + cmd := vb.client.B().Mget().Key(instanceKeys...) 
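+ // Fetch all instance states with a single MGET; results that come back empty are skipped below.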
+ instances, err := vb.client.Do(ctx, cmd.Build()).AsStrSlice() + if err != nil { + return nil, fmt.Errorf("getting instances: %w", err) + } + + instanceRefs := make([]*diag.WorkflowInstanceRef, 0, len(instances)) + for _, instance := range instances { + if instance == "" { + continue + } + + var state instanceState + if err := json.Unmarshal([]byte(instance), &state); err != nil { + return nil, fmt.Errorf("unmarshaling instance state: %w", err) + } + + instanceRefs = append(instanceRefs, &diag.WorkflowInstanceRef{ + Instance: state.Instance, + CreatedAt: state.CreatedAt, + CompletedAt: state.CompletedAt, + State: state.State, + Queue: state.Queue, + }) + } + + return instanceRefs, nil +} + +func (vb *valkeyBackend) GetWorkflowInstance(ctx context.Context, instance *core.WorkflowInstance) (*diag.WorkflowInstanceRef, error) { + instanceState, err := readInstance(ctx, vb.client, vb.keys.instanceKey(instance)) + if err != nil { + return nil, err + } + + return mapWorkflowInstance(instanceState), nil +} + +func (vb *valkeyBackend) GetWorkflowTree(ctx context.Context, instance *core.WorkflowInstance) (*diag.WorkflowInstanceTree, error) { + itb := diag.NewInstanceTreeBuilder(vb) + return itb.BuildWorkflowInstanceTree(ctx, instance) +} + +func mapWorkflowInstance(instance *instanceState) *diag.WorkflowInstanceRef { + return &diag.WorkflowInstanceRef{ + Instance: instance.Instance, + CreatedAt: instance.CreatedAt, + CompletedAt: instance.CompletedAt, + State: instance.State, + Queue: instance.Queue, + } +} diff --git a/backend/valkey/diagnostics_test.go b/backend/valkey/diagnostics_test.go new file mode 100644 index 00000000..998bfba6 --- /dev/null +++ b/backend/valkey/diagnostics_test.go @@ -0,0 +1,117 @@ +package valkey + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/cschleiden/go-workflows/backend" + "github.com/cschleiden/go-workflows/backend/history" + "github.com/cschleiden/go-workflows/backend/test" + "github.com/cschleiden/go-workflows/client" + "github.com/cschleiden/go-workflows/diag" + "github.com/stretchr/testify/require" + "github.com/valkey-io/valkey-go" +) + +func getClient() valkey.Client { + newClient, _ := valkey.NewClient(valkey.ClientOption{ + InitAddress: []string{"localhost:6379"}, + Password: "ValkeyPassw0rd", + SelectDB: 0, + }) + return newClient +} + +func getCreateBackend(client valkey.Client, additionalOptions ...BackendOption) func(options ...backend.BackendOption) test.TestBackend { + return func(options ...backend.BackendOption) test.TestBackend { + // Flush database + if err := client.Do(context.Background(), client.B().Flushdb().Build()).Error(); err != nil { + panic(err) + } + + r, err := client.Do(context.Background(), client.B().Keys().Pattern("*").Build()).AsStrSlice() + if err != nil { + panic(err) + } + + if len(r) > 0 { + panic("Keys should've been empty" + strings.Join(r, ", ")) + } + + redisOptions := []BackendOption{ + WithBlockTimeout(time.Millisecond * 10), + WithBackendOptions(options...), + } + + redisOptions = append(redisOptions, additionalOptions...) + + b, err := NewValkeyBackend(client, redisOptions...) 
+ if err != nil { + panic(err) + } + + return b + } +} + +var _ test.TestBackend = (*valkeyBackend)(nil) + +// GetFutureEvents +func (vb *valkeyBackend) GetFutureEvents(ctx context.Context) ([]*history.Event, error) { + r, err := vb.client.Do(ctx, vb.client.B().Zrangebyscore().Key(vb.keys.futureEventsKey()).Min("-inf").Max("+inf").Build()).AsStrSlice() + if err != nil { + return nil, fmt.Errorf("getting future events: %w", err) + } + + events := make([]*history.Event, 0) + + for _, eventID := range r { + eventStr, err := vb.client.Do(ctx, vb.client.B().Hget().Key(eventID).Field("event").Build()).AsBytes() + if err != nil { + return nil, fmt.Errorf("getting event %v: %w", eventID, err) + } + + var event *history.Event + if err := json.Unmarshal(eventStr, &event); err != nil { + return nil, fmt.Errorf("unmarshaling event %v: %w", eventID, err) + } + + events = append(events, event) + } + + return events, nil +} + +func Test_Diag_GetWorkflowInstances(t *testing.T) { + if testing.Short() { + t.Skip() + } + + c := getClient() + t.Cleanup(func() { c.Close() }) + + vc := getCreateBackend(c)() + + bd := vc.(diag.Backend) + + ctx := context.Background() + instances, err := bd.GetWorkflowInstances(ctx, "", "", 5) + require.NoError(t, err) + require.Empty(t, instances) + + cl := client.New(bd) + + _, err = cl.CreateWorkflowInstance(ctx, client.WorkflowInstanceOptions{ + InstanceID: "ex1", + }, "some-workflow") + require.NoError(t, err) + + instances, err = bd.GetWorkflowInstances(ctx, "", "", 5) + require.NoError(t, err) + require.Len(t, instances, 1) + require.Equal(t, "ex1", instances[0].Instance.InstanceID) +} diff --git a/backend/valkey/events.go b/backend/valkey/events.go new file mode 100644 index 00000000..d34dae08 --- /dev/null +++ b/backend/valkey/events.go @@ -0,0 +1,30 @@ +package valkey + +import ( + "encoding/json" + + "github.com/cschleiden/go-workflows/backend/history" +) + +type eventWithoutAttributes struct { + *history.Event +} + +func (e *eventWithoutAttributes) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + *history.Event + Attributes interface{} `json:"attr"` + }{ + Event: e.Event, + Attributes: nil, + }) +} + +func marshalEventWithoutAttributes(event *history.Event) (string, error) { + data, err := json.Marshal(&eventWithoutAttributes{event}) + if err != nil { + return "", err + } + + return string(data), nil +} diff --git a/backend/valkey/events_future.go b/backend/valkey/events_future.go new file mode 100644 index 00000000..d348f6a3 --- /dev/null +++ b/backend/valkey/events_future.go @@ -0,0 +1,20 @@ +package valkey + +import ( + "context" + "fmt" + "strconv" + "time" +) + +func scheduleFutureEvents(ctx context.Context, vb *valkeyBackend) error { + now := time.Now().UnixMilli() + nowStr := strconv.FormatInt(now, 10) + err := futureEventsScript.Exec(ctx, vb.client, []string{vb.keys.futureEventsKey()}, []string{nowStr, vb.keys.prefix}).Error() + + if err != nil { + return fmt.Errorf("checking future events: %w", err) + } + + return nil +} diff --git a/backend/valkey/expire.go b/backend/valkey/expire.go new file mode 100644 index 00000000..27f362ec --- /dev/null +++ b/backend/valkey/expire.go @@ -0,0 +1,34 @@ +package valkey + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/cschleiden/go-workflows/core" +) + +func (vb *valkeyBackend) setWorkflowInstanceExpiration(ctx context.Context, instance *core.WorkflowInstance, expiration time.Duration) error { + now := time.Now().UnixMilli() + nowStr := strconv.FormatInt(now, 10) + + exp := 
time.Now().Add(expiration).UnixMilli() + expStr := strconv.FormatInt(exp, 10) + + err := expireWorkflowInstanceScript.Exec(ctx, vb.client, []string{ + vb.keys.instancesByCreation(), + vb.keys.instancesExpiring(), + vb.keys.instanceKey(instance), + vb.keys.pendingEventsKey(instance), + vb.keys.historyKey(instance), + vb.keys.payloadKey(instance), + }, []string{ + nowStr, + fmt.Sprintf("%.0f", expiration.Seconds()), + expStr, + instanceSegment(instance), + }).Error() + + return err +} diff --git a/backend/valkey/instance.go b/backend/valkey/instance.go new file mode 100644 index 00000000..422809b0 --- /dev/null +++ b/backend/valkey/instance.go @@ -0,0 +1,231 @@ +package valkey + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/cschleiden/go-workflows/backend" + "github.com/cschleiden/go-workflows/backend/history" + "github.com/cschleiden/go-workflows/backend/metadata" + "github.com/cschleiden/go-workflows/core" + "github.com/cschleiden/go-workflows/workflow" + "github.com/valkey-io/valkey-go" +) + +func (vb *valkeyBackend) CreateWorkflowInstance(ctx context.Context, instance *workflow.Instance, event *history.Event) error { + a := event.Attributes.(*history.ExecutionStartedAttributes) + + instanceState, err := json.Marshal(&instanceState{ + Queue: string(a.Queue), + Instance: instance, + State: core.WorkflowInstanceStateActive, + Metadata: a.Metadata, + CreatedAt: time.Now(), + }) + if err != nil { + return fmt.Errorf("marshaling instance state: %w", err) + } + + activeInstance, err := json.Marshal(instance) + if err != nil { + return fmt.Errorf("marshaling instance: %w", err) + } + + eventData, payloadData, err := marshalEvent(event) + if err != nil { + return err + } + + keyInfo := vb.workflowQueue.Keys(a.Queue) + + // Execute Lua script for atomic creation + err = createWorkflowInstanceScript.Exec(ctx, vb.client, []string{ + vb.keys.instanceKey(instance), + vb.keys.activeInstanceExecutionKey(instance.InstanceID), + vb.keys.pendingEventsKey(instance), + vb.keys.payloadKey(instance), + vb.keys.instancesActive(), + vb.keys.instancesByCreation(), + keyInfo.SetKey, + keyInfo.StreamKey, + vb.workflowQueue.queueSetKey, + }, []string{ + instanceSegment(instance), + string(instanceState), + string(activeInstance), + event.ID, + eventData, + payloadData, + fmt.Sprintf("%d", time.Now().UTC().UnixNano()), + }).Error() + + if err != nil { + if err.Error() == "ERR InstanceAlreadyExists" { + return backend.ErrInstanceAlreadyExists + } + return fmt.Errorf("creating workflow instance: %w", err) + } + + return nil +} + +func (vb *valkeyBackend) GetWorkflowInstanceHistory(ctx context.Context, instance *core.WorkflowInstance, lastSequenceID *int64) ([]*history.Event, error) { + start := "-" + if lastSequenceID != nil { + start = strconv.FormatInt(*lastSequenceID, 10) + } + + msgs, err := vb.client.Do(ctx, vb.client.B().Xrange().Key(vb.keys.historyKey(instance)).Start(start).End("+").Build()).AsXRange() + if err != nil { + return nil, err + } + + payloadKeys := make([]string, 0, len(msgs)) + events := make([]*history.Event, 0, len(msgs)) + for _, msg := range msgs { + eventStr, ok := msg.FieldValues["event"] + if !ok || eventStr == "" { + continue + } + + var event *history.Event + if err := json.Unmarshal([]byte(eventStr), &event); err != nil { + return nil, fmt.Errorf("unmarshaling event: %w", err) + } + + payloadKeys = append(payloadKeys, event.ID) + events = append(events, event) + } + + if len(payloadKeys) > 0 { + cmd := 
vb.client.B().Hmget().Key(vb.keys.payloadKey(instance)).Field(payloadKeys...) + res, err := vb.client.Do(ctx, cmd.Build()).AsStrSlice() + if err != nil { + return nil, fmt.Errorf("reading payloads: %w", err) + } + + for i, event := range events { + event.Attributes, err = history.DeserializeAttributes(event.Type, []byte(res[i])) + if err != nil { + return nil, fmt.Errorf("deserializing attributes for event %v: %w", event.Type, err) + } + } + } + + return events, nil +} + +func (vb *valkeyBackend) GetWorkflowInstanceState(ctx context.Context, instance *core.WorkflowInstance) (core.WorkflowInstanceState, error) { + instanceState, err := readInstance(ctx, vb.client, vb.keys.instanceKey(instance)) + if err != nil { + return core.WorkflowInstanceStateActive, err + } + + return instanceState.State, nil +} + +func (vb *valkeyBackend) CancelWorkflowInstance(ctx context.Context, instance *core.WorkflowInstance, event *history.Event) error { + // Read the instance to check if it exists + instanceState, err := readInstance(ctx, vb.client, vb.keys.instanceKey(instance)) + if err != nil { + return err + } + + // Prepare event data + eventData, payloadData, err := marshalEvent(event) + if err != nil { + return err + } + + keyInfo := vb.workflowQueue.Keys(workflow.Queue(instanceState.Queue)) + + // Cancel instance + err = cancelWorkflowInstanceScript.Exec(ctx, vb.client, []string{ + vb.keys.payloadKey(instance), + vb.keys.pendingEventsKey(instance), + keyInfo.SetKey, + keyInfo.StreamKey, + }, []string{ + event.ID, + eventData, + payloadData, + instanceSegment(instance), + }).Error() + + if err != nil { + return fmt.Errorf("canceling workflow instance: %w", err) + } + + return nil +} + +func (vb *valkeyBackend) RemoveWorkflowInstance(ctx context.Context, instance *core.WorkflowInstance) error { + i, err := readInstance(ctx, vb.client, vb.keys.instanceKey(instance)) + if err != nil { + return err + } + + if i.State != core.WorkflowInstanceStateFinished && i.State != core.WorkflowInstanceStateContinuedAsNew { + return backend.ErrInstanceNotFinished + } + + return vb.deleteInstance(ctx, instance) +} + +func (vb *valkeyBackend) RemoveWorkflowInstances(_ context.Context, _ ...backend.RemovalOption) error { + return backend.ErrNotSupported{ + Message: "not supported, use auto-expiration", + } +} + +type instanceState struct { + Queue string `json:"queue"` + + Instance *core.WorkflowInstance `json:"instance,omitempty"` + State core.WorkflowInstanceState `json:"state,omitempty"` + + Metadata *metadata.WorkflowMetadata `json:"metadata,omitempty"` + + CreatedAt time.Time `json:"created_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + + LastSequenceID int64 `json:"last_sequence_id,omitempty"` +} + +func readInstance(ctx context.Context, client valkey.Client, instanceKey string) (*instanceState, error) { + val, err := client.Do(ctx, client.B().Get().Key(instanceKey).Build()).ToString() + if err != nil { + if valkey.IsValkeyNil(err) { + return nil, backend.ErrInstanceNotFound + } + return nil, fmt.Errorf("reading instance: %w", err) + } + + var state instanceState + if err := json.Unmarshal([]byte(val), &state); err != nil { + return nil, fmt.Errorf("unmarshaling instance state: %w", err) + } + + return &state, nil +} + +func (vb *valkeyBackend) readActiveInstanceExecution(ctx context.Context, instanceID string) (*core.WorkflowInstance, error) { + val, err := vb.client.Do(ctx, vb.client.B().Get().Key(vb.keys.activeInstanceExecutionKey(instanceID)).Build()).ToString() + if err != nil { + return 
nil, err + } + + if val == "" { + return nil, nil + } + + var instance *core.WorkflowInstance + if err := json.Unmarshal([]byte(val), &instance); err != nil { + return nil, fmt.Errorf("unmarshaling instance: %w", err) + } + + return instance, nil +} diff --git a/backend/valkey/keys.go b/backend/valkey/keys.go new file mode 100644 index 00000000..30abeb9f --- /dev/null +++ b/backend/valkey/keys.go @@ -0,0 +1,73 @@ +package valkey + +import ( + "fmt" + + "github.com/cschleiden/go-workflows/core" +) + +type keys struct { + // Ensure prefix ends with `:` + prefix string +} + +func newKeys(prefix string) *keys { + if prefix != "" && prefix[len(prefix)-1] != ':' { + prefix += ":" + } + + return &keys{prefix: prefix} +} + +// activeInstanceExecutionKey returns the key for the latest execution of the given instance +func (k *keys) activeInstanceExecutionKey(instanceID string) string { + return fmt.Sprintf("%sactive-instance-execution:%v", k.prefix, instanceID) +} + +func instanceSegment(instance *core.WorkflowInstance) string { + return fmt.Sprintf("%v:%v", instance.InstanceID, instance.ExecutionID) +} + +func (k *keys) instanceKey(instance *core.WorkflowInstance) string { + return k.instanceKeyFromSegment(instanceSegment(instance)) +} + +func (k *keys) instanceKeyFromSegment(segment string) string { + return fmt.Sprintf("%sinstance:%v", k.prefix, segment) +} + +// instancesByCreation returns the key for the ZSET that contains all instances sorted by creation date. The score is the +// creation time as a unix timestamp. Used for listing all workflow instances in the diagnostics UI. +func (k *keys) instancesByCreation() string { + return fmt.Sprintf("%sinstances-by-creation", k.prefix) +} + +// instancesActive returns the key for the SET that contains all active instances. Used for reporting active workflow +// instances in stats. 
+func (k *keys) instancesActive() string { + return fmt.Sprintf("%sinstances-active", k.prefix) +} + +func (k *keys) instancesExpiring() string { + return fmt.Sprintf("%sinstances-expiring", k.prefix) +} + +func (k *keys) pendingEventsKey(instance *core.WorkflowInstance) string { + return fmt.Sprintf("%spending-events:%v", k.prefix, instanceSegment(instance)) +} + +func (k *keys) historyKey(instance *core.WorkflowInstance) string { + return fmt.Sprintf("%shistory:%v", k.prefix, instanceSegment(instance)) +} + +func (k *keys) futureEventsKey() string { + return fmt.Sprintf("%sfuture-events", k.prefix) +} + +func (k *keys) futureEventKey(instance *core.WorkflowInstance, scheduleEventID int64) string { + return fmt.Sprintf("%sfuture-event:%v:%v", k.prefix, instanceSegment(instance), scheduleEventID) +} + +func (k *keys) payloadKey(instance *core.WorkflowInstance) string { + return fmt.Sprintf("%spayload:%v", k.prefix, instanceSegment(instance)) +} diff --git a/backend/valkey/keys_test.go b/backend/valkey/keys_test.go new file mode 100644 index 00000000..867c045d --- /dev/null +++ b/backend/valkey/keys_test.go @@ -0,0 +1,24 @@ +package valkey + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_newKeys(t *testing.T) { + t.Run("WithEmptyPrefix", func(t *testing.T) { + k := newKeys("") + require.Empty(t, k.prefix) + }) + + t.Run("WithNonEmptyPrefixWithoutColon", func(t *testing.T) { + k := newKeys("prefix") + require.Equal(t, "prefix:", k.prefix) + }) + + t.Run("WithNonEmptyPrefixWithColon", func(t *testing.T) { + k := newKeys("prefix:") + require.Equal(t, "prefix:", k.prefix) + }) +} diff --git a/backend/valkey/options.go b/backend/valkey/options.go new file mode 100644 index 00000000..37404983 --- /dev/null +++ b/backend/valkey/options.go @@ -0,0 +1,59 @@ +package valkey + +import ( + "time" + + "github.com/cschleiden/go-workflows/backend" +) + +type Options struct { + *backend.Options + + BlockTimeout time.Duration + + AutoExpiration time.Duration + AutoExpirationContinueAsNew time.Duration + + KeyPrefix string +} + +type BackendOption func(*Options) + +// WithKeyPrefix sets the prefix for all keys used in the Valkey backend. +func WithKeyPrefix(prefix string) BackendOption { + return func(o *Options) { + o.KeyPrefix = prefix + } +} + +// WithBlockTimeout sets the timeout for blocking operations like dequeuing a workflow or activity task +func WithBlockTimeout(timeout time.Duration) BackendOption { + return func(o *Options) { + o.BlockTimeout = timeout + } +} + +// WithAutoExpiration sets the duration after which finished runs will expire from the data store. +// If set to 0 (default), runs will never expire and need to be manually removed. +func WithAutoExpiration(expireFinishedRunsAfter time.Duration) BackendOption { + return func(o *Options) { + o.AutoExpiration = expireFinishedRunsAfter + } +} + +// WithAutoExpirationContinueAsNew sets the duration after which runs that were completed with `ContinueAsNew` +// automatically expire. +// If set to 0 (default), the overall expiration setting set with `WithAutoExpiration` will be used. 
+func WithAutoExpirationContinueAsNew(expireContinuedAsNewRunsAfter time.Duration) BackendOption { + return func(o *Options) { + o.AutoExpirationContinueAsNew = expireContinuedAsNewRunsAfter + } +} + +func WithBackendOptions(opts ...backend.BackendOption) BackendOption { + return func(o *Options) { + for _, opt := range opts { + opt(o.Options) + } + } +} diff --git a/backend/valkey/queue.go b/backend/valkey/queue.go new file mode 100644 index 00000000..9ec9e902 --- /dev/null +++ b/backend/valkey/queue.go @@ -0,0 +1,289 @@ +package valkey + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/cschleiden/go-workflows/workflow" + "github.com/google/uuid" + "github.com/valkey-io/valkey-go" +) + +var ( + prepareCmd *valkey.Lua + enqueueCmd *valkey.Lua + completeCmd *valkey.Lua + recoverCmd *valkey.Lua + sizeCmd *valkey.Lua +) + +type taskQueue[T any] struct { + keyPrefix string + tasktype string + groupName string + workerName string + queueSetKey string +} + +type TaskItem[T any] struct { + // TaskID is the generated ID of the task item + TaskID string + + // ID is the provided id + ID string + + // Optional data stored with a task, needs to be serializable + Data T +} + +type KeyInfo struct { + StreamKey string + SetKey string +} + +func newTaskQueue[T any](keyPrefix, tasktype, workerName string) (*taskQueue[T], error) { + // Ensure the key prefix ends with a colon + if keyPrefix != "" && keyPrefix[len(keyPrefix)-1] != ':' { + keyPrefix += ":" + } + + if workerName == "" { + workerName = uuid.NewString() + } + + tq := &taskQueue[T]{ + keyPrefix: keyPrefix, + tasktype: tasktype, + groupName: "task-workers", + workerName: workerName, + queueSetKey: fmt.Sprintf("%s%s:queues", keyPrefix, tasktype), + } + + // Load all Lua scripts + cmdMapping := map[string]**valkey.Lua{ + "queue/prepare.lua": &prepareCmd, + "queue/size.lua": &sizeCmd, + "queue/recover.lua": &recoverCmd, + "queue/enqueue.lua": &enqueueCmd, + "queue/complete.lua": &completeCmd, + } + + if err := loadScripts(cmdMapping); err != nil { + return nil, fmt.Errorf("loading Lua scripts: %w", err) + } + + return tq, nil +} + +func (q *taskQueue[T]) Prepare(ctx context.Context, client valkey.Client, queues []workflow.Queue) error { + var queueStreamKeys []string + for _, queue := range queues { + queueStreamKeys = append(queueStreamKeys, q.Keys(queue).StreamKey) + } + + err := prepareCmd.Exec(ctx, client, queueStreamKeys, []string{q.groupName}).Error() + if err != nil && !valkey.IsValkeyNil(err) { + return fmt.Errorf("preparing queues: %w", err) + } + + return nil +} + +func (q *taskQueue[T]) Keys(queue workflow.Queue) KeyInfo { + return KeyInfo{ + StreamKey: fmt.Sprintf("%stask-stream:%s:%s", q.keyPrefix, queue, q.tasktype), + SetKey: fmt.Sprintf("%stask-set:%s:%s", q.keyPrefix, queue, q.tasktype), + } +} + +func (q *taskQueue[T]) Size(ctx context.Context, client valkey.Client) (map[workflow.Queue]int64, error) { + sizeData, err := sizeCmd.Exec(ctx, client, []string{q.queueSetKey}, []string{}).ToArray() + if err != nil { + return nil, fmt.Errorf("getting queue size: %w", err) + } + + res := map[workflow.Queue]int64{} + for i := 0; i < len(sizeData); i += 2 { + queueName, err := sizeData[i].ToString() + if err != nil { + return nil, fmt.Errorf("parsing queue name: %w", err) + } + + queueName = strings.TrimPrefix(queueName, q.keyPrefix) + queueName = strings.Split(queueName, ":")[1] // queue name is the third part of the key (0-indexed) + + queue := workflow.Queue(queueName) + size, err := 
sizeData[i+1].AsInt64() + if err != nil { + return nil, fmt.Errorf("parsing queue size: %w", err) + } + + res[queue] = size + } + + return res, nil +} + +func (q *taskQueue[T]) Enqueue(ctx context.Context, client valkey.Client, queue workflow.Queue, id string, data *T) error { + ds, err := json.Marshal(data) + if err != nil { + return err + } + + queueStreamInfo := q.Keys(queue) + if err := enqueueCmd.Exec(ctx, client, []string{q.queueSetKey, queueStreamInfo.SetKey, queueStreamInfo.StreamKey}, []string{q.groupName, id, string(ds)}).Error(); err != nil { + return fmt.Errorf("enqueueing task: %w", err) + } + + return nil +} + +func (q *taskQueue[T]) Dequeue(ctx context.Context, client valkey.Client, queues []workflow.Queue, lockTimeout, timeout time.Duration) (*TaskItem[T], error) { + // Try to recover abandoned tasks + task, err := q.recover(ctx, client, queues, lockTimeout) + if err != nil { + return nil, fmt.Errorf("checking for abandoned tasks: %w", err) + } + + if task != nil { + return task, nil + } + + // Check for new tasks + streamKeys := make([]string, 0, len(queues)) + for _, queue := range queues { + keyInfo := q.Keys(queue) + streamKeys = append(streamKeys, keyInfo.StreamKey) + } + + ids := make([]string, len(streamKeys)) + for i := range ids { + ids[i] = ">" + } + + // Try to dequeue from all given queues + cmd := client.B().Xreadgroup().Group(q.groupName, q.workerName).Block(timeout.Milliseconds()).Streams().Key(streamKeys...).Id(ids...) + results, err := client.Do(ctx, cmd.Build()).AsXRead() + if err != nil && !valkey.IsValkeyNil(err) { + return nil, fmt.Errorf("dequeueing task: %w", err) + } + + var msgs []valkey.XRangeEntry + for _, streamResult := range results { + msgs = append(msgs, streamResult...) + } + + if len(results) == 0 || len(msgs) == 0 || valkey.IsValkeyNil(err) { + return nil, nil + } + + return msgToTaskItem[T](msgs[0]) +} + +func (q *taskQueue[T]) Extend(ctx context.Context, client valkey.Client, queue workflow.Queue, taskID string) error { + // Claiming a message resets the idle timer + err := client.Do(ctx, client.B().Xclaim().Key(q.Keys(queue).StreamKey).Group(q.groupName).Consumer(q.workerName).MinIdleTime("0").Id(taskID).Build()).Error() + if err != nil { + // Check if error is due to no data available (nil response) + if valkey.IsValkeyNil(err) { + return nil + } + return fmt.Errorf("extending lease: %w", err) + } + + return nil +} + +func (q *taskQueue[T]) Complete(ctx context.Context, client valkey.Client, queue workflow.Queue, taskID string) error { + err := completeCmd.Exec(ctx, client, []string{ + q.Keys(queue).SetKey, + q.Keys(queue).StreamKey, + }, []string{taskID, q.groupName}).Error() + if err != nil && !valkey.IsValkeyNil(err) { + return fmt.Errorf("completing task: %w", err) + } + + return nil +} + +func (q *taskQueue[T]) recover(ctx context.Context, client valkey.Client, queues []workflow.Queue, idleTimeout time.Duration) (*TaskItem[T], error) { + var keys []string + for _, queue := range queues { + keys = append(keys, q.Keys(queue).StreamKey) + } + + r, err := recoverCmd.Exec(ctx, client, keys, []string{q.groupName, q.workerName, strconv.FormatInt(idleTimeout.Milliseconds(), 10), "0"}).ToArray() + if err != nil { + if valkey.IsValkeyNil(err) { + return nil, nil + } + + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + + if len(r) > 1 { + msgs, err := r[1].ToArray() + if err != nil { + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + if len(msgs) > 0 && !msgs[0].IsNil() { + msgData, err := 
msgs[0].ToArray() + if err != nil { + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + id, err := msgData[0].ToString() + if err != nil { + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + rawValues, err := msgData[1].ToArray() + if err != nil { + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + values := make(map[string]string) + for i := 0; i < len(rawValues); i += 2 { + key, err := rawValues[i].ToString() + if err != nil { + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + value, err := rawValues[i+1].ToString() + if err != nil { + return nil, fmt.Errorf("recovering abandoned task: %w", err) + } + values[key] = value + } + + return msgToTaskItem[T](valkey.XRangeEntry{ + ID: id, + FieldValues: values, + }) + } + } + + return nil, nil +} + +func msgToTaskItem[T any](msg valkey.XRangeEntry) (*TaskItem[T], error) { + id, idOk := msg.FieldValues["id"] + data, dataOk := msg.FieldValues["data"] + + var t T + if dataOk && data != "" { + if err := json.Unmarshal([]byte(data), &t); err != nil { + return nil, err + } + } + + if !idOk { + return nil, fmt.Errorf("message missing id field") + } + + return &TaskItem[T]{ + TaskID: msg.ID, + ID: id, + Data: t, + }, nil +} diff --git a/backend/valkey/queue_test.go b/backend/valkey/queue_test.go new file mode 100644 index 00000000..5b396668 --- /dev/null +++ b/backend/valkey/queue_test.go @@ -0,0 +1,237 @@ +package valkey + +import ( + "context" + "testing" + "time" + + "github.com/cschleiden/go-workflows/core" + "github.com/cschleiden/go-workflows/workflow" + "github.com/stretchr/testify/assert" +) + +func Test_TaskQueue(t *testing.T) { + // These tests rely on a Valkey server on localhost:6379. + // Skip when running with -short. + if testing.Short() { + t.Skip() + } + + taskType := "taskType" + + client := getClient() + + lockTimeout := time.Millisecond * 10 + blockTimeout := time.Millisecond * 10 + + tests := []struct { + name string + f func(t *testing.T, q *taskQueue[any]) + }{ + { + name: "Simple enqueue/dequeue", + f: func(t *testing.T, q *taskQueue[any]) { + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + + task, err := q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + assert.Equal(t, "t1", task.ID) + }, + }, + { + name: "Size", + f: func(t *testing.T, q *taskQueue[any]) { + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t2", nil)) + assert.NoError(t, q.Enqueue(ctx, client, "OtherQueue", "t3", nil)) + + s1, err := q.Size(ctx, client) + assert.NoError(t, err) + assert.Equal(t, map[workflow.Queue]int64{ + workflow.QueueDefault: 2, + "OtherQueue": 1, + }, s1) + }, + }, + { + name: "Guarantee uniqueness", + f: func(t *testing.T, q *taskQueue[any]) { + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + + task, err := q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + + assert.NoError(t, q.Complete(ctx, client, workflow.QueueDefault, task.TaskID)) + + // After completion, the same id can be enqueued again + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", 
nil)) + }, + }, + { + name: "Store custom data", + f: func(t *testing.T, _ *taskQueue[any]) { + type foo struct { + Count int `json:"count"` + Name string `json:"name"` + } + + ctx := context.Background() + + q, err := newTaskQueue[foo]("prefix", taskType, "") + assert.NoError(t, err) + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", &foo{ + Count: 1, + Name: "bar", + })) + + task, err := q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + assert.Equal(t, "t1", task.ID) + assert.Equal(t, 1, task.Data.Count) + assert.Equal(t, "bar", task.Data.Name) + }, + }, + { + name: "Simple enqueue/dequeue different worker", + f: func(t *testing.T, q *taskQueue[any]) { + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + + q2, err := newTaskQueue[any]("prefix", taskType, "") + assert.NoError(t, err) + + // Dequeue using second worker + task, err := q2.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + assert.Equal(t, "t1", task.ID) + }, + }, + { + name: "Complete removes task", + f: func(t *testing.T, q *taskQueue[any]) { + q2, err := newTaskQueue[any]("prefix", taskType, "") + assert.NoError(t, err) + + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + + task, err := q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + + // Complete task using second worker + assert.NoError(t, q2.Complete(ctx, client, workflow.QueueDefault, task.TaskID)) + + time.Sleep(time.Millisecond * 10) + + // Try to recover using second worker; should not find anything + task2, err := q2.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.Nil(t, task2) + }, + }, + { + name: "Recover task", + f: func(t *testing.T, _ *taskQueue[any]) { + type taskData struct { + Count int `json:"count"` + } + q, err := newTaskQueue[taskData]("prefix", taskType, "") + assert.NoError(t, err) + + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", &taskData{Count: 42})) + + q2, err := newTaskQueue[taskData]("prefix", taskType, "") + assert.NoError(t, err) + + task, err := q2.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + assert.Equal(t, "t1", task.ID) + + time.Sleep(time.Millisecond * 10) + + // Assume q2 crashed, recover from other worker + recoveredTask, err := q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, time.Millisecond*1, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, recoveredTask) + assert.Equal(t, task, recoveredTask) + }, + }, + { + name: "Extending task prevents recovering", + f: func(t *testing.T, q *taskQueue[any]) { + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + + // Create second worker (with different name) + q2, err := newTaskQueue[any]("prefix", taskType, "") + assert.NoError(t, err) + + task, err := q2.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + assert.Equal(t, "t1", task.ID) + + time.Sleep(time.Millisecond * 5) + + 
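+ // Extend the lease from the worker that currently holds the task; claiming resets the idle time so the task is not treated as abandoned.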
assert.NoError(t, q2.Extend(ctx, client, workflow.QueueDefault, task.TaskID)) + + // Use large lock timeout; should not recover + recoveredTask, err := q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, time.Second*2, blockTimeout) + assert.NoError(t, err) + assert.Nil(t, recoveredTask) + }, + }, + { + name: "Will only dequeue from given queue", + f: func(t *testing.T, q *taskQueue[any]) { + ctx := context.Background() + + assert.NoError(t, q.Enqueue(ctx, client, workflow.QueueDefault, "t1", nil)) + + assert.NoError(t, q.Prepare(ctx, client, []workflow.Queue{core.QueueSystem, workflow.QueueDefault})) + + task, err := q.Dequeue(ctx, client, []workflow.Queue{core.QueueSystem}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.Nil(t, task) + + task, err = q.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout) + assert.NoError(t, err) + assert.NotNil(t, task) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + // Best-effort cleanup between tests + client.Do(ctx, client.B().Flushdb().Build()) + + q, err := newTaskQueue[any]("prefix", taskType, "") + assert.NoError(t, err) + + assert.NoError(t, q.Prepare(ctx, client, []workflow.Queue{workflow.QueueDefault})) + + tt.f(t, q) + }) + } +} diff --git a/backend/valkey/scripts/cancel_workflow_instance.lua b/backend/valkey/scripts/cancel_workflow_instance.lua new file mode 100644 index 00000000..d3237c96 --- /dev/null +++ b/backend/valkey/scripts/cancel_workflow_instance.lua @@ -0,0 +1,23 @@ +local payloadHashKey = KEYS[1] +local pendingEventsKey = KEYS[2] +local workflowSetKey = KEYS[3] +local workflowStreamKey = KEYS[4] + +local eventId = ARGV[1] +local eventData = ARGV[2] +local payload = ARGV[3] +local instanceSegment = ARGV[4] + +-- Add event payload +redis.pcall("HSETNX", payloadHashKey, eventId, payload) + +-- Add event to pending events stream +server.call("XADD", pendingEventsKey, "*", "event", eventData) + +-- Queue workflow task +local added = server.call("SADD", workflowSetKey, instanceSegment) +if added == 1 then + server.call("XADD", workflowStreamKey, "*", "id", instanceSegment, "data", "") +end + +return true diff --git a/backend/valkey/scripts/complete_activity_task.lua b/backend/valkey/scripts/complete_activity_task.lua new file mode 100644 index 00000000..762085f0 --- /dev/null +++ b/backend/valkey/scripts/complete_activity_task.lua @@ -0,0 +1,43 @@ +-- Complete an activity task, add the result event to the workflow instance, and enqueue the workflow task +-- KEYS[1] = activity set key +-- KEYS[2] = activity stream key +-- KEYS[3] = pending events stream key +-- KEYS[4] = payload hash key +-- KEYS[5] = workflow queues set key +-- KEYS[6] = workflow set key (for specific queue) +-- KEYS[7] = workflow stream key (for specific queue) +-- ARGV[1] = task id (activity) +-- ARGV[2] = group name (activity group) +-- ARGV[3] = event id +-- ARGV[4] = event data (json, without attributes) +-- ARGV[5] = payload data (json, can be empty) +-- ARGV[6] = workflow queue group name +-- ARGV[7] = workflow instance segment id + +-- Complete the activity task (from queue/complete.lua) +local task = server.call("XRANGE", KEYS[2], ARGV[1], ARGV[1]) +if #task == 0 then + return nil +end + +local id = task[1][2][2] +server.call("SREM", KEYS[1], id) +server.call("XACK", KEYS[2], ARGV[2], ARGV[1]) +server.call("XDEL", KEYS[2], ARGV[1]) + +-- Add event to pending events stream for workflow instance +server.call("XADD",
KEYS[3], "*", "event", ARGV[4]) + +-- Store payload if provided (only if not empty) +if ARGV[5] ~= "" then + redis.pcall("HSETNX", KEYS[4], ARGV[3], ARGV[5]) +end + +-- Enqueue workflow task (from queue/enqueue.lua) +server.call("SADD", KEYS[5], KEYS[6]) +local added = server.call("SADD", KEYS[6], ARGV[7]) +if added == 1 then + server.call("XADD", KEYS[7], "*", "id", ARGV[7], "data", "") +end + +return true diff --git a/backend/valkey/scripts/complete_workflow_task.lua b/backend/valkey/scripts/complete_workflow_task.lua new file mode 100644 index 00000000..d112bb17 --- /dev/null +++ b/backend/valkey/scripts/complete_workflow_task.lua @@ -0,0 +1,234 @@ +local keyIdx = 1 +local argvIdx = 1 + +local getKey = function() + local key = KEYS[keyIdx] + keyIdx = keyIdx + 1 + return key +end + +local getArgv = function() + local argv = ARGV[argvIdx] + argvIdx = argvIdx + 1 + -- server.call("ECHO", argv) + return argv +end + +-- Shared keys +local instanceKey = getKey() +local historyStreamKey = getKey() +local pendingEventsKey = getKey() +local payloadHashKey = getKey() +local futureEventZSetKey = getKey() +local activeInstancesKey = getKey() +local instancesByCreation = getKey() + +local workflowSetKey = getKey() +local workflowStreamKey = getKey() +local workflowQueuesSetKey = getKey() + +local prefix = getArgv() +local instanceSegment = getArgv() + +local storePayload = function(eventId, payload) + redis.pcall("HSETNX", payloadHashKey, eventId, payload) +end + +-- Read instance +local instance = cjson.decode(server.call("GET", instanceKey)) + +-- Add executed events to history +local executedEvents = tonumber(getArgv()) +local lastSequenceId = 0 +for i = 1, executedEvents do + local eventId = getArgv() + local eventData = getArgv() + local payloadData = getArgv() + local sequenceId = getArgv() + + -- Add event to history + server.call("XADD", historyStreamKey, sequenceId, "event", eventData) + + storePayload(eventId, payloadData) + + lastSequenceId = tonumber(sequenceId) +end + +-- Remove executed pending events +local lastPendingEventMessageId = getArgv() +server.call("XTRIM", pendingEventsKey, "MINID", lastPendingEventMessageId) +server.call("XDEL", pendingEventsKey, lastPendingEventMessageId) + +-- Update instance state +local now = getArgv() +local nowUnix = tonumber(getArgv()) +local state = tonumber(getArgv()) + +-- State constants +local ContinuedAsNew = tonumber(getArgv()) +local Finished = tonumber(getArgv()) + +instance["state"] = state + +-- If workflow instance finished, remove active execution +local activeInstanceExecutionKey = getKey() +if state == ContinuedAsNew or state == Finished then + -- Remove active execution + server.call("DEL", activeInstanceExecutionKey) + + instance["completed_at"] = now + + server.call("SREM", activeInstancesKey, instanceSegment) +end + +if lastSequenceId > 0 then + instance["last_sequence_id"] = lastSequenceId +end + +server.call("SET", instanceKey, cjson.encode(instance)) + +-- Remove canceled timers +local timersToCancel = tonumber(getArgv()) +for i = 1, timersToCancel do + local futureEventKey = getKey() + + local eventRemoved = server.call("ZREM", futureEventZSetKey, futureEventKey) + -- Event might've become visible while this task was being processed, in that + -- case it would be already removed from futureEventZSetKey + if eventRemoved == 1 then + -- remove payload + local eventId = server.call("HGET", futureEventKey, "id") + server.call("HDEL", payloadHashKey, eventId) + -- remove event hash + server.call("DEL", futureEventKey) + end +end + 
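+-- Timers are tracked as a member of the future-event ZSET (scored by their due time) plus a per-event hash holding the event data; the scheduling below mirrors the cleanup above.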
+-- Schedule timers +local timersToSchedule = tonumber(getArgv()) +for i = 1, timersToSchedule do + local eventId = getArgv() + local timestamp = getArgv() + local eventData = getArgv() + local payloadData = getArgv() + + local futureEventKey = getKey() + + server.call("ZADD", futureEventZSetKey, timestamp, futureEventKey) + server.call("HSET", futureEventKey, "instance", instanceSegment, "id", eventId, "event", eventData, "queue", instance["queue"]) + storePayload(eventId, payloadData) +end + +-- Schedule activities +local activities = tonumber(getArgv()) + +for i = 1, activities do + local activityQueue = getArgv() + local activityId = getArgv() + local activityData = getArgv() + + local activitySetKey = prefix .. "task-set:" .. activityQueue .. ":activities" + local activityStreamKey = prefix .. "task-stream:" .. activityQueue .. ":activities" + server.call("SADD", prefix .. "activities:queues", activitySetKey) + + local added = server.call("SADD", activitySetKey, activityId) + if added == 1 then + server.call("XADD", activityStreamKey, "*", "id", activityId, "data", activityData) + end +end + +-- Send events to other workflow instances +local otherWorkflowInstances = tonumber(getArgv()) +for i = 1, otherWorkflowInstances do + local targetInstanceKey = getKey() + local targetActiveInstanceExecutionKey = getKey() + + local targetInstanceSegment = getArgv() + local targetInstanceId = getArgv() + local createNewInstance = tonumber(getArgv()) + local eventsToDeliver = tonumber(getArgv()) + local skipEvents = false + + -- Creating a new instance? + if createNewInstance == 1 then + local targetInstanceState = getArgv() + local targetActiveInstanceExecutionState = getArgv() + + local conflictEventId = getArgv() + local conflictEventData = getArgv() + local conflictEventPayloadData = getArgv() + + -- Does the instance exist already? + local instanceExists = server.call("EXISTS", targetActiveInstanceExecutionKey) + if instanceExists == 1 then + server.call("XADD", pendingEventsKey, "*", "event", conflictEventData) + storePayload(conflictEventId, conflictEventPayloadData) + server.call("ECHO", + "Conflict detected, event " .. + conflictEventId .. " was not delivered to instance " .. targetInstanceSegment .. 
".") + + skipEvents = true + else + -- Create new instance + server.call("SETNX", targetInstanceKey, targetInstanceState) + + -- Set active execution + server.call("SET", targetActiveInstanceExecutionKey, targetActiveInstanceExecutionState) + + -- Track active instance + server.call("SADD", activeInstancesKey, targetInstanceSegment) + server.call("ZADD", instancesByCreation, nowUnix, targetInstanceSegment) + end + end + + local instanceQueueSetKey = getKey() + local instanceQueueStreamKey = getKey() + local instancePendingEventsKey = getKey() + local instancePayloadHashKey = getKey() + + for j = 1, eventsToDeliver do + local eventId = getArgv() + local eventData = getArgv() + local payloadData = getArgv() + + if not skipEvents then + -- Add event to pending events + server.call("XADD", instancePendingEventsKey, "*", "event", eventData) + + -- Store payload + redis.pcall("HSETNX", instancePayloadHashKey, eventId, payloadData) + end + end + + -- If events were delivered, try to queue a workflow task + if not skipEvents then + -- Enqueue workflow task + server.call("SADD", workflowQueuesSetKey, instanceQueueSetKey) + local added = server.call("SADD", instanceQueueSetKey, targetInstanceSegment) + if added == 1 then + server.call("XADD", instanceQueueStreamKey, "*", "id", targetInstanceSegment, "data", "") + end + end +end + +-- Complete workflow task and mark instance task as completed +local taskId = getArgv() +local groupName = getArgv() +local task = server.call("XRANGE", workflowStreamKey, taskId, taskId) +if #task ~= 0 then + local id = task[1][2][2] + server.call("SREM", workflowSetKey, id) + server.call("XACK", workflowStreamKey, groupName, taskId) + server.call("XDEL", workflowStreamKey, taskId) +end + +-- If there are pending events, queue the instance again +local pending_events = server.call("XLEN", pendingEventsKey) +if pending_events > 0 then + local added = server.call("SADD", workflowSetKey, instanceSegment) + if added == 1 then + server.call("XADD", workflowStreamKey, "*", "id", instanceSegment, "data", "") + end +end + +return true \ No newline at end of file diff --git a/backend/valkey/scripts/create_workflow_instance.lua b/backend/valkey/scripts/create_workflow_instance.lua new file mode 100644 index 00000000..92355341 --- /dev/null +++ b/backend/valkey/scripts/create_workflow_instance.lua @@ -0,0 +1,65 @@ +local keyIdx = 1 +local argvIdx = 1 + +local getKey = function() + local key = KEYS[keyIdx] + keyIdx = keyIdx + 1 + return key +end + +local getArgv = function() + local argv = ARGV[argvIdx] + argvIdx = argvIdx + 1 + return argv +end + +local instanceKey = getKey() +local activeInstanceExecutionKey = getKey() +local pendingEventsKey = getKey() +local payloadHashKey = getKey() + +local instancesActiveKey = getKey() +local instancesByCreation = getKey() + +local workflowSetKey = getKey() +local workflowStreamKey = getKey() +local workflowQueuesSet = getKey() + +local instanceSegment = getArgv() + +-- Is there an existing instance with active execution? 
+local instanceExists = server.call("EXISTS", activeInstanceExecutionKey) +if instanceExists == 1 then + return redis.error_reply("ERR InstanceAlreadyExists") +end + +-- Create new instance +local instanceState = getArgv() +server.call("SETNX", instanceKey, instanceState) + +-- Set active execution +local activeInstanceExecutionState = getArgv() +server.call("SET", activeInstanceExecutionKey, activeInstanceExecutionState) + +-- Track active instance +server.call("SADD", instancesActiveKey, instanceSegment) + +-- add initial event & payload +local eventId = getArgv() +local eventData = getArgv() +server.call("XADD", pendingEventsKey, "*", "event", eventData) + +local payload = getArgv() +redis.pcall("HSETNX", payloadHashKey, eventId, payload) + +local creationTimestamp = tonumber(getArgv()) +server.call("ZADD", instancesByCreation, creationTimestamp, instanceSegment) + +-- queue workflow task +server.call("SADD", workflowQueuesSet, workflowSetKey) -- track queue +local added = server.call("SADD", workflowSetKey, instanceSegment) +if added == 1 then + server.call("XADD", workflowStreamKey, "*", "id", instanceSegment, "data", "") +end + +return true \ No newline at end of file diff --git a/backend/valkey/scripts/delete_instance.lua b/backend/valkey/scripts/delete_instance.lua new file mode 100644 index 00000000..6538eaf5 --- /dev/null +++ b/backend/valkey/scripts/delete_instance.lua @@ -0,0 +1,14 @@ +local instanceKey = KEYS[1] +local pendingEventsKey = KEYS[2] +local historyKey = KEYS[3] +local payloadKey = KEYS[4] +local activeInstanceExecutionKey = KEYS[5] +local instancesByCreationKey = KEYS[6] + +local instanceSegment = ARGV[1] + +-- Delete all instance-related keys +server.call("DEL", instanceKey, pendingEventsKey, historyKey, payloadKey, activeInstanceExecutionKey) + +-- Remove instance from sorted set +return server.call("ZREM", instancesByCreationKey, instanceSegment) diff --git a/backend/valkey/scripts/expire_workflow_instance.lua b/backend/valkey/scripts/expire_workflow_instance.lua new file mode 100644 index 00000000..2335f5f6 --- /dev/null +++ b/backend/valkey/scripts/expire_workflow_instance.lua @@ -0,0 +1,29 @@ +-- Set the given expiration time on all keys passed in +-- KEYS[1] - instances-by-creation key +-- KEYS[2] - instances-expiring key +-- KEYS[3] - instance key +-- KEYS[4] - pending events key +-- KEYS[5] - history key +-- KEYS[6] - payload key +-- ARGV[1] - current timestamp +-- ARGV[2] - expiration time in seconds +-- ARGV[3] - expiration timestamp in unix milliseconds +-- ARGV[4] - instance segment + +-- Find instances which have already expired and remove from the index set +local expiredInstances = server.call("ZRANGE", KEYS[2], "-inf", ARGV[1], "BYSCORE") +for i = 1, #expiredInstances do + local instanceSegment = expiredInstances[i] + server.call("ZREM", KEYS[1], instanceSegment) -- index set + server.call("ZREM", KEYS[2], instanceSegment) -- expiration set +end + +-- Add expiration time for future cleanup +server.call("ZADD", KEYS[2], ARGV[3], ARGV[4]) + +-- Set expiration on all keys +for i = 3, #KEYS do + server.call("EXPIRE", KEYS[i], ARGV[2]) +end + +return 0 diff --git a/backend/valkey/scripts/queue/complete.lua b/backend/valkey/scripts/queue/complete.lua new file mode 100644 index 00000000..80e50187 --- /dev/null +++ b/backend/valkey/scripts/queue/complete.lua @@ -0,0 +1,21 @@ +-- We need TaskIDs for the stream and caller provided IDs for the set. 
So first look up +-- the ID in the stream using the TaskID, then remove from the set and the stream +-- KEYS[1] = set +-- KEYS[2] = stream +-- ARGV[1] = task id +-- ARGV[2] = group +-- We have to XACK _and_ XDEL here. See https://github.com/redis/redis/issues/5754 +local task = server.call("XRANGE", KEYS[2], ARGV[1], ARGV[1]) +if #task == 0 then + return nil +end + +local id = task[1][2][2] +server.call("SREM", KEYS[1], id) +server.call("XACK", KEYS[2], ARGV[2], ARGV[1]) + +-- Delete the task here. Overall we'll keep the stream at a small size, so fragmentation +-- is not an issue for us. +server.call("XDEL", KEYS[2], ARGV[1]) + +return true \ No newline at end of file diff --git a/backend/valkey/scripts/queue/enqueue.lua b/backend/valkey/scripts/queue/enqueue.lua new file mode 100644 index 00000000..1cc83b14 --- /dev/null +++ b/backend/valkey/scripts/queue/enqueue.lua @@ -0,0 +1,13 @@ +-- KEYS[1] = queues set +-- KEYS[2] = set +-- KEYS[3] = stream +-- ARGV[1] = consumer group +-- ARGV[2] = caller provided id of the task +-- ARGV[3] = additional data to store with the task +server.call("SADD", KEYS[1], KEYS[2]) +local added = server.call("SADD", KEYS[2], ARGV[2]) +if added == 1 then + server.call("XADD", KEYS[3], "*", "id", ARGV[2], "data", ARGV[3]) +end + +return true \ No newline at end of file diff --git a/backend/valkey/scripts/queue/prepare.lua b/backend/valkey/scripts/queue/prepare.lua new file mode 100644 index 00000000..a2e05b7a --- /dev/null +++ b/backend/valkey/scripts/queue/prepare.lua @@ -0,0 +1,30 @@ +-- KEYS[1..n] - queue stream keys +-- ARGV[1] - group name + +for i = 1, #KEYS do + local streamKey = KEYS[i] + local groupName = ARGV[1] + local exists = false + local res = redis.pcall('XINFO', 'GROUPS', streamKey) + + if res and type(res) == 'table' then + for _, groupInfo in ipairs(res) do + if type(groupInfo) == 'table' then + for i = 1, #groupInfo, 2 do + if groupInfo[i] == 'name' and groupInfo[i + 1] == groupName then + exists = true + break + end + end + end + + if exists then + break + end + end + end + + if not exists then + server.call('XGROUP', 'CREATE', streamKey, groupName, '0', 'MKSTREAM') + end +end \ No newline at end of file diff --git a/backend/valkey/scripts/queue/recover.lua b/backend/valkey/scripts/queue/recover.lua new file mode 100644 index 00000000..4572896d --- /dev/null +++ b/backend/valkey/scripts/queue/recover.lua @@ -0,0 +1,16 @@ +-- KEYS[1..n] = queue stream keys +-- ARGV[1] = group name +-- ARGV[2] = consumer/worker name +-- ARGV[3] = min-idle time in ms +-- ARGV[4] = start + +-- Try to recover abandoned tasks +for i = 1, #KEYS do + local stream = KEYS[i] + local recovered = server.call("XAUTOCLAIM", stream, ARGV[1], ARGV[2], ARGV[3], ARGV[4], "COUNT", 1) + if #recovered > 0 then + if #recovered[1] > 0 then + return recovered + end + end +end \ No newline at end of file diff --git a/backend/valkey/scripts/queue/size.lua b/backend/valkey/scripts/queue/size.lua new file mode 100644 index 00000000..09fa3230 --- /dev/null +++ b/backend/valkey/scripts/queue/size.lua @@ -0,0 +1,13 @@ +-- Return a flat array alternating each queue's task set key and the number of tasks in that queue +-- KEYS[1] = queues set key +local res = {} +local r = server.call("SMEMBERS", KEYS[1]) +local idx = 1 +for i = 1, #r, 1 do + local queue = r[i] + local length = server.call("SCARD", queue) + table.insert(res, queue) + table.insert(res, length) +end + +return res diff --git a/backend/valkey/scripts/schedule_future_events.lua
b/backend/valkey/scripts/schedule_future_events.lua new file mode 100644 index 00000000..befdcf32 --- /dev/null +++ b/backend/valkey/scripts/schedule_future_events.lua @@ -0,0 +1,39 @@ +-- Find all due future events. For each event: +-- - Look up event data +-- - Add to pending event stream for workflow instance +-- - Try to queue workflow task for workflow instance +-- - Remove event from future event set and delete event data +-- +-- KEYS[1] - future event set key +-- ARGV[1] - current timestamp for zrange +-- ARGV[2] - redis key prefix +-- +-- Note: this does not work with Redis Cluster since not all keys are passed into the script. +-- Find events which should become visible now +local now = ARGV[1] +local events = server.call("ZRANGE", KEYS[1], "-inf", now, "BYSCORE") +local prefix = ARGV[2] +for i = 1, #events do + local instanceSegment = server.call("HGET", events[i], "instance") + local queue = server.call("HGET", events[i], "queue") + + local setKey = prefix .. "task-set:" .. queue .. ":workflows" + local streamKey = prefix .. "task-stream:" .. queue .. ":workflows" + + -- Try to queue workflow task. If a workflow task is already queued, ignore this event for now. + local added = server.call("SADD", setKey, instanceSegment) + if added == 1 then + server.call("XADD", streamKey, "*", "id", instanceSegment, "data", "") + + -- Add event to pending event stream + local eventData = server.call("HGET", events[i], "event") + local pending_events_key = prefix .. "pending-events:" .. instanceSegment + server.call("XADD", pending_events_key, "*", "event", eventData) + + -- Delete event hash data + server.call("DEL", events[i]) + server.call("ZREM", KEYS[1], events[i]) + end +end + +return #events diff --git a/backend/valkey/scripts/signal_workflow.lua b/backend/valkey/scripts/signal_workflow.lua new file mode 100644 index 00000000..ce05e822 --- /dev/null +++ b/backend/valkey/scripts/signal_workflow.lua @@ -0,0 +1,35 @@ +-- Signal a workflow instance by adding an event to its pending events stream and queuing it +-- +-- KEYS[1] - payload hash key +-- KEYS[2] - pending events stream key +-- KEYS[3] - workflow task set key +-- KEYS[4] - workflow task stream key +-- +-- ARGV[1] - event id +-- ARGV[2] - event data (JSON) +-- ARGV[3] - event payload (JSON) +-- ARGV[4] - instance segment + +local payloadHashKey = KEYS[1] +local pendingEventsKey = KEYS[2] +local workflowSetKey = KEYS[3] +local workflowStreamKey = KEYS[4] + +local eventId = ARGV[1] +local eventData = ARGV[2] +local payload = ARGV[3] +local instanceSegment = ARGV[4] + +-- Add event payload +redis.pcall("HSETNX", payloadHashKey, eventId, payload) + +-- Add event to pending events stream +server.call("XADD", pendingEventsKey, "*", "event", eventData) + +-- Queue workflow task +local added = server.call("SADD", workflowSetKey, instanceSegment) +if added == 1 then + server.call("XADD", workflowStreamKey, "*", "id", instanceSegment, "data", "") +end + +return true diff --git a/backend/valkey/signal.go b/backend/valkey/signal.go new file mode 100644 index 00000000..d81b5787 --- /dev/null +++ b/backend/valkey/signal.go @@ -0,0 +1,54 @@ +package valkey + +import ( + "context" + "fmt" + + "github.com/cschleiden/go-workflows/backend" + "github.com/cschleiden/go-workflows/backend/history" + "github.com/cschleiden/go-workflows/workflow" +) + +func (vb *valkeyBackend) SignalWorkflow(ctx context.Context, instanceID string, event *history.Event) error { + // Get current execution of the instance + instance, err := vb.readActiveInstanceExecution(ctx, 
instanceID) + if err != nil { + return fmt.Errorf("reading active instance execution: %w", err) + } + + if instance == nil { + return backend.ErrInstanceNotFound + } + + instanceState, err := readInstance(ctx, vb.client, vb.keys.instanceKey(instance)) + if err != nil { + return err + } + + eventData, payload, err := marshalEvent(event) + if err != nil { + return fmt.Errorf("marshaling event: %w", err) + } + + queue := workflow.Queue(instanceState.Queue) + queueKeys := vb.workflowQueue.Keys(queue) + + // Execute the Lua script + err = signalWorkflowScript.Exec(ctx, vb.client, []string{ + vb.keys.payloadKey(instanceState.Instance), + vb.keys.pendingEventsKey(instanceState.Instance), + queueKeys.SetKey, + queueKeys.StreamKey, + }, []string{ + event.ID, + eventData, + payload, + instanceSegment(instanceState.Instance), + }).Error() + + if err != nil { + return fmt.Errorf("signaling workflow: %w", err) + } + + return nil +} diff --git a/backend/valkey/stats.go b/backend/valkey/stats.go new file mode 100644 index 00000000..27a1db69 --- /dev/null +++ b/backend/valkey/stats.go @@ -0,0 +1,38 @@ +package valkey + +import ( + "context" + "fmt" + + "github.com/cschleiden/go-workflows/backend" +) + +func (vb *valkeyBackend) GetStats(ctx context.Context) (*backend.Stats, error) { + s := &backend.Stats{} + + // get workflow instances + activeInstances, err := vb.client.Do(ctx, vb.client.B().Scard().Key(vb.keys.instancesActive()).Build()).AsInt64() + if err != nil { + return nil, fmt.Errorf("getting active instances: %w", err) + } + + s.ActiveWorkflowInstances = activeInstances + + // get pending workflow tasks + pendingWorkflows, err := vb.workflowQueue.Size(ctx, vb.client) + if err != nil { + return nil, fmt.Errorf("getting active workflows: %w", err) + } + + s.PendingWorkflowTasks = pendingWorkflows + + // get pending activities + pendingActivities, err := vb.activityQueue.Size(ctx, vb.client) + if err != nil { + return nil, fmt.Errorf("getting active activities: %w", err) + } + + s.PendingActivityTasks = pendingActivities + + return s, nil +} diff --git a/backend/valkey/valkey.go b/backend/valkey/valkey.go new file mode 100644 index 00000000..e624c668 --- /dev/null +++ b/backend/valkey/valkey.go @@ -0,0 +1,135 @@ +package valkey + +import ( + "embed" + "fmt" + "io/fs" + "time" + + "github.com/cschleiden/go-workflows/backend" + "github.com/cschleiden/go-workflows/backend/history" + "github.com/cschleiden/go-workflows/backend/metrics" + "github.com/cschleiden/go-workflows/core" + "github.com/cschleiden/go-workflows/internal/metrickeys" + "github.com/valkey-io/valkey-go" + "go.opentelemetry.io/otel/trace" +) + +var _ backend.Backend = (*valkeyBackend)(nil) + +//go:embed scripts +var luaScripts embed.FS + +var ( + createWorkflowInstanceScript *valkey.Lua + completeWorkflowTaskScript *valkey.Lua + completeActivityTaskScript *valkey.Lua + deleteInstanceScript *valkey.Lua + futureEventsScript *valkey.Lua + expireWorkflowInstanceScript *valkey.Lua + cancelWorkflowInstanceScript *valkey.Lua + signalWorkflowScript *valkey.Lua +) + +func NewValkeyBackend(client valkey.Client, opts ...BackendOption) (*valkeyBackend, error) { + vopts := &Options{ + Options: backend.ApplyOptions(), + BlockTimeout: time.Second * 2, + } + + for _, opt := range opts { + opt(vopts) + } + + workflowQueue, err := newTaskQueue[workflowData](vopts.KeyPrefix, "workflows", vopts.WorkerName) + if err != nil { + return nil, fmt.Errorf("creating workflow task queue: %w", err) + } + + activityQueue, err := 
newTaskQueue[activityData](vopts.KeyPrefix, "activities", vopts.WorkerName) + if err != nil { + return nil, fmt.Errorf("creating activity task queue: %w", err) + } + + vb := &valkeyBackend{ + client: client, + options: vopts, + keys: newKeys(vopts.KeyPrefix), + workflowQueue: workflowQueue, + activityQueue: activityQueue, + } + + // Load all Lua scripts + scriptMapping := map[string]**valkey.Lua{ + "cancel_workflow_instance.lua": &cancelWorkflowInstanceScript, + "complete_activity_task.lua": &completeActivityTaskScript, + "complete_workflow_task.lua": &completeWorkflowTaskScript, + "create_workflow_instance.lua": &createWorkflowInstanceScript, + "delete_instance.lua": &deleteInstanceScript, + "expire_workflow_instance.lua": &expireWorkflowInstanceScript, + "schedule_future_events.lua": &futureEventsScript, + "signal_workflow.lua": &signalWorkflowScript, + } + + if err := loadScripts(scriptMapping); err != nil { + return nil, fmt.Errorf("loading Lua scripts: %w", err) + } + + return vb, nil +} + +func loadScripts(scriptMapping map[string]**valkey.Lua) error { + for scriptFile, scriptVar := range scriptMapping { + scriptContent, err := fs.ReadFile(luaScripts, "scripts/"+scriptFile) + if err != nil { + return fmt.Errorf("reading Lua script %s: %w", scriptFile, err) + } + + *scriptVar = valkey.NewLuaScript(string(scriptContent)) + } + + return nil +} + +type valkeyBackend struct { + client valkey.Client + options *Options + keys *keys + workflowQueue *taskQueue[workflowData] + activityQueue *taskQueue[activityData] +} + +type workflowData struct{} + +type activityData struct { + Instance *core.WorkflowInstance `json:"instance,omitempty"` + Queue string `json:"queue,omitempty"` + ID string `json:"id,omitempty"` + Event *history.Event `json:"event,omitempty"` +} + +func (vb *valkeyBackend) Metrics() metrics.Client { + return vb.options.Metrics.WithTags(metrics.Tags{metrickeys.Backend: "valkey"}) +} + +func (vb *valkeyBackend) Tracer() trace.Tracer { + return vb.options.TracerProvider.Tracer(backend.TracerName) +} + +func (vb *valkeyBackend) Options() *backend.Options { + return vb.options.Options +} + +func (vb *valkeyBackend) Close() error { + vb.client.Close() + return nil +} + +func (vb *valkeyBackend) FeatureSupported(feature backend.Feature) bool { + switch feature { + case backend.Feature_Expiration: + return false + } + + return true +} diff --git a/backend/valkey/workflow.go b/backend/valkey/workflow.go new file mode 100644 index 00000000..6008a84b --- /dev/null +++ b/backend/valkey/workflow.go @@ -0,0 +1,328 @@ +package valkey + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/cschleiden/go-workflows/backend" + "github.com/cschleiden/go-workflows/backend/history" + "github.com/cschleiden/go-workflows/core" + "github.com/cschleiden/go-workflows/internal/log" + "github.com/cschleiden/go-workflows/internal/propagators" + "github.com/cschleiden/go-workflows/internal/workflowerrors" + "github.com/cschleiden/go-workflows/workflow" +) + +func (vb *valkeyBackend) PrepareWorkflowQueues(ctx context.Context, queues []workflow.Queue) error { + return vb.workflowQueue.Prepare(ctx, vb.client, queues) +} + +func (vb *valkeyBackend) GetWorkflowTask(ctx context.Context, queues []workflow.Queue) (*backend.WorkflowTask, error) { + if err := scheduleFutureEvents(ctx, vb); err != nil { + return nil, fmt.Errorf("scheduling future events: %w", err) + } + + // Try to get a workflow task, this locks the instance when it dequeues one + instanceTask, err := 
vb.workflowQueue.Dequeue(ctx, vb.client, queues, vb.options.WorkflowLockTimeout, vb.options.BlockTimeout) + if err != nil { + return nil, err + } + + if instanceTask == nil { + return nil, nil + } + + state, err := readInstance(ctx, vb.client, vb.keys.instanceKeyFromSegment(instanceTask.ID)) + if err != nil { + return nil, fmt.Errorf("reading workflow instance with ID %s: %w", instanceTask.ID, err) + } + + // Read all pending events for this instance + msgs, err := vb.client.Do(ctx, vb.client.B().Xrange().Key(vb.keys.pendingEventsKey(state.Instance)).Start("-").End("+").Build()).AsXRange() + if err != nil { + return nil, fmt.Errorf("reading event stream: %w", err) + } + + payloadKeys := make([]string, 0, len(msgs)) + newEvents := make([]*history.Event, 0, len(msgs)) + lastMessageID := "" + for _, msg := range msgs { + eventStr, ok := msg.FieldValues["event"] + if !ok || eventStr == "" { + continue + } + + var event *history.Event + if err := json.Unmarshal([]byte(eventStr), &event); err != nil { + return nil, fmt.Errorf("unmarshaling event: %w", err) + } + + payloadKeys = append(payloadKeys, event.ID) + newEvents = append(newEvents, event) + lastMessageID = msg.ID + } + + // Fetch event payloads + if len(payloadKeys) > 0 { + cmd := vb.client.B().Hmget().Key(vb.keys.payloadKey(state.Instance)).Field(payloadKeys...) + res, err := vb.client.Do(ctx, cmd.Build()).AsStrSlice() + if err != nil { + return nil, fmt.Errorf("reading payloads: %w", err) + } + + for i, event := range newEvents { + event.Attributes, err = history.DeserializeAttributes(event.Type, []byte(res[i])) + if err != nil { + return nil, fmt.Errorf("deserializing attributes for event %v: %w", event.Type, err) + } + } + } + + return &backend.WorkflowTask{ + ID: instanceTask.TaskID, + Queue: core.Queue(state.Queue), + WorkflowInstance: state.Instance, + WorkflowInstanceState: state.State, + Metadata: state.Metadata, + LastSequenceID: state.LastSequenceID, + NewEvents: newEvents, + CustomData: lastMessageID, + }, nil +} + +func (vb *valkeyBackend) ExtendWorkflowTask(ctx context.Context, task *backend.WorkflowTask) error { + return vb.workflowQueue.Extend(ctx, vb.client, task.Queue, task.ID) +} + +func (vb *valkeyBackend) CompleteWorkflowTask( + ctx context.Context, + task *backend.WorkflowTask, + state core.WorkflowInstanceState, + executedEvents, activityEvents, timerEvents []*history.Event, + workflowEvents []*history.WorkflowEvent, +) error { + keys := make([]string, 0) + args := make([]string, 0) + + instance := task.WorkflowInstance + + queueKeys := vb.workflowQueue.Keys(task.Queue) + keys = append(keys, + vb.keys.instanceKey(instance), + vb.keys.historyKey(instance), + vb.keys.pendingEventsKey(instance), + vb.keys.payloadKey(instance), + vb.keys.futureEventsKey(), + vb.keys.instancesActive(), + vb.keys.instancesByCreation(), + queueKeys.SetKey, + queueKeys.StreamKey, + vb.workflowQueue.queueSetKey, + ) + args = append(args, vb.keys.prefix, instanceSegment(instance)) + + // Add executed events to the history + args = append(args, fmt.Sprintf("%d", len(executedEvents))) + + for _, event := range executedEvents { + eventData, payloadData, err := marshalEvent(event) + if err != nil { + return err + } + + args = append(args, event.ID, eventData, payloadData, fmt.Sprintf("%d", event.SequenceID)) + } + + // Remove executed pending events + lastPendingEventMessageID := task.CustomData.(string) + args = append(args, lastPendingEventMessageID) + + // Update instance state and update active execution + now := time.Now().UTC() + nowStr := 
now.Format(time.RFC3339) + nowUnix := now.Unix() + args = append( + args, + nowStr, + fmt.Sprintf("%d", nowUnix), + fmt.Sprintf("%d", int(state)), + fmt.Sprintf("%d", int(core.WorkflowInstanceStateContinuedAsNew)), + fmt.Sprintf("%d", int(core.WorkflowInstanceStateFinished)), + ) + keys = append(keys, vb.keys.activeInstanceExecutionKey(instance.InstanceID)) + + // Remove canceled timers + timersToCancel := make([]*history.Event, 0) + for _, event := range executedEvents { + if event.Type == history.EventType_TimerCanceled { + timersToCancel = append(timersToCancel, event) + } + } + + args = append(args, fmt.Sprintf("%d", len(timersToCancel))) + for _, event := range timersToCancel { + keys = append(keys, vb.keys.futureEventKey(instance, event.ScheduleEventID)) + } + + // Schedule timers + args = append(args, fmt.Sprintf("%d", len(timerEvents))) + for _, timerEvent := range timerEvents { + eventData, payloadEventData, err := marshalEvent(timerEvent) + if err != nil { + return err + } + + args = append(args, timerEvent.ID, strconv.FormatInt(timerEvent.VisibleAt.UnixMilli(), 10), eventData, payloadEventData) + keys = append(keys, vb.keys.futureEventKey(instance, timerEvent.ScheduleEventID)) + } + + // Schedule activities + args = append(args, fmt.Sprintf("%d", len(activityEvents))) + for _, activityEvent := range activityEvents { + a := activityEvent.Attributes.(*history.ActivityScheduledAttributes) + queue := a.Queue + if queue == "" { + // Default to workflow queue + queue = task.Queue + } + + activityData, err := json.Marshal(&activityData{ + Instance: instance, + ID: activityEvent.ID, + Event: activityEvent, + Queue: string(queue), + }) + if err != nil { + return fmt.Errorf("marshaling activity data: %w", err) + } + + activityQueue := string(queue) + args = append(args, activityQueue, activityEvent.ID, string(activityData)) + } + + // Send new workflow events to the respective streams + groupedEvents := history.EventsByWorkflowInstance(workflowEvents) + args = append(args, fmt.Sprintf("%d", len(groupedEvents))) + for targetInstance, events := range groupedEvents { + keys = append(keys, vb.keys.instanceKey(&targetInstance), vb.keys.activeInstanceExecutionKey(targetInstance.InstanceID)) + args = append(args, instanceSegment(&targetInstance), targetInstance.InstanceID) + + // Are we creating a new workflow instance? 
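+		// If the first event for the target instance is ExecutionStarted, this task is creating
+		// a sub-workflow instance: pass its serialized initial state plus a pre-built
+		// SubWorkflowFailed event so the script can report ErrInstanceAlreadyExists on conflict.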
+ m := events[0] + createNewInstance := m.HistoryEvent.Type == history.EventType_WorkflowExecutionStarted + args = append(args, fmt.Sprintf("%v", createNewInstance)) + args = append(args, fmt.Sprintf("%d", len(events))) + + if createNewInstance { + a := m.HistoryEvent.Attributes.(*history.ExecutionStartedAttributes) + + queue := a.Queue + if queue == "" { + queue = task.Queue + } + + isb, err := json.Marshal(&instanceState{ + Queue: string(queue), + Instance: &targetInstance, + State: core.WorkflowInstanceStateActive, + Metadata: a.Metadata, + CreatedAt: time.Now(), + }) + if err != nil { + return fmt.Errorf("marshaling new instance state: %w", err) + } + + ib, err := json.Marshal(targetInstance) + if err != nil { + return fmt.Errorf("marshaling instance: %w", err) + } + + args = append(args, string(isb), string(ib)) + + // Create pending event for conflicts + pfe := history.NewPendingEvent(time.Now(), history.EventType_SubWorkflowFailed, &history.SubWorkflowFailedAttributes{ + Error: workflowerrors.FromError(backend.ErrInstanceAlreadyExists), + }, history.ScheduleEventID(m.WorkflowInstance.ParentEventID)) + eventData, payloadEventData, err := marshalEvent(pfe) + if err != nil { + return fmt.Errorf("marshaling event: %w", err) + } + + args = append(args, pfe.ID, eventData, payloadEventData) + + queueKeys := vb.workflowQueue.Keys(queue) + keys = append(keys, queueKeys.SetKey, queueKeys.StreamKey) + } else { + targetInstanceState, err := readInstance(ctx, vb.client, vb.keys.instanceKey(&targetInstance)) + if err != nil { + return fmt.Errorf("reading target instance: %w", err) + } + + queueKeys := vb.workflowQueue.Keys(core.Queue(targetInstanceState.Queue)) + keys = append(keys, queueKeys.SetKey, queueKeys.StreamKey) + } + + keys = append(keys, vb.keys.pendingEventsKey(&targetInstance), vb.keys.payloadKey(&targetInstance)) + for _, m := range events { + eventData, payloadEventData, err := marshalEvent(m.HistoryEvent) + if err != nil { + return fmt.Errorf("marshaling event: %w", err) + } + + args = append(args, m.HistoryEvent.ID, eventData, payloadEventData) + } + } + + // Complete workflow task and unlock instance. 
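+	// Final args: the dequeued task ID and the consumer group name, which the script uses to
+	// acknowledge and remove the workflow task from its queue stream and set, mirroring queue/complete.lua.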
+ args = append(args, task.ID, vb.workflowQueue.groupName) + + // Run script + err := completeWorkflowTaskScript.Exec(ctx, vb.client, keys, args).Error() + if err != nil { + return fmt.Errorf("completing workflow task: %w", err) + } + + if state == core.WorkflowInstanceStateFinished || state == core.WorkflowInstanceStateContinuedAsNew { + // Trace workflow completion + ctx, err := (&propagators.TracingContextPropagator{}).Extract(ctx, task.Metadata) + if err != nil { + vb.options.Logger.Error("extracting tracing context", log.ErrorKey, err) + } + + // Auto expiration + expiration := vb.options.AutoExpiration + if state == core.WorkflowInstanceStateContinuedAsNew && vb.options.AutoExpirationContinueAsNew > 0 { + expiration = vb.options.AutoExpirationContinueAsNew + } + + if expiration > 0 { + if err := vb.setWorkflowInstanceExpiration(ctx, instance, expiration); err != nil { + return fmt.Errorf("setting workflow instance expiration: %w", err) + } + } + + if vb.options.RemoveContinuedAsNewInstances && state == core.WorkflowInstanceStateContinuedAsNew { + if err := vb.RemoveWorkflowInstance(ctx, instance); err != nil { + return fmt.Errorf("removing workflow instance: %w", err) + } + } + } + + return nil +} + +func marshalEvent(event *history.Event) (string, string, error) { + eventData, err := marshalEventWithoutAttributes(event) + if err != nil { + return "", "", fmt.Errorf("marshaling event payload: %w", err) + } + + payloadEventData, err := json.Marshal(event.Attributes) + if err != nil { + return "", "", fmt.Errorf("marshaling event payload: %w", err) + } + return eventData, string(payloadEventData), nil +} diff --git a/diag/diag.go b/diag/diag.go index e5f36bd8..cd296bd9 100644 --- a/diag/diag.go +++ b/diag/diag.go @@ -34,12 +34,14 @@ func NewServeMux(backend Backend) *http.ServeMux { stats, err := backend.GetStats(r.Context()) if err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(stats); err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } @@ -72,12 +74,14 @@ func NewServeMux(backend Backend) *http.ServeMux { instances, err := backend.GetWorkflowInstances(r.Context(), afterInstanceID, afterExecutionID, count) if err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(instances); err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } @@ -106,6 +110,7 @@ func NewServeMux(backend Backend) *http.ServeMux { history, err := backend.GetWorkflowInstanceHistory(r.Context(), instanceRef.Instance, nil) if err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } @@ -130,6 +135,7 @@ func NewServeMux(backend Backend) *http.ServeMux { w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(result); err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } @@ -160,6 +166,7 @@ func NewServeMux(backend Backend) *http.ServeMux { w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(tree); err != nil { w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) return } diff --git a/go.mod b/go.mod index 3ee7dc00..a7fc2a9b 100644 --- a/go.mod +++ 
b/go.mod @@ -13,6 +13,7 @@ require ( github.com/jellydator/ttlcache/v3 v3.0.0 github.com/redis/go-redis/v9 v9.0.2 github.com/stretchr/testify v1.10.0 + github.com/valkey-io/valkey-go v1.0.68 go.opentelemetry.io/otel v1.31.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.31.0 diff --git a/go.sum b/go.sum index a0e1c03b..51fbfa4b 100644 --- a/go.sum +++ b/go.sum @@ -43,8 +43,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69 github.com/golang-migrate/migrate/v4 v4.16.2 h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA= github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -84,6 +84,8 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrBg0D7ufOcFM= @@ -107,6 +109,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/valkey-io/valkey-go v1.0.68 h1:bTbfonp49b41DqrF30q+y2JL3gcbjd2IiacFAtO4JBA= +github.com/valkey-io/valkey-go v1.0.68/go.mod h1:bHmwjIEOrGq/ubOJfh5uMRs7Xj6mV3mQ/ZXUbmqpjqY= go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 h1:K0XaT3DwHAcV4nKLzcQvwAgSyisUghWoY20I7huthMk=