diff --git a/intra/dnsx/alg.go b/intra/dnsx/alg.go index ea8dbfa3..be7e2036 100644 --- a/intra/dnsx/alg.go +++ b/intra/dnsx/alg.go @@ -1927,6 +1927,19 @@ func (t *dnsgateway) ptrLocked(maybeAlg netip.Addr, uid, tid string, useptr bool return copyUniq(domains) } +func minNonZeroDuration(curr, next time.Duration) time.Duration { + if curr == 0 { + return next + } + if next == 0 { + return curr + } + if next < curr { + return next + } + return curr +} + // resolvLocked returns IPs and related targets for domain // depending on typ. // If typ is typalg, returns all algips for domain. @@ -1950,7 +1963,7 @@ func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (i if life, fresh := ans.fresh(); fresh { // not stale ip4s = append(ip4s, ans.algip) targets = append(targets, domainsFor(ans.baseans, tid, uid, ans.algip, xalive)...) - until = min(until, life) + until = minNonZeroDuration(until, life) } else { staleips = append(staleips, ans.algip) } @@ -1964,7 +1977,7 @@ func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (i if life, fresh := ans.fresh(); fresh { // not stale ip6s = append(ip6s, ans.algip) targets = append(targets, domainsFor(ans.baseans, tid, uid, ans.algip, xalive)...) - until = min(until, life) + until = minNonZeroDuration(until, life) } else { staleips = append(staleips, ans.algip) } @@ -1984,7 +1997,7 @@ func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (i all4s := v4only(ans.ips.realipsFor(tid, uid, xalive)) ip4s = append(ip4s, all4s...) targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all4s), xalive)...) - until = min(until, life) + until = minNonZeroDuration(until, life) } else { staleips = append(staleips, ans.ips.realipsFor(tid, uid, xall)...) } @@ -1999,7 +2012,7 @@ func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (i all6s := v6only(ans.ips.realipsFor(tid, uid, xalive)) ip6s = append(ip6s, all6s...) targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all6s), xalive)...) - until = min(until, life) + until = minNonZeroDuration(until, life) } else { staleips = append(staleips, ans.ips.realipsFor(tid, uid, xall)...) } @@ -2019,7 +2032,7 @@ func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (i all4s := v4only(ans.ips.secipsFor(tid, uid)) ip4s = append(ip4s, all4s...) targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all4s), xalive)...) - until = min(until, life) + until = minNonZeroDuration(until, life) } else { staleips = append(staleips, ans.ips.secips(xall)...) } @@ -2034,7 +2047,7 @@ func (t *dnsgateway) resolvLocked(domain string, typ iptype, tid, uid string) (i all6s := v6only(ans.ips.secipsFor(tid, uid)) ip6s = append(ip6s, all6s...) targets = append(targets, domainsFor(ans.baseans, tid, uid, core.FirstOf(all6s), xalive)...) - until = min(until, life) + until = minNonZeroDuration(until, life) } else { // TODO: stale targets? staleips = append(staleips, ans.ips.secips(xall)...) diff --git a/intra/dnsx/alg_cache_test.go b/intra/dnsx/alg_cache_test.go new file mode 100644 index 00000000..0a2a21b8 --- /dev/null +++ b/intra/dnsx/alg_cache_test.go @@ -0,0 +1,62 @@ +package dnsx + +import ( + "net/netip" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestResolvLockedUntilUsesRemainingTTL(t *testing.T) { + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + domain := qname(q) + + tid := "tid" + uid := "uid" + ttl := time.Now().Add(30 * time.Second) + + realip := netip.MustParseAddr("1.1.1.1") + algip := netip.MustParseAddr("100.64.0.1") + + xips := NewXips(tid, uid, []netip.Addr{realip}, nil, ttl) + if xips == nil { + t.Fatalf("expected xips") + } + xdomains := NewXdomains(tid, uid, []string{domain}, ttl) + if xdomains == nil { + t.Fatalf("expected xdomains") + } + + gw := &dnsgateway{ + alg: map[string]*algans{ + domain + key4 + "0": { + algip: algip, + baseans: &baseans{ + ips: xips, + domains: xdomains, + ttl: ttl, + }, + }, + }, + nat: make(map[netip.Addr]*baseans), + ptr: make(map[netip.Addr]*baseans), + } + + _, _, _, until, _ := gw.resolvLocked(domain, typreal, tid, uid) + if until <= 0 { + t.Fatalf("expected positive ttl, got %s", until) + } + + ans, err := gw.fromInternalCache(tid, uid, q, typreal) + if err != nil { + t.Fatalf("expected cache hit, got error %v", err) + } + if len(ans.Answer) == 0 { + t.Fatalf("expected at least one answer") + } + if got := ans.Answer[0].Header().Ttl; got < 20 { + t.Fatalf("expected ttl >= 20s, got %d", got) + } +}