diff --git a/server/handlers.go b/server/handlers.go index 527926b..ba1b718 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1030,6 +1030,7 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request // message during hydration, and we want to return the messages as they were before. var messages []generated.Message var conversation generated.Conversation + resuming := lastSeqID >= 0 if lastSeqID < 0 { err := s.db.Queries(ctx, func(q *generated.Queries) error { var err error @@ -1050,9 +1051,16 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request lastSeqID = messages[len(messages)-1].SequenceID } } else { - // Resuming - just get conversation metadata + // Resuming - fetch any messages we missed while disconnected err := s.db.Queries(ctx, func(q *generated.Queries) error { var err error + messages, err = q.ListMessagesSince(ctx, generated.ListMessagesSinceParams{ + ConversationID: conversationID, + SequenceID: lastSeqID, + }) + if err != nil { + return err + } conversation, err = q.GetConversation(ctx, conversationID) return err }) @@ -1061,6 +1069,10 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request http.Error(w, "Internal server error", http.StatusInternalServerError) return } + // Update lastSeqID so the subscription starts after these messages + if len(messages) > 0 { + lastSeqID = messages[len(messages)-1].SequenceID + } } // Get or create conversation manager to access working state @@ -1071,10 +1083,16 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request return } - // Send initial response + // Send initial response (all messages for fresh connections, missed messages for resumes) if len(messages) > 0 { - // Fresh connection - send all messages apiMessages := toAPIMessages(messages) + // Only send context_window_size for fresh connections where we have all messages. + // On resume we only have the missed messages, so the calculation would be wrong. + // The client keeps its previous value and gets updates from subsequent stream events. + var ctxSize uint64 + if !resuming { + ctxSize = calculateContextWindowSize(apiMessages) + } streamData := StreamResponse{ Messages: apiMessages, Conversation: conversation, @@ -1083,7 +1101,7 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request Working: manager.IsAgentWorking(), Model: manager.GetModel(), }, - ContextWindowSize: calculateContextWindowSize(apiMessages), + ContextWindowSize: ctxSize, } data, _ := json.Marshal(streamData) fmt.Fprintf(w, "data: %s\n\n", data) diff --git a/server/stream_heartbeat_test.go b/server/stream_heartbeat_test.go index 1e8b656..3afd709 100644 --- a/server/stream_heartbeat_test.go +++ b/server/stream_heartbeat_test.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -16,7 +17,7 @@ import ( ) // TestStreamResumeWithLastSequenceID verifies that using last_sequence_id -// parameter skips sending messages and sends a heartbeat instead. +// parameter only sends messages newer than the given sequence ID. func TestStreamResumeWithLastSequenceID(t *testing.T) { database, cleanup := setupTestDB(t) defer cleanup() @@ -34,7 +35,7 @@ func TestStreamResumeWithLastSequenceID(t *testing.T) { Role: llm.MessageRoleUser, Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}}, } - msg1, err := database.CreateMessage(ctx, db.CreateMessageParams{ + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ ConversationID: conv.ConversationID, Type: db.MessageTypeUser, LLMData: userMsg, @@ -49,7 +50,7 @@ func TestStreamResumeWithLastSequenceID(t *testing.T) { Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}}, EndOfTurn: true, } - msg2, err := database.CreateMessage(ctx, db.CreateMessageParams{ + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ ConversationID: conv.ConversationID, Type: db.MessageTypeAgent, LLMData: agentMsg, @@ -107,50 +108,112 @@ func TestStreamResumeWithLastSequenceID(t *testing.T) { } }) - // Test 2: Resume with last_sequence_id - should get heartbeat with no messages - t.Run("resume_connection", func(t *testing.T) { + // Find the actual last sequence ID (system prompt may have been added) + var lastSeqID int64 + t.Run("find_last_seq_id", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + w := newResponseRecorderWithClose() + done := make(chan struct{}) + go func() { defer close(done); mux.ServeHTTP(w, req) }() + time.Sleep(300 * time.Millisecond) + w.Close() + cancel() + <-done + jsonData := strings.TrimPrefix(strings.Split(w.Body.String(), "\n")[0], "data: ") + var resp StreamResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("Failed to parse: %v", err) + } + for _, m := range resp.Messages { + if m.SequenceID > lastSeqID { + lastSeqID = m.SequenceID + } + } + }) + + // Test 2: Resume with no new messages - should get heartbeat + t.Run("resume_no_new_messages", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - // Use the sequence ID of the last message - req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream?last_sequence_id="+string(rune('0'+msg2.SequenceID)), nil).WithContext(ctx) + url := fmt.Sprintf("/api/conversation/%s/stream?last_sequence_id=%d", conv.ConversationID, lastSeqID) + req := httptest.NewRequest("GET", url, nil).WithContext(ctx) req.Header.Set("Accept", "text/event-stream") w := newResponseRecorderWithClose() - done := make(chan struct{}) - go func() { - defer close(done) - mux.ServeHTTP(w, req) - }() - + go func() { defer close(done); mux.ServeHTTP(w, req) }() time.Sleep(300 * time.Millisecond) w.Close() cancel() <-done - body := w.Body.String() - if !strings.HasPrefix(body, "data: ") { - t.Fatalf("Expected SSE data, got: %s", body) - } - - jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + jsonData := strings.TrimPrefix(strings.Split(w.Body.String(), "\n")[0], "data: ") var response StreamResponse if err := json.Unmarshal([]byte(jsonData), &response); err != nil { t.Fatalf("Failed to parse response: %v", err) } - if len(response.Messages) != 0 { - t.Errorf("Expected 0 messages when resuming, got %d", len(response.Messages)) + t.Errorf("Expected 0 messages, got %d", len(response.Messages)) } if !response.Heartbeat { - t.Error("Resume connection should be a heartbeat") + t.Error("Resume with no new messages should be a heartbeat") + } + }) + + // Test 3: Resume with missed messages - should get the missed messages + t.Run("resume_with_missed_messages", func(t *testing.T) { + // Add a new message with usage data (simulating what happens while client is disconnected) + newMsg := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "You missed this!"}}, + } + usage := llm.Usage{InputTokens: 5000, OutputTokens: 200} + _, err := database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: newMsg, + UsageData: &usage, + }) + if err != nil { + t.Fatalf("Failed to create message: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + url := fmt.Sprintf("/api/conversation/%s/stream?last_sequence_id=%d", conv.ConversationID, lastSeqID) + req := httptest.NewRequest("GET", url, nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + done := make(chan struct{}) + go func() { defer close(done); mux.ServeHTTP(w, req) }() + time.Sleep(300 * time.Millisecond) + w.Close() + cancel() + <-done + + jsonData := strings.TrimPrefix(strings.Split(w.Body.String(), "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + if len(response.Messages) != 1 { + t.Errorf("Expected 1 missed message, got %d", len(response.Messages)) + } + if response.Heartbeat { + t.Error("Should not be a heartbeat when there are missed messages") } if response.ConversationState == nil { - t.Error("Expected ConversationState in heartbeat") + t.Error("Expected ConversationState") + } + if response.ContextWindowSize != 0 { + t.Errorf("Resume should not send context_window_size (got %d)", response.ContextWindowSize) } }) - // Suppress unused variable warnings - _ = msg1 } diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx index 899329e..1b4f541 100644 --- a/ui/src/components/ChatInterface.tsx +++ b/ui/src/components/ChatInterface.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useRef, useCallback } from "react"; +import React, { useState, useEffect, useLayoutEffect, useRef, useCallback, useMemo } from "react"; import { Message, Conversation, @@ -653,7 +653,6 @@ function ChatInterface({ const [showScrollToBottom, setShowScrollToBottom] = useState(false); const [terminalInjectedText, setTerminalInjectedText] = useState(null); const [terminalAutoFocusId, setTerminalAutoFocusId] = useState(null); - const messagesEndRef = useRef(null); const messagesContainerRef = useRef(null); const eventSourceRef = useRef(null); const overflowMenuRef = useRef(null); @@ -663,6 +662,9 @@ function ChatInterface({ const lastSequenceIdRef = useRef(-1); const hasConnectedRef = useRef(false); const userScrolledRef = useRef(false); + const loadingRef = useRef(false); + // Pending scroll target from loadMessages: undefined = none, null = bottom, number = saved position + const pendingScrollRef = useRef(undefined); // Load messages and set up streaming useEffect(() => { @@ -674,6 +676,7 @@ function ChatInterface({ // No conversation yet, show empty state setMessages([]); setContextWindowSize(0); + loadingRef.current = false; setLoading(false); } @@ -703,7 +706,40 @@ function ChatInterface({ } }, [agentWorking]); - // Check scroll position and handle scroll-to-bottom button + const scrollStore = useMemo(() => { + const key = conversationId ? `shelley_scroll_${conversationId}` : null; + return { + save(scrollTop: number) { + if (key) localStorage.setItem(key, String(scrollTop)); + }, + load(): number | null { + if (!key) return null; + const v = localStorage.getItem(key); + return v != null ? Number(v) : null; + }, + }; + }, [conversationId]); + + // Save scroll position to localStorage on page hide/unload + useEffect(() => { + const save = () => { + const container = messagesContainerRef.current; + if (!container || !conversationId) return; + scrollStore.save(container.scrollTop); + }; + const onVisChange = () => { + if (document.visibilityState === "hidden") save(); + }; + document.addEventListener("visibilitychange", onVisChange); + window.addEventListener("beforeunload", save); + return () => { + document.removeEventListener("visibilitychange", onVisChange); + window.removeEventListener("beforeunload", save); + }; + }, [conversationId]); + + // Check scroll position, handle scroll-to-bottom button, and re-scroll on content resize + const scrollSaveTimerRef = useRef(null); useEffect(() => { const container = messagesContainerRef.current; if (!container) return; @@ -713,18 +749,75 @@ function ChatInterface({ const isNearBottom = scrollHeight - scrollTop - clientHeight < 100; setShowScrollToBottom(!isNearBottom); userScrolledRef.current = !isNearBottom; + // Debounced save — 100ms after scroll settles + if (scrollSaveTimerRef.current) clearTimeout(scrollSaveTimerRef.current); + scrollSaveTimerRef.current = window.setTimeout(() => { + if (!loadingRef.current) scrollStore.save(container.scrollTop); + }, 100); }; container.addEventListener("scroll", handleScroll); - return () => container.removeEventListener("scroll", handleScroll); - }, []); - // Auto-scroll to bottom when new messages arrive (only if user is already at bottom) - useEffect(() => { - if (!userScrolledRef.current) { + // Re-scroll to bottom when content expands (images loading, tool outputs rendering) + // but only if the user hasn't scrolled away. + let lastScrollHeight = container.scrollHeight; + const ro = new ResizeObserver(() => { + if (container.scrollHeight === lastScrollHeight) return; + lastScrollHeight = container.scrollHeight; + if (!userScrolledRef.current && !catchingUpRef.current) { + container.scrollTop = container.scrollHeight; + } + }); + // .messages-list may not exist yet (loading spinner). Use MutationObserver + // to attach ResizeObserver when it appears. + const attachRO = () => { + const list = container.querySelector(".messages-list"); + if (list) { + ro.observe(list); + return true; + } + return false; + }; + let mo: MutationObserver | null = null; + if (!attachRO()) { + mo = new MutationObserver((_, self) => { + if (attachRO()) { self.disconnect(); mo = null; } + }); + mo.observe(container, { childList: true, subtree: true }); + } + + return () => { + container.removeEventListener("scroll", handleScroll); + if (scrollSaveTimerRef.current) clearTimeout(scrollSaveTimerRef.current); + mo?.disconnect(); + ro.disconnect(); + }; + }, [scrollStore]); + + // Scroll after React commits the DOM, before the browser paints. + // Handles both initial load (pending scroll from loadMessages) and streaming updates. + useLayoutEffect(() => { + if (loading) return; + const pending = pendingScrollRef.current; + if (pending !== undefined) { + pendingScrollRef.current = undefined; + if (pending != null) { + const container = messagesContainerRef.current; + if (container) { + container.scrollTop = pending; + const isNearBottom = container.scrollHeight - pending - container.clientHeight < 100; + userScrolledRef.current = !isNearBottom; + setShowScrollToBottom(!isNearBottom); + } + } else { + scrollToBottom(); + } + return; + } + if (!userScrolledRef.current && !catchingUpRef.current) { scrollToBottom(); } - }, [messages]); + }, [messages, loading]); // Close overflow menu when clicking outside useEffect(() => { @@ -742,10 +835,6 @@ function ChatInterface({ } }, [showOverflowMenu]); - // Reconnect when page becomes visible, focused, or online - // Store reconnect function in a ref so event listeners always have the latest version - const reconnectRef = useRef<() => void>(() => {}); - // Check connection health - returns true if connection needs to be re-established const checkConnectionHealth = useCallback(() => { if (!conversationId) return false; @@ -760,53 +849,24 @@ function ChatInterface({ return false; }, [conversationId]); - useEffect(() => { - const handleVisibilityChange = () => { - if (document.visibilityState === "visible") { - // When tab becomes visible, always check connection health - if (checkConnectionHealth()) { - console.log("Tab visible: connection unhealthy, reconnecting"); - reconnectRef.current(); - } else { - console.log("Tab visible: connection healthy"); - } - } - }; + // Track when the page was last hidden (for detecting stale connections on iOS Safari) + const hiddenAtRef = useRef(null); - const handleFocus = () => { - // On focus, check connection health - if (checkConnectionHealth()) { - console.log("Window focus: connection unhealthy, reconnecting"); - reconnectRef.current(); - } - }; - - const handleOnline = () => { - // Coming back online - definitely try to reconnect if needed - if (checkConnectionHealth()) { - console.log("Online: connection unhealthy, reconnecting"); - reconnectRef.current(); - } - }; - - document.addEventListener("visibilitychange", handleVisibilityChange); - window.addEventListener("focus", handleFocus); - window.addEventListener("online", handleOnline); - - return () => { - document.removeEventListener("visibilitychange", handleVisibilityChange); - window.removeEventListener("focus", handleFocus); - window.removeEventListener("online", handleOnline); - }; - }, [checkConnectionHealth]); + // Suppress auto-scroll during catch-up after returning from a backgrounded tab + const catchingUpRef = useRef(false); const loadMessages = async () => { if (!conversationId) return; try { + loadingRef.current = true; setLoading(true); setError(null); const response = await api.getConversation(conversationId); + // Set pending scroll target before state updates so useLayoutEffect can handle it. + pendingScrollRef.current = scrollStore.load(); setMessages(response.messages ?? []); + loadingRef.current = false; + setLoading(false); // ConversationState is sent via the streaming endpoint, not on initial load // We don't update agentWorking here - the stream will provide the current state // Always update context window size when loading a conversation. @@ -818,29 +878,27 @@ function ChatInterface({ } catch (err) { console.error("Failed to load messages:", err); setError("Failed to load messages"); - } finally { - // Always set loading to false, even if other operations fail + loadingRef.current = false; setLoading(false); } }; - // Reset heartbeat timeout - called on every message received - const resetHeartbeatTimeout = () => { - if (heartbeatTimeoutRef.current) { - clearTimeout(heartbeatTimeoutRef.current); - } - // If we don't receive any message (including heartbeat) within 60 seconds, reconnect - heartbeatTimeoutRef.current = window.setTimeout(() => { - console.warn("No heartbeat received in 60 seconds, reconnecting..."); - if (eventSourceRef.current) { - eventSourceRef.current.close(); - eventSourceRef.current = null; + const setupMessageStream = useCallback(() => { + const resetHeartbeatTimeout = () => { + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); } - setupMessageStream(); - }, 60000); - }; + // If we don't receive any message (including heartbeat) within 60 seconds, reconnect + heartbeatTimeoutRef.current = window.setTimeout(() => { + console.warn("No heartbeat received in 60 seconds, reconnecting..."); + if (eventSourceRef.current) { + eventSourceRef.current.close(); + eventSourceRef.current = null; + } + setupMessageStream(); + }, 60000); + }; - const setupMessageStream = () => { if (!conversationId) return; if (eventSourceRef.current) { @@ -864,6 +922,10 @@ function ChatInterface({ // Reset heartbeat timeout on every message resetHeartbeatTimeout(); + // Clear catch-up flag after the first message event (the catch-up batch) + // so that subsequent messages auto-scroll normally again. + catchingUpRef.current = false; + try { const streamResponse: StreamResponse = JSON.parse(event.data); const incomingMessages = Array.isArray(streamResponse.messages) @@ -1002,7 +1064,86 @@ function ChatInterface({ // Start heartbeat timeout monitoring resetHeartbeatTimeout(); }; - }; + }, [conversationId, onConversationUpdate, onConversationListUpdate, onConversationStateUpdate]); + + // Force-reconnect: close existing connection and reconnect to get missed messages + const forceReconnect = useCallback(() => { + if (!conversationId) return; + if (eventSourceRef.current) { + eventSourceRef.current.close(); + eventSourceRef.current = null; + } + if (reconnectTimeoutRef.current) { + clearTimeout(reconnectTimeoutRef.current); + reconnectTimeoutRef.current = null; + } + if (periodicRetryRef.current) { + clearInterval(periodicRetryRef.current); + periodicRetryRef.current = null; + } + setIsDisconnected(false); + setIsReconnecting(false); + setReconnectAttempts(0); + setupMessageStream(); + }, [conversationId, setupMessageStream]); + + // Reconnect only if connection is dead + const reconnect = useCallback(() => { + if (!eventSourceRef.current || eventSourceRef.current.readyState === 2) { + forceReconnect(); + } + }, [forceReconnect]); + + // Reconnect when page becomes visible, focused, or online + useEffect(() => { + const handleVisibilityChange = () => { + if (document.visibilityState === "hidden") { + hiddenAtRef.current = Date.now(); + return; + } + // Page became visible + const hiddenFor = hiddenAtRef.current ? Date.now() - hiddenAtRef.current : 0; + hiddenAtRef.current = null; + + if (checkConnectionHealth()) { + // Connection is already known-dead + console.log("Tab visible: connection unhealthy, reconnecting"); + catchingUpRef.current = true; + reconnect(); + } else if (hiddenFor > 5000) { + // On iOS Safari, backgrounded tabs have their TCP connections killed + // but EventSource.readyState may still show OPEN. Force reconnect + // to pick up any missed messages from the server. + console.log(`Tab visible after ${Math.round(hiddenFor / 1000)}s hidden, force reconnecting`); + catchingUpRef.current = true; + forceReconnect(); + } + }; + + const handleFocus = () => { + if (checkConnectionHealth()) { + console.log("Window focus: connection unhealthy, reconnecting"); + reconnect(); + } + }; + + const handleOnline = () => { + if (checkConnectionHealth()) { + console.log("Online: connection unhealthy, reconnecting"); + reconnect(); + } + }; + + document.addEventListener("visibilitychange", handleVisibilityChange); + window.addEventListener("focus", handleFocus); + window.addEventListener("online", handleOnline); + + return () => { + document.removeEventListener("visibilitychange", handleVisibilityChange); + window.removeEventListener("focus", handleFocus); + window.removeEventListener("online", handleOnline); + }; + }, [checkConnectionHealth, reconnect, forceReconnect]); const sendMessage = async (message: string) => { if (!message.trim() || sending) return; @@ -1067,7 +1208,10 @@ function ChatInterface({ }; const scrollToBottom = () => { - messagesEndRef.current?.scrollIntoView({ behavior: "instant" }); + const container = messagesContainerRef.current; + if (container) { + container.scrollTop = container.scrollHeight; + } userScrolledRef.current = false; setShowScrollToBottom(false); }; @@ -1077,47 +1221,6 @@ function ChatInterface({ setTerminalInjectedText(text); }, []); - const handleManualReconnect = () => { - if (!conversationId || eventSourceRef.current) return; - setIsDisconnected(false); - setIsReconnecting(false); - setReconnectAttempts(0); - if (reconnectTimeoutRef.current) { - clearTimeout(reconnectTimeoutRef.current); - reconnectTimeoutRef.current = null; - } - if (periodicRetryRef.current) { - clearInterval(periodicRetryRef.current); - periodicRetryRef.current = null; - } - setupMessageStream(); - }; - - // Update the reconnect ref - always attempt reconnect if connection is unhealthy - useEffect(() => { - reconnectRef.current = () => { - if (!conversationId) return; - // Always try to reconnect if there's no active connection - if (!eventSourceRef.current || eventSourceRef.current.readyState === 2) { - console.log("Reconnect triggered: no active connection"); - // Clear any pending reconnect attempts - if (reconnectTimeoutRef.current) { - clearTimeout(reconnectTimeoutRef.current); - reconnectTimeoutRef.current = null; - } - if (periodicRetryRef.current) { - clearInterval(periodicRetryRef.current); - periodicRetryRef.current = null; - } - // Reset state and reconnect - setIsDisconnected(false); - setIsReconnecting(false); - setReconnectAttempts(0); - setupMessageStream(); - } - }; - }, [conversationId]); - // Handle external trigger to open diff viewer useEffect(() => { if (openDiffViewerTrigger && openDiffViewerTrigger > 0) { @@ -1766,7 +1869,6 @@ function ChatInterface({
{renderMessages()} -
)}
@@ -1818,7 +1920,7 @@ function ChatInterface({ <> Disconnected