Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ $(LOCALBIN):
mkdir -p $(LOCALBIN)

test: generate fmt vet manifests setup-envtest
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./... -coverprofile cover.out
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test -race ./... -coverprofile cover.out

clean:
rm -rf bin/*
Expand Down
10 changes: 9 additions & 1 deletion pkg/dns/dns_proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"time"

"github.com/go-logr/logr"
Expand All @@ -22,6 +23,7 @@ const (
)

type DNSProxyHandler struct {
sync.RWMutex
log logr.Logger
udpClient *dnsgo.Client
tcpClient *dnsgo.Client
Expand Down Expand Up @@ -111,7 +113,9 @@ func (h *DNSProxyHandler) UpdateDNSServerAddr(addr string) error {
return fmt.Errorf("new DNS server address not valid: %w", err)
}

h.Lock()
h.dnsServerAddr = addr
h.Unlock()
return nil
}

Expand All @@ -128,7 +132,11 @@ func (h *DNSProxyHandler) getDataFromDNS(addr net.Addr, request *dnsgo.Msg) (*dn
return nil, fmt.Errorf("failed to determine transport protocol: %s", protocol)
}

response, _, err := client.Exchange(request, h.dnsServerAddr)
h.RLock()
dnsAddr := h.dnsServerAddr
h.RUnlock()

response, _, err := client.Exchange(request, dnsAddr)
if err != nil {
return nil, fmt.Errorf("failed to call target DNS: %w", err)
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/dns/dnscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ func (c *DNSCache) getSetsForRendering(fqdns []firewallv1.FQDNSelector) (result
}

func (c *DNSCache) updateDNSServerAddr(addr string) {
c.Lock()
c.dnsServerAddr = addr
c.Unlock()
}

// getSetNameForFQDN returns FQDN set data
Expand Down Expand Up @@ -339,13 +341,16 @@ func (c *DNSCache) loadDataFromDNSServer(fqdns []string) error {
return fmt.Errorf("too many hops, fqdn chain: %s", strings.Join(fqdns, ","))
}
qname := fqdns[len(fqdns)-1]
c.RLock()
dnsAddr := c.dnsServerAddr
c.RUnlock()
cl := new(dnsgo.Client)
for _, t := range []uint16{dnsgo.TypeA, dnsgo.TypeAAAA} {
m := new(dnsgo.Msg)
m.Id = dnsgo.Id()
m.SetQuestion(qname, t)
c.log.V(4).Info("DEBUG dnscache loadDataFromDNSServer function querying DNS", "message", m)
in, _, err := cl.Exchange(m, c.dnsServerAddr)
in, _, err := cl.Exchange(m, dnsAddr)
if err != nil {
return fmt.Errorf("failed to get DNS data about fqdn %s: %w", fqdns[0], err)
}
Expand Down
249 changes: 249 additions & 0 deletions pkg/dns/dnscache_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package dns

import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"

"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"github.com/google/nftables"
firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1"
dnsgo "github.com/miekg/dns"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

func Test_GetSetsForFQDN(t *testing.T) {
Expand Down Expand Up @@ -226,3 +233,245 @@ func Test_createIPSetFromIPEntry(t *testing.T) {
})
}
}

const (
raceNumGoroutines = 10
raceNumIterations = 100
)

func newTestDNSCache(entries map[string]cacheEntry) *DNSCache {
return &DNSCache{
log: logr.Discard(),
fqdnToEntry: entries,
setNames: make(map[string]struct{}),
dnsServerAddr: "127.0.0.1:53",
ctx: context.Background(),
shootClient: fake.NewClientBuilder().Build(),
ipv4Enabled: true,
ipv6Enabled: true,
}
}

func makeTestRRs(fqdn string, ip string) []dnsgo.RR {
return []dnsgo.RR{
&dnsgo.A{
Hdr: dnsgo.RR_Header{Name: fqdn, Rrtype: dnsgo.TypeA, Ttl: 300},
A: net.ParseIP(ip),
},
}
}

func seedEntries(n int) map[string]cacheEntry {
entries := make(map[string]cacheEntry, n)
for i := range n {
fqdn := fmt.Sprintf("host%d.example.com.", i)
entries[fqdn] = cacheEntry{
IPv4: &iPEntry{
SetName: fmt.Sprintf("set%d", i),
IPs: map[string]time.Time{fmt.Sprintf("10.0.0.%d", i%256): time.Now().Add(5 * time.Minute)},
},
}
}
return entries
}

func TestRace_UpdateAndGetSetsForRendering(t *testing.T) {
cache := newTestDNSCache(seedEntries(5))
fqdns := []firewallv1.FQDNSelector{{MatchPattern: "*.example.com"}}

var wg sync.WaitGroup
start := make(chan struct{})

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
fqdn := fmt.Sprintf("writer%d.example.com.", id)
for j := range raceNumIterations {
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.1.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
}
}(i)
}

for range raceNumGoroutines {
wg.Go(func() {
<-start
for range raceNumIterations {
cache.getSetsForRendering(fqdns)
}
})
}

close(start)
wg.Wait()
}

func TestRace_UpdateAndGetSetNameForRegex(t *testing.T) {
cache := newTestDNSCache(seedEntries(5))

var wg sync.WaitGroup
start := make(chan struct{})

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
fqdn := fmt.Sprintf("writer%d.example.com.", id)
for j := range raceNumIterations {
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.2.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
}
}(i)
}

for range raceNumGoroutines {
wg.Go(func() {
<-start
for range raceNumIterations {
cache.getSetNameForRegex(`.*\.example\.com\.`)
}
})
}

close(start)
wg.Wait()
}

func TestRace_UpdateAndGetSetNameForFQDN(t *testing.T) {
cache := newTestDNSCache(seedEntries(5))

var wg sync.WaitGroup
start := make(chan struct{})

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
fqdn := fmt.Sprintf("host%d.example.com.", id%5)
for j := range raceNumIterations {
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.3.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
}
}(i)
}

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
fqdn := fmt.Sprintf("host%d.example.com.", id%5)
for range raceNumIterations {
cache.getSetNameForFQDN(fqdn)
}
}(i)
}

close(start)
wg.Wait()
}

func TestRace_UpdateAndWriteStateToConfigmap(t *testing.T) {
cache := newTestDNSCache(seedEntries(5))

var wg sync.WaitGroup
start := make(chan struct{})

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
fqdn := fmt.Sprintf("writer%d.example.com.", id)
for j := range raceNumIterations {
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.4.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
}
}(i)
}

for range raceNumGoroutines {
wg.Go(func() {
<-start
for range raceNumIterations {
_ = cache.writeStateToConfigmap()
}
})
}

close(start)
wg.Wait()
}

func TestRace_UpdateDNSServerAddr(t *testing.T) {
cache := newTestDNSCache(seedEntries(1))

var wg sync.WaitGroup
start := make(chan struct{})

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
for j := range raceNumIterations {
cache.updateDNSServerAddr(fmt.Sprintf("10.0.%d.%d:53", id, j%256))
}
}(i)
}

for range raceNumGoroutines {
wg.Go(func() {
<-start
for range raceNumIterations {
cache.RLock()
_ = cache.dnsServerAddr
cache.RUnlock()
}
})
}

close(start)
wg.Wait()
}

func TestRace_ConcurrentMultipleReaders(t *testing.T) {
cache := newTestDNSCache(seedEntries(10))
fqdns := []firewallv1.FQDNSelector{{MatchPattern: "*.example.com"}}

var wg sync.WaitGroup
start := make(chan struct{})

for range raceNumGoroutines {
wg.Go(func() {
<-start
for range raceNumIterations {
cache.getSetsForRendering(fqdns)
}
})
}

for range raceNumGoroutines {
wg.Go(func() {
<-start
for range raceNumIterations {
cache.getSetNameForRegex(`.*\.example\.com\.`)
}
})
}

for i := range raceNumGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
<-start
fqdn := fmt.Sprintf("host%d.example.com.", id%10)
for range raceNumIterations {
cache.getSetNameForFQDN(fqdn)
}
}(i)
}

close(start)
wg.Wait()
}
Loading