diff --git a/go.mod b/go.mod index 2d19925..8a24544 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/getlantern/amp v0.0.0-20251211213807-4cbc22624b9f github.com/getlantern/dnstt v0.0.0-20250530230749-4d64f4edcf0f github.com/getlantern/fronted v0.0.0-20260105215156-9ae1d001d54f + github.com/stretchr/testify v1.11.1 ) require ( @@ -27,6 +28,7 @@ require ( github.com/alitto/pond/v2 v2.1.5 // indirect github.com/andybalholm/brotli v1.0.6 // indirect github.com/cloudflare/circl v1.5.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4 // indirect github.com/getlantern/context v0.0.0-20220418194847-3d5e7a086201 // indirect github.com/getlantern/errors v1.0.3 // indirect @@ -48,6 +50,7 @@ require ( github.com/klauspost/reedsolomon v1.12.0 // indirect github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/refraction-networking/utls v1.7.1 // indirect github.com/shadowsocks/go-shadowsocks2 v0.1.5 // indirect github.com/templexxx/cpu v0.1.1 // indirect diff --git a/kindling.go b/kindling.go index 4632101..cdd588a 100644 --- a/kindling.go +++ b/kindling.go @@ -93,7 +93,7 @@ func (k *kindling) ReplaceTransport(name string, rt func(ctx context.Context, ad for i, tr := range k.transports { slog.Info("Checking transport", "name", tr.Name()) if tr.Name() == name { - k.transports[i] = newTransport(name, tr.MaxLength(), rt) + k.transports[i] = newTransport(name, tr.MaxLength(), tr.IsStreamable(), rt) return nil } } @@ -108,7 +108,7 @@ func WithDomainFronting(f fronted.Fronted) Option { log.Error("Fronted instance is nil") return &emptyOption{} } - return WithTransport(newTransport("fronted", 0, func(ctx context.Context, addr string) (http.RoundTripper, error) { + return WithTransport(newTransport("fronted", 0, true, func(ctx context.Context, addr string) (http.RoundTripper, error) { return f.NewConnectedRoundTripper(ctx, addr) })) } @@ -121,7 +121,7 @@ func WithDNSTunnel(d dnstt.DNSTT) Option { log.Error("DNSTT instance is nil") return &emptyOption{} } - return WithTransport(newTransport("dnstt", 0, func(ctx context.Context, addr string) (http.RoundTripper, error) { + return WithTransport(newTransport("dnstt", 0, true, func(ctx context.Context, addr string) (http.RoundTripper, error) { return d.NewRoundTripper(ctx, addr) })) } @@ -133,7 +133,7 @@ func WithAMPCache(c amp.Client) Option { log.Error("AMP client is nil") return &emptyOption{} } - return WithTransport(newTransport("amp", 6000, func(ctx context.Context, addr string) (http.RoundTripper, error) { + return WithTransport(newTransport("amp", 6000, false, func(ctx context.Context, addr string) (http.RoundTripper, error) { return c.RoundTripper() })) } @@ -148,7 +148,7 @@ func WithProxyless(domains ...string) Option { log.Error("Failed to create smart dialer", "error", err) return } - k.transports = append(k.transports, newTransport("smart", 0, smartDialer)) + k.transports = append(k.transports, newTransport("smart", 0, true, smartDialer)) }) } @@ -178,6 +178,9 @@ type Transport interface { // A value of 0 means there is no limit. MaxLength() int + // IsStreamable returns if the transport support streaming + IsStreamable() bool + // Name returns the name of the transport for logging and debugging purposes. Name() string } @@ -314,9 +317,10 @@ func (o *emptyOption) apply(k *kindling) {} func (o *emptyOption) priority() int { return 0 } type namedTransport struct { - name string - maxLength int - rtg roundTripperGenerator + name string + maxLength int + isStreamable bool + rtg roundTripperGenerator } func (t *namedTransport) NewRoundTripper(ctx context.Context, addr string) (http.RoundTripper, error) { @@ -327,14 +331,19 @@ func (t *namedTransport) MaxLength() int { return t.maxLength } +func (t *namedTransport) IsStreamable() bool { + return t.isStreamable +} + func (t *namedTransport) Name() string { return t.name } -func newTransport(name string, maxLength int, rtg roundTripperGenerator) Transport { +func newTransport(name string, maxLength int, isStreamable bool, rtg roundTripperGenerator) Transport { return &namedTransport{ - name: name, - maxLength: maxLength, - rtg: rtg, + name: name, + maxLength: maxLength, + isStreamable: isStreamable, + rtg: rtg, } } diff --git a/kindling_test.go b/kindling_test.go index 3546330..42d527d 100644 --- a/kindling_test.go +++ b/kindling_test.go @@ -147,7 +147,7 @@ func TestKindling_ReplaceTransport(t *testing.T) { originalRT := func(ctx context.Context, addr string) (http.RoundTripper, error) { return &dummyRoundTripper{}, nil } - transport := newTransport("test-transport", 100, originalRT) + transport := newTransport("test-transport", 100, true, originalRT) kindling := NewKindling("test-app", WithTransport(transport)) // Replace the transport @@ -167,7 +167,7 @@ func TestKindling_ReplaceTransport(t *testing.T) { originalRT := func(ctx context.Context, addr string) (http.RoundTripper, error) { return &dummyRoundTripper{}, nil } - transport := newTransport("test-transport", 100, originalRT) + transport := newTransport("test-transport", 100, true, originalRT) kindling := NewKindling("test-app", WithTransport(transport)) // Try to replace a non-existent transport @@ -194,8 +194,8 @@ func TestKindling_ReplaceTransport(t *testing.T) { rt2 := func(ctx context.Context, addr string) (http.RoundTripper, error) { return &dummyRoundTripper{}, nil } - transport1 := newTransport("transport-1", 100, rt1) - transport2 := newTransport("transport-2", 200, rt2) + transport1 := newTransport("transport-1", 100, true, rt1) + transport2 := newTransport("transport-2", 200, true, rt2) kindling := NewKindling("test-app", WithTransport(transport1), WithTransport(transport2)) // Replace the second transport diff --git a/race_transport.go b/race_transport.go index 6b51830..620dd04 100644 --- a/race_transport.go +++ b/race_transport.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" "sync" @@ -56,6 +57,7 @@ func (t *raceTransport) RoundTrip(originalRequest *http.Request) (*http.Response // Store a raw copy of the request body for request copies sent to the various // transports. bodyBytes := requestBodyBytes(originalRequest) + hasStreamingHeader := originalRequest.Header.Get("accept") == "text/event-stream" log.Debug(fmt.Sprintf("Dialing with %v dialers and body length %v", len(t.transports), len(bodyBytes))) for _, tr := range t.transports { hasLimit := tr.MaxLength() > 0 @@ -63,6 +65,10 @@ func (t *raceTransport) RoundTrip(originalRequest *http.Request) (*http.Response log.Debug("Skipping transport due to size limit", "name", tr.Name(), "size", len(bodyBytes), "maxLength", tr.MaxLength()) continue } + if hasStreamingHeader && !tr.IsStreamable() { + log.Debug("Skipping transport because it doesn't support streaming", slog.String("name", tr.Name())) + continue + } go func(tr Transport) { // Recover from panics in the dialer. defer func() { diff --git a/race_transport_test.go b/race_transport_test.go index a40730b..cc1e44e 100644 --- a/race_transport_test.go +++ b/race_transport_test.go @@ -2,12 +2,149 @@ package kindling import ( "bytes" + "context" + "errors" "io" "net/http" + "net/http/httptest" "strings" + "sync/atomic" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// mockTransport is a test implementation of Transport that records whether it was called. +type mockTransport struct { + name string + isStreamable bool + maxLength int + newRoundTripper func(ctx context.Context, addr string) (http.RoundTripper, error) +} + +func (m *mockTransport) Name() string { return m.name } +func (m *mockTransport) IsStreamable() bool { return m.isStreamable } +func (m *mockTransport) MaxLength() int { return m.maxLength } +func (m *mockTransport) NewRoundTripper(ctx context.Context, addr string) (http.RoundTripper, error) { + return m.newRoundTripper(ctx, addr) +} + +func TestRaceTransport_StreamingHeaderFilter(t *testing.T) { + t.Parallel() + + // Verifies that when the Accept header is "text/event-stream", only streamable + // transports are used and non-streamable ones are skipped entirely. + t.Run("StreamingRequest_SkipsNonStreamableTransport", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + var streamableCalled, nonStreamableCalled atomic.Bool + rt := newRaceTransport("test", func(string) {}, + &mockTransport{ + name: "streamable", + isStreamable: true, + newRoundTripper: func(ctx context.Context, addr string) (http.RoundTripper, error) { + streamableCalled.Store(true) + return server.Client().Transport, nil + }, + }, + &mockTransport{ + name: "non-streamable", + isStreamable: false, + newRoundTripper: func(ctx context.Context, addr string) (http.RoundTripper, error) { + nonStreamableCalled.Store(true) + return server.Client().Transport, nil + }, + }, + ) + + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + resp.Body.Close() + + assert.True(t, streamableCalled.Load(), "streamable transport should be used for streaming request") + assert.False(t, nonStreamableCalled.Load(), "non-streamable transport should be skipped for streaming request") + }) + + // Verifies that a streamable transport successfully handles a streaming + // request end-to-end and returns a successful response. + t.Run("StreamingRequest_UsesStreamableTransport", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + var called atomic.Bool + rt := newRaceTransport("test", func(string) {}, + &mockTransport{ + name: "streamable", + isStreamable: true, + newRoundTripper: func(ctx context.Context, addr string) (http.RoundTripper, error) { + called.Store(true) + return server.Client().Transport, nil + }, + }, + ) + + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + resp.Body.Close() + + assert.True(t, called.Load(), "streamable transport should be called for streaming request") + }) + + // Verifies that the streaming filter is not applied for ordinary (non-streaming) + // requests, so both streamable and non-streamable transports are attempted. + t.Run("NonStreamingRequest_AllowsNonStreamableTransport", func(t *testing.T) { + t.Parallel() + + var streamableCalled, nonStreamableCalled atomic.Bool + rt := newRaceTransport("test", func(string) {}, + &mockTransport{ + name: "streamable", + isStreamable: true, + newRoundTripper: func(ctx context.Context, addr string) (http.RoundTripper, error) { + streamableCalled.Store(true) + return nil, errors.New("intentional error") + }, + }, + &mockTransport{ + name: "non-streamable", + isStreamable: false, + newRoundTripper: func(ctx context.Context, addr string) (http.RoundTripper, error) { + nonStreamableCalled.Store(true) + return nil, errors.New("intentional error") + }, + }, + ) + + req, err := http.NewRequest("GET", "http://example.com", nil) + require.NoError(t, err) + // No Accept: text/event-stream header — streaming filter must not apply. + + _, err = rt.RoundTrip(req) + require.Error(t, err, "expected error when all transports fail") + + assert.True(t, streamableCalled.Load(), "streamable transport should be attempted for non-streaming request") + assert.True(t, nonStreamableCalled.Load(), "non-streamable transport should be attempted for non-streaming request") + }) +} + func TestCloneRequest_NilBody(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil {