Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,6 @@ The dashboard will automatically open in your browser at `http://127.0.0.1:5173`

**Options:**
- `--port <PORT>`: Run on a custom port (default: 5173)
- `--no-browser`: Don't automatically open the browser

Example:
```bash
alphatrion dashboard --port 8080 --no-browser
```

**Documentation:**
- [Dashboard Setup Guide](./docs/dashboard/setup.md) - Complete setup and troubleshooting guide
Expand Down
10 changes: 6 additions & 4 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ def _cancel(self):
self._context.cancel()

def _stop(self):
# cancel the runs first, then stop the experiment.
for run in self._runs.values():
# When experiment is stopped, we consider the unfinished runs as cancelled.
run.cancel()
self._runs.clear()

exp = self._runtime._metadb.get_experiment(experiment_id=self.id)
if exp is not None and exp.status not in FINISHED_STATUS:
duration = (
Expand All @@ -365,10 +371,6 @@ def _stop(self):
)

self._runtime.current_proj.unregister_experiment(self.id)
for run in self._runs.values():
# When experiment is stopped, we consider the unfinished runs as cancelled.
run.cancel()
self._runs.clear()

def _get_obj(self):
return self._runtime._metadb.get_experiment(experiment_id=self.id)
Expand Down
81 changes: 80 additions & 1 deletion alphatrion/server/cmd/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# ruff: noqa: E501
# ruff: noqa: B904

import logging
from importlib.metadata import version

from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter

from alphatrion.server.graphql.schema import schema

# Configure logging
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware - allows frontend to access the API
Expand All @@ -20,6 +24,81 @@
allow_headers=["*"],
)


# Helper function to extract operation name from query
def extract_operation_name(query: str) -> str:
    """Best-effort extraction of an operation name from a raw GraphQL query.

    Resolution order:
      1. An explicitly named operation (``query Foo``, ``mutation Foo`` or
         ``subscription Foo``).
      2. The first top-level field selection of an anonymous operation,
         e.g. ``{ getExperiment { ... } }`` -> ``getExperiment``.
      3. The literal string ``"Anonymous"`` when neither pattern matches.

    Used only for logging, so it must never raise on malformed queries.
    """
    import re

    # Named operation. Word boundaries avoid false hits inside identifiers
    # that merely contain a keyword (e.g. "thisquery Foo" must not yield
    # "Foo"); "subscription" is included alongside query/mutation.
    match = re.search(r"\b(query|mutation|subscription)\b\s+(\w+)", query)
    if match:
        return match.group(2)

    # Anonymous operation: fall back to the first selected field name.
    match = re.search(r"\{\s*(\w+)", query)
    if match:
        return match.group(1)

    return "Anonymous"


# Add middleware to log GraphQL requests
@app.middleware("http")
async def log_graphql_requests(request: Request, call_next):
    """Middleware to log GraphQL requests and responses.

    For POST /graphql requests, parses the JSON body to log the operation
    type, operation name and variable keys (keys only — variable values are
    never logged). After the downstream app responds, logs the HTTP status.
    Logging failures are swallowed so they can never break the request.
    """
    # Defaults used in the response-side log when body parsing fails or the
    # request is not a GraphQL POST.
    operation_name = "Unknown"
    operation_type = "query"

    if request.url.path == "/graphql" and request.method == "POST":
        try:
            # Read and cache the body
            # NOTE: this consumes the request stream; see the _receive
            # replacement below that makes the body readable again.
            body = await request.body()
            import json

            data = json.loads(body)
            query = data.get("query", "")
            variables = data.get("variables", {})

            # Get operation name from request or extract from query
            operation_name = data.get("operationName")
            if not operation_name:
                operation_name = extract_operation_name(query)

            # Extract operation type (query or mutation)
            if query.strip().startswith("mutation"):
                operation_type = "mutation"

            # Log the GraphQL operation request
            # Only the variable *names* are logged, not their values.
            variable_keys = list(variables.keys()) if variables else []
            logger.info(
                f"GraphQL {operation_type}: {operation_name} | Variables: {variable_keys if variable_keys else 'None'}"
            )
            logger.debug(f"GraphQL {operation_type} full query:\n{query}")

            # Create a new request with the cached body
            # NOTE(review): this replays the cached body by swapping the
            # private Starlette `_receive` callable so the downstream
            # GraphQL router can read the body again — relies on Starlette
            # internals; verify when upgrading Starlette/FastAPI.
            async def receive():
                return {"type": "http.request", "body": body}

            request._receive = receive

        except Exception as e:
            # Best-effort logging only: malformed JSON or any other parsing
            # problem must not prevent the request from being handled.
            logger.error(f"Failed to log GraphQL request: {e}")

    response = await call_next(request)

    # Log response status for GraphQL operations
    if request.url.path == "/graphql" and request.method == "POST":
        try:
            logger.info(
                f"GraphQL {operation_type} {operation_name} completed | Status: {response.status_code}"
            )
        except Exception as e:
            logger.error(f"Failed to log GraphQL response: {e}")

    return response


# Create GraphQL router
graphql_app = GraphQLRouter(schema)

Expand Down
5 changes: 5 additions & 0 deletions alphatrion/server/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from alphatrion import envs
from alphatrion.storage import runtime
from alphatrion.utils import log

load_dotenv()
console = Console()
Expand Down Expand Up @@ -198,6 +199,10 @@ def run_server(args):
style="bold green",
)
console.print(msg)

