Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions backend/chainlit/data/chainlit_data_layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -35,6 +36,24 @@
from chainlit.step import StepDict

# Canonical timestamp format for this data layer; the trailing Z marks UTC.
ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
# Same format without the trailing Z, matching strings produced by datetime.isoformat().
ISO_FORMAT_NO_Z = "%Y-%m-%dT%H:%M:%S.%f"


def _parse_iso_datetime(date_string: str) -> "datetime":
"""Parse an ISO datetime string, tolerating both with and without trailing Z."""
if date_string.endswith("Z"):
return datetime.strptime(date_string, ISO_FORMAT)
return datetime.strptime(date_string, ISO_FORMAT_NO_Z)


def _datetime_to_utc_iso(dt: "datetime") -> str:
"""Convert a datetime to a UTC ISO string with trailing Z for consistency."""
s = dt.isoformat()
if not s.endswith("Z"):
# Strip any timezone offset info (+00:00) before appending Z
s = re.sub(r"[+-]\d{2}:\d{2}$", "", s)
s += "Z"
return s


class ChainlitDataLayer(BaseDataLayer):
Expand All @@ -54,7 +73,7 @@ async def connect(self):
self.pool = await asyncpg.create_pool(self.database_url)

async def get_current_timestamp(self) -> datetime:
return datetime.now()
return datetime.utcnow()

