diff --git a/backend/chainlit/data/chainlit_data_layer.py b/backend/chainlit/data/chainlit_data_layer.py index d2fb2d9f62..c2efd27193 100644 --- a/backend/chainlit/data/chainlit_data_layer.py +++ b/backend/chainlit/data/chainlit_data_layer.py @@ -591,6 +591,7 @@ async def update_thread( ) # Merge incoming metadata with existing metadata, deleting incoming keys with None values + merged_metadata = None if metadata is not None: existing = await self.execute_query( 'SELECT "metadata" FROM "Thread" WHERE id = $1', @@ -609,14 +610,16 @@ async def update_thread( to_delete = {k for k, v in metadata.items() if v is None} incoming = {k: v for k, v in metadata.items() if v is not None} base = {k: v for k, v in base.items() if k not in to_delete} - metadata = {**base, **incoming} + merged_metadata = {**base, **incoming} data = { "id": thread_id, "name": thread_name, "userId": user_id, "tags": tags, - "metadata": json.dumps(metadata or {}), + "metadata": json.dumps(merged_metadata) + if merged_metadata is not None + else None, "updatedAt": datetime.now(), } diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 6accf1ccc4..f9393e5e3b 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1012,6 +1012,7 @@ async def get_shared_thread( if not isinstance(metadata, dict): metadata = {} + user_can_view = False if getattr(config.code, "on_shared_thread_view", None): try: user_can_view = await config.code.on_shared_thread_view( diff --git a/backend/tests/data/test_chainlit_data_layer.py b/backend/tests/data/test_chainlit_data_layer.py new file mode 100644 index 0000000000..6ec61dc223 --- /dev/null +++ b/backend/tests/data/test_chainlit_data_layer.py @@ -0,0 +1,141 @@ +import json +from unittest.mock import AsyncMock + +import pytest + +from chainlit.data.chainlit_data_layer import ChainlitDataLayer + + +@pytest.mark.asyncio +async def test_update_thread_preserves_metadata_when_none(): + """Test that update_thread does not overwrite existing metadata when metadata=None.""" + # Create a mock data layer + data_layer = ChainlitDataLayer( + database_url="postgresql://test", storage_client=None, show_logger=False + ) + + # Mock the execute_query method + data_layer.execute_query = AsyncMock() + + # Simulate calling update_thread with only a name, metadata=None (default) + await data_layer.update_thread(thread_id="test-thread-123", name="Updated Name") + + # Verify execute_query was called + assert data_layer.execute_query.called + + # Get the query and params from the call + call_args = data_layer.execute_query.call_args + query = call_args[0][0] + params = call_args[0][1] + + # The query should NOT include metadata in the update + # because metadata was None and should be excluded from the data dict + assert "metadata" not in query.lower() + assert "metadata" not in str(params.values()) + + +@pytest.mark.asyncio +async def test_update_thread_merges_metadata_when_provided(): + """Test that update_thread merges metadata correctly when provided.""" + # Create a mock data layer + data_layer = ChainlitDataLayer( + database_url="postgresql://test", storage_client=None, show_logger=False + ) + + # Mock the execute_query method to return existing metadata + existing_metadata = {"is_shared": True, "custom_field": "original"} + + async def mock_execute_query(query, params): + if "SELECT" in query and "metadata" in query: + # Return existing thread metadata + return [{"metadata": json.dumps(existing_metadata)}] + # For the UPDATE/INSERT, return None + return None + + data_layer.execute_query = AsyncMock(side_effect=mock_execute_query) + + # Call update_thread with partial metadata update + new_metadata = {"custom_field": "updated", "new_field": "added"} + await data_layer.update_thread( + thread_id="test-thread-123", name="Updated Name", metadata=new_metadata + ) + + # Verify execute_query was called twice (once for SELECT, once for UPDATE) + assert data_layer.execute_query.call_count == 2 + + # Get the UPDATE call + update_call = data_layer.execute_query.call_args_list[1] + update_params = update_call[0][1] + + # The metadata should be merged + # Expected: {"is_shared": True, "custom_field": "updated", "new_field": "added"} + # Find the JSON metadata in the params + metadata_json = None + for value in update_params.values(): + if isinstance(value, str) and value.startswith("{"): + try: + metadata_json = json.loads(value) + break + except json.JSONDecodeError: + pass + + assert metadata_json is not None + assert metadata_json.get("is_shared") is True + assert metadata_json.get("custom_field") == "updated" + assert metadata_json.get("new_field") == "added" + + +@pytest.mark.asyncio +async def test_update_thread_deletes_keys_with_none_values(): + """Test that update_thread deletes keys when value is None.""" + # Create a mock data layer + data_layer = ChainlitDataLayer( + database_url="postgresql://test", storage_client=None, show_logger=False + ) + + # Mock the execute_query method to return existing metadata + existing_metadata = { + "is_shared": True, + "to_delete": "will be removed", + "keep": "stays", + } + + async def mock_execute_query(query, params): + if "SELECT" in query and "metadata" in query: + # Return existing thread metadata + return [{"metadata": json.dumps(existing_metadata)}] + # For the UPDATE/INSERT, return None + return None + + data_layer.execute_query = AsyncMock(side_effect=mock_execute_query) + + # Call update_thread with None value to delete a key + new_metadata = {"to_delete": None, "new_field": "added"} + await data_layer.update_thread(thread_id="test-thread-123", metadata=new_metadata) + + # Verify execute_query was called twice + assert data_layer.execute_query.call_count == 2 + + # Get the UPDATE call + update_call = data_layer.execute_query.call_args_list[1] + update_params = update_call[0][1] + + # The metadata should have deleted "to_delete" key and added "new_field" + # Expected: {"is_shared": True, "keep": "stays", "new_field": "added"} + metadata_json = None + for value in update_params.values(): + if isinstance(value, str) and value.startswith("{"): + try: + metadata_json = json.loads(value) + break + except json.JSONDecodeError: + pass + + if metadata_json: + # Verify "to_delete" is not in the merged metadata + assert "to_delete" not in metadata_json + # Verify "new_field" was added + assert metadata_json.get("new_field") == "added" + # Verify "is_shared" and "keep" are preserved + assert metadata_json.get("is_shared") is True + assert metadata_json.get("keep") == "stays"