Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions front.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fronted

import (
"context"
"crypto/sha256"
"crypto/x509"
"encoding/json"
Expand Down Expand Up @@ -87,14 +88,16 @@ type front struct {
ProviderID string
mx sync.RWMutex
cacheDirty chan interface{}
dialFunc DialFunc
}

func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}) Front {
func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}, dialFunc DialFunc) Front {
return &front{
Masquerade: *m,
ProviderID: providerID,
LastSucceeded: time.Time{},
cacheDirty: cacheDirty,
dialFunc: dialFunc,
}
}
func (fr *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) {
Expand All @@ -117,8 +120,20 @@ func (fr *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (
return verifyPeerCertificate(rawCerts, rootCAs, verifyHostname)
}
}

var doDial func(network, addr string, timeout time.Duration) (net.Conn, error)
if fr.dialFunc != nil {
doDial = func(network, addr string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return fr.dialFunc(ctx, network, addr)
}
} else {
doDial = dialWithTimeout
}

dialer := &tlsdialer.Dialer{
DoDial: dialWithTimeout,
DoDial: doDial,
Timeout: dialTimeout,
SendServerName: sendServerNameExtension,
Config: tlsConfig,
Expand Down
23 changes: 20 additions & 3 deletions fronted.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@ var (
defaultFrontedProviderID = "cloudfront"
)

// DialFunc is the function type used for dialing network connections.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// fronted identifies working IP address/domain pairings for domain fronting and is
// an implementation of http.RoundTripper for the convenience of callers.
type fronted struct {
dialFunc DialFunc
certPool atomic.Value
fronts *threadSafeFronts
maxAllowedCachedAge time.Duration
Expand Down Expand Up @@ -123,6 +127,9 @@ func NewFronted(options ...Option) Fronted {
for _, opt := range options {
opt(f)
}
if f.dialFunc == nil {
f.dialFunc = (&net.Dialer{}).DialContext
}
if f.cacheFile == "" {
f.cacheFile = defaultCacheFilePath()
}
Expand Down Expand Up @@ -178,6 +185,16 @@ func WithPanicListener(panicListener func(string)) Option {
}
}

// WithDialer sets a custom dialer function for the fronted instance. This allows callers to
// inject their own dialer for making the underlying TCP connections. The dialer will typically
// be invoked with an IP:port destination (derived from the configured fronting infrastructure),
// while the fronting domain name (SNI/ServerName) is configured separately via the TLS settings.
func WithDialer(dial DialFunc) Option {
return func(f *fronted) {
f.dialFunc = dial
}
}
Comment on lines 188 to 196
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New public option WithDialer adds behavior that isn’t covered by tests yet. Please add a unit test that injects a dialer via WithDialer, triggers a dial (e.g., by calling fr.dial(...) or f.doDial(...) with a dialer that records invocation and returns a sentinel error), and asserts the injected dialer is called and its error is propagated.

Copilot uses AI. Check for mistakes.

// SetLogger sets the logger to use by fronted.
func SetLogger(logger *slog.Logger) {
log = logger
Expand Down Expand Up @@ -296,7 +313,7 @@ func (f *fronted) onNewFronts(pool *x509.CertPool, providers map[string]*Provide
f.addProviders(providersCopy)

log.Debug("Loading candidates for providers", "numProviders", len(providersCopy))
fronts := loadFronts(providersCopy, f.cacheDirty)
fronts := loadFronts(providersCopy, f.cacheDirty, f.dialFunc)
log.Debug("Finished loading candidates")

log.Debug("Existing fronts", slog.Int("size", f.fronts.frontSize()))
Expand Down Expand Up @@ -596,7 +613,7 @@ func copyProviders(providers map[string]*Provider, countryCode string) map[strin
return providersCopy
}

func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) []Front {
func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}, dialFunc DialFunc) []Front {
// Preallocate the slice to avoid reallocation
size := 0
for _, p := range providers {
Expand All @@ -622,7 +639,7 @@ func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) []F
}

for _, c := range sh {
fronts[index] = newFront(c, providerID, cacheDirty)
fronts[index] = newFront(c, providerID, cacheDirty, dialFunc)
index++
}
}
Expand Down
61 changes: 60 additions & 1 deletion fronted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ func TestLoadFronts(t *testing.T) {

// Create the cache dirty channel
cacheDirty := make(chan interface{}, 10)
masquerades := loadFronts(providers, cacheDirty)
masquerades := loadFronts(providers, cacheDirty, nil)

assert.Equal(t, 4, len(masquerades), "Unexpected number of masquerades loaded")

Expand Down Expand Up @@ -926,3 +926,62 @@ func (m *mockFront) markWithResult(good bool) bool {

// Make sure that the mockMasquerade implements the MasqueradeInterface
var _ Front = (*mockFront)(nil)

func TestWithDialer(t *testing.T) {
called := false
customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
called = true
return (&net.Dialer{}).DialContext(ctx, network, addr)
}

f := NewFronted(
WithDialer(customDialer),
WithEmbeddedConfigName("noconfig.yaml"),
)
defer f.Close()

d := f.(*fronted)
assert.NotNil(t, d.dialFunc, "dialFunc should be set")

// Verify the custom dialer is stored (we can't compare funcs directly, but we can
// verify it's not the default by calling it and checking our flag).
_, _ = d.dialFunc(context.Background(), "tcp", "localhost:0")
assert.True(t, called, "custom dialer should have been called")
}

func TestWithDialerDefault(t *testing.T) {
f := NewFronted(WithEmbeddedConfigName("noconfig.yaml"))
defer f.Close()

d := f.(*fronted)
assert.NotNil(t, d.dialFunc, "default dialFunc should be set when WithDialer is not used")
}

func TestWithDialerFlowsToFronts(t *testing.T) {
called := false
customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
called = true
return nil, errors.New("custom dialer called")
}

f := NewFronted(
WithDialer(customDialer),
WithEmbeddedConfigName("noconfig.yaml"),
)
defer f.Close()

d := f.(*fronted)

// Create a provider and fronts using the fronted's dialer
masquerades := []*Masquerade{{Domain: "example.com", IpAddress: "127.0.0.1"}}
providers := map[string]*Provider{
"test": NewProvider(nil, "", masquerades, nil, nil, nil, ""),
}
fronts := loadFronts(providers, d.cacheDirty, d.dialFunc)
assert.Equal(t, 1, len(fronts))

// Dialing through the front should use the custom dialer
_, err := fronts[0].dial(nil, tls.HelloChrome_131)
assert.Error(t, err)
assert.True(t, called, "custom dialer should flow through to fronts")
}
Loading