diff --git a/.env.example b/.env.example index 9523d1be..226e3800 100644 --- a/.env.example +++ b/.env.example @@ -15,3 +15,4 @@ ALPHATRION_ARTIFACT_INSECURE=false # Tracing configurations ALPHATRION_ENABLE_TRACING=true ALPHATRION_CLICKHOUSE_INIT_TABLES=true +ALPHATRION_CLICKHOUSE_ENABLE_BATCH=true diff --git a/.env.integration-test b/.env.integration-test index f7c39e84..af9cde7d 100644 --- a/.env.integration-test +++ b/.env.integration-test @@ -5,4 +5,5 @@ ALPHATRION_ARTIFACT_INSECURE=true ALPHATRION_LOG_LEVEL=INFO ALPHATRION_AUTO_CLEANUP=true ALPHATRION_ENABLE_TRACING=true -ALPHATRION_CLICKHOUSE_INIT_TABLES=true \ No newline at end of file +ALPHATRION_CLICKHOUSE_INIT_TABLES=true +ALPHATRION_CLICKHOUSE_ENABLE_BATCH=true \ No newline at end of file diff --git a/alphatrion/envs.py b/alphatrion/envs.py index b4962d60..ab24d640 100644 --- a/alphatrion/envs.py +++ b/alphatrion/envs.py @@ -14,7 +14,7 @@ CLICKHOUSE_DATABASE = "ALPHATRION_CLICKHOUSE_DATABASE" CLICKHOUSE_USERNAME = "ALPHATRION_CLICKHOUSE_USERNAME" CLICKHOUSE_PASSWORD = "ALPHATRION_CLICKHOUSE_PASSWORD" -INIT_CLICKHOUSE_TABLES = "ALPHATRION_INIT_CLICKHOUSE_TABLES" +CLICKHOUSE_ENABLE_BATCH = "ALPHATRION_CLICKHOUSE_ENABLE_BATCH" # Dashboard only related envs DASHBOARD_USER_ID = "ALPHATRION_DASHBOARD_USER_ID" diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index ea7173fa..f3cb7047 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -220,13 +220,23 @@ def get_run(id: strawberry.ID) -> Run | None: metadb = runtime.storage_runtime().metadb run = metadb.get_run(run_id=uuid.UUID(id)) if run: + meta = run.meta or {} + + # Aggregate and cache tokens for completed runs. + # It could be slow for the first time. + if Status(run.status) == Status.COMPLETED and "total_tokens" not in meta: + token_data = GraphQLResolvers.aggregate_run_tokens(run_id=id) + if token_data["total_tokens"] > 0: + meta.update(token_data) + metadb.update_run(run_id=uuid.UUID(id), meta=meta) + return Run( id=run.uuid, team_id=run.team_id, user_id=run.user_id, project_id=run.project_id, experiment_id=run.experiment_id, - meta=run.meta, + meta=meta, status=GraphQLStatusEnum[Status(run.status).name], created_at=run.created_at, ) @@ -250,6 +260,24 @@ def list_exp_metrics(experiment_id: strawberry.ID) -> list[Metric]: for m in metrics ] + @staticmethod + def list_run_metrics(run_id: strawberry.ID) -> list[Metric]: + metadb = runtime.storage_runtime().metadb + metrics = metadb.list_metrics_by_run_id(run_id=run_id) + return [ + Metric( + id=m.uuid, + key=m.key, + value=m.value, + team_id=m.team_id, + project_id=m.project_id, + experiment_id=m.experiment_id, + run_id=m.run_id, + created_at=m.created_at, + ) + for m in metrics + ] + @staticmethod def total_projects(team_id: strawberry.ID) -> int: metadb = runtime.storage_runtime().metadb @@ -373,8 +401,48 @@ async def get_artifact_content( raise RuntimeError(f"Failed to get artifact content: {e}") from e @staticmethod - def list_traces(run_id: strawberry.ID) -> list[Span]: - """List all traces/spans for a specific run.""" + def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]: + """Aggregate token usage from all traces for a run.""" + from alphatrion import envs + + # Check if tracing is enabled + if os.getenv(envs.ENABLE_TRACING, "false").lower() != "true": + return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0} + + try: + trace_store = runtime.storage_runtime().tracestore + spans = trace_store.get_spans_by_run_id(uuid.UUID(run_id)) + trace_store.close() + + total_tokens = 0 + input_tokens = 0 + output_tokens = 0 + + for span in spans: + span_attrs = span.get("SpanAttributes", {}) + + # Aggregate tokens from LLM spans + if "llm.usage.total_tokens" in span_attrs: + total_tokens += int(span_attrs["llm.usage.total_tokens"]) + if "gen_ai.usage.input_tokens" in span_attrs: + input_tokens += int(span_attrs["gen_ai.usage.input_tokens"]) + if "gen_ai.usage.output_tokens" in span_attrs: + output_tokens += int(span_attrs["gen_ai.usage.output_tokens"]) + + return { + "total_tokens": total_tokens, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } + except Exception as e: + import logging + + logging.error(f"Failed to aggregate tokens for run {run_id}: {e}") + return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0} + + @staticmethod + def list_spans(run_id: strawberry.ID) -> list[Span]: + """List all spans for a specific run.""" from alphatrion import envs # Check if tracing is enabled @@ -385,12 +453,12 @@ def list_traces(run_id: strawberry.ID) -> list[Span]: trace_store = runtime.storage_runtime().tracestore # Get traces from ClickHouse - traces = trace_store.get_traces_by_run_id(uuid.UUID(run_id)) + raw_spans = trace_store.get_spans_by_run_id(uuid.UUID(run_id)) trace_store.close() # Convert to GraphQL Span objects spans = [] - for t in traces: + for t in raw_spans: # Convert events events = [] if t.get("Events"): diff --git a/alphatrion/server/graphql/types.py b/alphatrion/server/graphql/types.py index 4aee1005..48a15f0f 100644 --- a/alphatrion/server/graphql/types.py +++ b/alphatrion/server/graphql/types.py @@ -129,6 +129,20 @@ class Run: status: GraphQLStatusEnum created_at: datetime + @strawberry.field + def metrics(self) -> list["Metric"]: + """Get metrics for this run.""" + from alphatrion.server.graphql.resolvers import GraphQLResolvers + + return GraphQLResolvers.list_run_metrics(run_id=self.id) + + @strawberry.field + def spans(self) -> list["Span"]: + """Get spans for this run.""" + from alphatrion.server.graphql.resolvers import GraphQLResolvers + + return GraphQLResolvers.list_spans(run_id=str(self.id)) + @strawberry.type class Metric: diff --git a/alphatrion/storage/runtime.py b/alphatrion/storage/runtime.py index 3a2c4899..9607b34a 100644 --- a/alphatrion/storage/runtime.py +++ b/alphatrion/storage/runtime.py @@ -2,6 +2,7 @@ import os from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider from traceloop.sdk import Traceloop from alphatrion import envs @@ -38,10 +39,13 @@ def __init__(self): == "true", ) + enable_batch = ( + os.getenv(envs.CLICKHOUSE_ENABLE_BATCH, "true").lower() == "true" + ) Traceloop.init( app_name="alphatrion", exporter=ClickHouseSpanExporter(self.tracestore), - disable_batch=False, # Enable batching + disable_batch=not enable_batch, telemetry_enabled=False, ) @@ -60,6 +64,12 @@ def metadb(self): def tracestore(self): return self._tracestore + def flush(self): + if self._tracestore: + tracer_provider = trace.get_tracer_provider() + if isinstance(tracer_provider, TracerProvider): + tracer_provider.force_flush(timeout_millis=5000) + def init(): """ diff --git a/alphatrion/storage/sqlstore.py b/alphatrion/storage/sqlstore.py index 18d2c3ba..91272801 100644 --- a/alphatrion/storage/sqlstore.py +++ b/alphatrion/storage/sqlstore.py @@ -695,3 +695,14 @@ def list_metrics_by_experiment_id(self, experiment_id: uuid.UUID) -> list[Metric ) session.close() return metrics + + def list_metrics_by_run_id(self, run_id: uuid.UUID) -> list[Metric]: + session = self._session() + metrics = ( + session.query(Metric) + .filter(Metric.run_id == run_id) + .order_by(Metric.created_at.asc()) + .all() + ) + session.close() + return metrics diff --git a/alphatrion/storage/tracestore.py b/alphatrion/storage/tracestore.py index 502e8f73..df5278f9 100644 --- a/alphatrion/storage/tracestore.py +++ b/alphatrion/storage/tracestore.py @@ -1,6 +1,7 @@ # ruff: noqa: E501 import logging +import threading import uuid from typing import Any @@ -30,6 +31,7 @@ def __init__( init_tables: If True, create tables on initialization """ self.database = database + self._lock = threading.Lock() # Protect concurrent access to ClickHouse client # Parse host and port, stripping protocol if present # Handle URLs like "http://localhost:8123" or "localhost:8123" @@ -68,9 +70,9 @@ def _create_database(self) -> None: raise def _create_tables(self) -> None: - """Create the otel_traces table if it doesn't exist.""" + """Create the otel_spans table if it doesn't exist.""" create_table_sql = f""" - CREATE TABLE IF NOT EXISTS {self.database}.otel_traces ( + CREATE TABLE IF NOT EXISTS {self.database}.otel_spans ( Timestamp DateTime64(9) CODEC(Delta, ZSTD(1)), TraceId String CODEC(ZSTD(1)), SpanId String CODEC(ZSTD(1)), @@ -111,7 +113,7 @@ def _create_tables(self) -> None: try: self.client.command(create_table_sql) - logger.info(f"Table {self.database}.otel_traces ready") + logger.info(f"Table {self.database}.otel_spans ready") except Exception as e: logger.error(f"Failed to create table: {e}") raise @@ -125,224 +127,149 @@ def insert_spans(self, spans: list[dict[str, Any]]) -> None: if not spans: return - try: - # Prepare data for insertion - data = [] - for span in spans: - data.append( - ( - span.get("Timestamp"), - span.get("TraceId", ""), - span.get("SpanId", ""), - span.get("ParentSpanId", ""), - span.get("SpanName", ""), - span.get("SpanKind", ""), - span.get("ServiceName", ""), - span.get("Duration", 0), - span.get("StatusCode", ""), - span.get("StatusMessage", ""), - span.get("TeamId", ""), - span.get("ProjectId", ""), - span.get("RunId", ""), - span.get("ExperimentId", ""), - span.get("SpanAttributes", {}), - span.get("ResourceAttributes", {}), - span.get("Events.Timestamp", []), - span.get("Events.Name", []), - span.get("Events.Attributes", []), - span.get("Links.TraceId", []), - span.get("Links.SpanId", []), - span.get("Links.Attributes", []), + with self._lock: # Protect concurrent access to ClickHouse client + try: + # Prepare data for insertion + data = [] + for span in spans: + data.append( + ( + span.get("Timestamp"), + span.get("TraceId", ""), + span.get("SpanId", ""), + span.get("ParentSpanId", ""), + span.get("SpanName", ""), + span.get("SpanKind", ""), + span.get("ServiceName", ""), + span.get("Duration", 0), + span.get("StatusCode", ""), + span.get("StatusMessage", ""), + span.get("TeamId", ""), + span.get("ProjectId", ""), + span.get("RunId", ""), + span.get("ExperimentId", ""), + span.get("SpanAttributes", {}), + span.get("ResourceAttributes", {}), + span.get("Events.Timestamp", []), + span.get("Events.Name", []), + span.get("Events.Attributes", []), + span.get("Links.TraceId", []), + span.get("Links.SpanId", []), + span.get("Links.Attributes", []), + ) ) - ) - - # Insert into ClickHouse - self.client.insert( - f"{self.database}.otel_traces", - data, - column_names=[ - "Timestamp", - "TraceId", - "SpanId", - "ParentSpanId", - "SpanName", - "SpanKind", - "ServiceName", - "Duration", - "StatusCode", - "StatusMessage", - "TeamId", - "ProjectId", - "RunId", - "ExperimentId", - "SpanAttributes", - "ResourceAttributes", - "Events.Timestamp", - "Events.Name", - "Events.Attributes", - "Links.TraceId", - "Links.SpanId", - "Links.Attributes", - ], - ) - logger.debug(f"Inserted {len(spans)} spans into ClickHouse") - except Exception as e: - logger.error(f"Failed to insert spans: {e}") - # Don't raise - we don't want to crash the application if tracing fails - - def get_traces_by_run_id(self, run_id: uuid.UUID) -> list[dict[str, Any]]: - """Get all traces/spans for a specific run_id. - Args: - run_id: The run ID to filter by - - Returns: - List of span dictionaries - """ - try: - query = f""" - SELECT - Timestamp, - TraceId, - SpanId, - ParentSpanId, - SpanName, - SpanKind, - ServiceName, - Duration, - StatusCode, - StatusMessage, - TeamId, - ProjectId, - RunId, - ExperimentId, - SpanAttributes, - ResourceAttributes, - Events.Timestamp as EventTimestamps, - Events.Name as EventNames, - Events.Attributes as EventAttributes, - Links.TraceId as LinkTraceIds, - Links.SpanId as LinkSpanIds, - Links.Attributes as LinkAttributes - FROM {self.database}.otel_traces - WHERE RunId = '{run_id}' - ORDER BY Timestamp ASC - """ - - result = self.client.query(query) - spans = [] - for row in result.result_rows: - spans.append( - { - "Timestamp": row[0], - "TraceId": row[1], - "SpanId": row[2], - "ParentSpanId": row[3], - "SpanName": row[4], - "SpanKind": row[5], - "ServiceName": row[6], - "Duration": row[7], - "StatusCode": row[8], - "StatusMessage": row[9], - "TeamId": row[10], - "ProjectId": row[11], - "RunId": row[12], - "ExperimentId": row[13], - "SpanAttributes": row[14], - "ResourceAttributes": row[15], - "Events": { - "Timestamp": row[16], - "Name": row[17], - "Attributes": row[18], - }, - "Links": { - "TraceId": row[19], - "SpanId": row[20], - "Attributes": row[21], - }, - } + # Insert into ClickHouse + self.client.insert( + f"{self.database}.otel_spans", + data, + column_names=[ + "Timestamp", + "TraceId", + "SpanId", + "ParentSpanId", + "SpanName", + "SpanKind", + "ServiceName", + "Duration", + "StatusCode", + "StatusMessage", + "TeamId", + "ProjectId", + "RunId", + "ExperimentId", + "SpanAttributes", + "ResourceAttributes", + "Events.Timestamp", + "Events.Name", + "Events.Attributes", + "Links.TraceId", + "Links.SpanId", + "Links.Attributes", + ], ) - return spans - except Exception as e: - logger.error(f"Failed to get traces by run_id: {e}") - return [] + logger.debug(f"Inserted {len(spans)} spans into ClickHouse") + except Exception as e: + logger.error(f"Failed to insert spans: {e}") + # Don't raise - we don't want to crash the application if tracing fails - def get_spans_by_trace_id(self, trace_id: str) -> list[dict[str, Any]]: - """Get all spans for a specific trace_id. + def get_spans_by_run_id(self, run_id: uuid.UUID) -> list[dict[str, Any]]: + """Get all spans for a specific run_id. Args: - trace_id: The trace ID to filter by + run_id: The run ID to filter by Returns: List of span dictionaries """ - try: - query = f""" - SELECT - Timestamp, - TraceId, - SpanId, - ParentSpanId, - SpanName, - SpanKind, - ServiceName, - Duration, - StatusCode, - StatusMessage, - TeamId, - ProjectId, - RunId, - ExperimentId, - SpanAttributes, - ResourceAttributes, - Events.Timestamp as EventTimestamps, - Events.Name as EventNames, - Events.Attributes as EventAttributes, - Links.TraceId as LinkTraceIds, - Links.SpanId as LinkSpanIds, - Links.Attributes as LinkAttributes - FROM {self.database}.otel_traces - WHERE TraceId = '{trace_id}' - ORDER BY Timestamp ASC - """ + with self._lock: # Protect concurrent access to ClickHouse client + try: + query = f""" + SELECT + Timestamp, + TraceId, + SpanId, + ParentSpanId, + SpanName, + SpanKind, + ServiceName, + Duration, + StatusCode, + StatusMessage, + TeamId, + ProjectId, + RunId, + ExperimentId, + SpanAttributes, + ResourceAttributes, + Events.Timestamp as EventTimestamps, + Events.Name as EventNames, + Events.Attributes as EventAttributes, + Links.TraceId as LinkTraceIds, + Links.SpanId as LinkSpanIds, + Links.Attributes as LinkAttributes + FROM {self.database}.otel_spans + WHERE RunId = '{run_id}' + ORDER BY Timestamp ASC + """ - result = self.client.query(query) - spans = [] - for row in result.result_rows: - spans.append( - { - "Timestamp": row[0], - "TraceId": row[1], - "SpanId": row[2], - "ParentSpanId": row[3], - "SpanName": row[4], - "SpanKind": row[5], - "ServiceName": row[6], - "Duration": row[7], - "StatusCode": row[8], - "StatusMessage": row[9], - "TeamId": row[10], - "ProjectId": row[11], - "RunId": row[12], - "ExperimentId": row[13], - "SpanAttributes": row[14], - "ResourceAttributes": row[15], - "Events": { - "Timestamp": row[16], - "Name": row[17], - "Attributes": row[18], - }, - "Links": { - "TraceId": row[19], - "SpanId": row[20], - "Attributes": row[21], - }, - } - ) - return spans - except Exception as e: - logger.error(f"Failed to get spans by trace_id: {e}") - return [] + result = self.client.query(query) + spans = [] + for row in result.result_rows: + spans.append( + { + "Timestamp": row[0], + "TraceId": row[1], + "SpanId": row[2], + "ParentSpanId": row[3], + "SpanName": row[4], + "SpanKind": row[5], + "ServiceName": row[6], + "Duration": row[7], + "StatusCode": row[8], + "StatusMessage": row[9], + "TeamId": row[10], + "ProjectId": row[11], + "RunId": row[12], + "ExperimentId": row[13], + "SpanAttributes": row[14], + "ResourceAttributes": row[15], + "Events": { + "Timestamp": row[16], + "Name": row[17], + "Attributes": row[18], + }, + "Links": { + "TraceId": row[19], + "SpanId": row[20], + "Attributes": row[21], + }, + } + ) + return spans + except Exception as e: + logger.error(f"Failed to get traces by run_id: {e}") + return [] def close(self) -> None: """Close the ClickHouse connection.""" diff --git a/dashboard/src/App.tsx b/dashboard/src/App.tsx index 410659fe..d4a2aaa0 100644 --- a/dashboard/src/App.tsx +++ b/dashboard/src/App.tsx @@ -129,9 +129,10 @@ function App() { } return ( - - - }> +
+ + + }> } /> } /> @@ -150,6 +151,7 @@ function App() { +
); } diff --git a/dashboard/src/components/layout/header.tsx b/dashboard/src/components/layout/header.tsx index efb80e2d..af77e16a 100644 --- a/dashboard/src/components/layout/header.tsx +++ b/dashboard/src/components/layout/header.tsx @@ -104,7 +104,7 @@ export function Header() { const breadcrumbs = generateBreadcrumbs(); return ( -
+
{/* Breadcrumbs */}