Skip to content

Commit 3198932

Browse files
committed
cre-1601: improved comments; improved naming; simplified logic; observations validation; ring performance improvements
1 parent 693eb34 commit 3198932

File tree

6 files changed

+121
-70
lines changed

6 files changed

+121
-70
lines changed

pkg/workflows/ring/factory.go

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ package ring
22

33
import (
44
"context"
5+
"errors"
56

67
"github.com/smartcontractkit/chainlink-common/pkg/logger"
78
"github.com/smartcontractkit/chainlink-common/pkg/services"
@@ -27,8 +28,10 @@ type Factory struct {
2728
services.StateMachine
2829
}
2930

30-
// NewFactory creates a factory for the shard orchestrator consensus plugin
3131
func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) {
32+
if arbiterScaler == nil {
33+
return nil, errors.New("arbiterScaler is required")
34+
}
3235
if cfg == nil {
3336
cfg = &ConsensusConfig{
3437
BatchSize: defaultBatchSize,
@@ -38,14 +41,14 @@ func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logg
3841
store: s,
3942
arbiterScaler: arbiterScaler,
4043
config: cfg,
41-
lggr: logger.Named(lggr, "ShardOrchestratorFactory"),
44+
lggr: logger.Named(lggr, "RingPluginFactory"),
4245
}, nil
4346
}
4447

4548
func (o *Factory) NewReportingPlugin(_ context.Context, config ocr3types.ReportingPluginConfig) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) {
4649
plugin, err := NewPlugin(o.store, o.arbiterScaler, config, o.lggr, o.config)
4750
pluginInfo := ocr3types.ReportingPluginInfo{
48-
Name: "Shard Orchestrator Consensus Plugin",
51+
Name: "RingPlugin",
4952
Limits: ocr3types.ReportingPluginLimits{
5053
MaxQueryLength: defaultMaxPhaseOutputBytes,
5154
MaxObservationLength: defaultMaxPhaseOutputBytes,

pkg/workflows/ring/factory_test.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,47 @@ import (
1212
func TestFactory_NewFactory(t *testing.T) {
1313
lggr := logger.Test(t)
1414
store := NewStore()
15+
arbiter := &mockArbiter{}
1516

1617
t.Run("with_nil_config", func(t *testing.T) {
17-
f, err := NewFactory(store, nil, lggr, nil)
18+
f, err := NewFactory(store, arbiter, lggr, nil)
1819
require.NoError(t, err)
1920
require.NotNil(t, f)
2021
})
2122

2223
t.Run("with_custom_config", func(t *testing.T) {
2324
cfg := &ConsensusConfig{BatchSize: 50}
24-
f, err := NewFactory(store, nil, lggr, cfg)
25+
f, err := NewFactory(store, arbiter, lggr, cfg)
2526
require.NoError(t, err)
2627
require.NotNil(t, f)
2728
})
29+
30+
t.Run("nil_arbiter_returns_error", func(t *testing.T) {
31+
_, err := NewFactory(store, nil, lggr, nil)
32+
require.Error(t, err)
33+
require.Contains(t, err.Error(), "arbiterScaler is required")
34+
})
2835
}
2936

3037
func TestFactory_NewReportingPlugin(t *testing.T) {
3138
lggr := logger.Test(t)
3239
store := NewStore()
33-
f, err := NewFactory(store, nil, lggr, nil)
40+
f, err := NewFactory(store, &mockArbiter{}, lggr, nil)
3441
require.NoError(t, err)
3542

3643
config := ocr3types.ReportingPluginConfig{N: 4, F: 1}
3744
plugin, info, err := f.NewReportingPlugin(context.Background(), config)
3845
require.NoError(t, err)
3946
require.NotNil(t, plugin)
4047
require.NotEmpty(t, info.Name)
41-
require.Equal(t, "Shard Orchestrator Consensus Plugin", info.Name)
48+
require.Equal(t, "RingPlugin", info.Name)
4249
require.Equal(t, defaultMaxReportCount, info.Limits.MaxReportCount)
4350
}
4451

4552
func TestFactory_Lifecycle(t *testing.T) {
4653
lggr := logger.Test(t)
4754
store := NewStore()
48-
f, err := NewFactory(store, nil, lggr, nil)
55+
f, err := NewFactory(store, &mockArbiter{}, lggr, nil)
4956
require.NoError(t, err)
5057

5158
err = f.Start(context.Background())

pkg/workflows/ring/plugin.go

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ const (
4444
DefaultTimeToSync = 5 * time.Minute
4545
)
4646

47-
// NewPlugin creates a consensus reporting plugin for shard orchestration
4847
func NewPlugin(store *Store, arbiterScaler pb.ArbiterScalerClient, config ocr3types.ReportingPluginConfig, lggr logger.Logger, cfg *ConsensusConfig) (*Plugin, error) {
48+
if arbiterScaler == nil {
49+
return nil, errors.New("RingOCR arbiterScaler is required")
50+
}
4951
if cfg == nil {
5052
cfg = &ConsensusConfig{
5153
BatchSize: DefaultBatchSize,
@@ -86,20 +88,14 @@ func (p *Plugin) Observation(ctx context.Context, _ ocr3types.OutcomeContext, _
8688
var wantShards uint32
8789
shardStatus := make(map[uint32]*pb.ShardStatus)
8890

89-
if p.arbiterScaler != nil {
90-
status, err := p.arbiterScaler.Status(ctx, &emptypb.Empty{})
91-
if err != nil {
92-
p.lggr.Warnw("failed to get arbiter scaler status", "error", err)
93-
} else {
94-
wantShards = status.WantShards
95-
shardStatus = status.Status
96-
}
91+
status, err := p.arbiterScaler.Status(ctx, &emptypb.Empty{})
92+
if err != nil {
93+
p.lggr.Warnw("RingOCR failed to get arbiter scaler status", "error", err)
94+
wantShards = 0
95+
shardStatus = make(map[uint32]*pb.ShardStatus)
9796
} else {
98-
// Fallback to store if no arbiter scaler configured
99-
shardHealth := p.store.GetShardHealth()
100-
for shardID, healthy := range shardHealth {
101-
shardStatus[shardID] = &pb.ShardStatus{IsHealthy: healthy}
102-
}
97+
wantShards = status.WantShards
98+
shardStatus = status.Status
10399
}
104100

105101
allWorkflowIDs := make([]string, 0)
@@ -108,9 +104,11 @@ func (p *Plugin) Observation(ctx context.Context, _ ocr3types.OutcomeContext, _
108104
}
109105

110106
pendingAllocs := p.store.GetPendingAllocations()
111-
allWorkflowIDs = append(allWorkflowIDs, pendingAllocs...)
107+
p.lggr.Infow("RingOCR Observation pending allocations", "pendingAllocs", pendingAllocs)
112108

109+
allWorkflowIDs = append(allWorkflowIDs, pendingAllocs...)
113110
allWorkflowIDs = uniqueSorted(allWorkflowIDs)
111+
p.lggr.Infow("RingOCR Observation all workflow IDs unique", "allWorkflowIDs", allWorkflowIDs, "wantShards", wantShards)
114112

115113
observation := &pb.Observation{
116114
ShardStatus: shardStatus,
@@ -122,23 +120,30 @@ func (p *Plugin) Observation(ctx context.Context, _ ocr3types.OutcomeContext, _
122120
return proto.MarshalOptions{Deterministic: true}.Marshal(observation)
123121
}
124122

125-
//coverage:ignore
126-
func (p *Plugin) ValidateObservation(_ context.Context, _ ocr3types.OutcomeContext, _ types.Query, _ types.AttributedObservation) error {
123+
func (p *Plugin) ValidateObservation(_ context.Context, _ ocr3types.OutcomeContext, _ types.Query, ao types.AttributedObservation) error {
124+
observation := &pb.Observation{}
125+
if err := proto.Unmarshal(ao.Observation, observation); err != nil {
126+
return err
127+
}
128+
if observation.Now == nil {
129+
return errors.New("observation missing timestamp")
130+
}
131+
if observation.WantShards == 0 {
132+
return errors.New("observation missing WantShards")
133+
}
127134
return nil
128135
}
129136

130137
func (p *Plugin) ObservationQuorum(_ context.Context, _ ocr3types.OutcomeContext, _ types.Query, aos []types.AttributedObservation) (quorumReached bool, err error) {
131138
return quorumhelper.ObservationCountReachesObservationQuorum(quorumhelper.QuorumTwoFPlusOne, p.config.N, p.config.F, aos), nil
132139
}
133140

134-
func (p *Plugin) collectShardInfo(aos []types.AttributedObservation) (shardHealth map[uint32]int, workflows []string, timestamps []time.Time, wantShardVotes []uint32) {
141+
func (p *Plugin) collectShardInfo(aos []types.AttributedObservation) (shardHealth map[uint32]int, workflows []string, timestamps []time.Time, wantShardVotes map[commontypes.OracleID]uint32) {
135142
shardHealth = make(map[uint32]int)
143+
wantShardVotes = make(map[commontypes.OracleID]uint32)
136144
for _, ao := range aos {
137145
observation := &pb.Observation{}
138-
if err := proto.Unmarshal(ao.Observation, observation); err != nil {
139-
p.lggr.Warnf("failed to unmarshal observation: %v", err)
140-
continue
141-
}
146+
_ = proto.Unmarshal(ao.Observation, observation) // validated in ValidateObservation
142147

143148
for shardID, status := range observation.ShardStatus {
144149
if status != nil && status.IsHealthy {
@@ -147,14 +152,9 @@ func (p *Plugin) collectShardInfo(aos []types.AttributedObservation) (shardHealt
147152
}
148153

149154
workflows = append(workflows, observation.WorkflowIds...)
155+
timestamps = append(timestamps, observation.Now.AsTime())
150156

151-
if observation.Now != nil {
152-
timestamps = append(timestamps, observation.Now.AsTime())
153-
}
154-
155-
if observation.WantShards > 0 {
156-
wantShardVotes = append(wantShardVotes, observation.WantShards)
157-
}
157+
wantShardVotes[ao.Observer] = observation.WantShards
158158
}
159159
return shardHealth, workflows, timestamps, wantShardVotes
160160
}
@@ -183,6 +183,7 @@ func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ t
183183
}
184184

185185
currentShardHealth, allWorkflows, nows, wantShardVotes := p.collectShardInfo(aos)
186+
p.lggr.Infow("RingOCR Outcome collect shard info", "currentShardHealth", currentShardHealth, "wantShardVotes", wantShardVotes)
186187

187188
// Need at least F+1 timestamps; fewer means >F faulty nodes and we can't trust this round
188189
if len(nows) < p.config.F+1 {
@@ -193,18 +194,13 @@ func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ t
193194
// Use the median timestamp to determine the current time
194195
now := nows[len(nows)/2]
195196

196-
// Use median for wantShards consensus; fall back to current state if insufficient votes
197-
var wantShards uint32
198-
if len(wantShardVotes) >= p.config.F+1 {
199-
slices.Sort(wantShardVotes)
200-
wantShards = wantShardVotes[len(wantShardVotes)/2]
201-
} else if rs := prior.State.GetRoutableShards(); rs > 0 {
202-
wantShards = rs
203-
} else if tr := prior.State.GetTransition(); tr != nil {
204-
wantShards = tr.WantShards
205-
} else {
206-
wantShards = 1 // ultimate fallback
197+
// Use median for wantShards consensus (all validated observations have WantShards > 0)
198+
votes := make([]uint32, 0, len(wantShardVotes))
199+
for _, v := range wantShardVotes {
200+
votes = append(votes, v)
207201
}
202+
slices.Sort(votes)
203+
wantShards := votes[len(votes)/2]
208204

209205
allWorkflows = uniqueSorted(allWorkflows)
210206

@@ -217,20 +213,23 @@ func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ t
217213

218214
// Deterministic hashing ensures all nodes agree on workflow-to-shard assignments
219215
// without coordination, preventing protocol failures from inconsistent routing
216+
ring := newShardRing(healthyShards)
220217
routes := make(map[string]*pb.WorkflowRoute)
221218
for _, wfID := range allWorkflows {
222-
assignedShard := getShardForWorkflow(wfID, healthyShards)
223-
routes[wfID] = &pb.WorkflowRoute{
224-
Shard: assignedShard,
219+
shard, err := locateShard(ring, wfID)
220+
if err != nil {
221+
p.lggr.Warnw("RingOCR failed to locate shard for workflow", "workflowID", wfID, "error", err)
222+
shard = 0 // fallback to shard 0 when no healthy shards
225223
}
224+
routes[wfID] = &pb.WorkflowRoute{Shard: shard}
226225
}
227226

228227
outcome := &pb.Outcome{
229228
State: nextState,
230229
Routes: routes,
231230
}
232231

233-
p.lggr.Infow("Consensus Outcome", "healthyShards", len(healthyShards), "totalObservations", len(aos), "workflowCount", len(routes))
232+
p.lggr.Infow("RingOCR Outcome", "healthyShards", len(healthyShards), "totalObservations", len(aos), "workflowCount", len(routes))
234233

235234
return proto.MarshalOptions{Deterministic: true}.Marshal(outcome)
236235
}

pkg/workflows/ring/plugin_test.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77

88
"github.com/smartcontractkit/libocr/commontypes"
99
"github.com/stretchr/testify/require"
10+
"google.golang.org/grpc"
1011
"google.golang.org/protobuf/proto"
12+
"google.golang.org/protobuf/types/known/emptypb"
1113
"google.golang.org/protobuf/types/known/timestamppb"
1214

1315
"github.com/smartcontractkit/chainlink-common/pkg/logger"
@@ -16,6 +18,21 @@ import (
1618
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
1719
)
1820

21+
type mockArbiter struct {
22+
status *pb.ReplicaStatus
23+
}
24+
25+
func (m *mockArbiter) Status(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*pb.ReplicaStatus, error) {
26+
if m.status != nil {
27+
return m.status, nil
28+
}
29+
return &pb.ReplicaStatus{}, nil
30+
}
31+
32+
func (m *mockArbiter) ConsensusWantShards(ctx context.Context, req *pb.ConsensusWantShardsRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
33+
return &emptypb.Empty{}, nil
34+
}
35+
1936
var twoHealthyShards = []map[uint32]*pb.ShardStatus{
2037
{0: {IsHealthy: true}, 1: {IsHealthy: true}},
2138
{0: {IsHealthy: true}, 1: {IsHealthy: true}},
@@ -44,7 +61,7 @@ func TestPlugin_Outcome(t *testing.T) {
4461
MaxDurationShouldTransmitAcceptedReport: 0,
4562
}
4663

47-
plugin, err := NewPlugin(store, nil, config, lggr, nil)
64+
plugin, err := NewPlugin(store, &mockArbiter{}, config, lggr, nil)
4865
require.NoError(t, err)
4966

5067
ctx := t.Context()
@@ -154,7 +171,7 @@ func TestPlugin_StateTransitions(t *testing.T) {
154171
}
155172

156173
// Use short time to sync for testing
157-
plugin, err := NewPlugin(store, nil, config, lggr, &ConsensusConfig{
174+
plugin, err := NewPlugin(store, &mockArbiter{}, config, lggr, &ConsensusConfig{
158175
BatchSize: 100,
159176
TimeToSync: 1 * time.Second,
160177
})
@@ -371,6 +388,16 @@ func makeObservationsWithWantShards(t *testing.T, shardStatuses []map[uint32]*pb
371388
return aos
372389
}
373390

391+
func TestPlugin_NewPlugin_NilArbiter(t *testing.T) {
392+
lggr := logger.Test(t)
393+
store := NewStore()
394+
config := ocr3types.ReportingPluginConfig{N: 4, F: 1}
395+
396+
_, err := NewPlugin(store, nil, config, lggr, nil)
397+
require.Error(t, err)
398+
require.Contains(t, err.Error(), "RingOCR arbiterScaler is required")
399+
}
400+
374401
func TestPlugin_getHealthyShards(t *testing.T) {
375402
tests := []struct {
376403
name string
@@ -407,13 +434,14 @@ func TestPlugin_NoHealthyShardsFallbackToShardZero(t *testing.T) {
407434
N: 4, F: 1,
408435
}
409436

410-
plugin, err := NewPlugin(store, nil, config, lggr, &ConsensusConfig{
437+
arbiter := &mockArbiter{}
438+
plugin, err := NewPlugin(store, arbiter, config, lggr, &ConsensusConfig{
411439
BatchSize: 100,
412440
TimeToSync: 1 * time.Second,
413441
})
414442
require.NoError(t, err)
415443

416-
transmitter := NewTransmitter(lggr, store, nil, "test-account")
444+
transmitter := NewTransmitter(lggr, store, arbiter, "test-account")
417445

418446
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
419447
defer cancel()
@@ -507,7 +535,7 @@ func TestPlugin_ObservationQuorum(t *testing.T) {
507535
lggr := logger.Test(t)
508536
store := NewStore()
509537
config := ocr3types.ReportingPluginConfig{N: 4, F: 1}
510-
plugin, err := NewPlugin(store, nil, config, lggr, nil)
538+
plugin, err := NewPlugin(store, &mockArbiter{}, config, lggr, nil)
511539
require.NoError(t, err)
512540

513541
ctx := context.Background()

0 commit comments

Comments
 (0)