diff --git a/compress_test.go b/compress_test.go index 1964c84f..2aee1bd1 100644 --- a/compress_test.go +++ b/compress_test.go @@ -3,11 +3,15 @@ package websocket import ( + "bufio" "bytes" "compress/flate" + "context" "io" + "net" "strings" "testing" + "time" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/xrand" @@ -59,3 +63,142 @@ func BenchmarkFlateReader(b *testing.B) { io.ReadAll(r) } } + +// TestWriteSingleFrameCompressed verifies that Conn.Write sends compressed +// messages in a single frame instead of multiple frames, and that messages +// below the flateThreshold are sent uncompressed. +// This is a regression test for https://github.com/coder/websocket/issues/435 +func TestWriteSingleFrameCompressed(t *testing.T) { + t.Parallel() + + var ( + flateThreshold = 64 + + largeMsg = []byte(strings.Repeat("hello world ", 100)) // ~1200 bytes, above threshold + smallMsg = []byte("small message") // 13 bytes, below threshold + ) + + testCases := []struct { + name string + mode CompressionMode + msg []byte + wantRsv1 bool // true = compressed, false = uncompressed + }{ + {"ContextTakeover/AboveThreshold", CompressionContextTakeover, largeMsg, true}, + {"NoContextTakeover/AboveThreshold", CompressionNoContextTakeover, largeMsg, true}, + {"ContextTakeover/BelowThreshold", CompressionContextTakeover, smallMsg, false}, + {"NoContextTakeover/BelowThreshold", CompressionNoContextTakeover, smallMsg, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + c := newConn(connConfig{ + rwc: clientConn, + client: true, + copts: tc.mode.opts(), + flateThreshold: flateThreshold, + br: bufio.NewReader(clientConn), + bw: bufio.NewWriterSize(clientConn, 4096), + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + writeDone := make(chan error, 1) + go func() { + writeDone <- c.Write(ctx, MessageText, tc.msg) + }() + + reader := bufio.NewReader(serverConn) + readBuf := make([]byte, 8) + + h, err := readFrameHeader(reader, readBuf) + assert.Success(t, err) + + _, err = io.CopyN(io.Discard, reader, h.payloadLength) + assert.Success(t, err) + + assert.Equal(t, "opcode", opText, h.opcode) + assert.Equal(t, "rsv1 (compressed)", tc.wantRsv1, h.rsv1) + assert.Equal(t, "fin", true, h.fin) + + err = <-writeDone + assert.Success(t, err) + }) + } +} + +// TestWriteThenWriterContextTakeover verifies that using Conn.Write followed by +// Conn.Writer works correctly with context takeover enabled. This tests that +// the flateWriter destination is properly restored after Conn.Write redirects +// it to a temporary buffer. +func TestWriteThenWriterContextTakeover(t *testing.T) { + t.Parallel() + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + client := newConn(connConfig{ + rwc: clientConn, + client: true, + copts: CompressionContextTakeover.opts(), + flateThreshold: 64, + br: bufio.NewReader(clientConn), + bw: bufio.NewWriterSize(clientConn, 4096), + }) + + server := newConn(connConfig{ + rwc: serverConn, + client: false, + copts: CompressionContextTakeover.opts(), + flateThreshold: 64, + br: bufio.NewReader(serverConn), + bw: bufio.NewWriterSize(serverConn, 4096), + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500) + defer cancel() + + msg1 := []byte(strings.Repeat("first message ", 100)) + msg2 := []byte(strings.Repeat("second message ", 100)) + + type readResult struct { + typ MessageType + p []byte + err error + } + readCh := make(chan readResult, 2) + go func() { + for range 2 { + typ, p, err := server.Read(ctx) + readCh <- readResult{typ, p, err} + } + }() + + // First message: Write() redirects flateWriter to temp buffer + assert.Success(t, client.Write(ctx, MessageText, msg1)) + + r := <-readCh + assert.Success(t, r.err) + assert.Equal(t, "msg1 type", MessageText, r.typ) + assert.Equal(t, "msg1 content", string(msg1), string(r.p)) + + // Second message: Writer() streaming API + w, err := client.Writer(ctx, MessageBinary) + assert.Success(t, err) + _, err = w.Write(msg2) + assert.Success(t, err) + assert.Success(t, w.Close()) + + r = <-readCh + assert.Success(t, r.err) + assert.Equal(t, "msg2 type", MessageBinary, r.typ) + assert.Equal(t, "msg2 content", string(msg2), string(r.p)) +} diff --git a/write.go b/write.go index d7172a7b..cd234c93 100644 --- a/write.go +++ b/write.go @@ -14,6 +14,7 @@ import ( "net" "time" + "github.com/coder/websocket/internal/bpool" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) @@ -100,7 +101,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { - mw, err := c.writer(ctx, typ) + _, err := c.writer(ctx, typ) if err != nil { return 0, err } @@ -110,13 +111,49 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } - n, err := mw.Write(p) + // Below threshold: write uncompressed in single frame. + if len(p) < c.flateThreshold { + defer c.msgWriter.mu.unlock() + return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) + } + + // Compress into buffer, then write as single frame. + defer c.msgWriter.mu.unlock() + + buf := bpool.Get() + defer bpool.Put(buf) + + c.msgWriter.ensureFlate() + fw := c.msgWriter.flateWriter + fw.Reset(buf) + + _, err = fw.Write(p) if err != nil { - return n, err + return 0, fmt.Errorf("failed to compress: %w", err) } - err = mw.Close() - return n, err + err = fw.Flush() + if err != nil { + return 0, fmt.Errorf("failed to flush compression: %w", err) + } + + if !c.msgWriter.flateContextTakeover() { + c.msgWriter.putFlateWriter() + } else { + // Restore flateWriter destination for subsequent Writer() API calls. + fw.Reset(c.msgWriter.trimWriter) + } + + // Remove deflate tail bytes (last 4 bytes: \x00\x00\xff\xff). + // See RFC 7692 section 7.2.1. + compressed := buf.Bytes() + compressed = compressed[:len(compressed)-4] + + _, err = c.writeFrame(ctx, true, true, c.msgWriter.opcode, compressed) + if err != nil { + return 0, err + } + return len(p), nil } func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {