From 260488e3ed5edcf5a6c1f95806d64e73fc4fc107 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sun, 22 Feb 2026 05:54:46 -0700 Subject: [PATCH 1/2] Add WithDialer option for custom TCP dialer injection Allow callers to inject a custom dialer function via WithDialer() that flows through to the tlsdialer used by each front. This enables kindling to set a single dialer that automatically applies to all fronted connections. Co-Authored-By: Claude Opus 4.6 --- front.go | 19 +++++++++++++++++-- fronted.go | 18 +++++++++++++++--- fronted_test.go | 2 +- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/front.go b/front.go index 8c4492d..75c2c4f 100644 --- a/front.go +++ b/front.go @@ -1,6 +1,7 @@ package fronted import ( + "context" "crypto/sha256" "crypto/x509" "encoding/json" @@ -87,14 +88,16 @@ type front struct { ProviderID string mx sync.RWMutex cacheDirty chan interface{} + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) } -func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}) Front { +func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) Front { return &front{ Masquerade: *m, ProviderID: providerID, LastSucceeded: time.Time{}, cacheDirty: cacheDirty, + dialFunc: dialFunc, } } func (fr *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { @@ -117,8 +120,20 @@ func (fr *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) ( return verifyPeerCertificate(rawCerts, rootCAs, verifyHostname) } } + + var doDial func(network, addr string, timeout time.Duration) (net.Conn, error) + if fr.dialFunc != nil { + doDial = func(network, addr string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return fr.dialFunc(ctx, network, addr) + } + } else { + doDial = dialWithTimeout + } + dialer := &tlsdialer.Dialer{ - DoDial: dialWithTimeout, + DoDial: doDial, Timeout: dialTimeout, SendServerName: sendServerNameExtension, Config: tlsConfig, diff --git a/fronted.go b/fronted.go index 0b9fc77..36fc4d6 100644 --- a/fronted.go +++ b/fronted.go @@ -48,6 +48,7 @@ var ( // fronted identifies working IP address/domain pairings for domain fronting and is // an implementation of http.RoundTripper for the convenience of callers. type fronted struct { + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) certPool atomic.Value fronts *threadSafeFronts maxAllowedCachedAge time.Duration @@ -123,6 +124,9 @@ func NewFronted(options ...Option) Fronted { for _, opt := range options { opt(f) } + if f.dialFunc == nil { + f.dialFunc = (&net.Dialer{}).DialContext + } if f.cacheFile == "" { f.cacheFile = defaultCacheFilePath() } @@ -178,6 +182,14 @@ func WithPanicListener(panicListener func(string)) Option { } } +// WithDialer sets a custom dialer function for the fronted instance. This allows callers to +// inject their own dialer for making TCP connections to fronting domains. +func WithDialer(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { + return func(f *fronted) { + f.dialFunc = dial + } +} + // SetLogger sets the logger to use by fronted. func SetLogger(logger *slog.Logger) { log = logger @@ -296,7 +308,7 @@ func (f *fronted) onNewFronts(pool *x509.CertPool, providers map[string]*Provide f.addProviders(providersCopy) log.Debug("Loading candidates for providers", "numProviders", len(providersCopy)) - fronts := loadFronts(providersCopy, f.cacheDirty) + fronts := loadFronts(providersCopy, f.cacheDirty, f.dialFunc) log.Debug("Finished loading candidates") log.Debug("Existing fronts", slog.Int("size", f.fronts.frontSize())) @@ -596,7 +608,7 @@ func copyProviders(providers map[string]*Provider, countryCode string) map[strin return providersCopy } -func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) []Front { +func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) []Front { // Preallocate the slice to avoid reallocation size := 0 for _, p := range providers { @@ -622,7 +634,7 @@ func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) []F } for _, c := range sh { - fronts[index] = newFront(c, providerID, cacheDirty) + fronts[index] = newFront(c, providerID, cacheDirty, dialFunc) index++ } } diff --git a/fronted_test.go b/fronted_test.go index e6c28e6..ee6909b 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -799,7 +799,7 @@ func TestLoadFronts(t *testing.T) { // Create the cache dirty channel cacheDirty := make(chan interface{}, 10) - masquerades := loadFronts(providers, cacheDirty) + masquerades := loadFronts(providers, cacheDirty, nil) assert.Equal(t, 4, len(masquerades), "Unexpected number of masquerades loaded") From 665091b389789f195c6c9ffb75992b3c9c6e3657 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sun, 22 Feb 2026 06:56:54 -0700 Subject: [PATCH 2/2] Address PR review: introduce DialFunc type, improve docs, add tests - Introduce named DialFunc type for the dialer function signature - Improve WithDialer docstring with more context about how the dialer is used - Add tests: TestWithDialer, TestWithDialerDefault, TestWithDialerFlowsToFronts Co-Authored-By: Claude Opus 4.6 --- front.go | 4 ++-- fronted.go | 13 +++++++---- fronted_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/front.go b/front.go index 75c2c4f..53ae1af 100644 --- a/front.go +++ b/front.go @@ -88,10 +88,10 @@ type front struct { ProviderID string mx sync.RWMutex cacheDirty chan interface{} - dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + dialFunc DialFunc } -func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) Front { +func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}, dialFunc DialFunc) Front { return &front{ Masquerade: *m, ProviderID: providerID, diff --git a/fronted.go b/fronted.go index 36fc4d6..00ec6a9 100644 --- a/fronted.go +++ b/fronted.go @@ -45,10 +45,13 @@ var ( defaultFrontedProviderID = "cloudfront" ) +// DialFunc is the function type used for dialing network connections. +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + // fronted identifies working IP address/domain pairings for domain fronting and is // an implementation of http.RoundTripper for the convenience of callers. type fronted struct { - dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + dialFunc DialFunc certPool atomic.Value fronts *threadSafeFronts maxAllowedCachedAge time.Duration @@ -183,8 +186,10 @@ func WithPanicListener(panicListener func(string)) Option { } // WithDialer sets a custom dialer function for the fronted instance. This allows callers to -// inject their own dialer for making TCP connections to fronting domains. -func WithDialer(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { +// inject their own dialer for making the underlying TCP connections. The dialer will typically +// be invoked with an IP:port destination (derived from the configured fronting infrastructure), +// while the fronting domain name (SNI/ServerName) is configured separately via the TLS settings. +func WithDialer(dial DialFunc) Option { return func(f *fronted) { f.dialFunc = dial } @@ -608,7 +613,7 @@ func copyProviders(providers map[string]*Provider, countryCode string) map[strin return providersCopy } -func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) []Front { +func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}, dialFunc DialFunc) []Front { // Preallocate the slice to avoid reallocation size := 0 for _, p := range providers { diff --git a/fronted_test.go b/fronted_test.go index ee6909b..bbe0c38 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -926,3 +926,62 @@ func (m *mockFront) markWithResult(good bool) bool { // Make sure that the mockMasquerade implements the MasqueradeInterface var _ Front = (*mockFront)(nil) + +func TestWithDialer(t *testing.T) { + called := false + customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + called = true + return (&net.Dialer{}).DialContext(ctx, network, addr) + } + + f := NewFronted( + WithDialer(customDialer), + WithEmbeddedConfigName("noconfig.yaml"), + ) + defer f.Close() + + d := f.(*fronted) + assert.NotNil(t, d.dialFunc, "dialFunc should be set") + + // Verify the custom dialer is stored (we can't compare funcs directly, but we can + // verify it's not the default by calling it and checking our flag). + _, _ = d.dialFunc(context.Background(), "tcp", "localhost:0") + assert.True(t, called, "custom dialer should have been called") +} + +func TestWithDialerDefault(t *testing.T) { + f := NewFronted(WithEmbeddedConfigName("noconfig.yaml")) + defer f.Close() + + d := f.(*fronted) + assert.NotNil(t, d.dialFunc, "default dialFunc should be set when WithDialer is not used") +} + +func TestWithDialerFlowsToFronts(t *testing.T) { + called := false + customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + called = true + return nil, errors.New("custom dialer called") + } + + f := NewFronted( + WithDialer(customDialer), + WithEmbeddedConfigName("noconfig.yaml"), + ) + defer f.Close() + + d := f.(*fronted) + + // Create a provider and fronts using the fronted's dialer + masquerades := []*Masquerade{{Domain: "example.com", IpAddress: "127.0.0.1"}} + providers := map[string]*Provider{ + "test": NewProvider(nil, "", masquerades, nil, nil, nil, ""), + } + fronts := loadFronts(providers, d.cacheDirty, d.dialFunc) + assert.Equal(t, 1, len(fronts)) + + // Dialing through the front should use the custom dialer + _, err := fronts[0].dial(nil, tls.HelloChrome_131) + assert.Error(t, err) + assert.True(t, called, "custom dialer should flow through to fronts") +}