async def execute_query(
self, query: str, params: Union[Dict, None] = None
Expand Down Expand Up @@ -95,7 +114,7 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
return PersistedUser(
id=str(row.get("id")),
identifier=str(row.get("identifier")),
createdAt=row.get("createdAt").isoformat(), # type: ignore
createdAt=_datetime_to_utc_iso(row.get("createdAt")), # type: ignore
metadata=json.loads(row.get("metadata", "{}")),
)

Expand All @@ -121,7 +140,7 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
return PersistedUser(
id=str(row.get("id")),
identifier=str(row.get("identifier")),
createdAt=row.get("createdAt").isoformat(), # type: ignore
createdAt=_datetime_to_utc_iso(row.get("createdAt")), # type: ignore
metadata=json.loads(row.get("metadata", "{}")),
)

Expand Down Expand Up @@ -277,10 +296,10 @@ async def get_element(
id=str(row["id"]),
threadId=str(row["threadId"]),
type=metadata.get("type", "file"),
url=str(row["url"]),
url=row.get("url"),
name=str(row["name"]),
mime=str(row["mime"]),
objectKey=str(row["objectKey"]),
mime=str(row["mime"]) if row.get("mime") else None,
objectKey=row.get("objectKey"),
forId=str(row["stepId"]),
chainlitKey=row.get("chainlitKey"),
display=row["display"],
Expand Down Expand Up @@ -372,7 +391,7 @@ async def create_step(self, step_dict: StepDict):
timestamp = await self.get_current_timestamp()
created_at = step_dict.get("createdAt")
if created_at:
timestamp = datetime.strptime(created_at, ISO_FORMAT)
timestamp = _parse_iso_datetime(created_at)

params = {
"id": step_dict["id"],
Expand Down Expand Up @@ -497,7 +516,7 @@ async def list_threads(
for thread in threads:
thread_dict = ThreadDict(
id=str(thread["id"]),
createdAt=thread["updatedAt"].isoformat(),
createdAt=_datetime_to_utc_iso(thread["updatedAt"]),
name=thread["name"],
userId=str(thread["userId"]) if thread["userId"] else None,
userIdentifier=thread["user_identifier"],
Expand Down Expand Up @@ -555,13 +574,20 @@ async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
if self.storage_client is not None:
for elem in elements_results:
if not elem["url"] and elem["objectKey"]:
elem["url"] = await self.storage_client.get_read_url(
object_key=elem["objectKey"],
)
try:
elem["url"] = await self.storage_client.get_read_url(
object_key=elem["objectKey"],
)
except Exception as e:
logger.warning(
"Failed to get read URL for element '%s': %s",
elem.get("id", "unknown"),
e,
)

return ThreadDict(
id=str(thread["id"]),
createdAt=thread["createdAt"].isoformat(),
createdAt=_datetime_to_utc_iso(thread["createdAt"]),
name=thread["name"],
userId=str(thread["userId"]) if thread["userId"] else None,
userIdentifier=thread["user_identifier"],
Expand Down Expand Up @@ -617,7 +643,7 @@ async def update_thread(
"userId": user_id,
"tags": tags,
"metadata": json.dumps(metadata or {}),
"updatedAt": datetime.now(),
"updatedAt": datetime.utcnow(),
}

# Remove None values
Expand Down Expand Up @@ -678,11 +704,11 @@ def _convert_step_row_to_dict(self, row: Dict) -> StepDict:
input=row.get("input", {}),
output=row.get("output", {}),
metadata=json.loads(row.get("metadata", "{}")),
createdAt=row["createdAt"].isoformat() if row.get("createdAt") else None,
start=row["startTime"].isoformat() if row.get("startTime") else None,
createdAt=_datetime_to_utc_iso(row["createdAt"]) if row.get("createdAt") else None,
start=_datetime_to_utc_iso(row["startTime"]) if row.get("startTime") else None,
showInput=row.get("showInput"),
isError=row.get("isError"),
end=row["endTime"].isoformat() if row.get("endTime") else None,
end=_datetime_to_utc_iso(row["endTime"]) if row.get("endTime") else None,
feedback=self._extract_feedback_dict_from_step_row(row),
)

Expand Down
137 changes: 137 additions & 0 deletions backend/tests/data/test_chainlit_data_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from datetime import datetime

import pytest

from chainlit.data.chainlit_data_layer import (
ChainlitDataLayer,
_datetime_to_utc_iso,
_parse_iso_datetime,
)


class TestParseIsoDatetime:
    """Test suite for _parse_iso_datetime helper."""

    def test_parse_with_z_suffix(self):
        """Test parsing ISO datetime string with trailing Z."""
        result = _parse_iso_datetime("2025-09-04T02:00:42.164000Z")
        assert result == datetime(2025, 9, 4, 2, 0, 42, 164000)

    def test_parse_without_z_suffix(self):
        """Test parsing ISO datetime string without trailing Z (the bug case)."""
        result = _parse_iso_datetime("2025-09-04T02:00:42.164000")
        assert result == datetime(2025, 9, 4, 2, 0, 42, 164000)

    def test_parse_without_z_raises_on_bad_format(self):
        """Test that invalid format still raises ValueError."""
        # `match` narrows the check so unrelated ValueErrors don't pass (ruff PT011).
        with pytest.raises(ValueError, match="does not match format"):
            _parse_iso_datetime("2025-09-04 02:00:42")

    def test_roundtrip_with_z(self):
        """Test that parsing a Z-suffixed string and formatting round-trips."""
        original = "2025-09-04T02:00:42.164000Z"
        dt = _parse_iso_datetime(original)
        formatted = _datetime_to_utc_iso(dt)
        assert formatted == original

    def test_roundtrip_without_z(self):
        """Test that parsing a non-Z string and formatting produces Z-suffixed output."""
        original = "2025-09-04T02:00:42.164000"
        dt = _parse_iso_datetime(original)
        formatted = _datetime_to_utc_iso(dt)
        assert formatted == original + "Z"


class TestDatetimeToUtcIso:
    """Test suite for the _datetime_to_utc_iso helper."""

    def test_adds_z_suffix(self):
        """A Z suffix is always appended to the formatted timestamp."""
        formatted = _datetime_to_utc_iso(datetime(2025, 9, 4, 2, 0, 42, 164000))
        assert formatted == "2025-09-04T02:00:42.164000Z"

    def test_no_double_z(self):
        """The Z suffix is never duplicated."""
        formatted = _datetime_to_utc_iso(datetime(2025, 1, 1))
        assert not formatted.endswith("ZZ")
        assert formatted.endswith("Z")

    def test_zero_microseconds(self):
        """Formatting still ends in Z when microseconds are zero."""
        formatted = _datetime_to_utc_iso(datetime(2025, 1, 1, 12, 30, 45))
        assert formatted == "2025-01-01T12:30:45Z"
        assert formatted.endswith("Z")


class TestConvertStepRowTimestamps:
    """Test that _convert_step_row_to_dict produces timestamps with trailing Z."""

    def _make_layer(self):
        # Construct a layer without connecting; the fake URL is never dialed.
        return ChainlitDataLayer(database_url="postgresql://fake", storage_client=None)

    def _make_step_row(self, **overrides):
        # Minimal fake DB row for a step; keyword arguments override defaults.
        base = {
            "id": "step-1",
            "threadId": "thread-1",
            "parentId": None,
            "name": "test_step",
            "type": "run",
            "input": "{}",
            "output": "{}",
            "metadata": "{}",
            "createdAt": datetime(2025, 9, 4, 2, 0, 42, 164000),
            "startTime": datetime(2025, 9, 4, 2, 0, 42, 164000),
            "endTime": datetime(2025, 9, 4, 2, 0, 43, 0),
            "showInput": "json",
            "isError": False,
            "feedback_id": None,
        }
        base.update(overrides)
        return base

    def test_step_timestamps_have_z_suffix(self):
        """Test that step createdAt, start, end all end with Z."""
        step = self._make_layer()._convert_step_row_to_dict(self._make_step_row())

        for key in ("createdAt", "start", "end"):
            assert step[key].endswith("Z"), (
                f"{key} should end with Z, got: {step[key]}"
            )

    def test_step_timestamps_can_be_reparsed(self):
        """Test that timestamps from _convert_step_row_to_dict can be parsed back.

        This is the exact scenario from bug #2491: after reading a step from DB,
        the createdAt string should be parseable when passed back to
        create_step/update_step.
        """
        step = self._make_layer()._convert_step_row_to_dict(self._make_step_row())

        # Simulate what create_step does when update_step feeds back the step dict
        assert _parse_iso_datetime(step["createdAt"]) == datetime(
            2025, 9, 4, 2, 0, 42, 164000
        )

    def test_step_none_timestamps_preserved(self):
        """Test that None timestamps are preserved as None."""
        row = self._make_step_row(createdAt=None, startTime=None, endTime=None)
        step = self._make_layer()._convert_step_row_to_dict(row)

        assert step["createdAt"] is None
        assert step["start"] is None
        assert step["end"] is None
Loading