Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
// message during hydration, and we want to return the messages as they were before.
var messages []generated.Message
var conversation generated.Conversation
resuming := lastSeqID >= 0
if lastSeqID < 0 {
err := s.db.Queries(ctx, func(q *generated.Queries) error {
var err error
Expand All @@ -1050,9 +1051,16 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
lastSeqID = messages[len(messages)-1].SequenceID
}
} else {
// Resuming - just get conversation metadata
// Resuming - fetch any messages the client missed while it was disconnected
err := s.db.Queries(ctx, func(q *generated.Queries) error {
var err error
messages, err = q.ListMessagesSince(ctx, generated.ListMessagesSinceParams{
ConversationID: conversationID,
SequenceID: lastSeqID,
})
if err != nil {
return err
}
conversation, err = q.GetConversation(ctx, conversationID)
return err
})
Expand All @@ -1061,6 +1069,10 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
// Update lastSeqID so the subscription starts after these messages
if len(messages) > 0 {
lastSeqID = messages[len(messages)-1].SequenceID
}
}

// Get or create conversation manager to access working state
Expand All @@ -1071,10 +1083,16 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
return
}

// Send initial response
// Send initial response (all messages for fresh connections, missed messages for resumes)
if len(messages) > 0 {
// Fresh connection - send all messages
apiMessages := toAPIMessages(messages)
// Only send context_window_size for fresh connections where we have all messages.
// On resume we only have the missed messages, so the calculation would be wrong.
// The client keeps its previous value and gets updates from subsequent stream events.
var ctxSize uint64
if !resuming {
ctxSize = calculateContextWindowSize(apiMessages)
}
streamData := StreamResponse{
Messages: apiMessages,
Conversation: conversation,
Expand All @@ -1083,7 +1101,7 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
Working: manager.IsAgentWorking(),
Model: manager.GetModel(),
},
ContextWindowSize: calculateContextWindowSize(apiMessages),
ContextWindowSize: ctxSize,
}
data, _ := json.Marshal(streamData)
fmt.Fprintf(w, "data: %s\n\n", data)
Expand Down
113 changes: 88 additions & 25 deletions server/stream_heartbeat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -16,7 +17,7 @@ import (
)

// TestStreamResumeWithLastSequenceID verifies that using last_sequence_id
// parameter skips sending messages and sends a heartbeat instead.
// parameter only sends messages newer than the given sequence ID.
func TestStreamResumeWithLastSequenceID(t *testing.T) {
database, cleanup := setupTestDB(t)
defer cleanup()
Expand All @@ -34,7 +35,7 @@ func TestStreamResumeWithLastSequenceID(t *testing.T) {
Role: llm.MessageRoleUser,
Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}},
}
msg1, err := database.CreateMessage(ctx, db.CreateMessageParams{
_, err = database.CreateMessage(ctx, db.CreateMessageParams{
ConversationID: conv.ConversationID,
Type: db.MessageTypeUser,
LLMData: userMsg,
Expand All @@ -49,7 +50,7 @@ func TestStreamResumeWithLastSequenceID(t *testing.T) {
Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}},
EndOfTurn: true,
}
msg2, err := database.CreateMessage(ctx, db.CreateMessageParams{
_, err = database.CreateMessage(ctx, db.CreateMessageParams{
ConversationID: conv.ConversationID,
Type: db.MessageTypeAgent,
LLMData: agentMsg,
Expand Down Expand Up @@ -107,50 +108,112 @@ func TestStreamResumeWithLastSequenceID(t *testing.T) {
}
})

// Test 2: Resume with last_sequence_id - should get heartbeat with no messages
t.Run("resume_connection", func(t *testing.T) {
// Find the actual last sequence ID (system prompt may have been added)
var lastSeqID int64
t.Run("find_last_seq_id", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx)
req.Header.Set("Accept", "text/event-stream")
w := newResponseRecorderWithClose()
done := make(chan struct{})
go func() { defer close(done); mux.ServeHTTP(w, req) }()
time.Sleep(300 * time.Millisecond)
w.Close()
cancel()
<-done
jsonData := strings.TrimPrefix(strings.Split(w.Body.String(), "\n")[0], "data: ")
var resp StreamResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("Failed to parse: %v", err)
}
for _, m := range resp.Messages {
if m.SequenceID > lastSeqID {
lastSeqID = m.SequenceID
}
}
})

// Test 2: Resume with no new messages - should get heartbeat
t.Run("resume_no_new_messages", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// Use the sequence ID of the last message
req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream?last_sequence_id="+string(rune('0'+msg2.SequenceID)), nil).WithContext(ctx)
url := fmt.Sprintf("/api/conversation/%s/stream?last_sequence_id=%d", conv.ConversationID, lastSeqID)
req := httptest.NewRequest("GET", url, nil).WithContext(ctx)
req.Header.Set("Accept", "text/event-stream")

w := newResponseRecorderWithClose()

done := make(chan struct{})
go func() {
defer close(done)
mux.ServeHTTP(w, req)
}()

go func() { defer close(done); mux.ServeHTTP(w, req) }()
time.Sleep(300 * time.Millisecond)
w.Close()
cancel()
<-done

body := w.Body.String()
if !strings.HasPrefix(body, "data: ") {
t.Fatalf("Expected SSE data, got: %s", body)
}

jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ")
jsonData := strings.TrimPrefix(strings.Split(w.Body.String(), "\n")[0], "data: ")
var response StreamResponse
if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}

if len(response.Messages) != 0 {
t.Errorf("Expected 0 messages when resuming, got %d", len(response.Messages))
t.Errorf("Expected 0 messages, got %d", len(response.Messages))
}
if !response.Heartbeat {
t.Error("Resume connection should be a heartbeat")
t.Error("Resume with no new messages should be a heartbeat")
}
})

// Test 3: Resume with missed messages - should get the missed messages
t.Run("resume_with_missed_messages", func(t *testing.T) {
// Add a new message with usage data (simulating what happens while the client is disconnected)
newMsg := llm.Message{
Role: llm.MessageRoleAssistant,
Content: []llm.Content{{Type: llm.ContentTypeText, Text: "You missed this!"}},
}
usage := llm.Usage{InputTokens: 5000, OutputTokens: 200}
_, err := database.CreateMessage(ctx, db.CreateMessageParams{
ConversationID: conv.ConversationID,
Type: db.MessageTypeAgent,
LLMData: newMsg,
UsageData: &usage,
})
if err != nil {
t.Fatalf("Failed to create message: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

url := fmt.Sprintf("/api/conversation/%s/stream?last_sequence_id=%d", conv.ConversationID, lastSeqID)
req := httptest.NewRequest("GET", url, nil).WithContext(ctx)
req.Header.Set("Accept", "text/event-stream")

w := newResponseRecorderWithClose()
done := make(chan struct{})
go func() { defer close(done); mux.ServeHTTP(w, req) }()
time.Sleep(300 * time.Millisecond)
w.Close()
cancel()
<-done

jsonData := strings.TrimPrefix(strings.Split(w.Body.String(), "\n")[0], "data: ")
var response StreamResponse
if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if len(response.Messages) != 1 {
t.Errorf("Expected 1 missed message, got %d", len(response.Messages))
}
if response.Heartbeat {
t.Error("Should not be a heartbeat when there are missed messages")
}
if response.ConversationState == nil {
t.Error("Expected ConversationState in heartbeat")
t.Error("Expected ConversationState")
}
if response.ContextWindowSize != 0 {
t.Errorf("Resume should not send context_window_size (got %d)", response.ContextWindowSize)
}
})

// Suppress unused variable warnings
_ = msg1
}
Loading