diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 30ec7ffc61..efddfb81ed 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -512,6 +512,7 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey, return nil, nil, err } edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Now(), @@ -528,6 +529,7 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey, return nil, nil, err } edgePolicy = &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Now(), diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 3ea80bae91..c13f52fb6c 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -3415,19 +3415,12 @@ func (d *AuthenticatedGossiper) handleChanUpdate(ctx context.Context, // different alias. This might mean that SigBytes is incorrect as it // signs a different SCID than the database SCID, but since there will // only be a difference if AuthProof == nil, this is fine. - update := &models.ChannelEdgePolicy{ - SigBytes: upd.Signature.ToSignatureBytes(), - ChannelID: chanInfo.ChannelID, - LastUpdate: timestamp, - MessageFlags: upd.MessageFlags, - ChannelFlags: upd.ChannelFlags, - TimeLockDelta: upd.TimeLockDelta, - MinHTLC: upd.HtlcMinimumMsat, - MaxHTLC: upd.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), - InboundFee: upd.InboundFee.ValOpt(), - ExtraOpaqueData: upd.ExtraOpaqueData, + update, err := models.ChanEdgePolicyFromWire( + chanInfo.ChannelID, upd, + ) + if err != nil { + nMsg.err <- err + return nil, false } if err := d.cfg.Graph.UpdateEdge(ctx, update, ops...); err != nil { diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index d5ebede7ba..47c92e88cb 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -154,7 +154,8 @@ * Prepare the graph DB for handling gossip V2 nodes and channels [1](https://github.com/lightningnetwork/lnd/pull/10339) [2](https://github.com/lightningnetwork/lnd/pull/10379) - [3](https://github.com/lightningnetwork/lnd/pull/10380). + [3](https://github.com/lightningnetwork/lnd/pull/10380) + [4](https://github.com/lightningnetwork/lnd/pull/10542). ## Code Health diff --git a/graph/builder.go b/graph/builder.go index 1f3309c048..9d3baebfe2 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -952,19 +952,12 @@ func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { return false } - update := &models.ChannelEdgePolicy{ - SigBytes: msg.Signature.ToSignatureBytes(), - ChannelID: msg.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(msg.Timestamp), 0), - MessageFlags: msg.MessageFlags, - ChannelFlags: msg.ChannelFlags, - TimeLockDelta: msg.TimeLockDelta, - MinHTLC: msg.HtlcMinimumMsat, - MaxHTLC: msg.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), - InboundFee: msg.InboundFee.ValOpt(), - ExtraOpaqueData: msg.ExtraOpaqueData, + update, err := models.ChanEdgePolicyFromWire( + msg.ShortChannelID.ToUint64(), msg, + ) + if err != nil { + log.Errorf("Unable to parse channel update: %v", err) + return false } err = b.UpdateEdge(ctx, update) @@ -1050,8 +1043,8 @@ func (b *Builder) addEdge(ctx context.Context, edge *models.ChannelEdgeInfo, // Prior to processing the announcement we first check if we // already know of this channel, if so, then we can exit early. - _, _, exists, isZombie, err := b.cfg.Graph.HasChannelEdge( - edge.ChannelID, + exists, isZombie, err := b.cfg.Graph.HasChannelEdge( + edge.Version, edge.ChannelID, ) if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) { return fmt.Errorf("unable to check for edge existence: %w", @@ -1152,7 +1145,7 @@ func (b *Builder) updateEdge(ctx context.Context, defer b.channelEdgeMtx.Unlock(policy.ChannelID) edge1Timestamp, edge2Timestamp, exists, isZombie, err := - b.cfg.Graph.HasChannelEdge(policy.ChannelID) + b.cfg.Graph.HasV1ChannelEdge(policy.ChannelID) if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) { return fmt.Errorf("unable to check for edge existence: %w", err) } @@ -1283,7 +1276,7 @@ func (b *Builder) ForAllOutgoingChannels(ctx context.Context, reset func()) error { return b.cfg.Graph.ForEachNodeChannel( - ctx, b.cfg.SelfNode, + ctx, lnwire.GossipVersion1, b.cfg.SelfNode, func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -1338,8 +1331,8 @@ func (b *Builder) IsPublicNode(node route.Vertex) (bool, error) { // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) IsKnownEdge(chanID lnwire.ShortChannelID) bool { - _, _, exists, isZombie, _ := b.cfg.Graph.HasChannelEdge( - chanID.ToUint64(), + exists, isZombie, _ := b.cfg.Graph.HasChannelEdge( + lnwire.GossipVersion1, chanID.ToUint64(), ) return exists || isZombie @@ -1350,7 +1343,9 @@ func (b *Builder) IsKnownEdge(chanID lnwire.ShortChannelID) bool { // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) IsZombieEdge(chanID lnwire.ShortChannelID) (bool, error) { - _, _, _, isZombie, err := b.cfg.Graph.HasChannelEdge(chanID.ToUint64()) + _, isZombie, err := b.cfg.Graph.HasChannelEdge( + lnwire.GossipVersion1, chanID.ToUint64(), + ) return isZombie, err } @@ -1363,7 +1358,7 @@ func (b *Builder) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool { edge1Timestamp, edge2Timestamp, exists, isZombie, err := - b.cfg.Graph.HasChannelEdge(chanID.ToUint64()) + b.cfg.Graph.HasV1ChannelEdge(chanID.ToUint64()) if err != nil { log.Debugf("Check stale edge policy got error: %v", err) return false diff --git a/graph/builder_test.go b/graph/builder_test.go index 25b2d28b1f..acab91ccfb 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -166,6 +166,7 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { require.NoError(t, err) edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -318,7 +319,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { } // Check that the fundingTxs are in the graph db. - _, _, has, isZombie, err := ctx.graph.HasChannelEdge(chanID1) + has, isZombie, err := ctx.graph.HasChannelEdge(chanID1) if err != nil { t.Fatalf("error looking for edge: %v", chanID1) } @@ -329,7 +330,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { t.Fatal("edge was marked as zombie") } - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) + has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) if err != nil { t.Fatalf("error looking for edge: %v", chanID2) } @@ -386,7 +387,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { // The channel with chanID2 should not be in the database anymore, // since it is not confirmed on the longest chain. chanID1 should // still be. - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID1) + has, isZombie, err = ctx.graph.HasChannelEdge(chanID1) require.NoError(t, err) if !has { @@ -396,7 +397,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { t.Fatal("edge was marked as zombie") } - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) + has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) if err != nil { t.Fatalf("error looking for edge: %v", chanID2) } @@ -526,7 +527,7 @@ func TestDisconnectedBlocks(t *testing.T) { } // Check that the fundingTxs are in the graph db. - _, _, has, isZombie, err := ctx.graph.HasChannelEdge(chanID1) + has, isZombie, err := ctx.graph.HasChannelEdge(chanID1) if err != nil { t.Fatalf("error looking for edge: %v", chanID1) } @@ -537,7 +538,7 @@ func TestDisconnectedBlocks(t *testing.T) { t.Fatal("edge was marked as zombie") } - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) + has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) if err != nil { t.Fatalf("error looking for edge: %v", chanID2) } @@ -579,7 +580,7 @@ func TestDisconnectedBlocks(t *testing.T) { // chanID2 should not be in the database anymore, since it is not // confirmed on the longest chain. chanID1 should still be. - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID1) + has, isZombie, err = ctx.graph.HasChannelEdge(chanID1) if err != nil { t.Fatalf("error looking for edge: %v", chanID1) } @@ -590,7 +591,7 @@ func TestDisconnectedBlocks(t *testing.T) { t.Fatal("edge was marked as zombie") } - _, _, has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) + has, isZombie, err = ctx.graph.HasChannelEdge(chanID2) if err != nil { t.Fatalf("error looking for edge: %v", chanID2) } @@ -664,7 +665,7 @@ func TestChansClosedOfflinePruneGraph(t *testing.T) { } // The router should now be aware of the channel we created above. - _, _, hasChan, isZombie, err := ctx.graph.HasChannelEdge( + hasChan, isZombie, err := ctx.graph.HasChannelEdge( chanID1.ToUint64(), ) if err != nil { @@ -746,7 +747,7 @@ func TestChansClosedOfflinePruneGraph(t *testing.T) { // At this point, the channel that was pruned should no longer be known // by the router. - _, _, hasChan, isZombie, err = ctx.graph.HasChannelEdge( + hasChan, isZombie, err = ctx.graph.HasChannelEdge( chanID1.ToUint64(), ) if err != nil { @@ -1219,6 +1220,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { // We'll also add two edge policies, one for each direction. edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: updateTimeStamp, @@ -1233,6 +1235,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { } edgePolicy = &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: updateTimeStamp, @@ -1557,6 +1560,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags( edge.MessageFlags, @@ -1715,7 +1719,7 @@ func assertChannelsPruned(t *testing.T, graph *graphdb.VersionedGraph, for _, channel := range channels { _, shouldPrune := pruned[channel.ChannelID] - _, _, exists, isZombie, err := graph.HasChannelEdge( + exists, isZombie, err := graph.HasChannelEdge( channel.ChannelID, ) if err != nil { @@ -1939,7 +1943,9 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, channelFlags |= lnwire.ChanUpdateDisabled } + //nolint:ll edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, @@ -1970,7 +1976,9 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } channelFlags |= lnwire.ChanUpdateDirection + //nolint:ll edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, diff --git a/graph/db/benchmark_test.go b/graph/db/benchmark_test.go index d7cc3fbc96..00cfeaf337 100644 --- a/graph/db/benchmark_test.go +++ b/graph/db/benchmark_test.go @@ -370,8 +370,8 @@ func TestPopulateDBs(t *testing.T) { numPolicies = 0 ) err := graph.ForEachChannel( - ctx, func(info *models.ChannelEdgeInfo, - policy, + ctx, lnwire.GossipVersion1, + func(info *models.ChannelEdgeInfo, policy, policy2 *models.ChannelEdgePolicy) error { numChans++ @@ -497,48 +497,49 @@ func syncGraph(t *testing.T, src, dest *ChannelGraph) { } var wgChans sync.WaitGroup - err = src.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo, - policy1, policy2 *models.ChannelEdgePolicy) error { - - // Add each channel & policy. We do this in a goroutine to - // take advantage of batch processing. - wgChans.Add(1) - go func() { - defer wgChans.Done() - - err := dest.AddChannelEdge( - ctx, info, batch.LazyAdd(), - ) - if !errors.Is(err, ErrEdgeAlreadyExist) { - require.NoError(t, err) - } - - if policy1 != nil { - err = dest.UpdateEdgePolicy( - ctx, policy1, batch.LazyAdd(), + err = src.ForEachChannel(ctx, lnwire.GossipVersion1, + func(info *models.ChannelEdgeInfo, + policy1, policy2 *models.ChannelEdgePolicy) error { + + // Add each channel & policy. We do this in a goroutine + // to take advantage of batch processing. + wgChans.Add(1) + go func() { + defer wgChans.Done() + + err := dest.AddChannelEdge( + ctx, info, batch.LazyAdd(), ) - require.NoError(t, err) - } + if !errors.Is(err, ErrEdgeAlreadyExist) { + require.NoError(t, err) + } - if policy2 != nil { - err = dest.UpdateEdgePolicy( - ctx, policy2, batch.LazyAdd(), - ) - require.NoError(t, err) - } + if policy1 != nil { + err = dest.UpdateEdgePolicy( + ctx, policy1, batch.LazyAdd(), + ) + require.NoError(t, err) + } - mu.Lock() - total++ - chunk++ - s.Do(func() { - reportChanStats() - chunk = 0 - }) - mu.Unlock() - }() + if policy2 != nil { + err = dest.UpdateEdgePolicy( + ctx, policy2, batch.LazyAdd(), + ) + require.NoError(t, err) + } - return nil - }, func() {}) + mu.Lock() + total++ + chunk++ + s.Do(func() { + reportChanStats() + chunk = 0 + }) + mu.Unlock() + }() + + return nil + }, func() {}) require.NoError(t, err) wgChans.Wait() @@ -638,7 +639,8 @@ func BenchmarkGraphReadMethods(b *testing.B) { fn: func(b testing.TB, store Store) { //nolint:ll err := store.ForEachChannel( - ctx, func(_ *models.ChannelEdgeInfo, + ctx, lnwire.GossipVersion1, + func(_ *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -821,7 +823,7 @@ func BenchmarkFindOptimalSQLQueryConfig(b *testing.B) { //nolint:ll err = store.ForEachChannel( - ctx, + ctx, lnwire.GossipVersion1, func(_ *models.ChannelEdgeInfo, _, _ *models.ChannelEdgePolicy) error { diff --git a/graph/db/graph.go b/graph/db/graph.go index 41396aceb3..baee3eccb8 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -597,11 +597,12 @@ func (c *ChannelGraph) ForEachSourceNodeChannel(ctx context.Context, // ForEachNodeChannel iterates through all channels of the given node. func (c *ChannelGraph) ForEachNodeChannel(ctx context.Context, - nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, + v lnwire.GossipVersion, nodePub route.Vertex, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { - return c.db.ForEachNodeChannel(ctx, nodePub, cb, reset) + return c.db.ForEachNodeChannel(ctx, v, nodePub, cb, reset) } // ForEachNode iterates through all stored vertices/nodes in the graph. @@ -642,10 +643,11 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { // ForEachChannel iterates through all channel edges stored within the graph. func (c *ChannelGraph) ForEachChannel(ctx context.Context, - cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error, reset func()) error { + v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, + reset func()) error { - return c.db.ForEachChannel(ctx, cb, reset) + return c.db.ForEachChannel(ctx, v, cb, reset) } // ForEachChannelCacheable iterates through all channel edges for the cache. @@ -661,11 +663,18 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { return c.db.DisabledChannelIDs() } +// HasV1ChannelEdge returns true if the database knows of a channel edge. +func (c *ChannelGraph) HasV1ChannelEdge(chanID uint64) (time.Time, + time.Time, bool, bool, error) { + + return c.db.HasV1ChannelEdge(chanID) +} + // HasChannelEdge returns true if the database knows of a channel edge. -func (c *ChannelGraph) HasChannelEdge(chanID uint64) (time.Time, time.Time, - bool, bool, error) { +func (c *ChannelGraph) HasChannelEdge(v lnwire.GossipVersion, + chanID uint64) (bool, bool, error) { - return c.db.HasChannelEdge(chanID) + return c.db.HasChannelEdge(v, chanID) } // AddEdgeProof sets the proof of an existing edge in the graph database. @@ -889,6 +898,31 @@ func (c *VersionedGraph) DeleteChannelEdges(strictZombiePruning, return err } +// HasChannelEdge returns true if the database knows of a channel edge with the +// passed channel ID and this graph's gossip version, and false otherwise. If it +// is not found, then the zombie index is checked and its result is returned as +// the second boolean. +func (c *VersionedGraph) HasChannelEdge(chanID uint64) (bool, bool, error) { + return c.db.HasChannelEdge(c.v, chanID) +} + +// ForEachNodeChannel iterates through all channels of the given node. +func (c *VersionedGraph) ForEachNodeChannel(ctx context.Context, + nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error { + + return c.db.ForEachNodeChannel(ctx, c.v, nodePub, cb, reset) +} + +// ForEachChannel iterates through all channel edges stored within the graph. +func (c *VersionedGraph) ForEachChannel(ctx context.Context, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error { + + return c.db.ForEachChannel(ctx, c.v, cb, reset) +} + // IsPublicNode determines whether the node is seen as public in the graph. func (c *VersionedGraph) IsPublicNode(pubKey [33]byte) (bool, error) { return c.db.IsPublicNode(c.v, pubKey) diff --git a/graph/db/graph_cache.go b/graph/db/graph_cache.go index 4a3a3b0f9b..7cf101face 100644 --- a/graph/db/graph_cache.go +++ b/graph/db/graph_cache.go @@ -142,8 +142,8 @@ func (c *GraphCache) AddChannel(info *models.CachedEdgeInfo, // Skip adding policies if both are disabled, as the channel is // currently unusable for routing. However, we still add the channel // structure above so that policy updates can later enable it. - if policy1 != nil && policy1.IsDisabled() && - policy2 != nil && policy2.IsDisabled() { + if policy1 != nil && policy1.IsDisabled && + policy2 != nil && policy2.IsDisabled { log.Debugf("Skipping policies for channel %v: both "+ "policies are disabled (channel structure still "+ @@ -156,14 +156,14 @@ func (c *GraphCache) AddChannel(info *models.CachedEdgeInfo, // of node 2 then we have the policy 1 as seen from node 1. if policy1 != nil { fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes - if !policy1.IsNode1() { + if !policy1.IsNode1 { fromNode, toNode = toNode, fromNode } c.UpdatePolicy(policy1, fromNode, toNode) } if policy2 != nil { fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes - if policy2.IsNode1() { + if policy2.IsNode1 { fromNode, toNode = toNode, fromNode } c.UpdatePolicy(policy2, fromNode, toNode) @@ -210,7 +210,7 @@ func (c *GraphCache) UpdatePolicy(policy *models.CachedEdgePolicy, fromNode, switch { // This is node 1, and it is edge 1, so this is the outgoing // policy for node 1. - case channel.IsNode1 && policy.IsNode1(): + case channel.IsNode1 && policy.IsNode1: channel.OutPolicySet = true policy.InboundFee.WhenSome(func(fee lnwire.Fee) { channel.InboundFee = fee @@ -218,7 +218,7 @@ func (c *GraphCache) UpdatePolicy(policy *models.CachedEdgePolicy, fromNode, // This is node 2, and it is edge 2, so this is the outgoing // policy for node 2. - case !channel.IsNode1 && !policy.IsNode1(): + case !channel.IsNode1 && !policy.IsNode1: channel.OutPolicySet = true policy.InboundFee.WhenSome(func(fee lnwire.Fee) { channel.InboundFee = fee diff --git a/graph/db/graph_cache_test.go b/graph/db/graph_cache_test.go index 89e3a7e87d..3d5fba85db 100644 --- a/graph/db/graph_cache_test.go +++ b/graph/db/graph_cache_test.go @@ -33,9 +33,9 @@ func TestGraphCacheAddNode(t *testing.T) { runTest := func(nodeA, nodeB route.Vertex) { t.Helper() - channelFlagA, channelFlagB := 0, 1 + isNode1A, isNode1B := true, false if nodeA == pubKey2 { - channelFlagA, channelFlagB = 1, 0 + isNode1A, isNode1B = false, true } inboundFee := lnwire.Fee{ @@ -44,8 +44,9 @@ func TestGraphCacheAddNode(t *testing.T) { } outPolicy1 := &models.CachedEdgePolicy{ - ChannelID: 1000, - ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), + ChannelID: 1000, + IsNode1: isNode1A, + IsDisabled: false, ToNodePubKey: func() route.Vertex { return nodeB }, @@ -53,8 +54,9 @@ func TestGraphCacheAddNode(t *testing.T) { InboundFee: fn.Some(inboundFee), } inPolicy1 := &models.CachedEdgePolicy{ - ChannelID: 1000, - ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), + ChannelID: 1000, + IsNode1: isNode1B, + IsDisabled: false, ToNodePubKey: func() route.Vertex { return nodeA }, @@ -125,8 +127,9 @@ func assertCachedPolicyEqual(t *testing.T, original, cached *models.CachedEdgePolicy) { require.Equal(t, original.ChannelID, cached.ChannelID) - require.Equal(t, original.MessageFlags, cached.MessageFlags) - require.Equal(t, original.ChannelFlags, cached.ChannelFlags) + require.Equal(t, original.HasMaxHTLC, cached.HasMaxHTLC) + require.Equal(t, original.IsNode1, cached.IsNode1) + require.Equal(t, original.IsDisabled, cached.IsDisabled) require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta) require.Equal(t, original.MinHTLC, cached.MinHTLC) require.Equal(t, original.MaxHTLC, cached.MaxHTLC) @@ -171,13 +174,14 @@ func TestGraphCacheDisabledPoliciesRegression(t *testing.T) { // Create two disabled policies. disabledPolicy1 := &models.CachedEdgePolicy{ - ChannelID: chanID, - ChannelFlags: lnwire.ChanUpdateDisabled, + ChannelID: chanID, + IsNode1: true, + IsDisabled: true, } disabledPolicy2 := &models.CachedEdgePolicy{ - ChannelID: chanID, - ChannelFlags: lnwire.ChanUpdateDisabled | - lnwire.ChanUpdateDirection, + ChannelID: chanID, + IsNode1: false, + IsDisabled: true, } // Add the channel with both policies disabled (simulating @@ -207,7 +211,8 @@ func TestGraphCacheDisabledPoliciesRegression(t *testing.T) { // Now simulate receiving a fresh update enabling one direction. enabledPolicy1 := &models.CachedEdgePolicy{ ChannelID: chanID, - ChannelFlags: 0, // NOT disabled anymore + IsNode1: true, + IsDisabled: false, TimeLockDelta: 40, MinHTLC: lnwire.MilliSatoshi(1000), } diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 5148f63a8a..1a36262168 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -146,6 +146,14 @@ var versionedTests = []versionedTest{ name: "edge insertion deletion", test: testEdgeInsertionDeletion, }, + { + name: "edge policy crud", + test: testEdgePolicyCRUD, + }, + { + name: "incomplete channel policies", + test: testIncompleteChannelPolicies, + }, { name: "partial node", test: testPartialNode, @@ -886,7 +894,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } // The two first edges should be removed from the db. - _, _, has, isZombie, err := graph.HasChannelEdge(edgeInfo.ChannelID) + has, isZombie, err := graph.HasChannelEdge( + lnwire.GossipVersion1, edgeInfo.ChannelID, + ) require.NoError(t, err, "unable to query for edge") if has { t.Fatalf("edge1 was not pruned from the graph") @@ -894,7 +904,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if isZombie { t.Fatal("reorged edge1 should not be marked as zombie") } - _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo2.ChannelID) + has, isZombie, err = graph.HasChannelEdge( + lnwire.GossipVersion1, edgeInfo2.ChannelID, + ) require.NoError(t, err, "unable to query for edge") if has { t.Fatalf("edge2 was not pruned from the graph") @@ -904,7 +916,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } // Edge 3 should not be removed. - _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo3.ChannelID) + has, isZombie, err = graph.HasChannelEdge( + lnwire.GossipVersion1, edgeInfo3.ChannelID, + ) require.NoError(t, err, "unable to query for edge") if !has { t.Fatalf("edge3 was pruned from the graph") @@ -1044,6 +1058,7 @@ func createChannelEdge(node1, node2 *models.Node) (*models.ChannelEdgeInfo, ) edge1 := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: chanID, LastUpdate: nextUpdateTime(), @@ -1058,6 +1073,7 @@ func createChannelEdge(node1, node2 *models.Node) (*models.ChannelEdgeInfo, ExtraOpaqueData: []byte{1, 0}, } edge2 := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: chanID, LastUpdate: nextUpdateTime(), @@ -1125,7 +1141,9 @@ func TestEdgeInfoUpdates(t *testing.T) { // Check for existence of the edge within the database, it should be // found. - _, _, found, isZombie, err := graph.HasChannelEdge(chanID) + found, isZombie, err := graph.HasChannelEdge( + lnwire.GossipVersion1, chanID, + ) require.NoError(t, err, "unable to query for edge") if !found { t.Fatalf("graph should have of inserted edge") @@ -1170,18 +1188,28 @@ func TestEdgeInfoUpdates(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } -// TestEdgePolicyCRUD tests basic CRUD operations for edge policies. -func TestEdgePolicyCRUD(t *testing.T) { +// testEdgePolicyCRUD tests basic CRUD operations for edge policies. +func testEdgePolicyCRUD(t *testing.T, v lnwire.GossipVersion) { t.Parallel() ctx := t.Context() - graph := MakeTestGraph(t) + graph := NewVersionedGraph(MakeTestGraph(t), v) - node1 := createTestVertex(t, lnwire.GossipVersion1) - node2 := createTestVertex(t, lnwire.GossipVersion1) + node1 := createTestVertex(t, v) + node2 := createTestVertex(t, v) // Create an edge. Don't add it to the DB yet. - edgeInfo, edge1, edge2 := createChannelEdge(node1, node2) + edgeInfo, shortChanID := createEdge( + v, 100, 1, 0, 0, node1, node2, + ) + chanID := shortChanID.ToUint64() + + edge1 := newEdgePolicy(v, chanID, nextUpdateTime().Unix(), true) + edge2 := newEdgePolicy(v, chanID, nextUpdateTime().Unix(), false) + edge1.ToNode = edgeInfo.NodeKey2Bytes + edge2.ToNode = edgeInfo.NodeKey1Bytes + edge1.SigBytes = testSig.Serialize() + edge2.SigBytes = testSig.Serialize() updateAndAssertPolicies := func() { // Make copies of the policies before calling UpdateEdgePolicy @@ -1190,8 +1218,14 @@ func TestEdgePolicyCRUD(t *testing.T) { edge1 := copyEdgePolicy(edge1) edge2 := copyEdgePolicy(edge2) - edge1.LastUpdate = nextUpdateTime() - edge2.LastUpdate = nextUpdateTime() + switch v { + case lnwire.GossipVersion1: + edge1.LastUpdate = nextUpdateTime() + edge2.LastUpdate = nextUpdateTime() + case lnwire.GossipVersion2: + edge1.LastBlockHeight = nextBlockHeight() + edge2.LastBlockHeight = nextBlockHeight() + } require.NoError(t, graph.UpdateEdgePolicy(ctx, edge1)) require.NoError(t, graph.UpdateEdgePolicy(ctx, edge2)) @@ -1207,7 +1241,8 @@ func TestEdgePolicyCRUD(t *testing.T) { // assert that the deserialized policies match the original // ones. err := graph.ForEachChannel( - ctx, func(info *models.ChannelEdgeInfo, + ctx, + func(info *models.ChannelEdgeInfo, policy1 *models.ChannelEdgePolicy, policy2 *models.ChannelEdgePolicy) error { @@ -1238,13 +1273,26 @@ func TestEdgePolicyCRUD(t *testing.T) { updateAndAssertPolicies() - // Update one of the edges to have ChannelFlags include a bit unknown - // to us. - edge1.ChannelFlags |= 1 << 6 + switch v { + case lnwire.GossipVersion1: + // Update one of the edges to have ChannelFlags include a bit + // unknown to us. + edge1.ChannelFlags |= 1 << 6 - // Update the other edge to have MessageFlags include a bit unknown to - // us. - edge2.MessageFlags |= 1 << 4 + // Update the other edge to have MessageFlags include a bit + // unknown to us. + edge2.MessageFlags |= 1 << 4 + + case lnwire.GossipVersion2: + // Update one of the edges to have DisableFlags include a bit + // unknown to us. + edge1.DisableFlags |= 1 << 6 + + // Update the other edge to have a modified extra signed field. + edge2.ExtraSignedFields = map[uint64][]byte{ + 200: {0x4, 0x5}, + } + } updateAndAssertPolicies() } @@ -1437,16 +1485,20 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, func randEdgePolicy(chanID uint64) *models.ChannelEdgePolicy { update := prand.Int63() - return newEdgePolicy(chanID, update) + return newEdgePolicy(lnwire.GossipVersion1, chanID, update, true) } func copyEdgePolicy(p *models.ChannelEdgePolicy) *models.ChannelEdgePolicy { return &models.ChannelEdgePolicy{ + Version: p.Version, SigBytes: p.SigBytes, ChannelID: p.ChannelID, LastUpdate: p.LastUpdate, + LastBlockHeight: p.LastBlockHeight, + SecondPeer: p.SecondPeer, MessageFlags: p.MessageFlags, ChannelFlags: p.ChannelFlags, + DisableFlags: p.DisableFlags, TimeLockDelta: p.TimeLockDelta, MinHTLC: p.MinHTLC, MaxHTLC: p.MaxHTLC, @@ -1454,21 +1506,42 @@ func copyEdgePolicy(p *models.ChannelEdgePolicy) *models.ChannelEdgePolicy { FeeProportionalMillionths: p.FeeProportionalMillionths, ToNode: p.ToNode, ExtraOpaqueData: p.ExtraOpaqueData, + ExtraSignedFields: p.ExtraSignedFields, } } -func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy { - return &models.ChannelEdgePolicy{ +func newEdgePolicy(v lnwire.GossipVersion, chanID uint64, + updateTime int64, isNode1 bool) *models.ChannelEdgePolicy { + + policy := &models.ChannelEdgePolicy{ + Version: v, ChannelID: chanID, - LastUpdate: time.Unix(updateTime, 0), - MessageFlags: 1, - ChannelFlags: 0, TimeLockDelta: uint16(prand.Int63()), MinHTLC: lnwire.MilliSatoshi(prand.Int63()), MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), } + + if v == lnwire.GossipVersion2 { + policy.LastBlockHeight = nextBlockHeight() + policy.SecondPeer = !isNode1 + policy.DisableFlags = 0 + policy.ExtraSignedFields = map[uint64][]byte{ + 100: {0x1, 0x2, 0x3}, + } + + return policy + } + + policy.LastUpdate = time.Unix(updateTime, 0) + policy.MessageFlags = 1 + if !isNode1 { + policy.ChannelFlags = lnwire.ChanUpdateDirection + } + policy.ExtraOpaqueData = []byte{1, 0} + + return policy } // testAddEdgeProof tests the ability to add an edge proof to an existing edge. @@ -1668,13 +1741,14 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(ctx, func(ei *models.ChannelEdgeInfo, - _ *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { + err = graph.ForEachChannel(ctx, lnwire.GossipVersion1, + func(ei *models.ChannelEdgeInfo, + _ *models.ChannelEdgePolicy, + _ *models.ChannelEdgePolicy) error { - delete(chanIndex, ei.ChannelID) - return nil - }, func() {}) + delete(chanIndex, ei.ChannelID) + return nil + }, func() {}) require.NoError(t, err) require.Len(t, chanIndex, 0) @@ -1683,7 +1757,7 @@ func TestGraphTraversal(t *testing.T) { numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] err = graph.ForEachNodeChannel( - ctx, firstNode.PubKeyBytes, + ctx, lnwire.GossipVersion1, firstNode.PubKeyBytes, func(_ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { @@ -1970,7 +2044,8 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 err := graph.ForEachChannel( - t.Context(), func(*models.ChannelEdgeInfo, + t.Context(), lnwire.GossipVersion1, + func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error { @@ -2354,7 +2429,8 @@ func TestChanUpdatesInHorizon(t *testing.T) { endTime = endTime.Add(time.Second * 10) edge1 := newEdgePolicy( - chanID.ToUint64(), edge1UpdateTime.Unix(), + lnwire.GossipVersion1, chanID.ToUint64(), + edge1UpdateTime.Unix(), true, ) edge1.ChannelFlags = 0 edge1.ToNode = node2.PubKeyBytes @@ -2364,7 +2440,8 @@ func TestChanUpdatesInHorizon(t *testing.T) { } edge2 := newEdgePolicy( - chanID.ToUint64(), edge2UpdateTime.Unix(), + lnwire.GossipVersion1, chanID.ToUint64(), + edge2UpdateTime.Unix(), false, ) edge2.ChannelFlags = 1 edge2.ToNode = node1.PubKeyBytes @@ -2816,8 +2893,10 @@ func TestChanUpdatesInHorizonBoundaryConditions(t *testing.T) { t, graph.AddChannelEdge(ctx, channel), ) + //nolint:ll edge1 := newEdgePolicy( - chanID.ToUint64(), updateTime.Unix(), + lnwire.GossipVersion1, chanID.ToUint64(), + updateTime.Unix(), true, ) edge1.ChannelFlags = 0 edge1.ToNode = node2.PubKeyBytes @@ -2826,8 +2905,10 @@ func TestChanUpdatesInHorizonBoundaryConditions(t *testing.T) { t, graph.UpdateEdgePolicy(ctx, edge1), ) + //nolint:ll edge2 := newEdgePolicy( - chanID.ToUint64(), updateTime.Unix(), + lnwire.GossipVersion1, chanID.ToUint64(), + updateTime.Unix(), false, ) edge2.ChannelFlags = 1 edge2.ToNode = node1.PubKeyBytes @@ -3284,7 +3365,8 @@ func TestStressTestChannelGraphAPI(t *testing.T) { return nil } - _, _, _, _, err := graph.HasChannelEdge( + _, _, err := graph.HasChannelEdge( + lnwire.GossipVersion1, channel.id.ToUint64(), ) @@ -3454,6 +3536,7 @@ func TestFilterChannelRange(t *testing.T) { updateTime = time.Unix(updateTimeSeed, 0) err = graph.UpdateEdgePolicy( ctx, &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ToNode: node.PubKeyBytes, ChannelFlags: chanFlags, ChannelID: chanID, @@ -3663,7 +3746,10 @@ func TestFetchChanInfos(t *testing.T) { updateTime := endTime endTime = updateTime.Add(time.Second * 10) - edge1 := newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) + edge1 := newEdgePolicy( + lnwire.GossipVersion1, chanID.ToUint64(), + updateTime.Unix(), true, + ) edge1.ChannelFlags = 0 edge1.ToNode = node2.PubKeyBytes edge1.SigBytes = testSig.Serialize() @@ -3671,7 +3757,10 @@ func TestFetchChanInfos(t *testing.T) { t.Fatalf("unable to update edge: %v", err) } - edge2 := newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) + edge2 := newEdgePolicy( + lnwire.GossipVersion1, chanID.ToUint64(), + updateTime.Unix(), false, + ) edge2.ChannelFlags = 1 edge2.ToNode = node1.PubKeyBytes edge2.SigBytes = testSig.Serialize() @@ -3727,32 +3816,26 @@ func TestFetchChanInfos(t *testing.T) { } } -// TestIncompleteChannelPolicies tests that a channel that only has a policy +// testIncompleteChannelPolicies tests that a channel that only has a policy // specified on one end is properly returned in ForEachChannel calls from // both sides. -func TestIncompleteChannelPolicies(t *testing.T) { +func testIncompleteChannelPolicies(t *testing.T, v lnwire.GossipVersion) { t.Parallel() ctx := t.Context() - graph := MakeTestGraph(t) + graph := NewVersionedGraph(MakeTestGraph(t), v) // Create two nodes. - node1 := createTestVertex(t, lnwire.GossipVersion1) - if err := graph.AddNode(ctx, node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2 := createTestVertex(t, lnwire.GossipVersion1) - if err := graph.AddNode(ctx, node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } + node1 := createTestVertex(t, v) + require.NoError(t, graph.AddNode(ctx, node1)) + node2 := createTestVertex(t, v) + require.NoError(t, graph.AddNode(ctx, node2)) channel, chanID := createEdge( - lnwire.GossipVersion1, uint32(0), 0, 0, 0, node1, node2, + v, uint32(0), 0, 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(ctx, channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } + require.NoError(t, graph.AddChannelEdge(ctx, channel)) // Ensure that channel is reported with unknown policies. checkPolicies := func(node *models.Node, expectedIn, @@ -3764,21 +3847,8 @@ func TestIncompleteChannelPolicies(t *testing.T) { func(_ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { - if !expectedOut && outEdge != nil { - t.Fatalf("Expected no outgoing policy") - } - - if expectedOut && outEdge == nil { - t.Fatalf("Expected an outgoing policy") - } - - if !expectedIn && inEdge != nil { - t.Fatalf("Expected no incoming policy") - } - - if expectedIn && inEdge == nil { - t.Fatalf("Expected an incoming policy") - } + require.Equal(t, expectedOut, outEdge != nil) + require.Equal(t, expectedIn, inEdge != nil) calls++ @@ -3791,30 +3861,38 @@ func TestIncompleteChannelPolicies(t *testing.T) { checkPolicies(node2, false, false) - // Only create an edge policy for node1 and leave the policy for node2 - // unknown. - updateTime := time.Unix(1234, 0) + newTestEdgePolicy := func(isNode1 bool, + toNode route.Vertex) *models.ChannelEdgePolicy { - edgePolicy := newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) - edgePolicy.ChannelFlags = 0 - edgePolicy.ToNode = node2.PubKeyBytes - edgePolicy.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(ctx, edgePolicy); err != nil { - t.Fatalf("unable to update edge: %v", err) + policy := newEdgePolicy( + v, chanID.ToUint64(), nextUpdateTime().Unix(), isNode1, + ) + policy.ToNode = toNode + policy.SigBytes = testSig.Serialize() + + if v == lnwire.GossipVersion1 { + if isNode1 { + policy.ChannelFlags = 0 + } else { + policy.ChannelFlags = lnwire.ChanUpdateDirection + } + } + + return policy } + // Only create an edge policy for node1 and leave the policy for node2 + // unknown. + edgePolicy := newTestEdgePolicy(true, node2.PubKeyBytes) + require.NoError(t, graph.UpdateEdgePolicy(ctx, edgePolicy)) + checkPolicies(node1, false, true) checkPolicies(node2, true, false) // Create second policy and assert that both policies are reported // as present. - edgePolicy = newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) - edgePolicy.ChannelFlags = 1 - edgePolicy.ToNode = node1.PubKeyBytes - edgePolicy.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(ctx, edgePolicy); err != nil { - t.Fatalf("unable to update edge: %v", err) - } + edgePolicy = newTestEdgePolicy(false, node1.PubKeyBytes) + require.NoError(t, graph.UpdateEdgePolicy(ctx, edgePolicy)) checkPolicies(node1, true, true) checkPolicies(node2, true, true) @@ -4659,22 +4737,14 @@ func compareNodes(t *testing.T, a, b *models.Node) { // compareEdgePolicies is used to compare two ChannelEdgePolices using // compareNodes, so as to exclude comparisons of the Nodes' Features struct. func compareEdgePolicies(a, b *models.ChannelEdgePolicy) error { + if a.Version != b.Version { + return fmt.Errorf("Version doesn't match: expected %v, got %v", + a.Version, b.Version) + } if a.ChannelID != b.ChannelID { return fmt.Errorf("ChannelID doesn't match: expected %v, "+ "got %v", a.ChannelID, b.ChannelID) } - if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("edge LastUpdate doesn't match: "+ - "expected %#v, got %#v", a.LastUpdate, b.LastUpdate) - } - if a.MessageFlags != b.MessageFlags { - return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ - "got %v", a.MessageFlags, b.MessageFlags) - } - if a.ChannelFlags != b.ChannelFlags { - return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+ - "got %v", a.ChannelFlags, b.ChannelFlags) - } if a.TimeLockDelta != b.TimeLockDelta { return fmt.Errorf("TimeLockDelta doesn't match: expected %v, "+ "got %v", a.TimeLockDelta, b.TimeLockDelta) @@ -4696,18 +4766,66 @@ func compareEdgePolicies(a, b *models.ChannelEdgePolicy) error { "expected %v, got %v", a.FeeProportionalMillionths, b.FeeProportionalMillionths) } - if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { - return fmt.Errorf("extra data doesn't match: %v vs %v", - a.ExtraOpaqueData, b.ExtraOpaqueData) - } if !bytes.Equal(a.ToNode[:], b.ToNode[:]) { return fmt.Errorf("ToNode doesn't match: expected %x, got %x", a.ToNode, b.ToNode) } + if a.Version == lnwire.GossipVersion2 { + if a.LastBlockHeight != b.LastBlockHeight { + return fmt.Errorf("LastBlockHeight doesn't match: "+ + "expected %v, got %v", a.LastBlockHeight, + b.LastBlockHeight) + } + if a.SecondPeer != b.SecondPeer { + return fmt.Errorf("SecondPeer doesn't match: "+ + "expected %v, got %v", a.SecondPeer, + b.SecondPeer) + } + if a.DisableFlags != b.DisableFlags { + return fmt.Errorf("DisableFlags doesn't match: "+ + "expected %v, got %v", a.DisableFlags, + b.DisableFlags) + } + if !equalExtraSignedFields( + a.ExtraSignedFields, b.ExtraSignedFields, + ) { + + return fmt.Errorf("ExtraSignedFields doesn't match: "+ + "expected %#v, got %#v", a.ExtraSignedFields, + b.ExtraSignedFields) + } + + return nil + } + + if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { + return fmt.Errorf("edge LastUpdate doesn't match: "+ + "expected %#v, got %#v", a.LastUpdate, b.LastUpdate) + } + if a.MessageFlags != b.MessageFlags { + return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ + "got %v", a.MessageFlags, b.MessageFlags) + } + if a.ChannelFlags != b.ChannelFlags { + return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+ + "got %v", a.ChannelFlags, b.ChannelFlags) + } + if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { + return fmt.Errorf("extra data doesn't match: %v vs %v", + a.ExtraOpaqueData, b.ExtraOpaqueData) + } return nil } +func equalExtraSignedFields(a, b map[uint64][]byte) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + + return reflect.DeepEqual(a, b) +} + // TestLightningNodeSigVerification checks that we can use the Node's // pubkey to verify signatures. func TestLightningNodeSigVerification(t *testing.T) { @@ -4747,6 +4865,7 @@ func TestLightningNodeSigVerification(t *testing.T) { func TestComputeFee(t *testing.T) { var ( policy = models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, FeeBaseMSat: 10000, FeeProportionalMillionths: 30000, } @@ -4942,7 +5061,9 @@ func BenchmarkForEachChannel(b *testing.B) { return nil } - err := graph.ForEachNodeChannel(ctx, n, cb, func() {}) + err := graph.ForEachNodeChannel( + ctx, lnwire.GossipVersion1, n, cb, func() {}, + ) require.NoError(b, err) } } diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index 24a48e4c73..0936386be3 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -73,8 +73,9 @@ type Store interface { //nolint:interfacebloat // to the caller. // // Unknown policies are passed into the callback as nil values. - ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, - cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + ForEachNodeChannel(ctx context.Context, v lnwire.GossipVersion, + nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error // ForEachNodeCached is similar to forEachNode, but it returns @@ -162,9 +163,12 @@ type Store interface { //nolint:interfacebloat // NOTE: If an edge can't be found, or wasn't advertised, then a nil // pointer for that particular channel edge routing policy will be // passed into the callback. - ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, - reset func()) error + // + // TODO(elle): add a cross-version iteration API and make this iterate + // over all versions. + ForEachChannel(ctx context.Context, v lnwire.GossipVersion, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error // ForEachChannelCacheable iterates through all the channel edges stored // within the graph and invokes the passed callback for each edge. The @@ -197,13 +201,20 @@ type Store interface { //nolint:interfacebloat AddChannelEdge(ctx context.Context, edge *models.ChannelEdgeInfo, op ...batch.SchedulerOption) error - // HasChannelEdge returns true if the database knows of a channel edge + // HasV1ChannelEdge returns true if the database knows of a channel edge // with the passed channel ID, and false otherwise. If an edge with that // ID is found within the graph, then two time stamps representing the // last time the edge was updated for both directed edges are returned // along with the boolean. If it is not found, then the zombie index is // checked and its result is returned as the second boolean. - HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, bool, + HasV1ChannelEdge(chanID uint64) (time.Time, time.Time, bool, bool, + error) + + // HasChannelEdge returns true if the database knows of a channel edge + // with the passed channel ID and gossip version, and false otherwise. + // If it is not found, then the zombie index is checked and its result + // is returned as the second boolean. + HasChannelEdge(v lnwire.GossipVersion, chanID uint64) (bool, bool, error) // DeleteChannelEdges removes edges with the given channel IDs from the diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 26c3e8ee39..c9e8aa9033 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -411,10 +411,14 @@ func (c *KVStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion, // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. -func (c *KVStore) ForEachChannel(_ context.Context, +func (c *KVStore) ForEachChannel(_ context.Context, v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { + if v != lnwire.GossipVersion1 { + return ErrVersionNotSupportedForKVDB + } + return forEachChannel(c.db, cb, reset) } @@ -1179,8 +1183,11 @@ func (c *KVStore) AddChannelEdge(ctx context.Context, case alreadyExists: return ErrEdgeAlreadyExist default: - c.rejectCache.remove(edge.ChannelID) + c.rejectCache.remove( + lnwire.GossipVersion1, edge.ChannelID, + ) c.chanCache.remove(edge.ChannelID) + return nil } }, @@ -1283,13 +1290,13 @@ func (c *KVStore) addChannelEdge(tx kvdb.RwTx, return chanIndex.Put(b.Bytes(), chanKey[:]) } -// HasChannelEdge returns true if the database knows of a channel edge with the -// passed channel ID, and false otherwise. If an edge with that ID is found -// within the graph, then two time stamps representing the last time the edge -// was updated for both directed edges are returned along with the boolean. If -// it is not found, then the zombie index is checked and its result is returned -// as the second boolean. -func (c *KVStore) HasChannelEdge( +// HasV1ChannelEdge returns true if the database knows of a channel edge +// with the passed channel ID, and false otherwise. If an edge with that ID +// is found within the graph, then two time stamps representing the last time +// the edge was updated for both directed edges are returned along with the +// boolean. If it is not found, then the zombie index is checked and its +// result is returned as the second boolean. +func (c *KVStore) HasV1ChannelEdge( chanID uint64) (time.Time, time.Time, bool, bool, error) { var ( @@ -1302,7 +1309,7 @@ func (c *KVStore) HasChannelEdge( // We'll query the cache with the shared lock held to allow multiple // readers to access values in the cache concurrently if they exist. c.cacheMu.RLock() - if entry, ok := c.rejectCache.get(chanID); ok { + if entry, ok := c.rejectCache.get(lnwire.GossipVersion1, chanID); ok { c.cacheMu.RUnlock() upd1Time = time.Unix(entry.upd1Time, 0) upd2Time = time.Unix(entry.upd2Time, 0) @@ -1318,7 +1325,7 @@ func (c *KVStore) HasChannelEdge( // The item was not found with the shared lock, so we'll acquire the // exclusive lock and check the cache again in case another method added // the entry to the cache while no lock was held. - if entry, ok := c.rejectCache.get(chanID); ok { + if entry, ok := c.rejectCache.get(lnwire.GossipVersion1, chanID); ok { upd1Time = time.Unix(entry.upd1Time, 0) upd2Time = time.Unix(entry.upd2Time, 0) exists, isZombie = entry.flags.unpack() @@ -1385,7 +1392,7 @@ func (c *KVStore) HasChannelEdge( return time.Time{}, time.Time{}, exists, isZombie, err } - c.rejectCache.insert(chanID, rejectCacheEntry{ + c.rejectCache.insert(lnwire.GossipVersion1, chanID, rejectCacheEntry{ upd1Time: upd1Time.Unix(), upd2Time: upd2Time.Unix(), flags: packRejectFlags(exists, isZombie), @@ -1394,6 +1401,22 @@ func (c *KVStore) HasChannelEdge( return upd1Time, upd2Time, exists, isZombie, nil } +// HasChannelEdge returns true if the database knows of a channel edge with the +// passed channel ID and gossip version, and false otherwise. If it is not +// found, then the zombie index is checked and its result is returned as the +// second boolean. +func (c *KVStore) HasChannelEdge(v lnwire.GossipVersion, + chanID uint64) (bool, bool, error) { + + if v != lnwire.GossipVersion1 { + return false, false, ErrVersionNotSupportedForKVDB + } + + _, _, exists, isZombie, err := c.HasV1ChannelEdge(chanID) + + return exists, isZombie, err +} + // AddEdgeProof sets the proof of an existing edge in the graph database. func (c *KVStore) AddEdgeProof(chanID lnwire.ShortChannelID, proof *models.ChannelAuthProof) error { @@ -1564,7 +1587,7 @@ func (c *KVStore) PruneGraph(spentOutputs []*wire.OutPoint, } for _, channel := range chansClosed { - c.rejectCache.remove(channel.ChannelID) + c.rejectCache.remove(lnwire.GossipVersion1, channel.ChannelID) c.chanCache.remove(channel.ChannelID) } @@ -1831,7 +1854,7 @@ func (c *KVStore) DisconnectBlockAtHeight(height uint32) ( } for _, channel := range removedChans { - c.rejectCache.remove(channel.ChannelID) + c.rejectCache.remove(lnwire.GossipVersion1, channel.ChannelID) c.chanCache.remove(channel.ChannelID) } @@ -1950,7 +1973,7 @@ func (c *KVStore) DeleteChannelEdges(v lnwire.GossipVersion, } for _, chanID := range chanIDs { - c.rejectCache.remove(chanID) + c.rejectCache.remove(lnwire.GossipVersion1, chanID) c.chanCache.remove(chanID) } @@ -3265,13 +3288,14 @@ func (c *KVStore) updateEdgeCache(e *models.ChannelEdgePolicy, // the entry with the updated timestamp for the direction that was just // written. If the edge doesn't exist, we'll load the cache entry lazily // during the next query for this edge. - if entry, ok := c.rejectCache.get(e.ChannelID); ok { + entry, ok := c.rejectCache.get(lnwire.GossipVersion1, e.ChannelID) + if ok { if isUpdate1 { entry.upd1Time = e.LastUpdate.Unix() } else { entry.upd2Time = e.LastUpdate.Unix() } - c.rejectCache.insert(e.ChannelID, entry) + c.rejectCache.insert(lnwire.GossipVersion1, e.ChannelID, entry) } // If an entry for this channel is found in channel cache, we'll modify @@ -3296,6 +3320,9 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy) ( route.Vertex, route.Vertex, bool, error) { var noVertex route.Vertex + if edge.Version != lnwire.GossipVersion1 { + return noVertex, noVertex, false, ErrVersionNotSupportedForKVDB + } edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { @@ -3663,10 +3690,15 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -func (c *KVStore) ForEachNodeChannel(_ context.Context, nodePub route.Vertex, +func (c *KVStore) ForEachNodeChannel(_ context.Context, + v lnwire.GossipVersion, nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { + if v != lnwire.GossipVersion1 { + return ErrVersionNotSupportedForKVDB + } + return nodeTraversal( nil, nodePub[:], c.db, func(_ kvdb.RTx, info *models.ChannelEdgeInfo, policy, @@ -4182,7 +4214,7 @@ func (c *KVStore) MarkEdgeZombie(chanID uint64, return err } - c.rejectCache.remove(chanID) + c.rejectCache.remove(lnwire.GossipVersion1, chanID) c.chanCache.remove(chanID) return nil @@ -4251,7 +4283,7 @@ func (c *KVStore) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error { return err } - c.rejectCache.remove(chanID) + c.rejectCache.remove(lnwire.GossipVersion1, chanID) c.chanCache.remove(chanID) return nil @@ -5157,6 +5189,10 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy, to []byte) error { + if edge.Version != lnwire.GossipVersion1 { + return ErrVersionNotSupportedForKVDB + } + err := wire.WriteVarBytes(w, 0, edge.SigBytes) if err != nil { return err @@ -5245,7 +5281,9 @@ func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy, error) { func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy, error) { - edge := &models.ChannelEdgePolicy{} + edge := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, + } var err error edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") diff --git a/graph/db/models/cached_edge_policy.go b/graph/db/models/cached_edge_policy.go index 40b0d9212c..90c8d56c39 100644 --- a/graph/db/models/cached_edge_policy.go +++ b/graph/db/models/cached_edge_policy.go @@ -20,13 +20,15 @@ type CachedEdgePolicy struct { // and the last 2 bytes are the output index for the channel. ChannelID uint64 - // MessageFlags is a bitfield which indicates the presence of optional - // fields (like max_htlc) in the policy. - MessageFlags lnwire.ChanUpdateMsgFlags + // HasMaxHTLC indicates whether the policy has a max HTLC value. + HasMaxHTLC bool - // ChannelFlags is a bitfield which signals the capabilities of the - // channel as well as the directed edge this update applies to. - ChannelFlags lnwire.ChanUpdateChanFlags + // IsNode1 indicates whether this policy was announced by the channel's + // node_1. + IsNode1 bool + + // IsDisabled indicates whether the policy disables forwarding. + IsDisabled bool // TimeLockDelta is the number of blocks this node will subtract from // the expiry of an incoming HTLC. This value expresses the time buffer @@ -75,24 +77,31 @@ func (c *CachedEdgePolicy) ComputeFee( return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts } -// IsDisabled returns true if the channel is disabled in the direction from the -// advertising node. -func (c *CachedEdgePolicy) IsDisabled() bool { - return c.ChannelFlags&lnwire.ChanUpdateDisabled != 0 -} - -// IsNode1 returns true if this policy was announced by the channel's node_1 -// node. -func (c *CachedEdgePolicy) IsNode1() bool { - return c.ChannelFlags&lnwire.ChanUpdateDirection == 0 -} - // NewCachedPolicy turns a full policy into a minimal one that can be cached. func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy { + if policy.Version != lnwire.GossipVersion2 { + return &CachedEdgePolicy{ + ChannelID: policy.ChannelID, + HasMaxHTLC: policy.MessageFlags.HasMaxHtlc(), + IsDisabled: policy.ChannelFlags& + lnwire.ChanUpdateDisabled != 0, + IsNode1: policy.ChannelFlags& + lnwire.ChanUpdateDirection == 0, + TimeLockDelta: policy.TimeLockDelta, + MinHTLC: policy.MinHTLC, + MaxHTLC: policy.MaxHTLC, + FeeBaseMSat: policy.FeeBaseMSat, + FeeProportionalMillionths: policy. + FeeProportionalMillionths, + InboundFee: policy.InboundFee, + } + } + return &CachedEdgePolicy{ ChannelID: policy.ChannelID, - MessageFlags: policy.MessageFlags, - ChannelFlags: policy.ChannelFlags, + HasMaxHTLC: true, + IsNode1: !policy.SecondPeer, + IsDisabled: !policy.DisableFlags.IsEnabled(), TimeLockDelta: policy.TimeLockDelta, MinHTLC: policy.MinHTLC, MaxHTLC: policy.MaxHTLC, diff --git a/graph/db/models/channel_edge_policy.go b/graph/db/models/channel_edge_policy.go index a2661ef6f8..7e862c1a0c 100644 --- a/graph/db/models/channel_edge_policy.go +++ b/graph/db/models/channel_edge_policy.go @@ -14,6 +14,10 @@ import ( // information concerning fees, and minimum time-lock information which is // utilized during path finding. type ChannelEdgePolicy struct { + // Version is the gossip version of the channel update that produced + // this policy. + Version lnwire.GossipVersion + // SigBytes is the raw bytes of the signature of the channel edge // policy. We'll only parse these if the caller needs to access the // signature for validation purposes. @@ -28,6 +32,14 @@ type ChannelEdgePolicy struct { // was received. LastUpdate time.Time + // LastBlockHeight is the block height that timestamps the last update + // for v2 channel updates. + LastBlockHeight uint32 + + // SecondPeer indicates whether this policy was announced by the second + // peer in the channel for v2 channel updates. + SecondPeer bool + // MessageFlags is a bitfield which indicates the presence of optional // fields (like max_htlc) in the policy. MessageFlags lnwire.ChanUpdateMsgFlags @@ -36,6 +48,10 @@ type ChannelEdgePolicy struct { // channel as well as the directed edge this update applies to. ChannelFlags lnwire.ChanUpdateChanFlags + // DisableFlags is a v2-specific bitfield which signals whether the + // channel is disabled for incoming or outgoing traffic. + DisableFlags lnwire.ChanUpdateDisableFlags + // TimeLockDelta is the number of blocks this node will subtract from // the expiry of an incoming HTLC. This value expresses the time buffer // the node would like to HTLC exchanges. @@ -77,11 +93,78 @@ type ChannelEdgePolicy struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData lnwire.ExtraOpaqueData + + // ExtraSignedFields are the extra signed fields found in v2 channel + // updates. + ExtraSignedFields map[uint64][]byte +} + +// ChanEdgePolicyFromWire constructs a ChannelEdgePolicy from a channel update +// message. +func ChanEdgePolicyFromWire(scid uint64, + update lnwire.ChannelUpdate) (*ChannelEdgePolicy, error) { + + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + //nolint:ll + return &ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, + SigBytes: upd.Signature.ToSignatureBytes(), + ChannelID: scid, + LastUpdate: time.Unix(int64(upd.Timestamp), 0), + MessageFlags: upd.MessageFlags, + ChannelFlags: upd.ChannelFlags, + TimeLockDelta: upd.TimeLockDelta, + MinHTLC: upd.HtlcMinimumMsat, + MaxHTLC: upd.HtlcMaximumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), + InboundFee: upd.InboundFee.ValOpt(), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil + + case *lnwire.ChannelUpdate2: + return &ChannelEdgePolicy{ + Version: lnwire.GossipVersion2, + SigBytes: upd.Signature.Val.ToSignatureBytes(), + ChannelID: upd.ShortChannelID.Val.ToUint64(), + LastBlockHeight: upd.BlockHeight.Val, + SecondPeer: upd.SecondPeer.IsSome(), + DisableFlags: upd.DisabledFlags.Val, + TimeLockDelta: upd.CLTVExpiryDelta.Val, + MinHTLC: upd.HTLCMinimumMsat.Val, + MaxHTLC: upd.HTLCMaximumMsat.Val, + FeeBaseMSat: lnwire.MilliSatoshi( + upd.FeeBaseMsat.Val, + ), + FeeProportionalMillionths: lnwire.MilliSatoshi( + upd.FeeProportionalMillionths.Val, + ), + InboundFee: upd.InboundFee.ValOpt(), + ExtraSignedFields: upd.ExtraSignedFields, + }, nil + } + + return nil, fmt.Errorf("unknown channel update version: %v", + update.MsgType()) +} + +// IsNode1 returns true if this policy was announced by the channel's node_1. +func (c *ChannelEdgePolicy) IsNode1() bool { + if c.Version == lnwire.GossipVersion1 { + return c.ChannelFlags&lnwire.ChanUpdateDirection == 0 + } + + return !c.SecondPeer } // IsDisabled determines whether the edge has the disabled bit set. func (c *ChannelEdgePolicy) IsDisabled() bool { - return c.ChannelFlags.IsDisabled() + if c.Version == lnwire.GossipVersion1 { + return c.ChannelFlags.IsDisabled() + } + + return !c.DisableFlags.IsEnabled() } // ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over @@ -95,7 +178,13 @@ func (c *ChannelEdgePolicy) ComputeFee( // String returns a human-readable version of the channel edge policy. func (c *ChannelEdgePolicy) String() string { - return fmt.Sprintf("ChannelID=%v, MessageFlags=%v, ChannelFlags=%v, "+ - "LastUpdate=%v", c.ChannelID, c.MessageFlags, c.ChannelFlags, - c.LastUpdate) + if c.Version == lnwire.GossipVersion1 { + return fmt.Sprintf("ChannelID=%v, MessageFlags=%v, "+ + "ChannelFlags=%v, LastUpdate=%v", c.ChannelID, + c.MessageFlags, c.ChannelFlags, c.LastUpdate) + } + + return fmt.Sprintf("ChannelID=%v, Node1=%v, DisableFlags=%v, "+ + "BlockHeight=%v", c.ChannelID, !c.SecondPeer, + c.DisableFlags, c.LastBlockHeight) } diff --git a/graph/db/reject_cache.go b/graph/db/reject_cache.go index 2a2721928b..b5b4386729 100644 --- a/graph/db/reject_cache.go +++ b/graph/db/reject_cache.go @@ -1,5 +1,11 @@ package graphdb +import ( + "time" + + "github.com/lightningnetwork/lnd/lnwire" +) + // rejectFlags is a compact representation of various metadata stored by the // reject cache about a particular channel. type rejectFlags uint8 @@ -41,9 +47,64 @@ func (f rejectFlags) unpack() (bool, bool) { // including the timestamps of its latest edge policies and whether or not the // channel exists in the graph. type rejectCacheEntry struct { + // upd{1,2}Time are Unix timestamps for v1 policies. upd1Time int64 upd2Time int64 - flags rejectFlags + + // upd{1,2}BlockHeight are the last known block heights for v2 + // policies. + upd1BlockHeight int64 + upd2BlockHeight int64 + + flags rejectFlags +} + +func newRejectCacheEntryV1(upd1, upd2 time.Time, exists, + isZombie bool) rejectCacheEntry { + + return rejectCacheEntry{ + upd1Time: upd1.Unix(), + upd2Time: upd2.Unix(), + flags: packRejectFlags(exists, isZombie), + } +} + +func newRejectCacheEntryV2(upd1, upd2 uint32, exists, + isZombie bool) rejectCacheEntry { + + return rejectCacheEntry{ + upd1BlockHeight: int64(upd1), + upd2BlockHeight: int64(upd2), + flags: packRejectFlags(exists, isZombie), + } +} + +func updateRejectCacheEntryV1(entry *rejectCacheEntry, isUpdate1 bool, + lastUpdate time.Time) { + + if isUpdate1 { + entry.upd1Time = lastUpdate.Unix() + } else { + entry.upd2Time = lastUpdate.Unix() + } +} + +func updateRejectCacheEntryV2(entry *rejectCacheEntry, isUpdate1 bool, + blockHeight uint32) { + + blockHeight64 := int64(blockHeight) + if isUpdate1 { + entry.upd1BlockHeight = blockHeight64 + } else { + entry.upd2BlockHeight = blockHeight64 + } +} + +// rejectCacheKey uniquely identifies a channel entry in the reject cache by +// gossip version and channel ID. +type rejectCacheKey struct { + version lnwire.GossipVersion + chanID uint64 } // rejectCache is an in-memory cache used to improve the performance of @@ -51,20 +112,25 @@ type rejectCacheEntry struct { // well as the most recent timestamps for each policy (if they exists). type rejectCache struct { n int - edges map[uint64]rejectCacheEntry + edges map[rejectCacheKey]rejectCacheEntry } // newRejectCache creates a new rejectCache with maximum capacity of n entries. func newRejectCache(n int) *rejectCache { return &rejectCache{ n: n, - edges: make(map[uint64]rejectCacheEntry, n), + edges: make(map[rejectCacheKey]rejectCacheEntry, n), } } // get returns the entry from the cache for chanid, if it exists. -func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) { - entry, ok := c.edges[chanid] +func (c *rejectCache) get(version lnwire.GossipVersion, chanid uint64) ( + rejectCacheEntry, bool) { + + entry, ok := c.edges[rejectCacheKey{ + version: version, + chanID: chanid, + }] return entry, ok } @@ -72,10 +138,17 @@ func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) { // exists, it will be replaced with the new entry. If the entry doesn't exists, // it will be inserted to the cache, performing a random eviction if the cache // is at capacity. -func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) { +func (c *rejectCache) insert(version lnwire.GossipVersion, chanid uint64, + entry rejectCacheEntry) { + + key := rejectCacheKey{ + version: version, + chanID: chanid, + } + // If entry exists, replace it. - if _, ok := c.edges[chanid]; ok { - c.edges[chanid] = entry + if _, ok := c.edges[key]; ok { + c.edges[key] = entry return } @@ -86,10 +159,13 @@ func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) { break } } - c.edges[chanid] = entry + c.edges[key] = entry } // remove deletes an entry for chanid from the cache, if it exists. -func (c *rejectCache) remove(chanid uint64) { - delete(c.edges, chanid) +func (c *rejectCache) remove(version lnwire.GossipVersion, chanid uint64) { + delete(c.edges, rejectCacheKey{ + version: version, + chanID: chanid, + }) } diff --git a/graph/db/reject_cache_test.go b/graph/db/reject_cache_test.go index f64c39c33d..c526f487b7 100644 --- a/graph/db/reject_cache_test.go +++ b/graph/db/reject_cache_test.go @@ -3,6 +3,8 @@ package graphdb import ( "reflect" "testing" + + "github.com/lightningnetwork/lnd/lnwire" ) // TestRejectCache checks the behavior of the rejectCache with respect to insertion, @@ -15,14 +17,14 @@ func TestRejectCache(t *testing.T) { // As a sanity check, assert that querying the empty cache does not // return an entry. - _, ok := c.get(0) + _, ok := c.get(lnwire.GossipVersion1, 0) if ok { t.Fatalf("reject cache should be empty") } // Now, fill up the cache entirely. for i := uint64(0); i < cacheSize; i++ { - c.insert(i, entryForInt(i)) + c.insert(lnwire.GossipVersion1, i, entryForInt(i)) } // Assert that the cache has all of the entries just inserted, since no @@ -30,7 +32,10 @@ func TestRejectCache(t *testing.T) { assertHasEntries(t, c, 0, cacheSize) // Now, insert a new element that causes the cache to evict an element. - c.insert(cacheSize, entryForInt(cacheSize)) + c.insert( + lnwire.GossipVersion1, cacheSize, + entryForInt(cacheSize), + ) // Assert that the cache has this last entry, as the cache should evict // some prior element and not the newly inserted one. @@ -40,7 +45,7 @@ func TestRejectCache(t *testing.T) { // elements. evicted := make(map[uint64]struct{}) for i := uint64(0); i < cacheSize+1; i++ { - _, ok := c.get(i) + _, ok := c.get(lnwire.GossipVersion1, i) if !ok { evicted[i] = struct{}{} } @@ -54,9 +59,9 @@ func TestRejectCache(t *testing.T) { // Remove the highest item which initially caused the eviction and // reinsert the element that was evicted prior. - c.remove(cacheSize) + c.remove(lnwire.GossipVersion1, cacheSize) for i := range evicted { - c.insert(i, entryForInt(i)) + c.insert(lnwire.GossipVersion1, i, entryForInt(i)) } // Since the removal created an extra slot, the last insertion should @@ -69,7 +74,7 @@ func TestRejectCache(t *testing.T) { // happening on inserts for existing cache items, we expect this to fail // with high probability. for i := uint64(0); i < cacheSize; i++ { - c.insert(i, entryForInt(i)) + c.insert(lnwire.GossipVersion1, i, entryForInt(i)) } assertHasEntries(t, c, 0, cacheSize) @@ -82,7 +87,7 @@ func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) { t.Helper() for i := start; i < end; i++ { - entry, ok := c.get(i) + entry, ok := c.get(lnwire.GossipVersion1, i) if !ok { t.Fatalf("reject cache should contain chan %d", i) } diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index c55110338a..dce78be671 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -747,8 +747,11 @@ func (s *SQLStore) AddChannelEdge(ctx context.Context, case alreadyExists: return ErrEdgeAlreadyExist default: - s.rejectCache.remove(edge.ChannelID) + s.rejectCache.remove( + edge.Version, edge.ChannelID, + ) s.chanCache.remove(edge.ChannelID) + return nil } }, @@ -862,13 +865,18 @@ func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy, // the entry with the updated timestamp for the direction that was just // written. If the edge doesn't exist, we'll load the cache entry lazily // during the next query for this edge. - if entry, ok := s.rejectCache.get(e.ChannelID); ok { - if isUpdate1 { - entry.upd1Time = e.LastUpdate.Unix() - } else { - entry.upd2Time = e.LastUpdate.Unix() + if entry, ok := s.rejectCache.get(e.Version, e.ChannelID); ok { + switch e.Version { + case lnwire.GossipVersion1: + updateRejectCacheEntryV1( + &entry, isUpdate1, e.LastUpdate, + ) + case lnwire.GossipVersion2: + updateRejectCacheEntryV2( + &entry, isUpdate1, e.LastBlockHeight, + ) } - s.rejectCache.insert(e.ChannelID, entry) + s.rejectCache.insert(e.Version, e.ChannelID, entry) } // If an entry for this channel is found in channel cache, we'll modify @@ -905,7 +913,7 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context, } return forEachNodeChannel( - ctx, db, s.cfg, nodeID, + ctx, db, s.cfg, lnwire.GossipVersion1, nodeID, func(info *models.ChannelEdgeInfo, outPolicy *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -1019,14 +1027,15 @@ func (s *SQLStore) ForEachNodeCacheable(ctx context.Context, // Unknown policies are passed into the callback as nil values. // // NOTE: part of the Store interface. -func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, +func (s *SQLStore) ForEachNodeChannel(ctx context.Context, + v lnwire.GossipVersion, nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { dbNode, err := db.GetNodeByPubKey( ctx, sqlc.GetNodeByPubKeyParams{ - Version: int16(lnwire.GossipVersion1), + Version: int16(v), PubKey: nodePub[:], }, ) @@ -1036,7 +1045,7 @@ func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, return fmt.Errorf("unable to fetch node: %w", err) } - return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb) + return forEachNodeChannel(ctx, db, s.cfg, v, dbNode.ID, cb) }, reset) } @@ -1602,11 +1611,16 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo, // // NOTE: part of the Store interface. func (s *SQLStore) ForEachChannel(ctx context.Context, - cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error, reset func()) error { + v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, + reset func()) error { + + if !isKnownGossipVersion(v) { + return fmt.Errorf("unsupported gossip version: %d", v) + } return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachChannelWithPolicies(ctx, db, s.cfg, cb) + return forEachChannelWithPolicies(ctx, db, s.cfg, v, cb) }, reset) } @@ -1764,7 +1778,7 @@ func (s *SQLStore) MarkEdgeZombie(chanID uint64, "(channel_id=%d): %w", chanID, err) } - s.rejectCache.remove(chanID) + s.rejectCache.remove(lnwire.GossipVersion1, chanID) s.chanCache.remove(chanID) return nil @@ -1813,7 +1827,7 @@ func (s *SQLStore) MarkEdgeLive(chanID uint64) error { "(channel_id=%d): %w", chanID, err) } - s.rejectCache.remove(chanID) + s.rejectCache.remove(lnwire.GossipVersion1, chanID) s.chanCache.remove(chanID) return err @@ -1995,7 +2009,7 @@ func (s *SQLStore) DeleteChannelEdges(v lnwire.GossipVersion, } for _, chanID := range chanIDs { - s.rejectCache.remove(chanID) + s.rejectCache.remove(v, chanID) s.chanCache.remove(chanID) } @@ -2199,15 +2213,15 @@ func (s *SQLStore) FetchChannelEdgesByOutpoint(v lnwire.GossipVersion, return edge, policy1, policy2, nil } -// HasChannelEdge returns true if the database knows of a channel edge with the -// passed channel ID, and false otherwise. If an edge with that ID is found -// within the graph, then two time stamps representing the last time the edge -// was updated for both directed edges are returned along with the boolean. If -// it is not found, then the zombie index is checked and its result is returned -// as the second boolean. +// HasV1ChannelEdge returns true if the database knows of a channel edge +// with the passed channel ID, and false otherwise. If an edge with that ID +// is found within the graph, then two time stamps representing the last time +// the edge was updated for both directed edges are returned along with the +// boolean. If it is not found, then the zombie index is checked and its +// result is returned as the second boolean. // // NOTE: part of the Store interface. -func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, +func (s *SQLStore) HasV1ChannelEdge(chanID uint64) (time.Time, time.Time, bool, bool, error) { ctx := context.TODO() @@ -2222,7 +2236,7 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, // We'll query the cache with the shared lock held to allow multiple // readers to access values in the cache concurrently if they exist. s.cacheMu.RLock() - if entry, ok := s.rejectCache.get(chanID); ok { + if entry, ok := s.rejectCache.get(lnwire.GossipVersion1, chanID); ok { s.cacheMu.RUnlock() node1LastUpdate = time.Unix(entry.upd1Time, 0) node2LastUpdate = time.Unix(entry.upd2Time, 0) @@ -2238,7 +2252,7 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, // The item was not found with the shared lock, so we'll acquire the // exclusive lock and check the cache again in case another method added // the entry to the cache while no lock was held. - if entry, ok := s.rejectCache.get(chanID); ok { + if entry, ok := s.rejectCache.get(lnwire.GossipVersion1, chanID); ok { node1LastUpdate = time.Unix(entry.upd1Time, 0) node2LastUpdate = time.Unix(entry.upd2Time, 0) exists, isZombie = entry.flags.unpack() @@ -2309,15 +2323,169 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, fmt.Errorf("unable to fetch channel: %w", err) } - s.rejectCache.insert(chanID, rejectCacheEntry{ - upd1Time: node1LastUpdate.Unix(), - upd2Time: node2LastUpdate.Unix(), - flags: packRejectFlags(exists, isZombie), - }) + s.rejectCache.insert( + lnwire.GossipVersion1, chanID, + newRejectCacheEntryV1( + node1LastUpdate, node2LastUpdate, exists, + isZombie, + ), + ) return node1LastUpdate, node2LastUpdate, exists, isZombie, nil } +// HasChannelEdge returns true if the database knows of a channel edge with the +// passed channel ID and gossip version, and false otherwise. If an edge with +// that ID is found within the graph, then the zombie index is checked and its +// result is returned as the second boolean. +// +// NOTE: part of the Store interface. +func (s *SQLStore) HasChannelEdge(v lnwire.GossipVersion, + chanID uint64) (bool, bool, error) { + + if !isKnownGossipVersion(v) { + return false, false, fmt.Errorf( + "unsupported gossip version: %d", v, + ) + } + + ctx := context.TODO() + + var ( + exists bool + isZombie bool + node1LastUpdate time.Time + node2LastUpdate time.Time + node1Block uint32 + node2Block uint32 + ) + + // We'll query the cache with the shared lock held to allow multiple + // readers to access values in the cache concurrently if they exist. + s.cacheMu.RLock() + if entry, ok := s.rejectCache.get(v, chanID); ok { + s.cacheMu.RUnlock() + exists, isZombie = entry.flags.unpack() + return exists, isZombie, nil + } + s.cacheMu.RUnlock() + + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + + // The item was not found with the shared lock, so we'll acquire the + // exclusive lock and check the cache again in case another method added + // the entry to the cache while no lock was held. + if entry, ok := s.rejectCache.get(v, chanID); ok { + exists, isZombie = entry.flags.unpack() + return exists, isZombie, nil + } + + chanIDB := channelIDToBytes(chanID) + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + channel, err := db.GetChannelBySCID( + ctx, sqlc.GetChannelBySCIDParams{ + Scid: chanIDB, + Version: int16(v), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + // Check if it is a zombie channel. + isZombie, err = db.IsZombieChannel( + ctx, sqlc.IsZombieChannelParams{ + Scid: chanIDB, + Version: int16(v), + }, + ) + if err != nil { + return fmt.Errorf("could not check if channel "+ + "is zombie: %w", err) + } + + return nil + } else if err != nil { + return fmt.Errorf("unable to fetch channel: %w", err) + } + + exists = true + + policy1, err := db.GetChannelPolicyByChannelAndNode( + ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{ + Version: int16(v), + ChannelID: channel.ID, + NodeID: channel.NodeID1, + }, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("unable to fetch channel policy: %w", + err) + } else if err == nil { + switch v { + case lnwire.GossipVersion1: + if policy1.LastUpdate.Valid { + node1LastUpdate = time.Unix( + policy1.LastUpdate.Int64, 0, + ) + } + case lnwire.GossipVersion2: + if policy1.BlockHeight.Valid { + node1Block = uint32( + policy1.BlockHeight.Int64, + ) + } + } + } + + policy2, err := db.GetChannelPolicyByChannelAndNode( + ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{ + Version: int16(v), + ChannelID: channel.ID, + NodeID: channel.NodeID2, + }, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("unable to fetch channel policy: %w", + err) + } else if err == nil { + switch v { + case lnwire.GossipVersion1: + if policy2.LastUpdate.Valid { + node2LastUpdate = time.Unix( + policy2.LastUpdate.Int64, 0, + ) + } + case lnwire.GossipVersion2: + if policy2.BlockHeight.Valid { + node2Block = uint32( + policy2.BlockHeight.Int64, + ) + } + } + } + + return nil + }, sqldb.NoOpReset) + if err != nil { + return false, false, + fmt.Errorf("unable to fetch channel: %w", err) + } + + var entry rejectCacheEntry + switch v { + case lnwire.GossipVersion1: + entry = newRejectCacheEntryV1( + node1LastUpdate, node2LastUpdate, exists, isZombie, + ) + case lnwire.GossipVersion2: + entry = newRejectCacheEntryV2( + node1Block, node2Block, exists, isZombie, + ) + } + s.rejectCache.insert(v, chanID, entry) + + return exists, isZombie, nil +} + // ChannelID attempt to lookup the 8-byte compact channel ID which maps to the // passed channel point (outpoint). If the passed channel doesn't exist within // the database, then ErrEdgeNotFound is returned. @@ -2732,7 +2900,7 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, } for _, channel := range closedChans { - s.rejectCache.remove(channel.ChannelID) + s.rejectCache.remove(channel.Version, channel.ChannelID) s.chanCache.remove(channel.ChannelID) } @@ -3001,7 +3169,7 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) ( s.cacheMu.Lock() for _, channel := range removedChans { - s.rejectCache.remove(channel.ChannelID) + s.rejectCache.remove(channel.Version, channel.ChannelID) s.chanCache.remove(channel.ChannelID) } s.cacheMu.Unlock() @@ -3357,14 +3525,15 @@ func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig, // edge information, the outgoing policy and the incoming policy for the // channel and node combo. func forEachNodeChannel(ctx context.Context, db SQLQueries, - cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo, + cfg *SQLStoreConfig, v lnwire.GossipVersion, id int64, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { - // Get all the V1 channels for this node. + // Get all the channels for this node. rows, err := db.ListChannelsByNodeID( ctx, sqlc.ListChannelsByNodeIDParams{ - Version: int16(lnwire.GossipVersion1), + Version: int16(v), NodeID1: id, }, ) @@ -3455,10 +3624,16 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, var ( node1Pub, node2Pub route.Vertex - isNode1 bool chanIDB = channelIDToBytes(edge.ChannelID) + version = edge.Version ) + if !isKnownGossipVersion(version) { + return node1Pub, node2Pub, false, fmt.Errorf( + "unsupported gossip version: %d", version, + ) + } + // Check that this edge policy refers to a channel that we already // know of. We do this explicitly so that we can return the appropriate // ErrEdgeNotFound error if the channel doesn't exist, rather than @@ -3466,7 +3641,7 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, dbChan, err := tx.GetChannelAndNodesBySCID( ctx, sqlc.GetChannelAndNodesBySCIDParams{ Scid: chanIDB, - Version: int16(lnwire.GossipVersion1), + Version: int16(version), }, ) if errors.Is(err, sql.ErrNoRows) { @@ -3480,7 +3655,7 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, copy(node2Pub[:], dbChan.Node2PubKey) // Figure out which node this edge is from. - isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 + isNode1 := edge.IsNode1() nodeID := dbChan.NodeID1 if !isNode1 { nodeID = dbChan.NodeID2 @@ -3495,29 +3670,41 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, inboundBase = sqldb.SQLInt64(fee.BaseFee) }) - id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{ - Version: int16(lnwire.GossipVersion1), - ChannelID: dbChan.ID, - NodeID: nodeID, - Timelock: int32(edge.TimeLockDelta), - FeePpm: int64(edge.FeeProportionalMillionths), - BaseFeeMsat: int64(edge.FeeBaseMSat), - MinHtlcMsat: int64(edge.MinHTLC), - LastUpdate: sqldb.SQLInt64(edge.LastUpdate.Unix()), - Disabled: sql.NullBool{ - Valid: true, - Bool: edge.IsDisabled(), - }, - MaxHtlcMsat: sql.NullInt64{ - Valid: edge.MessageFlags.HasMaxHtlc(), - Int64: int64(edge.MaxHTLC), - }, + params := sqlc.UpsertEdgePolicyParams{ + Version: int16(version), + ChannelID: dbChan.ID, + NodeID: nodeID, + Timelock: int32(edge.TimeLockDelta), + FeePpm: int64(edge.FeeProportionalMillionths), + BaseFeeMsat: int64(edge.FeeBaseMSat), + MinHtlcMsat: int64(edge.MinHTLC), MessageFlags: sqldb.SQLInt16(edge.MessageFlags), ChannelFlags: sqldb.SQLInt16(edge.ChannelFlags), InboundBaseFeeMsat: inboundBase, InboundFeeRateMilliMsat: inboundRate, Signature: edge.SigBytes, - }) + } + + switch version { + case lnwire.GossipVersion1: + params.LastUpdate = sqldb.SQLInt64(edge.LastUpdate.Unix()) + params.Disabled = sql.NullBool{ + Valid: true, + Bool: edge.IsDisabled(), + } + params.MaxHtlcMsat = sql.NullInt64{ + Valid: edge.MessageFlags.HasMaxHtlc(), + Int64: int64(edge.MaxHTLC), + } + case lnwire.GossipVersion2: + params.BlockHeight = sqldb.SQLInt64( + int64(edge.LastBlockHeight), + ) + params.DisableFlags = sqldb.SQLInt16(edge.DisableFlags) + params.MaxHtlcMsat = sqldb.SQLInt64(int64(edge.MaxHTLC)) + } + + id, err := tx.UpsertEdgePolicy(ctx, params) if err != nil { return node1Pub, node2Pub, isNode1, fmt.Errorf("unable to upsert edge policy: %w", err) @@ -3525,10 +3712,14 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, // Convert the flat extra opaque data into a map of TLV types to // values. - extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData) - if err != nil { - return node1Pub, node2Pub, false, fmt.Errorf("unable to "+ - "marshal extra opaque data: %w", err) + extra := edge.ExtraSignedFields + if version == lnwire.GossipVersion1 { + extra, err = marshalExtraOpaqueData(edge.ExtraOpaqueData) + if err != nil { + return node1Pub, node2Pub, false, fmt.Errorf( + "unable to marshal extra opaque data: %w", err, + ) + } } // Update the channel policy's extra signed fields. @@ -4711,14 +4902,14 @@ func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig, } pol1, err := buildChanPolicyWithBatchData( - dbPol1, channelID, node2, batchData, + true, dbPol1, channelID, node2, batchData, ) if err != nil { return nil, nil, fmt.Errorf("unable to build policy1: %w", err) } pol2, err := buildChanPolicyWithBatchData( - dbPol2, channelID, node1, batchData, + false, dbPol2, channelID, node1, batchData, ) if err != nil { return nil, nil, fmt.Errorf("unable to build policy2: %w", err) @@ -4736,7 +4927,9 @@ func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy, var p1, p2 *models.CachedEdgePolicy if dbPol1 != nil { - policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2) + policy1, err := buildChanPolicy( + true, *dbPol1, channelID, nil, node2, + ) if err != nil { return nil, nil, err } @@ -4744,7 +4937,9 @@ func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy, p1 = models.NewCachedPolicy(policy1) } if dbPol2 != nil { - policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1) + policy2, err := buildChanPolicy( + false, *dbPol2, channelID, nil, node1, + ) if err != nil { return nil, nil, err } @@ -4757,16 +4952,10 @@ func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy, // buildChanPolicy builds a models.ChannelEdgePolicy instance from the // provided sqlc.GraphChannelPolicy and other required information. -func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64, - extras map[uint64][]byte, +func buildChanPolicy(isNode1 bool, dbPolicy sqlc.GraphChannelPolicy, + channelID uint64, extras map[uint64][]byte, toNode route.Vertex) (*models.ChannelEdgePolicy, error) { - recs, err := lnwire.CustomRecords(extras).Serialize() - if err != nil { - return nil, fmt.Errorf("unable to serialize extra signed "+ - "fields: %w", err) - } - var inboundFee fn.Option[lnwire.Fee] if dbPolicy.InboundFeeRateMilliMsat.Valid || dbPolicy.InboundBaseFeeMsat.Valid { @@ -4777,18 +4966,11 @@ func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64, }) } - return &models.ChannelEdgePolicy{ - SigBytes: dbPolicy.Signature, - ChannelID: channelID, - LastUpdate: time.Unix( - dbPolicy.LastUpdate.Int64, 0, - ), - MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags]( - dbPolicy.MessageFlags, - ), - ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags]( - dbPolicy.ChannelFlags, - ), + p := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion(dbPolicy.Version), + SigBytes: dbPolicy.Signature, + ChannelID: channelID, + SecondPeer: !isNode1, TimeLockDelta: uint16(dbPolicy.Timelock), MinHTLC: lnwire.MilliSatoshi( dbPolicy.MinHtlcMsat, @@ -4802,8 +4984,40 @@ func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64, FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm), ToNode: toNode, InboundFee: inboundFee, - ExtraOpaqueData: recs, - }, nil + } + + if p.Version == lnwire.GossipVersion1 { + recs, err := lnwire.CustomRecords(extras).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra "+ + "signed fields: %w", err) + } + + p.ExtraOpaqueData = recs + p.LastUpdate = time.Unix(dbPolicy.LastUpdate.Int64, 0) + //nolint:ll + p.MessageFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags]( + dbPolicy.MessageFlags, + ) + //nolint:ll + p.ChannelFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags]( + dbPolicy.ChannelFlags, + ) + } else { + if dbPolicy.BlockHeight.Valid { + p.LastBlockHeight = uint32( + dbPolicy.BlockHeight.Int64, + ) + } + + //nolint:ll + p.DisableFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateDisableFlags]( + dbPolicy.DisableFlags, + ) + p.ExtraSignedFields = extras + } + + return p, nil } // extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give @@ -4820,6 +5034,7 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, case sqlc.ListChannelsWithPoliciesForCachePaginatedRow: if r.Policy1Timelock.Valid { policy1 = &sqlc.GraphChannelPolicy{ + Version: int16(lnwire.GossipVersion1), Timelock: r.Policy1Timelock.Int32, FeePpm: r.Policy1FeePpm.Int64, BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, @@ -4830,10 +5045,13 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, Disabled: r.Policy1Disabled, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2Timelock.Valid { policy2 = &sqlc.GraphChannelPolicy{ + Version: int16(lnwire.GossipVersion1), Timelock: r.Policy2Timelock.Int32, FeePpm: r.Policy2FeePpm.Int64, BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, @@ -4844,6 +5062,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, Disabled: r.Policy2Disabled, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -4868,6 +5088,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -4888,6 +5110,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -4912,6 +5136,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -4932,6 +5158,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -4956,6 +5184,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -4976,6 +5206,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -5000,6 +5232,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -5020,6 +5254,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -5044,6 +5280,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -5064,6 +5302,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -5088,6 +5328,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -5108,6 +5350,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -5132,6 +5376,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -5152,6 +5398,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -5176,6 +5424,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy1MessageFlags, ChannelFlags: r.Policy1ChannelFlags, Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, } } if r.Policy2ID.Valid { @@ -5196,6 +5446,8 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, MessageFlags: r.Policy2MessageFlags, ChannelFlags: r.Policy2ChannelFlags, Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, } } @@ -5477,14 +5729,14 @@ func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy, *models.ChannelEdgePolicy, error) { pol1, err := buildChanPolicyWithBatchData( - dbPol1, channelID, node2, batchData, + true, dbPol1, channelID, node2, batchData, ) if err != nil { return nil, nil, fmt.Errorf("unable to build policy1: %w", err) } pol2, err := buildChanPolicyWithBatchData( - dbPol2, channelID, node1, batchData, + false, dbPol2, channelID, node1, batchData, ) if err != nil { return nil, nil, fmt.Errorf("unable to build policy2: %w", err) @@ -5495,9 +5747,10 @@ func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy, // buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from // the provided sqlc.GraphChannelPolicy and the provided batchChannelData. -func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy, - channelID uint64, toNode route.Vertex, - batchData *batchChannelData) (*models.ChannelEdgePolicy, error) { +func buildChanPolicyWithBatchData(isNode1 bool, + dbPol *sqlc.GraphChannelPolicy, channelID uint64, + toNode route.Vertex, batchData *batchChannelData) ( + *models.ChannelEdgePolicy, error) { if dbPol == nil { return nil, nil @@ -5510,7 +5763,7 @@ func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy, dbPol1Extras = make(map[uint64][]byte) } - return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode) + return buildChanPolicy(isNode1, *dbPol, channelID, dbPol1Extras, toNode) } // batchChannelData holds all the related data for a batch of channels. @@ -5728,8 +5981,8 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, // forEachChannelWithPolicies executes a paginated query to process each channel // with policies in the graph. func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, - cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, + cfg *SQLStoreConfig, v lnwire.GossipVersion, + processChannel func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { type channelBatchIDs struct { @@ -5743,7 +5996,7 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, return db.ListChannelsWithPoliciesPaginated( ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{ - Version: int16(lnwire.GossipVersion1), + Version: int16(v), ID: lastID, Limit: limit, }, diff --git a/graph/notifications_test.go b/graph/notifications_test.go index b086fc5620..20e45026ae 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -113,6 +113,7 @@ func randEdgePolicy(chanID *lnwire.ShortChannelID, } return &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Unix(int64(prand.Int31()), 0), diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index ca83f8babc..104b2873dd 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -317,7 +317,9 @@ var shouldIncludeChannelTestCases = []struct { return edge }(), + //nolint:ll &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, @@ -364,7 +366,9 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, + //nolint:ll &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, @@ -409,7 +413,9 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, + //nolint:ll &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 5265d317f0..59ce324c69 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -116,12 +116,14 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, return edgeInfo, &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ChannelID: channel.ShortChanID().ToUint64(), ChannelFlags: dir1, LastUpdate: time.Now(), SigBytes: testSigBytes, }, &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ChannelID: channel.ShortChanID().ToUint64(), ChannelFlags: dir2, LastUpdate: time.Now(), @@ -222,6 +224,7 @@ func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, timestamp := time.Unix(int64(update.Timestamp), 0) policy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ChannelID: update.ShortChannelID.ToUint64(), ChannelFlags: update.ChannelFlags, LastUpdate: timestamp, diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index cc86d82ccc..6dd30b0ae3 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -364,6 +364,7 @@ func (r *Manager) createEdge(channel *channeldb.OpenChannel, // be updated with the new values in the call to processChan below. timeLockDelta := uint16(r.DefaultRoutingPolicy.TimeLockDelta) edge := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ChannelID: shortChanID.ToUint64(), LastUpdate: timestamp, TimeLockDelta: timeLockDelta, diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index 6196330bd1..c48e616416 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -64,6 +64,7 @@ func TestManager(t *testing.T) { } currentPolicy := models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, MinHTLC: minHTLC, MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, } @@ -451,6 +452,7 @@ func TestCreateEdgeLower(t *testing.T) { require.NoError(t, err) expectedEdge := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ChannelID: 8, LastUpdate: timestamp, TimeLockDelta: 7, @@ -542,6 +544,7 @@ func TestCreateEdgeHigher(t *testing.T) { require.NoError(t, err) expectedEdge := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, ChannelID: 8, LastUpdate: timestamp, TimeLockDelta: 7, diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index ed291a6aae..01f5d9849e 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -389,6 +389,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), ChannelFlags: channelFlags, @@ -740,7 +741,9 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, channelFlags |= lnwire.ChanUpdateDisabled } + //nolint:ll edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, @@ -772,7 +775,9 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } channelFlags |= lnwire.ChanUpdateDirection + //nolint:ll edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, diff --git a/routing/router_test.go b/routing/router_test.go index 6f9d2c3f31..115c02c2c9 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2751,6 +2751,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // We must add the edge policy to be able to use the edge for route // finding. edgePolicy := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2766,6 +2767,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // Create edge in the other direction as well. edgePolicy = &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2832,6 +2834,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { require.NoError(t, ctx.graph.AddChannelEdge(ctxb, edge)) edgePolicy = &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2846,6 +2849,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { require.NoError(t, ctx.graph.UpdateEdgePolicy(ctxb, edgePolicy)) edgePolicy = &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion1, SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2964,19 +2968,13 @@ func (m *mockGraphBuilder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { return false } - err := m.updateEdge(&models.ChannelEdgePolicy{ - SigBytes: msg.Signature.ToSignatureBytes(), - ChannelID: msg.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(msg.Timestamp), 0), - MessageFlags: msg.MessageFlags, - ChannelFlags: msg.ChannelFlags, - TimeLockDelta: msg.TimeLockDelta, - MinHTLC: msg.HtlcMinimumMsat, - MaxHTLC: msg.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), - ExtraOpaqueData: msg.ExtraOpaqueData, - }) + update, err := models.ChanEdgePolicyFromWire( + msg.ShortChannelID.ToUint64(), msg, + ) + if err != nil { + return false + } + err = m.updateEdge(update) return err == nil } diff --git a/routing/unified_edges.go b/routing/unified_edges.go index 9b8f6c5c03..fda06aabfd 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -188,7 +188,7 @@ func (u *unifiedEdge) amtInRange(amt lnwire.MilliSatoshi) bool { } // Skip channels for which this htlc is too large. - if u.policy.MessageFlags.HasMaxHtlc() && + if u.policy.HasMaxHTLC && amt > u.policy.MaxHTLC { log.Tracef("Exceeds policy's MaxHTLC: amt=%v, MaxHTLC=%v", @@ -376,7 +376,7 @@ func (u *edgeUnifier) getEdgeNetwork(netAmtReceived lnwire.MilliSatoshi, } // For network channels, skip the disabled ones. - if edge.policy.IsDisabled() { + if edge.policy.IsDisabled { log.Debugf("Skipped edge %v due to it being disabled", edge.policy.ChannelID) continue @@ -385,7 +385,7 @@ func (u *edgeUnifier) getEdgeNetwork(netAmtReceived lnwire.MilliSatoshi, // Track the maximal capacity for usable channels. If we don't // know the capacity, we fall back to MaxHTLC. capMsat := lnwire.NewMSatFromSatoshis(edge.capacity) - if capMsat == 0 && edge.policy.MessageFlags.HasMaxHtlc() { + if capMsat == 0 && edge.policy.HasMaxHTLC { log.Tracef("No capacity available for channel %v, "+ "using MaxHtlcMsat (%v) as a fallback.", edge.policy.ChannelID, edge.policy.MaxHTLC) diff --git a/routing/unified_edges_test.go b/routing/unified_edges_test.go index 8fc79031ac..25c8e9220b 100644 --- a/routing/unified_edges_test.go +++ b/routing/unified_edges_test.go @@ -30,7 +30,7 @@ func TestNodeEdgeUnifier(t *testing.T) { FeeProportionalMillionths: 100000, FeeBaseMSat: 30, TimeLockDelta: 60, - MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + HasMaxHTLC: true, MaxHTLC: 5000, MinHTLC: 100, } @@ -39,7 +39,7 @@ func TestNodeEdgeUnifier(t *testing.T) { FeeProportionalMillionths: 190000, FeeBaseMSat: 10, TimeLockDelta: 40, - MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + HasMaxHTLC: true, MaxHTLC: 4000, MinHTLC: 100, } diff --git a/rpcserver.go b/rpcserver.go index 166eb9be84..d6be226f05 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6852,10 +6852,11 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, } } - // Obtain the pointer to the global singleton channel graph, this will - // provide a consistent view of the graph due to bolt db's - // transactional model. - graph := r.server.graphDB + // Obtain the pointer to the V1 channel graph. This will provide a + // consistent view of the graph due to bolt db's transactional model. + // + // TODO(elle): switch to a cross-version graph view when available. + graph := r.server.v1Graph // First iterate through all the known nodes (connected or unconnected // within the graph), collating their current state into the RPC diff --git a/server.go b/server.go index 7ba74176a7..ad97261edb 100644 --- a/server.go +++ b/server.go @@ -1131,7 +1131,8 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, *models.ChannelEdgePolicy) error, reset func()) error { - return s.graphDB.ForEachNodeChannel(ctx, selfVertex, + return s.v1Graph.ForEachNodeChannel( + ctx, selfVertex, func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index b8bb884081..aa2a04671a 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -397,6 +397,8 @@ SELECT cp1.message_flags AS policy_1_message_flags, cp1.channel_flags AS policy_1_channel_flags, cp1.signature AS policy_1_signature, + cp1.block_height AS policy_1_block_height, + cp1.disable_flags AS policy_1_disable_flags, -- Node 2 policy cp2.id AS policy_2_id, @@ -413,7 +415,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy_2_message_flags, cp2.channel_flags AS policy_2_channel_flags, - cp2.signature AS policy_2_signature + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id JOIN graph_nodes n2 ON c.node_id_2 = n2.id @@ -448,6 +452,8 @@ type GetChannelByOutpointWithPoliciesRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -463,6 +469,8 @@ type GetChannelByOutpointWithPoliciesRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) { @@ -502,6 +510,8 @@ func (q *Queries) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetC &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -517,6 +527,8 @@ func (q *Queries) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetC &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ) return i, err } @@ -577,6 +589,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -593,7 +607,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy_2_message_flags, cp2.channel_flags AS policy_2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -630,6 +646,8 @@ type GetChannelBySCIDWithPoliciesRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -645,6 +663,8 @@ type GetChannelBySCIDWithPoliciesRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) { @@ -698,6 +718,8 @@ func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChann &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -713,6 +735,8 @@ func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChann &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ) return i, err } @@ -917,6 +941,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -933,7 +959,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -966,6 +994,8 @@ type GetChannelsByIDsRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -981,6 +1011,8 @@ type GetChannelsByIDsRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChannelsByIDsRow, error) { @@ -1038,6 +1070,8 @@ func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChann &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -1053,6 +1087,8 @@ func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChann &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -1159,6 +1195,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 (node_id_2) cp2.id AS policy2_id, @@ -1175,7 +1213,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -1244,6 +1284,8 @@ type GetChannelsByPolicyLastUpdateRangeRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -1259,6 +1301,8 @@ type GetChannelsByPolicyLastUpdateRangeRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) { @@ -1325,6 +1369,8 @@ func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg Ge &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -1340,6 +1386,8 @@ func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg Ge &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -1440,6 +1488,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -1456,7 +1506,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy_2_message_flags, cp2.channel_flags AS policy_2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -1494,6 +1546,8 @@ type GetChannelsBySCIDWithPoliciesRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -1509,6 +1563,8 @@ type GetChannelsBySCIDWithPoliciesRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) { @@ -1579,6 +1635,8 @@ func (q *Queries) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChan &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -1594,6 +1652,8 @@ func (q *Queries) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChan &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -2826,6 +2886,8 @@ SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -2842,7 +2904,9 @@ SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -2879,6 +2943,8 @@ type ListChannelsByNodeIDRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -2894,6 +2960,8 @@ type ListChannelsByNodeIDRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) { @@ -2939,6 +3007,8 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -2954,6 +3024,8 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -2992,6 +3064,8 @@ SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -3008,7 +3082,9 @@ SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -3047,6 +3123,8 @@ type ListChannelsForNodeIDsRow struct { Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 Policy2Version sql.NullInt16 @@ -3062,6 +3140,8 @@ type ListChannelsForNodeIDsRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) ListChannelsForNodeIDs(ctx context.Context, arg ListChannelsForNodeIDsParams) ([]ListChannelsForNodeIDsRow, error) { @@ -3126,6 +3206,8 @@ func (q *Queries) ListChannelsForNodeIDs(ctx context.Context, arg ListChannelsFo &i.Policy1MessageFlags, &i.Policy1ChannelFlags, &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2ID, &i.Policy2NodeID, &i.Policy2Version, @@ -3141,6 +3223,8 @@ func (q *Queries) ListChannelsForNodeIDs(ctx context.Context, arg ListChannelsFo &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -3225,6 +3309,8 @@ SELECT cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Node 2 policy cp2.timelock AS policy_2_timelock, @@ -3236,7 +3322,9 @@ SELECT cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, - cp2.channel_flags AS policy2_channel_flags + cp2.channel_flags AS policy2_channel_flags, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -3272,6 +3360,8 @@ type ListChannelsWithPoliciesForCachePaginatedRow struct { Policy1InboundFeeRateMilliMsat sql.NullInt64 Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy2Timelock sql.NullInt32 Policy2FeePpm sql.NullInt64 Policy2BaseFeeMsat sql.NullInt64 @@ -3282,6 +3372,8 @@ type ListChannelsWithPoliciesForCachePaginatedRow struct { Policy2InboundFeeRateMilliMsat sql.NullInt64 Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg ListChannelsWithPoliciesForCachePaginatedParams) ([]ListChannelsWithPoliciesForCachePaginatedRow, error) { @@ -3309,6 +3401,8 @@ func (q *Queries) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, &i.Policy1InboundFeeRateMilliMsat, &i.Policy1MessageFlags, &i.Policy1ChannelFlags, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy2Timelock, &i.Policy2FeePpm, &i.Policy2BaseFeeMsat, @@ -3319,6 +3413,8 @@ func (q *Queries) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, &i.Policy2InboundFeeRateMilliMsat, &i.Policy2MessageFlags, &i.Policy2ChannelFlags, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -3356,6 +3452,8 @@ SELECT cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, cp1.signature AS policy_1_signature, -- Node 2 policy @@ -3373,7 +3471,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy_2_signature + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -3411,6 +3511,8 @@ type ListChannelsWithPoliciesPaginatedRow struct { Policy1InboundFeeRateMilliMsat sql.NullInt64 Policy1MessageFlags sql.NullInt16 Policy1ChannelFlags sql.NullInt16 + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 Policy1Signature []byte Policy2ID sql.NullInt64 Policy2NodeID sql.NullInt64 @@ -3427,6 +3529,8 @@ type ListChannelsWithPoliciesPaginatedRow struct { Policy2MessageFlags sql.NullInt16 Policy2ChannelFlags sql.NullInt16 Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 } func (q *Queries) ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error) { @@ -3471,6 +3575,8 @@ func (q *Queries) ListChannelsWithPoliciesPaginated(ctx context.Context, arg Lis &i.Policy1InboundFeeRateMilliMsat, &i.Policy1MessageFlags, &i.Policy1ChannelFlags, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, &i.Policy1Signature, &i.Policy2ID, &i.Policy2NodeID, @@ -3487,6 +3593,8 @@ func (q *Queries) ListChannelsWithPoliciesPaginated(ctx context.Context, arg Lis &i.Policy2MessageFlags, &i.Policy2ChannelFlags, &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, ); err != nil { return nil, err } @@ -3674,9 +3782,9 @@ INSERT INTO graph_channel_policies ( base_fee_msat, min_htlc_msat, last_update, disabled, max_htlc_msat, inbound_base_fee_msat, inbound_fee_rate_milli_msat, message_flags, channel_flags, - signature + signature, block_height, disable_flags ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 ) ON CONFLICT (channel_id, node_id, version) -- Update the following fields if a conflict occurs on channel_id, @@ -3693,8 +3801,21 @@ ON CONFLICT (channel_id, node_id, version) inbound_fee_rate_milli_msat = EXCLUDED.inbound_fee_rate_milli_msat, message_flags = EXCLUDED.message_flags, channel_flags = EXCLUDED.channel_flags, - signature = EXCLUDED.signature -WHERE EXCLUDED.last_update > graph_channel_policies.last_update + signature = EXCLUDED.signature, + block_height = EXCLUDED.block_height, + disable_flags = EXCLUDED.disable_flags +WHERE ( + EXCLUDED.version = 1 AND ( + graph_channel_policies.last_update IS NULL + OR EXCLUDED.last_update > graph_channel_policies.last_update + ) +) +OR ( + EXCLUDED.version = 2 AND ( + graph_channel_policies.block_height IS NULL + OR EXCLUDED.block_height >= graph_channel_policies.block_height + ) +) RETURNING id ` @@ -3714,6 +3835,8 @@ type UpsertEdgePolicyParams struct { MessageFlags sql.NullInt16 ChannelFlags sql.NullInt16 Signature []byte + BlockHeight sql.NullInt64 + DisableFlags sql.NullInt16 } func (q *Queries) UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyParams) (int64, error) { @@ -3733,6 +3856,8 @@ func (q *Queries) UpsertEdgePolicy(ctx context.Context, arg UpsertEdgePolicyPara arg.MessageFlags, arg.ChannelFlags, arg.Signature, + arg.BlockHeight, + arg.DisableFlags, ) var id int64 err := row.Scan(&id) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 0ad71f8783..a4a42a1e3a 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -407,6 +407,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -423,7 +425,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy_2_message_flags, cp2.channel_flags AS policy_2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -462,6 +466,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -478,7 +484,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -511,6 +519,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 (node_id_2) cp2.id AS policy2_id, @@ -527,7 +537,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -590,6 +602,8 @@ SELECT cp1.message_flags AS policy_1_message_flags, cp1.channel_flags AS policy_1_channel_flags, cp1.signature AS policy_1_signature, + cp1.block_height AS policy_1_block_height, + cp1.disable_flags AS policy_1_disable_flags, -- Node 2 policy cp2.id AS policy_2_id, @@ -606,7 +620,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy_2_message_flags, cp2.channel_flags AS policy_2_channel_flags, - cp2.signature AS policy_2_signature + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id JOIN graph_nodes n2 ON c.node_id_2 = n2.id @@ -647,6 +663,8 @@ SELECT sqlc.embed(c), cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -663,7 +681,9 @@ SELECT sqlc.embed(c), cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -700,6 +720,8 @@ SELECT sqlc.embed(c), cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -716,7 +738,9 @@ SELECT sqlc.embed(c), cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -765,6 +789,8 @@ SELECT cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, cp1.signature AS policy_1_signature, -- Node 2 policy @@ -782,7 +808,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, cp2.channel_flags AS policy2_channel_flags, - cp2.signature AS policy_2_signature + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -816,6 +844,8 @@ SELECT cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Node 2 policy cp2.timelock AS policy_2_timelock, @@ -827,7 +857,9 @@ SELECT cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy2_message_flags, - cp2.channel_flags AS policy2_channel_flags + cp2.channel_flags AS policy2_channel_flags, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id @@ -900,9 +932,9 @@ INSERT INTO graph_channel_policies ( base_fee_msat, min_htlc_msat, last_update, disabled, max_htlc_msat, inbound_base_fee_msat, inbound_fee_rate_milli_msat, message_flags, channel_flags, - signature + signature, block_height, disable_flags ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 ) ON CONFLICT (channel_id, node_id, version) -- Update the following fields if a conflict occurs on channel_id, @@ -919,8 +951,21 @@ ON CONFLICT (channel_id, node_id, version) inbound_fee_rate_milli_msat = EXCLUDED.inbound_fee_rate_milli_msat, message_flags = EXCLUDED.message_flags, channel_flags = EXCLUDED.channel_flags, - signature = EXCLUDED.signature -WHERE EXCLUDED.last_update > graph_channel_policies.last_update + signature = EXCLUDED.signature, + block_height = EXCLUDED.block_height, + disable_flags = EXCLUDED.disable_flags +WHERE ( + EXCLUDED.version = 1 AND ( + graph_channel_policies.last_update IS NULL + OR EXCLUDED.last_update > graph_channel_policies.last_update + ) +) +OR ( + EXCLUDED.version = 2 AND ( + graph_channel_policies.block_height IS NULL + OR EXCLUDED.block_height >= graph_channel_policies.block_height + ) +) RETURNING id; -- name: GetChannelPolicyByChannelAndNode :one @@ -952,6 +997,8 @@ SELECT cp1.message_flags AS policy1_message_flags, cp1.channel_flags AS policy1_channel_flags, cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, -- Policy 2 cp2.id AS policy2_id, @@ -968,7 +1015,9 @@ SELECT cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, cp2.message_flags AS policy_2_message_flags, cp2.channel_flags AS policy_2_channel_flags, - cp2.signature AS policy2_signature + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags FROM graph_channels c JOIN graph_nodes n1 ON c.node_id_1 = n1.id