# Configure logging before starting the server
log.configure_logging()

runtime.init()
uvicorn.run("alphatrion.server.cmd.app:app", host=args.host, port=args.port)

Expand Down
75 changes: 72 additions & 3 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from alphatrion.artifact import artifact
from alphatrion.storage import runtime
from alphatrion.storage.sql_models import Status
from alphatrion.storage.sql_models import FINISHED_STATUS, Status

from .types import (
AddUserToTeamInput,
Expand Down Expand Up @@ -168,14 +168,30 @@ def get_experiment(id: strawberry.ID) -> Experiment | None:
metadb = runtime.storage_runtime().metadb
exp = metadb.get_experiment(experiment_id=uuid.UUID(id))
if exp:
meta = exp.meta or {}

# Aggregate and cache tokens for finished experiments
# Only calculate if experiment is in a finished state and tokens
# not already cached.
exp_status = Status(exp.status)
is_finished = exp_status in FINISHED_STATUS

if is_finished and "total_tokens" not in meta:
token_data = GraphQLResolvers.aggregate_experiment_tokens(
experiment_id=id
)
if token_data["total_tokens"] > 0:
meta.update(token_data)
metadb.update_experiment(experiment_id=uuid.UUID(id), meta=meta)

return Experiment(
id=exp.uuid,
team_id=exp.team_id,
user_id=exp.user_id,
project_id=exp.project_id,
name=exp.name,
description=exp.description,
meta=exp.meta,
meta=meta,
params=exp.params,
duration=exp.duration,
status=GraphQLStatusEnum[Status(exp.status).name],
Expand Down Expand Up @@ -405,7 +421,13 @@ def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]:
"""Aggregate token usage from all traces for a run."""
from alphatrion import envs

# Check if tracing is enabled
# One potential issue here is if tracing is disabled after the run
# has completed, we won't be able to aggregate tokens anymore since
# we rely on fetching spans from the trace store.
# For now we assume tracing is enabled if users want to see token usage.
# In the future, we could consider caching token data in the metadb when
# runs/experiments are completed to avoid relying on trace store for
# historical data.
if os.getenv(envs.ENABLE_TRACING, "false").lower() != "true":
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}

Expand Down Expand Up @@ -440,6 +462,53 @@ def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]:
logging.error(f"Failed to aggregate tokens for run {run_id}: {e}")
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}

@staticmethod
def aggregate_experiment_tokens(experiment_id: strawberry.ID) -> dict[str, int]:
    """Sum token usage across every run belonging to an experiment.

    Returns a dict with ``total_tokens``, ``input_tokens`` and
    ``output_tokens``; all zeros if anything goes wrong.
    """
    zeros = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    try:
        metadb = runtime.storage_runtime().metadb

        # Fetch every run in a single oversized page; 10000 is assumed to
        # exceed any realistic run count for one experiment.
        all_runs = metadb.list_runs_by_exp_id(
            exp_id=uuid.UUID(experiment_id),
            page=0,
            page_size=10000,
        )

        totals = dict(zeros)
        for candidate in all_runs:
            # NOTE(review): runs with an empty/None meta are skipped
            # entirely, so no aggregation is attempted for them — confirm
            # this is intentional.
            if not candidate.meta:
                continue

            # Per-run token totals are computed lazily: trigger aggregation
            # for runs that have not cached a total yet, then re-read the
            # run to pick up the refreshed meta.
            if "total_tokens" not in candidate.meta:
                GraphQLResolvers.aggregate_run_tokens(run_id=candidate.uuid)
                candidate = metadb.get_run(run_id=candidate.uuid)

            for key in totals:
                totals[key] += int(candidate.meta.get(key, 0))

        return totals
    except Exception as e:
        import logging

        logging.error(
            f"Failed to aggregate tokens for experiment {experiment_id}: {e}"
        )
        return zeros

@staticmethod
def list_spans(run_id: strawberry.ID) -> list[Span]:
"""List all spans for a specific run."""
Expand Down
29 changes: 29 additions & 0 deletions alphatrion/utils/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging
import os
import sys


def configure_logging():
    """Configure logging for the server with GraphQL debugging support.

    Reads the desired level from the ``ALPHATRION_LOG_LEVEL`` environment
    variable (default ``INFO``). Unknown level names fall back to ``INFO``
    instead of letting ``logging.basicConfig`` raise ``ValueError`` and
    crash server startup on a typo.
    """
    requested = os.getenv("ALPHATRION_LOG_LEVEL", "INFO").upper()

    # Validate the level name: getattr yields the numeric level for the
    # standard names (DEBUG, INFO, WARNING, ...); anything else -> INFO.
    level = getattr(logging, requested, None)
    if not isinstance(level, int):
        level = logging.INFO
        requested = "INFO"

    # Configure logging format
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"

    # Configure root logger (no-op if the root logger already has handlers,
    # per logging.basicConfig semantics).
    logging.basicConfig(
        level=level,
        format=log_format,
        datefmt=date_format,
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Set uvicorn logger to INFO to avoid too much noise
    logging.getLogger("uvicorn").setLevel(logging.INFO)
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)

    # Log startup info
    logger = logging.getLogger(__name__)
    logger.info("Logging configured with level: %s", requested)
    logger.info("Set ALPHATRION_LOG_LEVEL=DEBUG to see detailed GraphQL queries")
Loading
Loading