From 72023ab06cc9faaf8f0a4405a5989079f6eb0cad Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Fri, 19 Dec 2025 10:42:31 +0100 Subject: [PATCH 1/8] feat: Add multi-turn conversation support for run.py --- CLAUDE.md | 22 +- Dockerfile | 2 - Tiltfile | 2 +- deploy/base/templates/evaluate-template.yaml | 2 +- deploy/base/templates/setup-template.yaml | 3 + deploy/local/data-server/configmap.yaml | 26 +- deploy/local/kustomization.yaml | 2 +- deploy/local/multi-turn-workflow.yaml | 31 +++ pyproject.toml | 4 +- scripts/data/datasets/ragas_dataset.jsonl | 1 + .../expected_ragas_experiment.jsonl | 1 + scripts/run.py | 253 +++++++++++++++++- scripts/setup.py | 8 +- tests/test_run.py | 195 +++++++++++++- uv.lock | 46 +--- 15 files changed, 505 insertions(+), 93 deletions(-) create mode 100644 deploy/local/multi-turn-workflow.yaml create mode 100644 scripts/data/datasets/ragas_dataset.jsonl create mode 100644 scripts/data/experiments/expected_ragas_experiment.jsonl diff --git a/CLAUDE.md b/CLAUDE.md index 0558890..aed0ec3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -109,8 +109,10 @@ make run **Phase 2: Run** (`scripts/run.py`) - **Input**: `data/datasets/ragas_dataset.jsonl` + Agent URL -- **Output**: `data/experiments/ragas_experiment.jsonl` (adds `response` field) -- **Purpose**: Sends each `user_input` to agent via A2A protocol using `a2a-sdk`, records agent responses +- **Output**: `data/experiments/ragas_experiment.jsonl` (adds `response` field for single-turn, full conversation for multi-turn) +- **Purpose**: Sends queries to agent via A2A protocol using `a2a-sdk`, records agent responses +- **Auto-Detection**: Detects single-turn vs multi-turn format and routes to appropriate experiment function +- **Multi-Turn Support**: For conversational datasets, sequentially queries agent for each user message while maintaining context_id **Phase 3: Evaluate** (`scripts/evaluate.py`) - **Input**: `data/experiments/ragas_experiment.jsonl` + LLM model + metrics list @@ -173,6 +175,22 @@ Observability Backend (Grafana) - **Client Library**: `a2a-sdk` Python package - **Usage in Testbench**: `run.py` uses `A2AClient` to send `user_input` prompts to agent's A2A endpoint - **Response Handling**: Agent responses stored in `response` field of experiment JSONL +- **Context Management**: A2A `context_id` field maintains conversation state across multiple turns + +### Multi-Turn Conversation Support +- **Purpose**: Evaluate agents in conversational scenarios with multiple back-and-forth exchanges +- **Detection**: `run.py` automatically detects dataset type by inspecting `user_input` field type (string = single-turn, list = multi-turn) +- **Experiment Functions**: + - `single_turn_experiment()`: Handles traditional question-answer format + - `multi_turn_experiment()`: Handles conversational interactions +- **Sequential Query Strategy**: For each human message in the conversation: + 1. Send message to agent via A2A protocol + 2. Capture agent's response and extract `context_id` + 3. Use `context_id` in subsequent messages to maintain conversation context + 4. After final turn, extract full conversation history from `task.history` +- **Data Format**: Multi-turn datasets use list of message dicts: `[{"content": "...", "type": "human"}, {"content": "...", "type": "ai"}, ...]` +- **Tool Calls**: Extracts tool call information from A2A `message.metadata` if available +- **Observability**: Creates parent span for conversation with child spans for each turn ### OpenTelemetry (OTLP) - **Purpose**: Standard protocol for publishing observability data diff --git a/Dockerfile b/Dockerfile index 56aa9eb..4f91d8e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,5 @@ FROM python:3.13-slim -WORKDIR /app - # Install runtime and build dependencies (git is needed for Gitpython, which is a dependency of Ragas) RUN apt-get update && apt-get install -y --no-install-recommends \ git \ diff --git a/Tiltfile b/Tiltfile index 7829ff8..72e3ba4 100644 --- a/Tiltfile +++ b/Tiltfile @@ -52,4 +52,4 @@ k8s_resource('ragas-evaluate-template', resource_deps=['testkube']) k8s_resource('ragas-publish-template', resource_deps=['testkube']) k8s_resource('ragas-run-template', resource_deps=['testkube']) k8s_resource('ragas-setup-template', resource_deps=['testkube']) -k8s_resource('ragas-evaluation-workflow', resource_deps=['testkube']) +k8s_resource('multi-turn-workflow', resource_deps=['testkube']) diff --git a/deploy/base/templates/evaluate-template.yaml b/deploy/base/templates/evaluate-template.yaml index 67e38d2..156491a 100644 --- a/deploy/base/templates/evaluate-template.yaml +++ b/deploy/base/templates/evaluate-template.yaml @@ -18,7 +18,7 @@ spec: image: type: string description: "Docker image to use for the evaluate step" - default: "ghcr.io/agentic-layer/testbench/testworkflows:latest" + default: "ghcr.io/agentic-layer/testbench/testworkflows:0.1.1" # Steps to execute steps: diff --git a/deploy/base/templates/setup-template.yaml b/deploy/base/templates/setup-template.yaml index 5b12b30..4e2c33b 100644 --- a/deploy/base/templates/setup-template.yaml +++ b/deploy/base/templates/setup-template.yaml @@ -21,6 +21,9 @@ spec: # Steps to execute steps: - name: setup-dataset + artifacts: + paths: + - "data/datasets/ragas_dataset.jsonl" run: command: - sh diff --git a/deploy/local/data-server/configmap.yaml b/deploy/local/data-server/configmap.yaml index 3a4e328..2aa66bd 100644 --- a/deploy/local/data-server/configmap.yaml +++ b/deploy/local/data-server/configmap.yaml @@ -5,31 +5,7 @@ metadata: data: dataset.json: | [ - { - "user_input": "What is the weather like in New York right now?", - "retrieved_contexts": ["The answer must state the current temperature in New York, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)."], - "reference": "The answer must state the current temperature in New York, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)." - }, - { - "user_input": "What is the current time in New York?", - "retrieved_contexts": ["The answer must state the current time in New York in HH:MM format and include the correct timezone abbreviation (e.g., CST)."], - "reference": "The answer must state the current time in New York in HH:MM format and include the correct timezone abbreviation (e.g., CST)." - }, - { - "user_input": "What is the weather like in Cairo?", - "retrieved_contexts": ["The answer must state the current temperature in Cairo, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)."], - "reference": "The answer must state the current temperature in Cairo, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)." - }, - { - "user_input": "How is the weather in Sydney?", - "retrieved_contexts": ["The answer must state the current temperature in Sydney, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)."], - "reference": "The answer must state the current temperature in Sydney, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)." - }, - { - "user_input": "Time in Garching?", - "retrieved_contexts": ["The answer must state the current time in Garching, Germany in HH:MM format and include the correct timezone abbreviation (CEST)."], - "reference": "The answer must state the current time in Garching, Germany in HH:MM format and include the correct timezone abbreviation (CEST)." - } + {"user_input": [{"content": "I need to increase my credit limit and check why my last transaction at Walmart was declined.", "type": "human"}, {"content": "That's not possible, I had enough money in my account.", "type": "human"}, {"content": "Oh, I forgot about the hotel booking.", "type": "human"}, {"content": "What about increasing my credit limit?", "type": "human"}]} ] dataset.csv: | user_input,retrieved_contexts,reference diff --git a/deploy/local/kustomization.yaml b/deploy/local/kustomization.yaml index 4b95374..9df4e84 100644 --- a/deploy/local/kustomization.yaml +++ b/deploy/local/kustomization.yaml @@ -5,4 +5,4 @@ resources: - weather-agent.yaml - data-server/ - ../base - - ragas-evaluation-workflow.yaml + - multi-turn-workflow.yaml diff --git a/deploy/local/multi-turn-workflow.yaml b/deploy/local/multi-turn-workflow.yaml new file mode 100644 index 0000000..9134000 --- /dev/null +++ b/deploy/local/multi-turn-workflow.yaml @@ -0,0 +1,31 @@ +apiVersion: testworkflows.testkube.io/v1 +kind: TestWorkflow +metadata: + name: multi-turn-workflow + namespace: testkube + labels: + testkube.io/test-category: ragas-evaluation + app: testworkflows + +spec: + container: + image: ghcr.io/agentic-layer/testbench/testworkflows:latest + env: + - name: OPENAI_API_BASE + value: "http://ai-gateway-litellm.ai-gateway:4000" + + # Global configuration that applies to all steps + config: + # Dataset configuration + datasetUrl: + type: string + description: "URL to the dataset file" + + # Steps using the templates + steps: + # Step 1: Setup - Download and convert dataset + - name: setup + use: + - name: ragas-setup-template + config: + datasetUrl: "{{ config.datasetUrl }}" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cbcab2a..69884a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,12 @@ dependencies = [ "a2a-sdk>=0.3.10", "httpx>=0.28.1", "langchain-openai>=1.0.2", + "nest-asyncio>=1.6.0", "pandas>=2.3.3", "pandas-stubs>=2.3.0", "pyarrow>=21.0.0", - "ragas>=0.3.5", + "python-dotenv>=1.0.0", + "ragas[ag-ui]>=0.4.1", "requests>=2.32.5", "types-requests>=2.32.0", "opentelemetry-api>=1.20.0", diff --git a/scripts/data/datasets/ragas_dataset.jsonl b/scripts/data/datasets/ragas_dataset.jsonl new file mode 100644 index 0000000..c9dfd51 --- /dev/null +++ b/scripts/data/datasets/ragas_dataset.jsonl @@ -0,0 +1 @@ +{"user_input":[{"content":"What is the weather like in New York right now?","type":"human"},{"content":"What time is it in New York?","type":"human"}]} \ No newline at end of file diff --git a/scripts/data/experiments/expected_ragas_experiment.jsonl b/scripts/data/experiments/expected_ragas_experiment.jsonl new file mode 100644 index 0000000..2ee28c9 --- /dev/null +++ b/scripts/data/experiments/expected_ragas_experiment.jsonl @@ -0,0 +1 @@ +{"user_input":[{"content":"What is the weather like in New York right now?","type":"human"},{"content":"The weather is 25 degrees.","type":"agent"},{"content":"What time is it in New York?","type":"human"},{"content":"It is 11:49.","type":"agent"}],"reference":null,"trace_id":"5ee59682c1477b74b568078d477ad62d"} diff --git a/scripts/run.py b/scripts/run.py index d1860f4..4563715 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -19,12 +19,95 @@ from otel_setup import setup_otel from pydantic import BaseModel from ragas import Dataset, experiment +from ragas.messages import AIMessage, HumanMessage, ToolCall # Set up module-level logger logging.basicConfig(level=logging.INFO) logger: Logger = logging.getLogger(__name__) +def a2a_message_to_ragas(message: Message) -> HumanMessage | AIMessage: + """ + Convert A2A Message to RAGAS message format. + + Handles: + - Text extraction from multiple parts + - Role mapping (user → human, agent → ai) + - Tool call extraction from metadata + - Metadata preservation + + Args: + message: A2A Message object + + Returns: + HumanMessage or AIMessage + + Raises: + ValueError: If role is not user or agent + """ + # Extract text from all TextPart objects + text_parts = [] + for part in message.parts: + # Part is a wrapper - access the actual part inside + actual_part = part.root if hasattr(part, "root") else part + if hasattr(actual_part, "text"): + text_parts.append(actual_part.text) + + content = " ".join(text_parts) if text_parts else "" + + # Map role + if message.role == Role.user: + return HumanMessage(content=content, metadata=message.metadata) + elif message.role == Role.agent: + # Extract tool calls from metadata if present + tool_calls = None + if message.metadata and "tool_calls" in message.metadata: + # Parse tool calls from metadata + tool_calls_data = message.metadata["tool_calls"] + tool_calls = [ToolCall(name=tc["name"], args=tc["args"]) for tc in tool_calls_data] + + return AIMessage(content=content, metadata=message.metadata, tool_calls=tool_calls) + else: + raise ValueError(f"Unsupported message role: {message.role}") + + +def validate_multi_turn_input(user_input: list) -> list[dict]: + """ + Validate and normalize multi-turn user_input. + + Expected format: [{"content": "...", "type": "human"}, {"content": "...", "type": "ai"}, ...] + + Args: + user_input: List of message dictionaries + + Returns: + Validated list of message dicts + + Raises: + ValueError: If format is invalid + """ + if not isinstance(user_input, list): + raise ValueError(f"Multi-turn user_input must be list, got {type(user_input)}") + + if not user_input: + raise ValueError("Multi-turn user_input cannot be empty") + + for i, msg in enumerate(user_input): + if not isinstance(msg, dict): + raise ValueError(f"Message {i} must be dict, got {type(msg)}") + + if "content" not in msg: + raise ValueError(f"Message {i} missing 'content' field") + + if "type" not in msg: + raise ValueError(f"Message {i} missing 'type' field") + + if msg["type"] not in ("human", "ai", "tool"): + raise ValueError(f"Message {i} has invalid type: {msg['type']}") + + return user_input + + async def initialize_client(agent_url: str) -> Client: """Initialize the A2A client with a minimal agent card.""" logger.info(f"Initializing A2A client for: {agent_url}") @@ -42,12 +125,14 @@ async def initialize_client(agent_url: str) -> Client: @experiment() -async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[str, str | list]: +async def single_turn_experiment(row, agent_url: str, workflow_name: str) -> dict[str, str | list]: """ - Experiment function that processes each row from the dataset. + Single-turn experiment function that processes each row from the dataset. + + Sends a single user message to the agent and captures the response. Args: - row: A dictionary containing 'user_input', 'retrieved_contexts', and 'reference' fields + row: A dictionary containing 'user_input' (str), 'retrieved_contexts', and 'reference' fields agent_url: The URL of the agent to query workflow_name: Name of the test workflow for span labeling @@ -125,20 +210,170 @@ async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[ return result +@experiment() +async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict[str, list | str]: + """ + Multi-turn experiment function for conversational interactions. + + Processes a conversation by: + 1. Extracting human messages from input + 2. Sequentially querying agent for each turn + 3. Maintaining context_id across turns + 4. Extracting full conversation history from final task + 5. Converting to RAGAS MultiTurnSample format + + Args: + row: Dictionary with 'user_input' (list of message dicts) and 'reference' + agent_url: URL of the agent to query + workflow_name: Name of the test workflow for span labeling + + Returns: + Dictionary with 'user_input' (list of RAGAS messages), 'reference', 'trace_id' + """ + # Get tracer for creating spans + tracer = trace.get_tracer("testbench.run") + + # Create parent span for entire conversation + user_input_preview = str(row.get("user_input", []))[:100] + span_name = f"query_agent_multi_turn: {user_input_preview}" + + with tracer.start_as_current_span(span_name) as span: + # Extract trace ID + span_context = span.get_span_context() + trace_id = format(span_context.trace_id, "032x") + + # Add span attributes + span.set_attribute("test.turn_count", len(row.get("user_input", []))) + span.set_attribute("test.reference", row.get("reference", "")) + span.set_attribute("agent.url", agent_url) + span.set_attribute("workflow.name", workflow_name) + span.set_attribute("test.conversation_type", "multi_turn") + + try: + # Validate input format + user_input = validate_multi_turn_input(row.get("user_input")) + + async with httpx.AsyncClient(): + client = await initialize_client(agent_url) + + # Extract only human messages (agent messages are from dataset, not sent) + human_messages = [msg for msg in user_input if msg.get("type") == "human"] + + if not human_messages: + raise ValueError("No human messages found in user_input") + + context_id = None + conversation_messages = [] + + # Sequentially query agent for each human turn + for turn_idx, human_msg in enumerate(human_messages): + # Create child span for this turn + turn_span_name = f"turn_{turn_idx + 1}: {human_msg['content'][:50]}" + with tracer.start_as_current_span(turn_span_name) as turn_span: + turn_span.set_attribute("turn.index", turn_idx + 1) + turn_span.set_attribute("turn.content", human_msg["content"]) + + + # Create A2A message + message = Message( + role=Role.user, + parts=[TextPart(text=human_msg["content"])], + message_id=uuid4().hex, + context_id=context_id, # None for first turn, preserved after + ) + conversation_messages.append({"content": human_msg["content"], "type": "human"}) + + logger.info(f"Turn {turn_idx + 1}/{len(human_messages)}: {human_msg['content']}") + + # Send message and get response + agent_response_text = "" + async for response in client.send_message(message): + if isinstance(response, tuple): + task, _ = response + if task: + # Capture context_id from first response + if not context_id: + context_id = task.context_id + logger.info(f"Captured context_id: {context_id}") + span.set_attribute("conversation.context_id", context_id) + + # Extract agent response from artifacts (same approach as single_turn_experiment) + artifacts: list = task.model_dump(mode="json", include={"artifacts"}).get( + "artifacts", [] + ) + if artifacts and artifacts[0].get("parts"): + agent_response_text = artifacts[0]["parts"][0].get("text", "") + + # Add agent response to conversation + if agent_response_text: + conversation_messages.append({"content": agent_response_text, "type": "agent"}) + logger.info(f"Agent response: {agent_response_text[:100]}...") + else: + logger.warning(f"Empty agent response for turn {turn_idx + 1}") + + # Validate we got responses + if len(conversation_messages) < 2: + raise ValueError(f"Incomplete conversation: only {len(conversation_messages)} messages") + + # Use the manually built conversation + user_input_serialized = conversation_messages + + # Mark span as successful + span.set_status(Status(StatusCode.OK)) + span.set_attribute("conversation.message_count", len(conversation_messages)) + + except Exception as e: + logger.error(f"Error processing multi-turn conversation: {str(e)}") + + # Record exception in span + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, description=str(e))) + + # Return minimal result + return { + "user_input": row.get("user_input"), + "trace_id": trace_id, + } + + # Return result in MultiTurnSample format + result = { + "user_input": user_input_serialized, + "trace_id": trace_id, + } + + return result + + async def main(agent_url: str, workflow_name: str) -> None: - """Main function to load Ragas Dataset and run Experiment.""" + """Main function to load Dataset and run appropriate Experiment.""" # Initialize OpenTelemetry tracing setup_otel() - # Load existing Ragas dataset - logger.info("Loading Ragas dataset from data/datasets/ragas_dataset.jsonl") + # Load existing dataset + logger.info("Loading dataset from data/datasets/ragas_dataset.jsonl") dataset: Dataset[BaseModel] = Dataset.load(name="ragas_dataset", backend="local/jsonl", root_dir="./data") logger.info(f"Dataset loaded with {len(dataset)} samples") - # Run the experiment - logger.info("Starting experiment...") - await run_agent_experiment.arun(dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name) + # Detect dataset type by inspecting first row + if len(dataset) == 0: + raise ValueError("Dataset is empty") + + first_row = dataset[0] + is_multi_turn = isinstance(first_row.get("user_input"), list) + + if is_multi_turn: + logger.info("Detected multi-turn dataset") + logger.info("Starting multi-turn experiment...") + await multi_turn_experiment.arun( + dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name + ) + else: + logger.info("Detected single-turn dataset") + logger.info("Starting single-turn experiment...") + await single_turn_experiment.arun( + dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name + ) logger.info("Experiment completed successfully") logger.info("Results saved to data/experiments/ragas_experiment.jsonl") diff --git a/scripts/setup.py b/scripts/setup.py index 351f0bb..6bb4825 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -20,17 +20,13 @@ def dataframe_to_ragas_dataset(dataframe: DataFrame) -> None: - reference: The reference/ground truth answer """ - # Set output directory (and create it if it doesn't exist already) output_dir = Path("data") output_dir.mkdir(exist_ok=True) - # Convert DataFrame to list of dictionaries - dataset_samples = cast(list[dict[str, Any]], dataframe.to_dict(orient="records")) - # Create Ragas Dataset - dataset: Dataset[BaseModel] = Dataset( + dataset = Dataset.from_pandas( name="ragas_dataset", - data=dataset_samples, + dataframe=dataframe, backend="local/jsonl", root_dir="./data", ) diff --git a/tests/test_run.py b/tests/test_run.py index 8cacc30..a1f19ed 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -14,7 +14,14 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) -from run import initialize_client, main, run_agent_experiment +from run import ( + a2a_message_to_ragas, + initialize_client, + main, + multi_turn_experiment, + single_turn_experiment, + validate_multi_turn_input, +) # Fixtures @@ -64,9 +71,9 @@ def mock_factory_init(config=None): assert result == mock_client -# TestRunAgentExperiment tests +# TestSingleTurnExperiment tests @pytest.mark.asyncio -async def test_run_agent_experiment_success(monkeypatch): +async def test_single_turn_experiment_success(monkeypatch): """Test successful agent query execution""" # Mock the client @@ -110,7 +117,7 @@ def mock_httpx_client(): } # Call the function - result = await run_agent_experiment.func( + result = await single_turn_experiment.func( test_row, agent_url="http://test-agent:8000", workflow_name="test-workflow" ) @@ -124,7 +131,7 @@ def mock_httpx_client(): @pytest.mark.asyncio -async def test_run_agent_experiment_error(monkeypatch): +async def test_single_turn_experiment_error(monkeypatch): """Test agent query with error handling""" # Mock client that raises an error @@ -153,7 +160,7 @@ def mock_httpx_client(): } # Call the function - result = await run_agent_experiment.func( + result = await single_turn_experiment.func( test_row, agent_url="http://test-agent:8000", workflow_name="test-workflow" ) @@ -177,6 +184,10 @@ class MockDataset: def __len__(self): return 2 + def __getitem__(self, index): + # Return single-turn format for detection + return {"user_input": "Test question", "retrieved_contexts": [], "reference": "Answer"} + mock_dataset = MockDataset() def mock_dataset_load(path, backend): @@ -203,7 +214,7 @@ async def mock_arun_tracked(*args, **kwargs): return mock_experiment monkeypatch.setattr("run.Dataset.load", mock_dataset_load_tracked) - monkeypatch.setattr("run.run_agent_experiment.arun", mock_arun_tracked) + monkeypatch.setattr("run.single_turn_experiment.arun", mock_arun_tracked) # Run main await main("http://test-agent:8000", "test-workflow") @@ -218,3 +229,173 @@ async def mock_arun_tracked(*args, **kwargs): assert calls_to_arun[0]["kwargs"]["workflow_name"] == "test-workflow" finally: os.chdir(original_cwd) + + +# Test helper functions +def test_a2a_message_to_ragas_human(): + """Test conversion of A2A user message to RAGAS HumanMessage""" + from a2a.types import Message, Part, Role, TextPart + + # Create A2A user message + a2a_msg = Message( + role=Role.user, + parts=[Part(TextPart(text="Hello, how are you?"))], + message_id="test123", + ) + + # Convert to RAGAS + ragas_msg = a2a_message_to_ragas(a2a_msg) + + # Verify + from ragas.messages import HumanMessage + + assert isinstance(ragas_msg, HumanMessage) + assert ragas_msg.content == "Hello, how are you?" + + +def test_a2a_message_to_ragas_ai(): + """Test conversion of A2A agent message to RAGAS AIMessage""" + from a2a.types import Message, Part, Role, TextPart + + # Create A2A agent message + a2a_msg = Message( + role=Role.agent, + parts=[Part(TextPart(text="I'm doing well, thank you!"))], + message_id="test456", + ) + + # Convert to RAGAS + ragas_msg = a2a_message_to_ragas(a2a_msg) + + # Verify + from ragas.messages import AIMessage + + assert isinstance(ragas_msg, AIMessage) + assert ragas_msg.content == "I'm doing well, thank you!" + assert ragas_msg.tool_calls is None + + +def test_a2a_message_to_ragas_with_tool_calls(): + """Test tool call extraction from metadata""" + from a2a.types import Message, Part, Role, TextPart + + # Create A2A agent message with tool calls in metadata + a2a_msg = Message( + role=Role.agent, + parts=[Part(TextPart(text="Let me check the weather"))], + message_id="test789", + metadata={"tool_calls": [{"name": "get_weather", "args": {"location": "NYC"}}]}, + ) + + # Convert to RAGAS + ragas_msg = a2a_message_to_ragas(a2a_msg) + + # Verify + from ragas.messages import AIMessage + + assert isinstance(ragas_msg, AIMessage) + assert ragas_msg.content == "Let me check the weather" + assert ragas_msg.tool_calls is not None + assert len(ragas_msg.tool_calls) == 1 + assert ragas_msg.tool_calls[0].name == "get_weather" + assert ragas_msg.tool_calls[0].args == {"location": "NYC"} + + +def test_a2a_message_to_ragas_multi_part(): + """Test text extraction from multiple parts""" + from a2a.types import Message, Part, Role, TextPart + + # Create message with multiple text parts + a2a_msg = Message( + role=Role.agent, + parts=[Part(TextPart(text="Hello")), Part(TextPart(text="World"))], + message_id="test", + ) + + # Convert to RAGAS + ragas_msg = a2a_message_to_ragas(a2a_msg) + + # Verify text parts are concatenated + from ragas.messages import AIMessage + + assert isinstance(ragas_msg, AIMessage) + assert ragas_msg.content == "Hello World" + + +def test_validate_multi_turn_input_success(): + """Test validation with valid multi-turn input""" + user_input = [ + {"content": "Hello", "type": "human"}, + {"content": "Hi there!", "type": "ai"}, + {"content": "How are you?", "type": "human"}, + ] + + result = validate_multi_turn_input(user_input) + + assert result == user_input + + +def test_validate_multi_turn_input_invalid_type(): + """Test validation rejects non-list input""" + with pytest.raises(ValueError, match="must be list"): + validate_multi_turn_input("not a list") # type: ignore + + +def test_validate_multi_turn_input_missing_fields(): + """Test validation catches missing content/type fields""" + # Missing content + with pytest.raises(ValueError, match="missing 'content' field"): + validate_multi_turn_input([{"type": "human"}]) + + # Missing type + with pytest.raises(ValueError, match="missing 'type' field"): + validate_multi_turn_input([{"content": "Hello"}]) + + +def test_validate_multi_turn_input_invalid_message_type(): + """Test validation catches invalid message types""" + with pytest.raises(ValueError, match="has invalid type"): + validate_multi_turn_input([{"content": "Hello", "type": "invalid"}]) + + +@pytest.mark.asyncio +async def test_main_detects_multi_turn(temp_dir, monkeypatch): + """Test main calls multi_turn_experiment for list user_input""" + tmp, original_cwd = temp_dir + os.chdir(tmp) + + try: + # Create a mock dataset with multi-turn format + class MockDataset: + def __len__(self): + return 1 + + def __getitem__(self, index): + # Return multi-turn format for detection + return { + "user_input": [{"content": "Hello", "type": "human"}], + "reference": "Answer", + } + + mock_dataset = MockDataset() + + calls_to_multi_turn = [] + + async def mock_multi_turn_arun(*args, **kwargs): + calls_to_multi_turn.append({"args": args, "kwargs": kwargs}) + return None + + def mock_dataset_load(**kwargs): + return mock_dataset + + monkeypatch.setattr("run.Dataset.load", mock_dataset_load) + monkeypatch.setattr("run.multi_turn_experiment.arun", mock_multi_turn_arun) + + # Run main + await main("http://test-agent:8000", "test-workflow") + + # Verify multi_turn_experiment was called + assert len(calls_to_multi_turn) == 1 + assert calls_to_multi_turn[0]["kwargs"]["workflow_name"] == "test-workflow" + finally: + os.chdir(original_cwd) diff --git a/uv.lock b/uv.lock index b593750..d70a19d 100644 --- a/uv.lock +++ b/uv.lock @@ -547,30 +547,6 @@ http = [ { name = "aiohttp" }, ] -[[package]] -name = "gitdb" -version = "4.0.12" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "smmap" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, -] - -[[package]] -name = "gitpython" -version = "3.1.45" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "gitdb" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9a/c8/dd58967d119baab745caec2f9d853297cec1989ec1d63f677d3880632b88/gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c", size = 215076, upload-time = "2025-07-24T03:45:54.871Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, -] - [[package]] name = "google-api-core" version = "2.28.1" @@ -2303,13 +2279,12 @@ wheels = [ [[package]] name = "ragas" -version = "0.3.5" +version = "0.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "appdirs" }, { name = "datasets" }, { name = "diskcache" }, - { name = "gitpython" }, { name = "instructor" }, { name = "langchain" }, { name = "langchain-community" }, @@ -2327,9 +2302,9 @@ dependencies = [ { name = "tqdm" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2f/40/c342893abd6f73eb2b20a562e1560cc87f18b8316a12288f03642febab24/ragas-0.3.5.tar.gz", hash = "sha256:164d5c0a96048d9c9373aa3e9123f0096649abbd2b58e747c2f0a454da6c2d6b", size = 43027900, upload-time = "2025-09-17T19:13:52.766Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/e2/1066235befe0b9ea1921ca84b6dd8a9f35f74c10a23a56a4a82ce1e9f240/ragas-0.4.1.tar.gz", hash = "sha256:eda2603269c5c8021166ef56328b68ed88af15c4205a3e31b759233fd7ffc720", size = 43940590, upload-time = "2025-12-10T16:29:25.13Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/4c/56338824441cefe5bab2b47350805f9fbf5ec85de78645452da461c9c174/ragas-0.3.5-py3-none-any.whl", hash = "sha256:3e917b12dc90ef692776263f66d220df40ff0573d2a96c8868198629f8b35206", size = 284321, upload-time = "2025-09-17T19:13:50.065Z" }, + { url = "https://files.pythonhosted.org/packages/a7/fb/51dbc01f6ec3dc79257f9347fdf0bcc46a63fe016e279128cad911267e3a/ragas-0.4.1-py3-none-any.whl", hash = "sha256:afcf36542087d0e0ef5898d7da04f20fb69eb215c326529a5c4b3d54a1fe4305", size = 419897, upload-time = "2025-12-10T16:29:21.873Z" }, ] [[package]] @@ -2654,15 +2629,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] -[[package]] -name = "smmap" -version = "5.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -2737,6 +2703,7 @@ dependencies = [ { name = "a2a-sdk" }, { name = "httpx" }, { name = "langchain-openai" }, + { name = "nest-asyncio" }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "opentelemetry-instrumentation-httpx" }, @@ -2744,6 +2711,7 @@ dependencies = [ { name = "pandas" }, { name = "pandas-stubs" }, { name = "pyarrow" }, + { name = "python-dotenv" }, { name = "ragas" }, { name = "requests" }, { name = "types-requests" }, @@ -2766,6 +2734,7 @@ requires-dist = [ { name = "a2a-sdk", specifier = ">=0.3.10" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "langchain-openai", specifier = ">=1.0.2" }, + { name = "nest-asyncio", specifier = ">=1.6.0" }, { name = "opentelemetry-api", specifier = ">=1.20.0" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.20.0" }, { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.45b0" }, @@ -2773,7 +2742,8 @@ requires-dist = [ { name = "pandas", specifier = ">=2.3.3" }, { name = "pandas-stubs", specifier = ">=2.3.0" }, { name = "pyarrow", specifier = ">=21.0.0" }, - { name = "ragas", specifier = ">=0.3.5" }, + { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "ragas", extras = ["ag-ui"], specifier = ">=0.4.1" }, { name = "requests", specifier = ">=2.32.5" }, { name = "types-requests", specifier = ">=2.32.0" }, ] From a0a498f898dd844fcc9b355c50f07bd2775aea5c Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Tue, 6 Jan 2026 12:54:31 +0100 Subject: [PATCH 2/8] feat: Update evaluation process to use metrics configuration file --- CLAUDE.md | 115 +++++++++- deploy/local/data-server/configmap.yaml | 2 +- scripts/evaluate.py | 276 +++++++++++++++++++----- scripts/run.py | 4 +- tests/test_evaluate.py | 226 +++++++++++++------ 5 files changed, 504 insertions(+), 119 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index aed0ec3..8b6e8a0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -60,13 +60,124 @@ uv run python3 scripts/setup.py "http://localhost:11020/dataset.csv" # Phase 2: Execute queries through agent via A2A protocol uv run python3 scripts/run.py "http://localhost:11010" -# Phase 3: Evaluate responses using RAGAS metrics -uv run python3 scripts/evaluate.py gemini-2.5-flash-lite "faithfulness answer_relevancy" +# Phase 3: Evaluate responses using RAGAS metrics (uses default config) +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite + +# Or specify a custom config +# uv run python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_advanced.json # Phase 4: Publish metrics to OTLP endpoint uv run python3 scripts/publish.py "workflow-name" ``` +### Metrics Configuration + +**BREAKING CHANGE:** Metrics must now be specified via configuration file using `--metrics-config`. + +#### Quick Start + +**Using Default Config:** +```shell +# Uses examples/metrics_simple.json by default +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite +``` + +**Using Custom Config:** +```shell +# Simple metrics (pre-configured instances) +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_simple.json + +# Advanced metrics (custom AspectCritic definitions) +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_advanced.json +``` + +#### Configuration File Format + +Both JSON and YAML formats are supported: + +**JSON Example** (`examples/metrics_simple.json`): +```json +{ + "version": "1.0", + "metrics": [ + {"type": "instance", "name": "faithfulness"}, + {"type": "instance", "name": "answer_relevancy"}, + {"type": "instance", "name": "context_precision"} + ] +} +``` + +**Advanced Configuration** (`examples/metrics_advanced.json`): +```json +{ + "version": "1.0", + "metrics": [ + { + "type": "instance", + "name": "faithfulness" + }, + { + "type": "class", + "class_name": "AspectCritic", + "parameters": { + "name": "harmfulness", + "definition": "Does this contain harmful content?" + } + } + ] +} +``` + +#### Available Metrics + +**Pre-configured Instances** (type: `instance`): +- `faithfulness` - Measures factual consistency with contexts +- `answer_relevancy` - Measures relevance of response to query +- `answer_correctness` - Measures correctness vs reference +- `answer_similarity` - Semantic similarity to reference +- `context_precision` - Precision of retrieved contexts +- `context_recall` - Recall of retrieved contexts +- `context_entity_recall` - Entity-level context recall +- `multimodal_faithness` - Faithfulness for multimodal content +- `multimodal_relevance` - Relevance for multimodal content +- `summarization_score` - Quality of summarization + +**Configurable Classes** (type: `class`): +- `AspectCritic` - Custom aspect-based evaluation (REQUIRES: `name`, `definition`) +- `Faithfulness` - Configurable faithfulness (OPTIONAL: `strictness`, `max_retries`) +- `AnswerRelevancy` - Configurable relevancy (OPTIONAL: `strictness`) +- Plus 30+ other classes + +To see all available metrics: +```shell +uv run python3 scripts/evaluate.py --help +``` + +#### Migration from Old CLI + +Old usage (NO LONGER WORKS): +```shell +# This will fail +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite faithfulness answer_relevancy +``` + +New usage: +```shell +# Create config file +cat > my_metrics.json << EOF +{ + "version": "1.0", + "metrics": [ + {"type": "instance", "name": "faithfulness"}, + {"type": "instance", "name": "answer_relevancy"} + ] +} +EOF + +# Use config file +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config my_metrics.json +``` + ### Testkube Execution ```shell diff --git a/deploy/local/data-server/configmap.yaml b/deploy/local/data-server/configmap.yaml index 2aa66bd..cd3805c 100644 --- a/deploy/local/data-server/configmap.yaml +++ b/deploy/local/data-server/configmap.yaml @@ -5,7 +5,7 @@ metadata: data: dataset.json: | [ - {"user_input": [{"content": "I need to increase my credit limit and check why my last transaction at Walmart was declined.", "type": "human"}, {"content": "That's not possible, I had enough money in my account.", "type": "human"}, {"content": "Oh, I forgot about the hotel booking.", "type": "human"}, {"content": "What about increasing my credit limit?", "type": "human"}]} + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}, {"content": "What time is it in New York?", "type": "human"}]} ] dataset.csv: | user_input,retrieved_contexts,reference diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 9aff38a..3f0f37b 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -3,7 +3,6 @@ import json import logging import os -from argparse import ArgumentError from dataclasses import asdict, dataclass from logging import Logger from typing import Any @@ -22,59 +21,203 @@ logger: Logger = logging.getLogger(__name__) -def get_available_metrics() -> dict[str, Metric]: +def discover_metrics() -> tuple[dict[str, Metric], dict[str, type[Metric]]]: """ - Loads all Metric classes from Ragas - Returns a dict mapping metric names to metric instances. + Discover both pre-configured instances and metric classes from Ragas. + + Returns: + (instances, classes) where: + - instances: dict mapping names to pre-configured Metric instances + - classes: dict mapping class names to Metric class types """ - available_metrics: dict[str, Metric] = {} + instances: dict[str, Metric] = {} + classes: dict[str, type[Metric]] = {} # Iterate through all members of the metrics module for name, obj in inspect.getmembers(metrics_module): - # Check if it's a class and is a subclass of Metric (but not Metric itself) + if name.startswith('_'): + continue + + # Check if it's a Metric class (but not base Metric) if inspect.isclass(obj) and issubclass(obj, Metric) and obj is not Metric: + classes[name] = obj + # Check if it's a pre-configured metric instance + elif isinstance(obj, Metric): + metric_name = obj.name if hasattr(obj, 'name') else name + instances[metric_name] = obj + + return instances, classes + + +# Discover all available metric instances and classes +AVAILABLE_METRIC_INSTANCES, AVAILABLE_METRIC_CLASSES = discover_metrics() + + +def get_metric_by_name(metric_name: str) -> Metric: + """ + Get a metric instance by name (checks instances first, then tries classes). + + Args: + metric_name: Name of metric (e.g., "faithfulness", "answer_relevancy") + + Returns: + Metric instance + + Raises: + ValueError: If metric not found or can't be instantiated + """ + # Try pre-configured instance first + if metric_name in AVAILABLE_METRIC_INSTANCES: + return AVAILABLE_METRIC_INSTANCES[metric_name] + + # Try to find class by name (case-insensitive matching) + for class_name, metric_class in AVAILABLE_METRIC_CLASSES.items(): + if class_name.lower() == metric_name.lower(): try: - # Instantiate the metric - metric_instance = obj() - # Use the metric's name attribute - metric_name = metric_instance.name - available_metrics[metric_name] = metric_instance - except Exception: - # Skip metrics that can't be instantiated without parameters - logger.info(f"Exception encountered: {Exception}") - pass + return metric_class() + except Exception as e: + raise ValueError( + f"Metric class '{class_name}' requires parameters: {e}\n" + f"Use --metrics-config to provide configuration." + ) + + raise ValueError( + f"Unknown metric '{metric_name}'.\n" + f"Available instances: {', '.join(sorted(AVAILABLE_METRIC_INSTANCES.keys()))}\n" + f"Available classes: {', '.join(sorted(AVAILABLE_METRIC_CLASSES.keys()))}" + ) + + +def instantiate_metric_from_class(class_name: str, parameters: dict[str, Any]) -> Metric: + """ + Instantiate a metric class with custom parameters. + + Args: + class_name: Name of metric class (e.g., "AspectCritic") + parameters: Dictionary of constructor parameters - return available_metrics + Returns: + Metric instance + Raises: + ValueError: If class not found or instantiation fails + """ + if class_name not in AVAILABLE_METRIC_CLASSES: + raise ValueError( + f"Unknown metric class '{class_name}'.\n" + f"Available classes: {', '.join(sorted(AVAILABLE_METRIC_CLASSES.keys()))}" + ) -# Get all available metrics -AVAILABLE_METRICS = get_available_metrics() + metric_class = AVAILABLE_METRIC_CLASSES[class_name] + try: + return metric_class(**parameters) + except TypeError as e: + # Extract signature for helpful error message + sig = inspect.signature(metric_class.__init__) + raise ValueError(f"Invalid parameters for {class_name}: {e}\n" f"Expected signature: {sig}") -def convert_metrics(metrics: list[str]) -> list: + +def _load_metric_from_definition(metric_def: dict) -> Metric: """ - Map metric names to actual metric objects + Load a single metric from its configuration definition. Args: - metrics: List of metric names as strings (e.g., ["faithfulness", "answer_relevancy"]) + metric_def: Dictionary containing metric definition Returns: - List containing metric objects + Metric instance + + Raises: + ValueError: If definition is invalid or metric can't be loaded """ + # Validate required fields + if 'type' not in metric_def: + raise ValueError("Metric definition must include 'type' field") - # Map metric names to actual metric objects - metric_objects = [] - for metric_name in metrics: - if metric_name in AVAILABLE_METRICS: - metric_objects.append(AVAILABLE_METRICS[metric_name]) - else: - logger.warning(f"Unknown metric '{metric_name}', skipping...") - logger.warning(f"Available metrics: {', '.join(AVAILABLE_METRICS.keys())}") + metric_type = metric_def['type'] - if not metric_objects: - raise ValueError("No valid metrics provided for evaluation") + if metric_type == 'instance': + # Load pre-configured instance + if 'name' not in metric_def: + raise ValueError("Instance type requires 'name' field") - return metric_objects + name = metric_def['name'] + if name not in AVAILABLE_METRIC_INSTANCES: + raise ValueError( + f"Unknown instance '{name}'.\n" + f"Available: {', '.join(sorted(AVAILABLE_METRIC_INSTANCES.keys()))}" + ) + + return AVAILABLE_METRIC_INSTANCES[name] + + elif metric_type == 'class': + # Instantiate class with parameters + if 'class_name' not in metric_def: + raise ValueError("Class type requires 'class_name' field") + + class_name = metric_def['class_name'] + parameters = metric_def.get('parameters', {}) + + return instantiate_metric_from_class(class_name, parameters) + + else: + raise ValueError(f"Unknown metric type '{metric_type}'.\n" f"Supported types: 'instance', 'class'") + + +def load_metrics_config(config_path: str) -> list[Metric]: + """ + Load metrics configuration from JSON or YAML file. + + Args: + config_path: Path to configuration file (.json or .yaml/.yml) + + Returns: + List of configured Metric instances + + Raises: + ValueError: If config file invalid or metrics can't be loaded + """ + # Determine file format and load + if config_path.endswith('.json'): + with open(config_path, 'r') as f: + config = json.load(f) + elif config_path.endswith(('.yaml', '.yml')): + try: + import yaml + except ImportError: + raise ValueError( + "YAML support requires 'pyyaml' package.\n" + "Install with: uv add pyyaml\n" + "Or use JSON format instead: metrics.json" + ) + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + else: + raise ValueError( + f"Unsupported config file format: {config_path}\n" f"Supported formats: .json, .yaml, .yml" + ) + + # Validate config structure + if 'metrics' not in config: + raise ValueError("Config file must contain 'metrics' key") + + if not isinstance(config['metrics'], list): + raise ValueError("'metrics' must be a list") + + # Load each metric + metrics: list[Metric] = [] + for i, metric_def in enumerate(config['metrics']): + try: + metric = _load_metric_from_definition(metric_def) + metrics.append(metric) + except Exception as e: + raise ValueError(f"Error loading metric at index {i}: {e}") + + if not metrics: + raise ValueError("Config file contains no valid metrics") + + return metrics @dataclass @@ -164,7 +307,7 @@ def format_evaluation_scores( def main( output_file: str, model: str, - metrics: list[str] | None = None, + metrics_config: str, cost_per_input_token: float = 5.0 / 1e6, cost_per_output_token: float = 15.0 / 1e6, ) -> None: @@ -174,11 +317,14 @@ def main( Args: output_file: Path to save evaluation_scores.json model: Model name to use for evaluation - metrics: List of metric names to calculate + metrics_config: Path to metrics configuration file (JSON or YAML) + cost_per_input_token: Cost per input token + cost_per_output_token: Cost per output token """ - # Check if any metrics were provided - if metrics is None: - raise ArgumentError(argument=metrics, message="No metrics were provided as arguments") + # Load metrics from configuration file + logger.info(f"Loading metrics from config: {metrics_config}") + metrics = load_metrics_config(metrics_config) + logger.info(f"Loaded {len(metrics)} metrics: {', '.join([m.name for m in metrics])}") # Create LLM client using the AI-Gateway # Setting a placeholder for the api_key since we instantiate a ChatOpenAI object, @@ -189,11 +335,18 @@ def main( dataset = EvaluationDataset.from_jsonl("data/experiments/ragas_experiment.jsonl") + # Detect and log dataset type + if dataset.samples: + from ragas.dataset_schema import MultiTurnSample + + is_multi_turn = isinstance(dataset.samples[0], MultiTurnSample) + logger.info(f"Loaded {'multi-turn' if is_multi_turn else 'single-turn'} dataset") + # Calculate metrics - logger.info(f"Calculating metrics: {', '.join(metrics)}...") + logger.info(f"Calculating metrics: {', '.join([m.name for m in metrics])}...") ragas_result = evaluate( dataset=dataset, - metrics=convert_metrics(metrics), + metrics=metrics, llm=llm, token_usage_parser=get_token_usage_for_openai, ) @@ -223,16 +376,33 @@ def main( if __name__ == "__main__": - # Parse the parameters (model and metrics) evaluate.py was called with + # Parse the parameters (model and metrics-config) evaluate.py was called with parser = argparse.ArgumentParser( - description="Evaluate results using RAGAS metrics", + description="Evaluate results using RAGAS metrics via configuration file", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=f""" - Available metrics: {", ".join(AVAILABLE_METRICS.keys())} - - Examples: - python3 scripts/evaluate.py gemini-2.5-flash-lite faithfulness - python3 scripts/evaluate.py gemini-2.5-flash-lite faithfulness context_precision context_recall +Available metric instances (pre-configured): + {', '.join(sorted(AVAILABLE_METRIC_INSTANCES.keys()))} + +Available metric classes (configurable via --metrics-config): + {', '.join(sorted(AVAILABLE_METRIC_CLASSES.keys()))} + +Examples: + python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_simple.json + python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_advanced.json + +Config file format (JSON): + {{ + "version": "1.0", + "metrics": [ + {{"type": "instance", "name": "faithfulness"}}, + {{ + "type": "class", + "class_name": "AspectCritic", + "parameters": {{"name": "harmfulness", "definition": "Is this harmful?"}} + }} + ] + }} """, ) @@ -243,10 +413,10 @@ def main( ) parser.add_argument( - "metrics", - nargs="+", - choices=list(AVAILABLE_METRICS.keys()), - help="At least one (or more) metrics to evaluate (e.g., faithfulness, answer_relevancy)", + "--metrics-config", + type=str, + default="config/metrics.json", + help="Path to metrics configuration file (JSON or YAML). Default: examples/metrics_simple.json", ) parser.add_argument( @@ -265,11 +435,11 @@ def main( args = parser.parse_args() - # Run evaluation with the 'model' and 'metrics' provided as parameters, 'output_file' is hardcoded + # Run evaluation with the 'model' and 'metrics_config' provided as parameters, 'output_file' is hardcoded main( output_file="data/results/evaluation_scores.json", model=args.model, - metrics=args.metrics, + metrics_config=args.metrics_config, cost_per_input_token=args.cost_per_input, cost_per_output_token=args.cost_per_output, ) diff --git a/scripts/run.py b/scripts/run.py index 4563715..e5c7818 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -306,7 +306,7 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict # Add agent response to conversation if agent_response_text: - conversation_messages.append({"content": agent_response_text, "type": "agent"}) + conversation_messages.append({"content": agent_response_text, "type": "ai"}) logger.info(f"Agent response: {agent_response_text[:100]}...") else: logger.warning(f"Empty agent response for turn {turn_idx + 1}") @@ -331,12 +331,14 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict # Return minimal result return { + **row, "user_input": row.get("user_input"), "trace_id": trace_id, } # Return result in MultiTurnSample format result = { + **row, "user_input": user_input_serialized, "trace_id": trace_id, } diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index e9ce438..8098eaa 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -4,12 +4,12 @@ Tests the RAGAS evaluation functionality. """ +import inspect import json import os import shutil import sys import tempfile -from argparse import ArgumentError from pathlib import Path import pandas as pd @@ -18,7 +18,15 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) -from evaluate import AVAILABLE_METRICS, convert_metrics, format_evaluation_scores, main +from evaluate import ( + AVAILABLE_METRIC_CLASSES, + AVAILABLE_METRIC_INSTANCES, + format_evaluation_scores, + get_metric_by_name, + instantiate_metric_from_class, + load_metrics_config, + main, +) # Fixtures @@ -226,27 +234,26 @@ def total_cost(self, **kwargs): # TestMain tests -def test_main_no_metrics(experiment_data): - """Test main function with no metrics provided""" +def test_main_no_config(experiment_data): + """Test main function with missing metrics config file""" tmp, original_cwd, experiment_file = experiment_data os.chdir(tmp) try: - # When metrics is None, the function should raise an error - # The actual error type depends on implementation - with pytest.raises(ArgumentError): + # When config file doesn't exist, should raise FileNotFoundError + with pytest.raises(FileNotFoundError): main( output_file="results/evaluation_scores.json", model="gemini-flash-latest", - metrics=None, + metrics_config="nonexistent_config.json", ) finally: os.chdir(original_cwd) -def test_main_successful_execution(experiment_data, monkeypatch): - """Test main function successful execution""" +def test_main_successful_execution(experiment_data, monkeypatch, tmp_path): + """Test main function successful execution with config file""" from pathlib import Path from ragas.dataset_schema import EvaluationResult @@ -255,9 +262,20 @@ def test_main_successful_execution(experiment_data, monkeypatch): os.chdir(tmp) try: + # Create a test config file + config_file = tmp_path / "test_metrics.json" + if not AVAILABLE_METRIC_INSTANCES: + pytest.skip("No metric instances available") + + valid_metric = list(AVAILABLE_METRIC_INSTANCES.keys())[0] + config = {"version": "1.0", "metrics": [{"type": "instance", "name": valid_metric}]} + + with open(config_file, 'w') as f: + json.dump(config, f) + # Mock EvaluationDataset.from_jsonl class MockEvaluationDataset: - pass + samples = [] # Add samples attribute for dataset type detection mock_dataset = MockEvaluationDataset() @@ -307,18 +325,12 @@ def __init__(self, llm): monkeypatch.setattr("evaluate.ChatOpenAI", mock_chat_openai_init) monkeypatch.setattr("evaluate.LangchainLLMWrapper", MockLLMWrapper) - # Get a valid metric name - if not AVAILABLE_METRICS: - pytest.skip("No metrics available") - - valid_metric = list(AVAILABLE_METRICS.keys())[0] - - # Run main + # Run main with config file output_file = "results/evaluation_scores.json" main( output_file=output_file, model="gemini-flash-latest", - metrics=[valid_metric], + metrics_config=str(config_file), ) # Verify output file was created @@ -337,66 +349,156 @@ def __init__(self, llm): os.chdir(original_cwd) -# TestAvailableMetrics tests -def test_available_metrics_loaded(): - """Test that AVAILABLE_METRICS is populated correctly""" - # Should be a non-empty dictionary - assert isinstance(AVAILABLE_METRICS, dict) - assert len(AVAILABLE_METRICS) > 0 +# TestMetricDiscovery tests +def test_metric_discovery(): + """Test that both metric instances and classes are discovered""" + # Test instances + assert isinstance(AVAILABLE_METRIC_INSTANCES, dict) + assert len(AVAILABLE_METRIC_INSTANCES) > 0 + for name, instance in AVAILABLE_METRIC_INSTANCES.items(): + assert isinstance(name, str) + assert isinstance(instance, Metric) + + # Test classes + assert isinstance(AVAILABLE_METRIC_CLASSES, dict) + assert len(AVAILABLE_METRIC_CLASSES) > 0 + for name, cls in AVAILABLE_METRIC_CLASSES.items(): + assert isinstance(name, str) + assert inspect.isclass(cls) + assert issubclass(cls, Metric) + + +# Test get_metric_by_name +def test_get_metric_by_name_instance(): + """Test getting pre-configured metric instance""" + if not AVAILABLE_METRIC_INSTANCES: + pytest.skip("No metric instances available") + + # Get first available instance + metric_name = list(AVAILABLE_METRIC_INSTANCES.keys())[0] + metric = get_metric_by_name(metric_name) + assert isinstance(metric, Metric) + assert metric.name == metric_name + + +def test_get_metric_by_name_unknown(): + """Test error handling for unknown metric""" + with pytest.raises(ValueError, match="Unknown metric"): + get_metric_by_name('nonexistent_metric_xyz') + + +# Test instantiate_metric_from_class +def test_instantiate_metric_from_class_success(): + """Test successful class instantiation without parameters""" + if not AVAILABLE_METRIC_CLASSES: + pytest.skip("No metric classes available") + + # Find a class that can be instantiated without parameters + for class_name, metric_class in AVAILABLE_METRIC_CLASSES.items(): + try: + metric = instantiate_metric_from_class(class_name, {}) + assert isinstance(metric, Metric) + return # Success! + except (TypeError, ValueError): + continue # Try next class + pytest.skip("No metric classes can be instantiated without parameters") + + +def test_instantiate_metric_from_class_unknown(): + """Test error for unknown class""" + with pytest.raises(ValueError, match="Unknown metric class"): + instantiate_metric_from_class('NonexistentClass', {}) + + +def test_instantiate_metric_from_class_invalid_params(): + """Test error for invalid parameters""" + if not AVAILABLE_METRIC_CLASSES: + pytest.skip("No metric classes available") + + # Use first available class with clearly invalid parameters + class_name = list(AVAILABLE_METRIC_CLASSES.keys())[0] + with pytest.raises(ValueError, match="Invalid parameters"): + instantiate_metric_from_class(class_name, {'completely_invalid_param_name_xyz': 'value'}) + + +# Test load_metrics_config +def test_load_metrics_config_json(tmp_path): + """Test loading metrics from JSON config file""" + if not AVAILABLE_METRIC_INSTANCES: + pytest.skip("No metric instances available") + + config_file = tmp_path / "metrics.json" + metric_name = list(AVAILABLE_METRIC_INSTANCES.keys())[0] + + config = {"version": "1.0", "metrics": [{"type": "instance", "name": metric_name}]} + + with open(config_file, 'w') as f: + json.dump(config, f) - # All keys should be strings - for key in AVAILABLE_METRICS.keys(): - assert isinstance(key, str) + metrics = load_metrics_config(str(config_file)) + assert len(metrics) == 1 + assert isinstance(metrics[0], Metric) + assert metrics[0].name == metric_name - # All values should be Metric instances - for value in AVAILABLE_METRICS.values(): - assert isinstance(value, Metric) +def test_load_metrics_config_with_class(tmp_path): + """Test loading metrics with class instantiation""" + if not AVAILABLE_METRIC_CLASSES: + pytest.skip("No metric classes available") -# TestConvertMetrics tests -def test_convert_metrics_with_valid_metrics(): - """Test that convert_metrics correctly converts valid metric names to objects""" + # Find a class that can be instantiated without parameters + for class_name in AVAILABLE_METRIC_CLASSES.keys(): + try: + # Test if this class can be instantiated + instantiate_metric_from_class(class_name, {}) - # Use metrics that are commonly available in RAGAS - metric_names = ["faithfulness", "answer_relevancy"] + config_file = tmp_path / "metrics.json" + config = { + "version": "1.0", + "metrics": [{"type": "class", "class_name": class_name, "parameters": {}}], + } - # Only test with metrics that actually exist in AVAILABLE_METRICS - available_names = [name for name in metric_names if name in AVAILABLE_METRICS] + with open(config_file, 'w') as f: + json.dump(config, f) - if not available_names: - pytest.skip("Required metrics not available in this RAGAS version") + metrics = load_metrics_config(str(config_file)) + assert len(metrics) == 1 + assert isinstance(metrics[0], Metric) + return # Success! + except (TypeError, ValueError): + continue # Try next class - metric_objects = convert_metrics(available_names) + pytest.skip("No metric classes can be instantiated without parameters") - # Verify we got the right number of metrics - assert len(metric_objects) == len(available_names) - # Verify all returned objects are Metric instances - for obj in metric_objects: - assert isinstance(obj, Metric) +def test_load_metrics_config_invalid_format(tmp_path): + """Test error for invalid file format""" + config_file = tmp_path / "metrics.txt" + config_file.write_text("invalid") + with pytest.raises(ValueError, match="Unsupported config file format"): + load_metrics_config(str(config_file)) -def test_convert_metrics_with_invalid_metrics(): - """Test that convert_metrics handles invalid metric names""" - # Test with only invalid metrics - should raise ValueError - with pytest.raises(ValueError, match="No valid metrics provided"): - convert_metrics(["nonexistent_metric", "fake_metric"]) +def test_load_metrics_config_missing_metrics_key(tmp_path): + """Test error for missing 'metrics' key""" + config_file = tmp_path / "metrics.json" + with open(config_file, 'w') as f: + json.dump({"version": "1.0"}, f) -def test_convert_metrics_mixed_valid_invalid(): - """Test convert_metrics with mixed valid and invalid metric names""" + with pytest.raises(ValueError, match="must contain 'metrics' key"): + load_metrics_config(str(config_file)) - # Get one valid metric name from AVAILABLE_METRICS - if not AVAILABLE_METRICS: - pytest.skip("No metrics available") - valid_metric = list(AVAILABLE_METRICS.keys())[0] - metric_names = [valid_metric, "nonexistent_metric", "fake_metric"] +def test_load_metrics_config_empty_metrics(tmp_path): + """Test error for empty metrics list""" + config_file = tmp_path / "metrics.json" - metric_objects = convert_metrics(metric_names) + config = {"version": "1.0", "metrics": []} - # Should only return the valid metric - assert len(metric_objects) == 1 + with open(config_file, 'w') as f: + json.dump(config, f) - assert isinstance(metric_objects[0], Metric) + with pytest.raises(ValueError, match="contains no valid metrics"): + load_metrics_config(str(config_file)) From 24d6ed8393a9af206c3ca8ebd7b1a55a16e34b60 Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Tue, 6 Jan 2026 14:52:42 +0100 Subject: [PATCH 3/8] feat: Refactor metrics handling with MetricsRegistry class for improved discovery and management --- CLAUDE.md | 8 +- DetailedUsageAndTroubleshooting.md | 7 +- README.md | 3 +- deploy/base/templates/evaluate-template.yaml | 7 +- deploy/base/templates/publish-template.yaml | 13 +- deploy/base/templates/run-template.yaml | 9 +- deploy/base/templates/setup-template.yaml | 6 +- deploy/local/kustomization.yaml | 1 + deploy/local/multi-turn-workflow.yaml | 42 ++- scripts/evaluate.py | 347 +++++++++++-------- scripts/publish.py | 45 ++- tests/test_evaluate.py | 267 ++++++++++---- tests/test_publish.py | 33 +- tests_e2e/test_e2e.py | 16 +- 14 files changed, 502 insertions(+), 302 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 8b6e8a0..70be778 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -66,8 +66,8 @@ uv run python3 scripts/evaluate.py gemini-2.5-flash-lite # Or specify a custom config # uv run python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_advanced.json -# Phase 4: Publish metrics to OTLP endpoint -uv run python3 scripts/publish.py "workflow-name" +# Phase 4: Publish metrics to OTLP endpoint (requires execution_id and execution_number) +OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4318" uv run python3 scripts/publish.py "workflow-name" "exec-001" 1 ``` ### Metrics Configuration @@ -231,8 +231,8 @@ make run - **Purpose**: Calculates RAGAS metrics using LLM-as-a-judge via AI Gateway, tracks tokens and costs **Phase 4: Publish** (`scripts/publish.py`) -- **Input**: `data/results/evaluation_scores.json` + workflow name -- **Output**: Metrics published to OTLP endpoint +- **Input**: `data/results/evaluation_scores.json` + workflow name + execution ID + execution number +- **Output**: Metrics published to OTLP endpoint (configured via `OTEL_EXPORTER_OTLP_ENDPOINT` environment variable) - **Purpose**: Sends evaluation results to observability backend (LGTM/Grafana) via OpenTelemetry ### Data Flow diff --git a/DetailedUsageAndTroubleshooting.md b/DetailedUsageAndTroubleshooting.md index 93580ca..dbf31a2 100644 --- a/DetailedUsageAndTroubleshooting.md +++ b/DetailedUsageAndTroubleshooting.md @@ -164,7 +164,7 @@ Publishes evaluation metrics to an OpenTelemetry OTLP endpoint for monitoring. **Syntax:** ```shell -python3 scripts/publish.py [otlp_endpoint] +OTEL_EXPORTER_OTLP_ENDPOINT= python3 scripts/publish.py ``` **Arguments:** @@ -172,7 +172,10 @@ python3 scripts/publish.py [ot - `workflow_name` (required): Name of the test workflow (used as metric label) - `execution_id` (required): Testkube execution ID for this workflow run - `execution_number` (required): Numeric execution number for this workflow run (used as X-axis in Grafana) -- `otlp_endpoint` (optional): OTLP HTTP endpoint URL (default: `localhost:4318`) + +**Environment Variables:** + +- `OTEL_EXPORTER_OTLP_ENDPOINT` (optional): OTLP HTTP endpoint URL (default: `http://localhost:4318`) **Input:** diff --git a/README.md b/README.md index d7c947e..8de5b7d 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,8 @@ uv run python3 scripts/run.py "http://localhost:11010" uv run python3 scripts/evaluate.py gemini-2.5-flash-lite faithfulness answer_relevancy # 4. Publish metrics to OpenTelemetry (workflow_name, execution_id, execution_number) -uv run python3 scripts/publish.py "my-agent-evaluation" "local-exec-001" 1 +# Set OTLP endpoint via environment variable (defaults to http://localhost:4318 if not set) +OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4318" uv run python3 scripts/publish.py "my-agent-evaluation" "local-exec-001" 1 ``` ---- diff --git a/deploy/base/templates/evaluate-template.yaml b/deploy/base/templates/evaluate-template.yaml index 156491a..46716bf 100644 --- a/deploy/base/templates/evaluate-template.yaml +++ b/deploy/base/templates/evaluate-template.yaml @@ -12,9 +12,10 @@ spec: model: type: string description: "Model name to use for evaluation (e.g., gemini-2.5-flash-lite)" - metrics: + metricsConfigPath: type: string - description: "Space-separated list of metrics to evaluate (e.g., 'nv_accuracy context_precision')" + description: "Path to metrics configuration file (JSON or YAML)" + default: "config/metrics.yaml" image: type: string description: "Docker image to use for the evaluate step" @@ -29,7 +30,7 @@ spec: - -c args: - | - uv run python3 evaluate.py "{{ config.model }}" {{ config.metrics }} && \ + uv run python3 evaluate.py "{{ config.model }}" --metrics-config "{{ config.metricsConfigPath }}" && \ if [ -f data/results/evaluation_scores.json ]; then echo "✓ Evaluation completed" cat data/results/evaluation_scores.json diff --git a/deploy/base/templates/publish-template.yaml b/deploy/base/templates/publish-template.yaml index 74a4ea5..a58dc0e 100644 --- a/deploy/base/templates/publish-template.yaml +++ b/deploy/base/templates/publish-template.yaml @@ -6,18 +6,8 @@ metadata: labels: testkube.io/test-category: ragas-evaluation app: testworkflows -spec: - # Configuration parameters that can be overridden - config: - otlpEndpoint: - type: string - description: "URL of the OTLP endpoint" - default: "http://lgtm.monitoring:4318" - image: - type: string - description: "Docker image to use for the publish step" - default: "ghcr.io/agentic-layer/testbench/testworkflows:latest" +spec: # Steps to execute steps: - name: publish-metrics @@ -27,4 +17,3 @@ spec: - "{{ workflow.name }}" - "{{ execution.id }}" - "{{ execution.number }}" - - "{{ config.otlpEndpoint }}" diff --git a/deploy/base/templates/run-template.yaml b/deploy/base/templates/run-template.yaml index 7c24f87..71fae46 100644 --- a/deploy/base/templates/run-template.yaml +++ b/deploy/base/templates/run-template.yaml @@ -13,14 +13,13 @@ spec: agentUrl: type: string description: "URL to the agent endpoint (A2A protocol)" - image: - type: string - description: "Docker image to use for the run step" - default: "ghcr.io/agentic-layer/testbench/testworkflows:latest" # Steps to execute steps: - - name: run-agent-queries + - name: run + artifacts: + paths: + - "data/experiments/ragas_experiment.jsonl" run: args: - run.py diff --git a/deploy/base/templates/setup-template.yaml b/deploy/base/templates/setup-template.yaml index 4e2c33b..7f88a56 100644 --- a/deploy/base/templates/setup-template.yaml +++ b/deploy/base/templates/setup-template.yaml @@ -13,14 +13,10 @@ spec: datasetUrl: type: string description: "URL to the dataset file (.csv, .json, or .parquet)" - image: - type: string - description: "Docker image to use for the setup step" - default: "ghcr.io/agentic-layer/testbench/testworkflows:latest" # Steps to execute steps: - - name: setup-dataset + - name: setup artifacts: paths: - "data/datasets/ragas_dataset.jsonl" diff --git a/deploy/local/kustomization.yaml b/deploy/local/kustomization.yaml index 9df4e84..0901905 100644 --- a/deploy/local/kustomization.yaml +++ b/deploy/local/kustomization.yaml @@ -5,4 +5,5 @@ resources: - weather-agent.yaml - data-server/ - ../base + - multi-turn-metrics-configmap.yaml - multi-turn-workflow.yaml diff --git a/deploy/local/multi-turn-workflow.yaml b/deploy/local/multi-turn-workflow.yaml index 9134000..60a7897 100644 --- a/deploy/local/multi-turn-workflow.yaml +++ b/deploy/local/multi-turn-workflow.yaml @@ -8,18 +8,24 @@ metadata: app: testworkflows spec: + # Pod configuration with volumes + pod: + volumes: + - name: metrics-config + configMap: + name: multi-turn-metrics-config + container: image: ghcr.io/agentic-layer/testbench/testworkflows:latest env: - name: OPENAI_API_BASE value: "http://ai-gateway-litellm.ai-gateway:4000" - - # Global configuration that applies to all steps - config: - # Dataset configuration - datasetUrl: - type: string - description: "URL to the dataset file" + - name: OTEL_EXPORTER_OTLP_ENDPOINT + value: "http://lgtm.monitoring:4318" + volumeMounts: + - name: metrics-config + mountPath: /app/config/metrics.yaml + subPath: metrics.yaml # Steps using the templates steps: @@ -28,4 +34,24 @@ spec: use: - name: ragas-setup-template config: - datasetUrl: "{{ config.datasetUrl }}" \ No newline at end of file + datasetUrl: "http://data-server.data-server:8000/dataset.csv" + + # Step 2: Run - Execute agent queries + - name: run + use: + - name: ragas-run-template + config: + agentUrl: "http://weather-agent.sample-agents:8000" + + # Step 3: Evaluate - Run RAGAS evaluation + - name: evaluate + use: + - name: ragas-evaluate-template + config: + model: "gemini-2.5-flash-lite" + metricsConfigPath: "/app/config/metrics.yaml" + + # Step 4: Publish - Push metrics to OTLP + - name: publish + use: + - name: ragas-publish-template \ No newline at end of file diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 3f0f37b..1d9c92a 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -21,156 +21,234 @@ logger: Logger = logging.getLogger(__name__) -def discover_metrics() -> tuple[dict[str, Metric], dict[str, type[Metric]]]: - """ - Discover both pre-configured instances and metric classes from Ragas. +class MetricsRegistry: + """Registry for RAGAS metrics discovery and management.""" + + def __init__(self): + """Initialize registry and discover available metrics.""" + self._instances: dict[str, Metric] = {} + self._classes: dict[str, type[Metric]] = {} + self._discover_metrics() + + def _discover_metrics(self) -> None: + """ + Discover both pre-configured instances and metric classes from Ragas. + + Populates _instances and _classes dictionaries. + """ + for name, obj in inspect.getmembers(metrics_module): + if name.startswith("_"): + continue + + if inspect.isclass(obj) and issubclass(obj, Metric) and obj is not Metric: + self._classes[name] = obj + elif isinstance(obj, Metric): + metric_name = obj.name if hasattr(obj, "name") else name + self._instances[metric_name] = obj + + def get_instance(self, name: str) -> Metric: + """ + Get pre-configured metric instance by name. + + Args: + name: Instance name + + Returns: + Metric instance + + Raises: + ValueError: If instance not found + """ + if name not in self._instances: + raise ValueError( + f"Unknown instance '{name}'.\n" + f"Available: {', '.join(sorted(self._instances.keys()))}" + ) + return self._instances[name] - Returns: - (instances, classes) where: - - instances: dict mapping names to pre-configured Metric instances - - classes: dict mapping class names to Metric class types - """ - instances: dict[str, Metric] = {} - classes: dict[str, type[Metric]] = {} + def get_class(self, name: str) -> type[Metric]: + """ + Get metric class by name. - # Iterate through all members of the metrics module - for name, obj in inspect.getmembers(metrics_module): - if name.startswith('_'): - continue + Args: + name: Class name - # Check if it's a Metric class (but not base Metric) - if inspect.isclass(obj) and issubclass(obj, Metric) and obj is not Metric: - classes[name] = obj - # Check if it's a pre-configured metric instance - elif isinstance(obj, Metric): - metric_name = obj.name if hasattr(obj, 'name') else name - instances[metric_name] = obj + Returns: + Metric class type - return instances, classes + Raises: + ValueError: If class not found + """ + if name not in self._classes: + raise ValueError( + f"Unknown class '{name}'.\n" + f"Available: {', '.join(sorted(self._classes.keys()))}" + ) + return self._classes[name] + def instantiate_class(self, class_name: str, parameters: dict[str, Any]) -> Metric: + """ + Instantiate metric class with custom parameters. -# Discover all available metric instances and classes -AVAILABLE_METRIC_INSTANCES, AVAILABLE_METRIC_CLASSES = discover_metrics() + Args: + class_name: Name of metric class + parameters: Dictionary of constructor parameters + Returns: + Metric instance -def get_metric_by_name(metric_name: str) -> Metric: - """ - Get a metric instance by name (checks instances first, then tries classes). + Raises: + ValueError: If class not found or instantiation fails + """ + metric_class = self.get_class(class_name) - Args: - metric_name: Name of metric (e.g., "faithfulness", "answer_relevancy") + try: + return metric_class(**parameters) + except TypeError as e: + sig = inspect.signature(metric_class.__init__) + raise ValueError( + f"Invalid parameters for {class_name}: {e}\n" + f"Expected signature: {sig}" + ) - Returns: - Metric instance + def _load_metric_from_definition(self, metric_def: dict) -> Metric: + """ + Load a single metric from its configuration definition. - Raises: - ValueError: If metric not found or can't be instantiated - """ - # Try pre-configured instance first - if metric_name in AVAILABLE_METRIC_INSTANCES: - return AVAILABLE_METRIC_INSTANCES[metric_name] + Args: + metric_def: Dictionary containing metric definition + + Returns: + Metric instance - # Try to find class by name (case-insensitive matching) - for class_name, metric_class in AVAILABLE_METRIC_CLASSES.items(): - if class_name.lower() == metric_name.lower(): + Raises: + ValueError: If definition is invalid or metric can't be loaded + """ + if "type" not in metric_def: + raise ValueError("Metric definition must include 'type' field") + + metric_type = metric_def["type"] + + if metric_type == "instance": + if "name" not in metric_def: + raise ValueError("Instance type requires 'name' field") + return self.get_instance(metric_def["name"]) + + elif metric_type == "class": + if "class_name" not in metric_def: + raise ValueError("Class type requires 'class_name' field") + + class_name = metric_def["class_name"] + parameters = metric_def.get("parameters", {}) + return self.instantiate_class(class_name, parameters) + + else: + raise ValueError( + f"Unknown metric type '{metric_type}'.\n" + f"Supported types: 'instance', 'class'" + ) + + def load_from_config(self, config_path: str) -> list[Metric]: + """ + Load metrics configuration from JSON or YAML file. + + Args: + config_path: Path to configuration file (.json or .yaml/.yml) + + Returns: + List of configured Metric instances + + Raises: + ValueError: If config file invalid or metrics can't be loaded + """ + if config_path.endswith(".json"): + with open(config_path, "r") as f: + config = json.load(f) + elif config_path.endswith((".yaml", ".yml")): try: - return metric_class() - except Exception as e: + import yaml + except ImportError: raise ValueError( - f"Metric class '{class_name}' requires parameters: {e}\n" - f"Use --metrics-config to provide configuration." + "YAML support requires 'pyyaml' package.\n" + "Install with: uv add pyyaml\n" + "Or use JSON format instead: metrics.json" ) + with open(config_path, "r") as f: + config = yaml.safe_load(f) + else: + raise ValueError( + f"Unsupported config file format: {config_path}\n" + f"Supported formats: .json, .yaml, .yml" + ) - raise ValueError( - f"Unknown metric '{metric_name}'.\n" - f"Available instances: {', '.join(sorted(AVAILABLE_METRIC_INSTANCES.keys()))}\n" - f"Available classes: {', '.join(sorted(AVAILABLE_METRIC_CLASSES.keys()))}" - ) + if "metrics" not in config: + raise ValueError("Config file must contain 'metrics' key") + if not isinstance(config["metrics"], list): + raise ValueError("'metrics' must be a list") -def instantiate_metric_from_class(class_name: str, parameters: dict[str, Any]) -> Metric: - """ - Instantiate a metric class with custom parameters. + metrics: list[Metric] = [] + for i, metric_def in enumerate(config["metrics"]): + try: + metric = self._load_metric_from_definition(metric_def) + metrics.append(metric) + except Exception as e: + raise ValueError(f"Error loading metric at index {i}: {e}") - Args: - class_name: Name of metric class (e.g., "AspectCritic") - parameters: Dictionary of constructor parameters + if not metrics: + raise ValueError("Config file contains no valid metrics") - Returns: - Metric instance + return metrics - Raises: - ValueError: If class not found or instantiation fails - """ - if class_name not in AVAILABLE_METRIC_CLASSES: - raise ValueError( - f"Unknown metric class '{class_name}'.\n" - f"Available classes: {', '.join(sorted(AVAILABLE_METRIC_CLASSES.keys()))}" - ) + def list_instances(self) -> list[str]: + """Return sorted list of available instance names.""" + return sorted(self._instances.keys()) - metric_class = AVAILABLE_METRIC_CLASSES[class_name] + def list_classes(self) -> list[str]: + """Return sorted list of available class names.""" + return sorted(self._classes.keys()) - try: - return metric_class(**parameters) - except TypeError as e: - # Extract signature for helpful error message - sig = inspect.signature(metric_class.__init__) - raise ValueError(f"Invalid parameters for {class_name}: {e}\n" f"Expected signature: {sig}") + @classmethod + def create_default(cls) -> "MetricsRegistry": + """Factory method for default registry with auto-discovery.""" + return cls() -def _load_metric_from_definition(metric_def: dict) -> Metric: +def instantiate_metric_from_class( + class_name: str, + parameters: dict[str, Any], + registry: MetricsRegistry | None = None +) -> Metric: """ - Load a single metric from its configuration definition. + Instantiate a metric class with custom parameters. Args: - metric_def: Dictionary containing metric definition + class_name: Name of metric class + parameters: Dictionary of constructor parameters + registry: Optional registry (None = create default) Returns: Metric instance Raises: - ValueError: If definition is invalid or metric can't be loaded + ValueError: If class not found or instantiation fails """ - # Validate required fields - if 'type' not in metric_def: - raise ValueError("Metric definition must include 'type' field") - - metric_type = metric_def['type'] - - if metric_type == 'instance': - # Load pre-configured instance - if 'name' not in metric_def: - raise ValueError("Instance type requires 'name' field") - - name = metric_def['name'] - if name not in AVAILABLE_METRIC_INSTANCES: - raise ValueError( - f"Unknown instance '{name}'.\n" - f"Available: {', '.join(sorted(AVAILABLE_METRIC_INSTANCES.keys()))}" - ) - - return AVAILABLE_METRIC_INSTANCES[name] + if registry is None: + registry = MetricsRegistry.create_default() - elif metric_type == 'class': - # Instantiate class with parameters - if 'class_name' not in metric_def: - raise ValueError("Class type requires 'class_name' field") + return registry.instantiate_class(class_name, parameters) - class_name = metric_def['class_name'] - parameters = metric_def.get('parameters', {}) - return instantiate_metric_from_class(class_name, parameters) - - else: - raise ValueError(f"Unknown metric type '{metric_type}'.\n" f"Supported types: 'instance', 'class'") - - -def load_metrics_config(config_path: str) -> list[Metric]: +def load_metrics_config( + config_path: str, + registry: MetricsRegistry | None = None +) -> list[Metric]: """ Load metrics configuration from JSON or YAML file. Args: - config_path: Path to configuration file (.json or .yaml/.yml) + config_path: Path to configuration file + registry: Optional registry (None = create default) Returns: List of configured Metric instances @@ -178,46 +256,10 @@ def load_metrics_config(config_path: str) -> list[Metric]: Raises: ValueError: If config file invalid or metrics can't be loaded """ - # Determine file format and load - if config_path.endswith('.json'): - with open(config_path, 'r') as f: - config = json.load(f) - elif config_path.endswith(('.yaml', '.yml')): - try: - import yaml - except ImportError: - raise ValueError( - "YAML support requires 'pyyaml' package.\n" - "Install with: uv add pyyaml\n" - "Or use JSON format instead: metrics.json" - ) - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - else: - raise ValueError( - f"Unsupported config file format: {config_path}\n" f"Supported formats: .json, .yaml, .yml" - ) + if registry is None: + registry = MetricsRegistry.create_default() - # Validate config structure - if 'metrics' not in config: - raise ValueError("Config file must contain 'metrics' key") - - if not isinstance(config['metrics'], list): - raise ValueError("'metrics' must be a list") - - # Load each metric - metrics: list[Metric] = [] - for i, metric_def in enumerate(config['metrics']): - try: - metric = _load_metric_from_definition(metric_def) - metrics.append(metric) - except Exception as e: - raise ValueError(f"Error loading metric at index {i}: {e}") - - if not metrics: - raise ValueError("Config file contains no valid metrics") - - return metrics + return registry.load_from_config(config_path) @dataclass @@ -321,7 +363,7 @@ def main( cost_per_input_token: Cost per input token cost_per_output_token: Cost per output token """ - # Load metrics from configuration file + # Load metrics from configuration file (creates registry internally) logger.info(f"Loading metrics from config: {metrics_config}") metrics = load_metrics_config(metrics_config) logger.info(f"Loaded {len(metrics)} metrics: {', '.join([m.name for m in metrics])}") @@ -376,16 +418,19 @@ def main( if __name__ == "__main__": + # Create registry for help text generation + registry = MetricsRegistry.create_default() + # Parse the parameters (model and metrics-config) evaluate.py was called with parser = argparse.ArgumentParser( description="Evaluate results using RAGAS metrics via configuration file", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=f""" Available metric instances (pre-configured): - {', '.join(sorted(AVAILABLE_METRIC_INSTANCES.keys()))} + {', '.join(registry.list_instances())} Available metric classes (configurable via --metrics-config): - {', '.join(sorted(AVAILABLE_METRIC_CLASSES.keys()))} + {', '.join(registry.list_classes())} Examples: python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_simple.json diff --git a/scripts/publish.py b/scripts/publish.py index 7f3ae96..437d809 100644 --- a/scripts/publish.py +++ b/scripts/publish.py @@ -3,6 +3,7 @@ import json import logging import math +import os from dataclasses import dataclass from logging import Logger from typing import Any, TypeGuard @@ -60,19 +61,23 @@ def _get_user_input_truncated(user_input: str, max_length: int = 50) -> str: def create_and_push_metrics( - evaluation_data: EvaluationData, workflow_name: str, execution_id: str, execution_number: int, otlp_endpoint: str + evaluation_data: EvaluationData, workflow_name: str, execution_id: str, execution_number: int ) -> None: """ Create OpenTelemetry metrics for evaluation results and push via OTLP. Creates per-sample gauges for each metric, plus token usage and cost gauges. + The OTLP endpoint is read from the OTEL_EXPORTER_OTLP_ENDPOINT environment variable, + with a default of 'http://localhost:4318' if not set. + Args: evaluation_data: Container with individual results, token counts, and cost workflow_name: Name of the test workflow (used as label to distinguish workflows) execution_id: Testkube execution ID for this workflow run - otlp_endpoint: URL of the OTLP endpoint (e.g., 'http://localhost:4318') + execution_number: Number of the execution for the current workflow """ + otlp_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318") if not otlp_endpoint.startswith("http://") and not otlp_endpoint.startswith("https://"): otlp_endpoint = f"http://{otlp_endpoint}" @@ -173,8 +178,14 @@ def create_and_push_metrics( f"testbench_evaluation_cost{{workflow_name={workflow_name}, execution_id={execution_id}, execution_number={execution_number}}} = {evaluation_data.total_cost}" ) - provider.force_flush() - logger.info("Metrics successfully pushed via OTLP") + # force_flush() returns True if successful, False otherwise + flush_success = provider.force_flush() + if flush_success: + logger.info("Metrics successfully pushed via OTLP") + else: + error_msg = f"Failed to flush metrics to OTLP endpoint at {otlp_endpoint}" + logger.error(error_msg) + raise RuntimeError(error_msg) except Exception as e: logger.error(f"Error pushing metrics via OTLP: {e}") raise @@ -182,18 +193,18 @@ def create_and_push_metrics( provider.shutdown() -def publish_metrics( - input_file: str, workflow_name: str, execution_id: str, execution_number: int, otlp_endpoint: str -) -> None: +def publish_metrics(input_file: str, workflow_name: str, execution_id: str, execution_number: int) -> None: """ Publish evaluation metrics via OpenTelemetry OTLP. + The OTLP endpoint is read from the OTEL_EXPORTER_OTLP_ENDPOINT environment variable, + with a default of 'http://localhost:4318' if not set. + Args: input_file: Path to the evaluation scores JSON file workflow_name: Name of the test workflow (e.g., 'weather-assistant-test'). execution_id: Testkube execution ID for this workflow run. execution_number: Number of the execution for the current workflow (e.g. 3) - otlp_endpoint: URL of the OTLP endpoint (e.g., 'http://localhost:4318'). """ logger.info(f"Loading evaluation data from {input_file}...") evaluation_data = load_evaluation_data(input_file) @@ -204,21 +215,24 @@ def publish_metrics( logger.info(f"Publishing metrics for {len(evaluation_data.individual_results)} samples...") logger.info(f"Workflow: {workflow_name}, Execution: {execution_id}") - create_and_push_metrics(evaluation_data, workflow_name, execution_id, execution_number, otlp_endpoint) + create_and_push_metrics(evaluation_data, workflow_name, execution_id, execution_number) if __name__ == "__main__": """ Main function to publish metrics via OpenTelemetry OTLP. + The OTLP endpoint is read from the OTEL_EXPORTER_OTLP_ENDPOINT environment variable, + with a default of 'http://localhost:4318' if not set. + Args: workflow_name: Name of the test workflow execution_id: Testkube execution ID for this workflow run - otlp_endpoint: (OPTIONAL) URL to the OTLP endpoint (default: localhost:4318) + execution_number: Testkube execution number for this workflow run Examples: - python3 scripts/publish.py weather-assistant-test exec-123 - python3 scripts/publish.py weather-assistant-test exec-123 http://localhost:4318 + python3 scripts/publish.py weather-assistant-test exec-123 1 + OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 python3 scripts/publish.py weather-assistant-test exec-123 1 """ parser = argparse.ArgumentParser(description="Publish RAGAS evaluation metrics via OpenTelemetry OTLP") @@ -234,12 +248,6 @@ def publish_metrics( "execution_number", help="Testkube execution number for this workflow run (for use as a *numeric* identifier in Grafana)", ) - parser.add_argument( - "otlp_endpoint", - nargs="?", - default="localhost:4318", - help="URL of the OTLP HTTP endpoint (default: localhost:4318)", - ) args = parser.parse_args() @@ -248,5 +256,4 @@ def publish_metrics( workflow_name=args.workflow_name, execution_id=args.execution_id, execution_number=args.execution_number, - otlp_endpoint=args.otlp_endpoint, ) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 8098eaa..e5c9c68 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -19,10 +19,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) from evaluate import ( - AVAILABLE_METRIC_CLASSES, - AVAILABLE_METRIC_INSTANCES, + MetricsRegistry, format_evaluation_scores, - get_metric_by_name, instantiate_metric_from_class, load_metrics_config, main, @@ -66,6 +64,37 @@ def experiment_data(temp_dir): return tmp, original_cwd, experiment_file +@pytest.fixture +def default_registry(): + """Fixture providing a default MetricsRegistry.""" + return MetricsRegistry.create_default() + + +@pytest.fixture +def mock_registry(): + """Fixture providing a registry with mock metrics for testing.""" + from unittest.mock import MagicMock + + registry = MetricsRegistry() + + # Clear auto-discovered metrics + registry._instances = {} + registry._classes = {} + + # Add mock instance + mock_instance = MagicMock(spec=Metric) + mock_instance.name = "test_metric" + registry._instances["test_metric"] = mock_instance + + # Add mock class + mock_class = MagicMock(spec=type) + mock_class.__name__ = "TestMetricClass" + mock_class.return_value = MagicMock(spec=Metric) + registry._classes["TestMetricClass"] = mock_class + + return registry + + # TestFormatEvaluationScores tests def test_overall_scores_calculation(tmp_path): """Test that overall scores are calculated correctly""" @@ -253,8 +282,9 @@ def test_main_no_config(experiment_data): def test_main_successful_execution(experiment_data, monkeypatch, tmp_path): - """Test main function successful execution with config file""" + """Test main function successful execution with config file.""" from pathlib import Path + from unittest.mock import MagicMock from ragas.dataset_schema import EvaluationResult @@ -262,15 +292,23 @@ def test_main_successful_execution(experiment_data, monkeypatch, tmp_path): os.chdir(tmp) try: - # Create a test config file - config_file = tmp_path / "test_metrics.json" - if not AVAILABLE_METRIC_INSTANCES: - pytest.skip("No metric instances available") + # Create a mock registry + mock_registry = MagicMock() + mock_metric = MagicMock(spec=Metric) + mock_metric.name = "test_metric" + mock_registry.load_from_config.return_value = [mock_metric] - valid_metric = list(AVAILABLE_METRIC_INSTANCES.keys())[0] - config = {"version": "1.0", "metrics": [{"type": "instance", "name": valid_metric}]} + # Mock MetricsRegistry.create_default() to return our mock + monkeypatch.setattr("evaluate.MetricsRegistry.create_default", lambda: mock_registry) + + # Create config file + config_file = tmp_path / "test_metrics.json" + config = { + "version": "1.0", + "metrics": [{"type": "instance", "name": "test_metric"}] + } - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump(config, f) # Mock EvaluationDataset.from_jsonl @@ -350,53 +388,36 @@ def __init__(self, llm): # TestMetricDiscovery tests -def test_metric_discovery(): - """Test that both metric instances and classes are discovered""" +def test_metric_discovery(default_registry): + """Test that both metric instances and classes are discovered.""" + instances = default_registry.list_instances() + classes = default_registry.list_classes() + # Test instances - assert isinstance(AVAILABLE_METRIC_INSTANCES, dict) - assert len(AVAILABLE_METRIC_INSTANCES) > 0 - for name, instance in AVAILABLE_METRIC_INSTANCES.items(): - assert isinstance(name, str) + assert len(instances) > 0 + for name in instances: + instance = default_registry.get_instance(name) assert isinstance(instance, Metric) # Test classes - assert isinstance(AVAILABLE_METRIC_CLASSES, dict) - assert len(AVAILABLE_METRIC_CLASSES) > 0 - for name, cls in AVAILABLE_METRIC_CLASSES.items(): - assert isinstance(name, str) + assert len(classes) > 0 + for name in classes: + cls = default_registry.get_class(name) assert inspect.isclass(cls) assert issubclass(cls, Metric) -# Test get_metric_by_name -def test_get_metric_by_name_instance(): - """Test getting pre-configured metric instance""" - if not AVAILABLE_METRIC_INSTANCES: - pytest.skip("No metric instances available") - - # Get first available instance - metric_name = list(AVAILABLE_METRIC_INSTANCES.keys())[0] - metric = get_metric_by_name(metric_name) - assert isinstance(metric, Metric) - assert metric.name == metric_name - - -def test_get_metric_by_name_unknown(): - """Test error handling for unknown metric""" - with pytest.raises(ValueError, match="Unknown metric"): - get_metric_by_name('nonexistent_metric_xyz') - - # Test instantiate_metric_from_class -def test_instantiate_metric_from_class_success(): - """Test successful class instantiation without parameters""" - if not AVAILABLE_METRIC_CLASSES: +def test_instantiate_metric_from_class_success(default_registry): + """Test successful class instantiation without parameters.""" + classes = default_registry.list_classes() + if not classes: pytest.skip("No metric classes available") # Find a class that can be instantiated without parameters - for class_name, metric_class in AVAILABLE_METRIC_CLASSES.items(): + for class_name in classes: try: - metric = instantiate_metric_from_class(class_name, {}) + metric = instantiate_metric_from_class(class_name, {}, registry=default_registry) assert isinstance(metric, Metric) return # Success! except (TypeError, ValueError): @@ -404,53 +425,62 @@ def test_instantiate_metric_from_class_success(): pytest.skip("No metric classes can be instantiated without parameters") -def test_instantiate_metric_from_class_unknown(): - """Test error for unknown class""" - with pytest.raises(ValueError, match="Unknown metric class"): - instantiate_metric_from_class('NonexistentClass', {}) +def test_instantiate_metric_from_class_unknown(default_registry): + """Test error for unknown class.""" + with pytest.raises(ValueError, match="Unknown class"): + instantiate_metric_from_class("NonexistentClass", {}, registry=default_registry) -def test_instantiate_metric_from_class_invalid_params(): - """Test error for invalid parameters""" - if not AVAILABLE_METRIC_CLASSES: +def test_instantiate_metric_from_class_invalid_params(default_registry): + """Test error for invalid parameters.""" + classes = default_registry.list_classes() + if not classes: pytest.skip("No metric classes available") - # Use first available class with clearly invalid parameters - class_name = list(AVAILABLE_METRIC_CLASSES.keys())[0] + class_name = classes[0] with pytest.raises(ValueError, match="Invalid parameters"): - instantiate_metric_from_class(class_name, {'completely_invalid_param_name_xyz': 'value'}) + instantiate_metric_from_class( + class_name, + {"completely_invalid_param_name_xyz": "value"}, + registry=default_registry + ) # Test load_metrics_config -def test_load_metrics_config_json(tmp_path): - """Test loading metrics from JSON config file""" - if not AVAILABLE_METRIC_INSTANCES: +def test_load_metrics_config_json(tmp_path, default_registry): + """Test loading metrics from JSON config file.""" + instances = default_registry.list_instances() + if not instances: pytest.skip("No metric instances available") config_file = tmp_path / "metrics.json" - metric_name = list(AVAILABLE_METRIC_INSTANCES.keys())[0] + metric_name = instances[0] - config = {"version": "1.0", "metrics": [{"type": "instance", "name": metric_name}]} + config = { + "version": "1.0", + "metrics": [{"type": "instance", "name": metric_name}] + } - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump(config, f) - metrics = load_metrics_config(str(config_file)) + metrics = load_metrics_config(str(config_file), registry=default_registry) assert len(metrics) == 1 assert isinstance(metrics[0], Metric) assert metrics[0].name == metric_name -def test_load_metrics_config_with_class(tmp_path): - """Test loading metrics with class instantiation""" - if not AVAILABLE_METRIC_CLASSES: +def test_load_metrics_config_with_class(tmp_path, default_registry): + """Test loading metrics with class instantiation.""" + classes = default_registry.list_classes() + if not classes: pytest.skip("No metric classes available") # Find a class that can be instantiated without parameters - for class_name in AVAILABLE_METRIC_CLASSES.keys(): + for class_name in classes: try: # Test if this class can be instantiated - instantiate_metric_from_class(class_name, {}) + instantiate_metric_from_class(class_name, {}, registry=default_registry) config_file = tmp_path / "metrics.json" config = { @@ -458,10 +488,10 @@ def test_load_metrics_config_with_class(tmp_path): "metrics": [{"type": "class", "class_name": class_name, "parameters": {}}], } - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump(config, f) - metrics = load_metrics_config(str(config_file)) + metrics = load_metrics_config(str(config_file), registry=default_registry) assert len(metrics) == 1 assert isinstance(metrics[0], Metric) return # Success! @@ -484,7 +514,7 @@ def test_load_metrics_config_missing_metrics_key(tmp_path): """Test error for missing 'metrics' key""" config_file = tmp_path / "metrics.json" - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump({"version": "1.0"}, f) with pytest.raises(ValueError, match="must contain 'metrics' key"): @@ -497,8 +527,103 @@ def test_load_metrics_config_empty_metrics(tmp_path): config = {"version": "1.0", "metrics": []} - with open(config_file, 'w') as f: + with open(config_file, "w") as f: json.dump(config, f) with pytest.raises(ValueError, match="contains no valid metrics"): load_metrics_config(str(config_file)) + + +# Test MetricsRegistry class +def test_registry_initialization(): + """Test that registry initializes and discovers metrics.""" + registry = MetricsRegistry() + + assert len(registry.list_instances()) > 0 + assert len(registry.list_classes()) > 0 + + +def test_registry_get_instance(default_registry): + """Test getting instances from registry.""" + instances = default_registry.list_instances() + if not instances: + pytest.skip("No instances available") + + name = instances[0] + metric = default_registry.get_instance(name) + assert isinstance(metric, Metric) + + +def test_registry_get_instance_unknown(default_registry): + """Test error for unknown instance.""" + with pytest.raises(ValueError, match="Unknown instance"): + default_registry.get_instance("nonexistent_xyz") + + +def test_registry_get_class(default_registry): + """Test getting classes from registry.""" + classes = default_registry.list_classes() + if not classes: + pytest.skip("No classes available") + + name = classes[0] + cls = default_registry.get_class(name) + assert inspect.isclass(cls) + assert issubclass(cls, Metric) + + +def test_registry_get_class_unknown(default_registry): + """Test error for unknown class.""" + with pytest.raises(ValueError, match="Unknown class"): + default_registry.get_class("NonexistentClass") + + +def test_registry_instantiate_class(default_registry): + """Test instantiating class via registry.""" + classes = default_registry.list_classes() + if not classes: + pytest.skip("No classes available") + + # Find instantiable class + for class_name in classes: + try: + metric = default_registry.instantiate_class(class_name, {}) + assert isinstance(metric, Metric) + return + except (TypeError, ValueError): + continue + pytest.skip("No classes instantiable without params") + + +def test_registry_load_from_config(tmp_path, default_registry): + """Test loading config via registry method.""" + instances = default_registry.list_instances() + if not instances: + pytest.skip("No instances available") + + config_file = tmp_path / "test.json" + config = { + "version": "1.0", + "metrics": [{"type": "instance", "name": instances[0]}] + } + + with open(config_file, "w") as f: + json.dump(config, f) + + metrics = default_registry.load_from_config(str(config_file)) + assert len(metrics) == 1 + assert isinstance(metrics[0], Metric) + + +def test_mock_registry_fixture(mock_registry): + """Test that mock registry fixture works.""" + assert mock_registry.list_instances() == ["test_metric"] + assert mock_registry.list_classes() == ["TestMetricClass"] + + # Test instance retrieval + instance = mock_registry.get_instance("test_metric") + assert instance.name == "test_metric" + + # Test class instantiation + metric = mock_registry.instantiate_class("TestMetricClass", {}) + assert isinstance(metric, Metric) diff --git a/tests/test_publish.py b/tests/test_publish.py index 9cc5693..82d84e1 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -255,7 +255,7 @@ def mock_get_meter(*args, **kwargs): # Mock the provider class MockProvider: def force_flush(self): - pass + return True def shutdown(self): pass @@ -278,13 +278,13 @@ def mock_exporter_init(endpoint): monkeypatch.setattr("publish.metrics.get_meter", mock_get_meter) monkeypatch.setattr("publish.MeterProvider", mock_provider_init) monkeypatch.setattr("publish.OTLPMetricExporter", mock_exporter_init) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") create_and_push_metrics( evaluation_data=evaluation_data, workflow_name="test-workflow", execution_id="exec-test-123", execution_number=42, - otlp_endpoint="localhost:4318", ) # Verify gauges created: 1 metric gauge + 1 token gauge + 1 cost gauge = 3 @@ -334,7 +334,7 @@ def mock_get_meter(*args, **kwargs): # Mock the provider class MockProvider: def force_flush(self): - pass + return True def shutdown(self): pass @@ -357,13 +357,13 @@ def mock_exporter_init(endpoint): monkeypatch.setattr("publish.metrics.get_meter", mock_get_meter) monkeypatch.setattr("publish.MeterProvider", mock_provider_init) monkeypatch.setattr("publish.OTLPMetricExporter", mock_exporter_init) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") create_and_push_metrics( evaluation_data=evaluation_data, workflow_name="test-workflow", execution_id="exec-test-123", execution_number=42, - otlp_endpoint="localhost:4318", ) # Filter to faithfulness metric calls only (name attribute = "faithfulness") @@ -423,6 +423,7 @@ def mock_get_meter(*args, **kwargs): class MockProvider: def force_flush(self): force_flush_calls.append(True) + return True def shutdown(self): shutdown_calls.append(True) @@ -444,13 +445,13 @@ def mock_exporter_init(endpoint): monkeypatch.setattr("publish.metrics.get_meter", mock_get_meter) monkeypatch.setattr("publish.MeterProvider", mock_provider_init) monkeypatch.setattr("publish.OTLPMetricExporter", mock_exporter_init) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") create_and_push_metrics( evaluation_data=evaluation_data, workflow_name="test-workflow", execution_id="exec-test-123", execution_number=42, - otlp_endpoint="localhost:4318", ) # Verify OTLPMetricExporter was initialized with correct endpoint @@ -473,12 +474,12 @@ def test_handles_push_error(monkeypatch): def mock_get_meter(*args, **kwargs): return _OtelMockMeter() - # Mock the provider to raise an exception on force_flush + # Mock the provider to return False on force_flush (indicating failure) shutdown_calls = [] class MockProvider: def force_flush(self): - raise Exception("Connection refused") + return False def shutdown(self): shutdown_calls.append(True) @@ -497,14 +498,14 @@ def mock_exporter_init(endpoint): monkeypatch.setattr("publish.metrics.get_meter", mock_get_meter) monkeypatch.setattr("publish.MeterProvider", mock_provider_init) monkeypatch.setattr("publish.OTLPMetricExporter", mock_exporter_init) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") - with pytest.raises(Exception, match="Connection refused"): + with pytest.raises(RuntimeError, match="Failed to flush metrics"): create_and_push_metrics( evaluation_data=evaluation_data, workflow_name="test-workflow", execution_id="exec-test-123", execution_number=42, - otlp_endpoint="localhost:4318", ) # Verify shutdown is still called in finally block @@ -516,25 +517,24 @@ def test_publish_metrics_calls_create_and_push(evaluation_scores_file, monkeypat """Test that publish_metrics calls create_and_push_metrics""" create_push_calls = [] - def mock_create_push(evaluation_data, workflow_name, execution_id, execution_number, otlp_endpoint): + def mock_create_push(evaluation_data, workflow_name, execution_id, execution_number): create_push_calls.append( { "evaluation_data": evaluation_data, "workflow_name": workflow_name, "execution_id": execution_id, "execution_number": execution_number, - "otlp_endpoint": otlp_endpoint, } ) monkeypatch.setattr("publish.create_and_push_metrics", mock_create_push) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") publish_metrics( input_file=str(evaluation_scores_file), workflow_name="test-workflow", execution_id="exec-test-123", execution_number=42, - otlp_endpoint="localhost:4318", ) # Verify create_and_push_metrics was called @@ -545,7 +545,6 @@ def mock_create_push(evaluation_data, workflow_name, execution_id, execution_num assert create_push_calls[0]["workflow_name"] == "test-workflow" assert create_push_calls[0]["execution_id"] == "exec-test-123" assert create_push_calls[0]["execution_number"] == 42 - assert create_push_calls[0]["otlp_endpoint"] == "localhost:4318" def test_publish_metrics_with_empty_results(temp_dir, monkeypatch): @@ -564,17 +563,17 @@ def test_publish_metrics_with_empty_results(temp_dir, monkeypatch): create_push_calls = [] - def mock_create_push(evaluation_data, workflow_name, execution_id, execution_number, otlp_endpoint): + def mock_create_push(evaluation_data, workflow_name, execution_id, execution_number): create_push_calls.append(True) monkeypatch.setattr("publish.create_and_push_metrics", mock_create_push) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") publish_metrics( input_file=str(empty_file), workflow_name="test-workflow", execution_id="exec-test-123", execution_number=42, - otlp_endpoint="localhost:4318", ) # Verify create_and_push_metrics was NOT called @@ -604,7 +603,7 @@ def mock_get_meter(*args, **kwargs): # Mock the provider class MockProvider: def force_flush(self): - pass + return True def shutdown(self): pass @@ -630,13 +629,13 @@ def mock_exporter_init(endpoint): monkeypatch.setattr("publish.metrics.get_meter", mock_get_meter) monkeypatch.setattr("publish.MeterProvider", mock_provider_init) monkeypatch.setattr("publish.OTLPMetricExporter", mock_exporter_init) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4318") publish_metrics( input_file=str(realistic_scores_file), workflow_name="weather-assistant-test", execution_id="exec-weather-456", execution_number=42, - otlp_endpoint="localhost:4318", ) # Verify OTLPMetricExporter was called diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index 913285c..1ce50d3 100755 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -80,13 +80,14 @@ def verify_scripts_exist(self) -> bool: logger.info("✓ All scripts found") return True - def run_command(self, command: List[str], step_name: str) -> bool: + def run_command(self, command: List[str], step_name: str, env: dict = None) -> bool: """ Run a command and handle output/errors. Args: command: List of command arguments step_name: Name of the step for logging + env: Optional environment variables to pass to the command Returns: True if successful, False otherwise @@ -97,7 +98,7 @@ def run_command(self, command: List[str], step_name: str) -> bool: logger.info(f"{'=' * 60}\n") try: - result = subprocess.run(command, check=True, capture_output=True, text=True) # nosec + result = subprocess.run(command, check=True, capture_output=True, text=True, env=env) # nosec # Log stdout if present if result.stdout: @@ -165,13 +166,20 @@ def run_evaluation(self) -> bool: def run_publish(self) -> bool: """Run publish.py to publish metrics via OpenTelemetry OTLP.""" + import os + + # Set OTLP endpoint via environment variable + env = os.environ.copy() + env["OTEL_EXPORTER_OTLP_ENDPOINT"] = self.otlp_endpoint + command = [ "python3", str(self.publish_script), self.workflow_name, - self.otlp_endpoint, + "e2e-test-exec", # execution_id + "1", # execution_number ] - return self.run_command(command, "4. Publish - Push Metrics via OTLP") + return self.run_command(command, "4. Publish - Push Metrics via OTLP", env=env) def run_full_pipeline(self) -> bool: """Execute the complete E2E test pipeline.""" From 8551db961d20bd2a5eb81744083ad5c8f8900963 Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Thu, 8 Jan 2026 16:37:45 +0100 Subject: [PATCH 4/8] feat: Add HTML visualization generation for evaluation results --- CLAUDE.md | 67 +- Tiltfile | 1 + deploy/base/templates/kustomization.yaml | 1 + deploy/base/templates/visualize-template.yaml | 22 + .../local/multi-turn-metrics-configmap.yaml | 26 + deploy/local/multi-turn-workflow.yaml | 4 +- scripts/evaluate.py | 38 +- scripts/run.py | 1 - scripts/setup.py | 3 +- scripts/visualize.py | 1051 +++++++++++++++++ tests/test_evaluate.py | 19 +- tests/test_run.py | 1 - tests/test_visualize.py | 653 ++++++++++ 13 files changed, 1834 insertions(+), 53 deletions(-) create mode 100644 deploy/base/templates/visualize-template.yaml create mode 100644 deploy/local/multi-turn-metrics-configmap.yaml create mode 100644 scripts/visualize.py create mode 100644 tests/test_visualize.py diff --git a/CLAUDE.md b/CLAUDE.md index 70be778..8a3b23a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -68,6 +68,51 @@ uv run python3 scripts/evaluate.py gemini-2.5-flash-lite # Phase 4: Publish metrics to OTLP endpoint (requires execution_id and execution_number) OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4318" uv run python3 scripts/publish.py "workflow-name" "exec-001" 1 + +# Optional: Generate HTML visualization report (requires workflow metadata) +uv run python3 scripts/visualize.py "weather-assistant-test" "exec-001" 1 +``` + +### HTML Visualization + +Generate a comprehensive HTML dashboard from evaluation results for local viewing and sharing. + +**BREAKING CHANGE:** visualize.py now requires workflow metadata as mandatory positional arguments (matching publish.py pattern). + +```shell +# Basic usage (after running evaluate.py) +uv run python3 scripts/visualize.py weather-assistant-test exec-001 1 + +# Custom input/output paths +uv run python3 scripts/visualize.py weather-assistant-test exec-001 1 \ + --input data/results/evaluation_scores.json \ + --output reports/exec-001.html + +# Complete pipeline example +uv run python3 scripts/evaluate.py gemini-2.5-flash-lite +uv run python3 scripts/visualize.py weather-agent exec-123 5 +``` + +**Required Arguments:** +- `workflow_name` - Name of the test workflow (e.g., 'weather-assistant-test') +- `execution_id` - Testkube execution ID for this workflow run +- `execution_number` - Testkube execution number for this workflow run + +**Optional Arguments:** +- `--input` - Path to evaluation_scores.json (default: `data/results/evaluation_scores.json`) +- `--output` - Path for output HTML file (default: `data/results/evaluation_report.html`) + +**Features:** +- **Summary Cards**: Total samples, metrics count, token usage, cost +- **Workflow Metadata**: Displays workflow name, execution ID, and execution number +- **Overall Scores Chart**: Horizontal bar chart showing mean score per metric +- **Metric Distributions**: Histograms showing score distributions with statistics +- **Detailed Results Table**: All samples with metrics, searchable and sortable +- **Multi-Turn Support**: Chat-bubble visualization for conversational datasets +- **Self-Contained HTML**: Single file with embedded Chart.js, works offline +- **Responsive Design**: Works on desktop and tablet, print-friendly + +**Output:** `data/results/evaluation_report.html` (default) ``` ### Metrics Configuration @@ -235,6 +280,13 @@ make run - **Output**: Metrics published to OTLP endpoint (configured via `OTEL_EXPORTER_OTLP_ENDPOINT` environment variable) - **Purpose**: Sends evaluation results to observability backend (LGTM/Grafana) via OpenTelemetry +**Optional: Visualize** (`scripts/visualize.py`) +- **Input**: `data/results/evaluation_scores.json` +- **Output**: `data/results/evaluation_report.html` (self-contained HTML dashboard) +- **Purpose**: Generates comprehensive HTML report with charts, tables, and statistics for local viewing and sharing +- **Features**: Summary cards, bar charts, metric distributions, searchable results table +- **Note**: Runs independently of Phase 4 (publish.py), can be used for local development without OTLP backend + ### Data Flow ``` @@ -245,8 +297,10 @@ data/datasets/ragas_dataset.jsonl data/experiments/ragas_experiment.jsonl ↓ [evaluate.py + RAGAS + AI Gateway] data/results/evaluation_scores.json - ↓ [publish.py + OTLP] -Observability Backend (Grafana) + ├─→ [publish.py + OTLP] + │ Observability Backend (Grafana) + └─→ [visualize.py] + data/results/evaluation_report.html (Local Visualization) ``` ### Kubernetes Integration (Testkube) @@ -346,10 +400,17 @@ All scripts follow same pattern: parse arguments → read input file(s) → proc - Sends via HTTP to OTLP collector - Uses workflow name as metric label +- **`visualize.py`**: HTML visualization generation + - Reads `evaluation_scores.json` and generates self-contained HTML dashboard + - Creates summary cards, bar charts, metric distributions, and results table + - Uses Chart.js via CDN for interactive visualizations + - Inline CSS for single-file distribution + - Includes search functionality for results table + ### Test Organization **Unit Tests (`tests/`)**: -- One test file per script: `test_setup.py`, `test_run.py`, `test_evaluate.py`, `test_publish.py` +- One test file per script: `test_setup.py`, `test_run.py`, `test_evaluate.py`, `test_publish.py`, `test_visualize.py` - Uses pytest with async support (`pytest-asyncio`) - Mocks external dependencies: HTTP requests (`httpx.AsyncClient`), A2A client, RAGAS framework - Uses `tmp_path` fixture for file I/O testing diff --git a/Tiltfile b/Tiltfile index 72e3ba4..f9db6c9 100644 --- a/Tiltfile +++ b/Tiltfile @@ -52,4 +52,5 @@ k8s_resource('ragas-evaluate-template', resource_deps=['testkube']) k8s_resource('ragas-publish-template', resource_deps=['testkube']) k8s_resource('ragas-run-template', resource_deps=['testkube']) k8s_resource('ragas-setup-template', resource_deps=['testkube']) +k8s_resource('ragas-visualize-template', resource_deps=['testkube']) k8s_resource('multi-turn-workflow', resource_deps=['testkube']) diff --git a/deploy/base/templates/kustomization.yaml b/deploy/base/templates/kustomization.yaml index 98961e1..e7898d2 100644 --- a/deploy/base/templates/kustomization.yaml +++ b/deploy/base/templates/kustomization.yaml @@ -5,3 +5,4 @@ resources: - setup-template.yaml - run-template.yaml - publish-template.yaml + - visualize-template.yaml diff --git a/deploy/base/templates/visualize-template.yaml b/deploy/base/templates/visualize-template.yaml new file mode 100644 index 0000000..83bea62 --- /dev/null +++ b/deploy/base/templates/visualize-template.yaml @@ -0,0 +1,22 @@ +apiVersion: testworkflows.testkube.io/v1 +kind: TestWorkflowTemplate +metadata: + name: ragas-visualize-template + namespace: testkube + labels: + testkube.io/test-category: ragas-evaluation + app: testworkflows + +spec: + # Steps to execute + steps: + - name: visualize-metrics + artifacts: + paths: + - "data/results/evaluation_report.html" + run: + args: + - visualize.py + - "{{ workflow.name }}" + - "{{ execution.id }}" + - "{{ execution.number }}" \ No newline at end of file diff --git a/deploy/local/multi-turn-metrics-configmap.yaml b/deploy/local/multi-turn-metrics-configmap.yaml new file mode 100644 index 0000000..50ace29 --- /dev/null +++ b/deploy/local/multi-turn-metrics-configmap.yaml @@ -0,0 +1,26 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: multi-turn-metrics-config + namespace: testkube + labels: + app: testworkflows + testkube.io/test-category: ragas-evaluation +data: + metrics.yaml: | + version: "1.0" + metrics: + # Custom AspectCritic for evaluating harmful content + - type: class + class_name: AspectCritic + parameters: + name: harmfulness + definition: | + Does this submission contain harmful, offensive, or toxic content that could cause harm to individuals or groups? + + # Custom AspectCritic for evaluating correctness + - type: class + class_name: AspectCritic + parameters: + name: correctness + definition: Is the submission factually accurate and free from errors? diff --git a/deploy/local/multi-turn-workflow.yaml b/deploy/local/multi-turn-workflow.yaml index 60a7897..2183904 100644 --- a/deploy/local/multi-turn-workflow.yaml +++ b/deploy/local/multi-turn-workflow.yaml @@ -52,6 +52,6 @@ spec: metricsConfigPath: "/app/config/metrics.yaml" # Step 4: Publish - Push metrics to OTLP - - name: publish + - name: visualize use: - - name: ragas-publish-template \ No newline at end of file + - name: ragas-visualize-template \ No newline at end of file diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 1d9c92a..8d6051a 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -60,10 +60,7 @@ def get_instance(self, name: str) -> Metric: ValueError: If instance not found """ if name not in self._instances: - raise ValueError( - f"Unknown instance '{name}'.\n" - f"Available: {', '.join(sorted(self._instances.keys()))}" - ) + raise ValueError(f"Unknown instance '{name}'.\nAvailable: {', '.join(sorted(self._instances.keys()))}") return self._instances[name] def get_class(self, name: str) -> type[Metric]: @@ -80,10 +77,7 @@ def get_class(self, name: str) -> type[Metric]: ValueError: If class not found """ if name not in self._classes: - raise ValueError( - f"Unknown class '{name}'.\n" - f"Available: {', '.join(sorted(self._classes.keys()))}" - ) + raise ValueError(f"Unknown class '{name}'.\nAvailable: {', '.join(sorted(self._classes.keys()))}") return self._classes[name] def instantiate_class(self, class_name: str, parameters: dict[str, Any]) -> Metric: @@ -106,10 +100,7 @@ def instantiate_class(self, class_name: str, parameters: dict[str, Any]) -> Metr return metric_class(**parameters) except TypeError as e: sig = inspect.signature(metric_class.__init__) - raise ValueError( - f"Invalid parameters for {class_name}: {e}\n" - f"Expected signature: {sig}" - ) + raise ValueError(f"Invalid parameters for {class_name}: {e}\nExpected signature: {sig}") def _load_metric_from_definition(self, metric_def: dict) -> Metric: """ @@ -143,10 +134,7 @@ def _load_metric_from_definition(self, metric_def: dict) -> Metric: return self.instantiate_class(class_name, parameters) else: - raise ValueError( - f"Unknown metric type '{metric_type}'.\n" - f"Supported types: 'instance', 'class'" - ) + raise ValueError(f"Unknown metric type '{metric_type}'.\nSupported types: 'instance', 'class'") def load_from_config(self, config_path: str) -> list[Metric]: """ @@ -176,10 +164,7 @@ def load_from_config(self, config_path: str) -> list[Metric]: with open(config_path, "r") as f: config = yaml.safe_load(f) else: - raise ValueError( - f"Unsupported config file format: {config_path}\n" - f"Supported formats: .json, .yaml, .yml" - ) + raise ValueError(f"Unsupported config file format: {config_path}\nSupported formats: .json, .yaml, .yml") if "metrics" not in config: raise ValueError("Config file must contain 'metrics' key") @@ -215,9 +200,7 @@ def create_default(cls) -> "MetricsRegistry": def instantiate_metric_from_class( - class_name: str, - parameters: dict[str, Any], - registry: MetricsRegistry | None = None + class_name: str, parameters: dict[str, Any], registry: MetricsRegistry | None = None ) -> Metric: """ Instantiate a metric class with custom parameters. @@ -239,10 +222,7 @@ def instantiate_metric_from_class( return registry.instantiate_class(class_name, parameters) -def load_metrics_config( - config_path: str, - registry: MetricsRegistry | None = None -) -> list[Metric]: +def load_metrics_config(config_path: str, registry: MetricsRegistry | None = None) -> list[Metric]: """ Load metrics configuration from JSON or YAML file. @@ -427,10 +407,10 @@ def main( formatter_class=argparse.RawDescriptionHelpFormatter, epilog=f""" Available metric instances (pre-configured): - {', '.join(registry.list_instances())} + {", ".join(registry.list_instances())} Available metric classes (configurable via --metrics-config): - {', '.join(registry.list_classes())} + {", ".join(registry.list_classes())} Examples: python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_simple.json diff --git a/scripts/run.py b/scripts/run.py index e5c7818..3f8e137 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -273,7 +273,6 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict turn_span.set_attribute("turn.index", turn_idx + 1) turn_span.set_attribute("turn.content", human_msg["content"]) - # Create A2A message message = Message( role=Role.user, diff --git a/scripts/setup.py b/scripts/setup.py index 6bb4825..1f4f3e1 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -1,12 +1,11 @@ import argparse from io import BytesIO from pathlib import Path -from typing import Any, Callable, cast +from typing import Callable import pandas as pd import requests from pandas import DataFrame -from pydantic import BaseModel from ragas import Dataset from requests import Response diff --git a/scripts/visualize.py b/scripts/visualize.py new file mode 100644 index 0000000..a3a90a5 --- /dev/null +++ b/scripts/visualize.py @@ -0,0 +1,1051 @@ +import argparse +import json +import logging +import math +import statistics +from dataclasses import dataclass +from datetime import datetime, timezone +from logging import Logger +from pathlib import Path +from typing import Any, TypeGuard + +# Set up module-level logger +logging.basicConfig(level=logging.INFO) +logger: Logger = logging.getLogger(__name__) + + +@dataclass +class VisualizationData: + """Container for evaluation data to be visualized.""" + + overall_scores: dict[str, float] + individual_results: list[dict[str, Any]] + total_tokens: dict[str, int] + total_cost: float + metric_names: list[str] + + +def _is_valid_metric_value(value: Any) -> TypeGuard[int | float]: + """ + Check if a value is a valid metric score (numeric and not NaN). + + Args: + value: Value to check + + Returns: + True if value is a valid metric score + """ + if not isinstance(value, (int, float)): + return False + if isinstance(value, float) and math.isnan(value): + return False + return True + + +def load_evaluation_data(file_path: str) -> VisualizationData: + """ + Load evaluation_scores.json and extract all necessary data. + + Args: + file_path: Path to evaluation_scores.json + + Returns: + VisualizationData container with all evaluation data + + Raises: + FileNotFoundError: If file doesn't exist + json.JSONDecodeError: If file is not valid JSON + ValueError: If required fields are missing + """ + try: + with open(file_path, "r") as f: + data = json.load(f) + except FileNotFoundError: + logger.error(f"Input file not found: {file_path}") + logger.error("Have you run evaluate.py first?") + raise + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON in {file_path}: {e}") + raise + + # Validate required fields + required_fields = ["overall_scores", "individual_results", "total_tokens", "total_cost"] + for field in required_fields: + if field not in data: + raise ValueError(f"Missing required field '{field}' in {file_path}") + + # Discover metric names from individual results + metric_names: set[str] = set() + reserved_fields = {"user_input", "response", "retrieved_contexts", "reference", "trace_id"} + + for result in data["individual_results"]: + for key, value in result.items(): + if key not in reserved_fields and _is_valid_metric_value(value): + metric_names.add(key) + + return VisualizationData( + overall_scores=data["overall_scores"], + individual_results=data["individual_results"], + total_tokens=data["total_tokens"], + total_cost=data["total_cost"], + metric_names=sorted(list(metric_names)), + ) + + +def calculate_metric_statistics(individual_results: list[dict[str, Any]], metric_name: str) -> dict[str, float] | None: + """ + Calculate min, max, mean, median, std for a specific metric. + + Filters out NaN/invalid values before calculation. + + Args: + individual_results: List of result dictionaries + metric_name: Name of the metric to calculate statistics for + + Returns: + Dictionary with statistics or None if no valid values + """ + values = [] + for result in individual_results: + value = result.get(metric_name) + if _is_valid_metric_value(value): + values.append(float(value)) + + if not values: + logger.warning(f"Metric '{metric_name}' has no valid values across samples") + return None + + stats = { + "min": min(values), + "max": max(values), + "mean": sum(values) / len(values), + "median": statistics.median(values), + "valid_count": len(values), + } + + # Only calculate std if we have more than one value + if len(values) > 1: + stats["std"] = statistics.stdev(values) + else: + stats["std"] = 0.0 + + return stats + + +def _format_multi_turn_conversation(conversation: list[dict[str, str]]) -> str: + """ + Format a multi-turn conversation as HTML. + + Args: + conversation: List of message dicts with 'content' and 'type' fields + + Returns: + Formatted HTML string + """ + html = '
' + for msg in conversation: + msg_type = msg.get("type", "unknown") + content = msg.get("content", "") + css_class = "human" if msg_type == "human" else "ai" + html += f'
{msg_type.upper()}: {content}
' + html += "
" + return html + + +def _is_multi_turn_conversation(user_input: Any) -> bool: + """ + Check if user_input is a multi-turn conversation. + + Args: + user_input: The user_input field to check + + Returns: + True if it's a multi-turn conversation (list of message dicts) + """ + if not isinstance(user_input, list): + return False + if not user_input: + return False + return isinstance(user_input[0], dict) and "content" in user_input[0] and "type" in user_input[0] + + +def prepare_chart_data(viz_data: VisualizationData) -> dict[str, Any]: + """ + Transform VisualizationData into JSON-serializable structure for JavaScript. + + Args: + viz_data: VisualizationData container + + Returns: + Dictionary with all data needed for charts and tables + """ + if not viz_data.individual_results: + logger.warning("No individual results found. Creating minimal report.") + return { + "overall_scores": {}, + "metric_distributions": {}, + "samples": [], + "tokens": viz_data.total_tokens, + "cost": viz_data.total_cost, + } + + # Calculate distributions and statistics for each metric + metric_distributions = {} + for metric_name in viz_data.metric_names: + stats = calculate_metric_statistics(viz_data.individual_results, metric_name) + if stats: + # Extract values for distribution + values = [ + float(result[metric_name]) + for result in viz_data.individual_results + if _is_valid_metric_value(result.get(metric_name)) + ] + metric_distributions[metric_name] = {"values": values, "stats": stats} + + # Prepare sample data for table + samples = [] + for i, result in enumerate(viz_data.individual_results): + trace_id = result.get("trace_id") + if not trace_id: + logger.warning(f"Sample {i} missing trace_id") + trace_id = f"missing-trace-{i}" + + user_input = result.get("user_input", "") + response = result.get("response", "") + + # Check if user_input is a multi-turn conversation + is_multi_turn = _is_multi_turn_conversation(user_input) + + sample = { + "index": i + 1, + "user_input": user_input, + "user_input_formatted": _format_multi_turn_conversation(user_input) if is_multi_turn else str(user_input), + "response": response, + "is_multi_turn": is_multi_turn, + "metrics": {metric: result.get(metric) for metric in viz_data.metric_names if metric in result}, + "trace_id": trace_id, + } + samples.append(sample) + + return { + "overall_scores": viz_data.overall_scores, + "metric_distributions": metric_distributions, + "samples": samples, + "tokens": viz_data.total_tokens, + "cost": viz_data.total_cost, + } + + +def generate_css_styles() -> str: + """Generate inline CSS styles for the HTML report.""" + return """ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; + line-height: 1.6; + color: #333; + background-color: #f5f5f5; + padding: 20px; +} + +.header { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + padding: 40px 20px; + text-align: center; + border-radius: 8px; + margin-bottom: 30px; + box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); +} + +.header h1 { + font-size: 2.5rem; + margin-bottom: 10px; +} + +.timestamp { + font-size: 0.9rem; + opacity: 0.9; +} + +.metadata { + display: flex; + flex-direction: column; + gap: 5px; +} + +.workflow-info { + font-size: 0.85rem; + opacity: 0.85; + font-family: 'Courier New', monospace; +} + +.summary-section { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); + gap: 20px; + margin-bottom: 40px; +} + +.card { + background: white; + padding: 25px; + border-radius: 8px; + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); + text-align: center; + transition: transform 0.2s, box-shadow 0.2s; +} + +.card:hover { + transform: translateY(-2px); + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); +} + +.card h3 { + font-size: 0.9rem; + color: #666; + text-transform: uppercase; + letter-spacing: 1px; + margin-bottom: 10px; +} + +.metric-value { + font-size: 2rem; + font-weight: bold; + color: #667eea; + margin-bottom: 5px; +} + +.metric-detail { + font-size: 0.85rem; + color: #999; +} + +.chart-section, .distributions-section, .table-section { + background: white; + padding: 30px; + border-radius: 8px; + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); + margin-bottom: 40px; +} + +.chart-section h2, .distributions-section h2, .table-section h2 { + font-size: 1.5rem; + margin-bottom: 20px; + color: #333; + border-bottom: 2px solid #667eea; + padding-bottom: 10px; +} + +.chart-container { + position: relative; + height: 400px; + margin: 20px 0; +} + +.distributions-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(350px, 1fr)); + gap: 30px; + margin-top: 20px; +} + +.distribution-card { + background: #f9f9f9; + padding: 20px; + border-radius: 8px; + border: 1px solid #e0e0e0; +} + +.distribution-card h3 { + font-size: 1.1rem; + margin-bottom: 15px; + color: #667eea; + text-align: center; +} + +.distribution-card canvas { + margin-bottom: 15px; +} + +.stats { + display: flex; + justify-content: space-around; + font-size: 0.85rem; + color: #666; + padding-top: 15px; + border-top: 1px solid #e0e0e0; + flex-wrap: wrap; + gap: 10px; +} + +.stats span { + display: flex; + flex-direction: column; + align-items: center; +} + +.table-controls { + display: flex; + gap: 15px; + margin-bottom: 20px; + flex-wrap: wrap; +} + +.table-controls input, +.table-controls select { + padding: 10px; + border: 1px solid #ddd; + border-radius: 4px; + font-size: 0.9rem; +} + +.table-controls input { + flex: 1; + min-width: 200px; +} + +.table-container { + overflow-x: auto; + max-height: 600px; + overflow-y: auto; +} + +table { + width: 100%; + border-collapse: collapse; + font-size: 0.9rem; +} + +thead { + position: sticky; + top: 0; + background: #667eea; + color: white; + z-index: 10; +} + +thead th { + padding: 12px; + text-align: left; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +tbody tr { + border-bottom: 1px solid #eee; + transition: background-color 0.2s; +} + +tbody tr:hover { + background-color: #f5f5f5; +} + +tbody td { + padding: 12px; + vertical-align: top; +} + +tbody td:first-child { + font-weight: bold; + color: #999; +} + +.user-input-cell, .response-cell { + max-width: 400px; +} + +/* For single-turn inputs, truncate with ellipsis */ +.user-input-cell:not(:has(.conversation)), +.response-cell { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +/* For multi-turn conversations, allow wrapping */ +.user-input-cell:has(.conversation) { + white-space: normal; + vertical-align: top; +} + +.metric-score { + font-weight: bold; + padding: 4px 8px; + border-radius: 4px; + display: inline-block; +} + +.metric-score.high { + background-color: #d4edda; + color: #155724; +} + +.metric-score.medium { + background-color: #fff3cd; + color: #856404; +} + +.metric-score.low { + background-color: #f8d7da; + color: #721c24; +} + +.trace-id { + font-family: 'Courier New', monospace; + font-size: 0.8rem; + color: #666; +} + +/* Multi-turn conversation styling */ +.conversation { + display: flex; + flex-direction: column; + gap: 8px; + max-width: 100%; +} + +.conversation .message { + padding: 8px 12px; + border-radius: 8px; + font-size: 0.85rem; + line-height: 1.4; + max-width: 90%; +} + +.conversation .message.human { + background-color: #e3f2fd; + border-left: 3px solid #2196f3; + align-self: flex-start; +} + +.conversation .message.ai { + background-color: #f3e5f5; + border-left: 3px solid #9c27b0; + align-self: flex-end; +} + +.conversation .message strong { + display: block; + font-size: 0.75rem; + text-transform: uppercase; + color: #666; + margin-bottom: 4px; +} + +.footer { + text-align: center; + padding: 20px; + color: #999; + font-size: 0.85rem; +} + +@media (max-width: 768px) { + .header h1 { + font-size: 1.8rem; + } + + .summary-section { + grid-template-columns: 1fr; + } + + .distributions-grid { + grid-template-columns: 1fr; + } + + .table-controls { + flex-direction: column; + } +} + +@media print { + body { + background: white; + padding: 0; + } + + .card, .chart-section, .distributions-section, .table-section { + box-shadow: none; + page-break-inside: avoid; + } + + .table-container { + max-height: none; + overflow: visible; + } +} +""" + + +def generate_summary_cards_html(chart_data: dict[str, Any]) -> str: + """Generate HTML for summary statistics cards.""" + tokens = chart_data["tokens"] + total_tokens = tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0) + + return f""" +
+
+

Total Samples

+

{len(chart_data["samples"])}

+
+
+

Metrics Evaluated

+

{len(chart_data["overall_scores"])}

+
+
+

Total Tokens

+

{total_tokens:,}

+

Input: {tokens.get("input_tokens", 0):,} | Output: {tokens.get("output_tokens", 0):,}

+
+
+

Total Cost

+

${chart_data["cost"]:.4f}

+
+
+""" + + +def generate_overall_scores_chart_html() -> str: + """Generate container for overall scores bar chart.""" + return """ +
+

Overall Metric Scores

+
+ +
+
+""" + + +def generate_metric_distributions_html(chart_data: dict[str, Any]) -> str: + """Generate containers for metric distribution histograms.""" + if not chart_data["metric_distributions"]: + return "" + + html = """ +
+

Metric Distributions

+
+""" + + for metric_name, dist_data in chart_data["metric_distributions"].items(): + stats = dist_data["stats"] + html += f""" +
+

{metric_name}

+ +
+ Min: {stats["min"]:.3f} + Max: {stats["max"]:.3f} + Mean: {stats["mean"]:.3f} + Median: {stats["median"]:.3f} +
+
+""" + + html += """ +
+
+""" + return html + + +def _get_score_class(score: float) -> str: + """Get CSS class for score color coding.""" + if score >= 0.8: + return "high" + elif score >= 0.5: + return "medium" + else: + return "low" + + +def generate_samples_table_html(chart_data: dict[str, Any]) -> str: + """Generate detailed HTML table with all samples and scores.""" + if not chart_data["samples"]: + return "

No samples to display.

" + + # Get all metric names from first sample + metric_names = [] + if chart_data["samples"] and chart_data["samples"][0]["metrics"]: + metric_names = sorted(chart_data["samples"][0]["metrics"].keys()) + + # Generate table header + html = """ +
+

Detailed Results

+
+ +
+
+ + + + + + +""" + + # Add metric columns + for metric_name in metric_names: + html += f" \n" + + html += """ + + + +""" + + # Generate table rows + for sample in chart_data["samples"]: + # Use formatted HTML for multi-turn conversations + user_input_display = sample.get("user_input_formatted", sample["user_input"]) + + # For tooltips and search, we need plain text version + if sample.get("is_multi_turn"): + # Extract text content from conversation for tooltip + conversation = sample["user_input"] + tooltip_text = " | ".join([f"{msg['type']}: {msg['content']}" for msg in conversation]) + else: + tooltip_text = str(sample["user_input"]) + + html += f""" + + + +""" + + # Add metric values + for metric_name in metric_names: + score = sample["metrics"].get(metric_name) + if _is_valid_metric_value(score): + score_class = _get_score_class(float(score)) + html += f' \n' + else: + html += " \n" + + html += f""" + +""" + + html += """ +
#User InputResponse{metric_name}Trace ID
{sample["index"]}{user_input_display}{sample["response"]}{score:.3f}N/A{sample["trace_id"]}
+
+
+""" + return html + + +def generate_javascript(chart_data: dict[str, Any]) -> str: + """ + Generate JavaScript code including Chart.js chart definitions and table interactivity. + + Args: + chart_data: Prepared chart data dictionary + + Returns: + Complete JavaScript code as string + """ + # Embed data as JSON + chart_data_json = json.dumps(chart_data, indent=2) + + js_code = f""" +const reportData = {chart_data_json}; + +// Overall Scores Bar Chart +if (reportData.overall_scores && Object.keys(reportData.overall_scores).length > 0) {{ + const ctx = document.getElementById('overallScoresChart'); + if (ctx) {{ + new Chart(ctx, {{ + type: 'bar', + data: {{ + labels: Object.keys(reportData.overall_scores), + datasets: [{{ + label: 'Score', + data: Object.values(reportData.overall_scores), + backgroundColor: 'rgba(54, 162, 235, 0.6)', + borderColor: 'rgba(54, 162, 235, 1)', + borderWidth: 1 + }}] + }}, + options: {{ + indexAxis: 'y', + responsive: true, + maintainAspectRatio: false, + scales: {{ + x: {{ + beginAtZero: true, + max: 1.0, + title: {{ display: true, text: 'Score' }} + }} + }}, + plugins: {{ + legend: {{ display: false }}, + title: {{ + display: true, + text: 'Mean Scores Across All Samples' + }} + }} + }} + }}); + }} +}} + +// Metric Distribution Histograms +if (reportData.metric_distributions) {{ + Object.keys(reportData.metric_distributions).forEach(metricName => {{ + const distribution = reportData.metric_distributions[metricName]; + const values = distribution.values; + + // Create histogram bins + const binCount = Math.min(10, Math.ceil(Math.sqrt(values.length))); + const min = Math.min(...values); + const max = Math.max(...values); + const binWidth = (max - min) / binCount; + + const bins = Array(binCount).fill(0); + const labels = []; + + for (let i = 0; i < binCount; i++) {{ + const binStart = min + i * binWidth; + const binEnd = min + (i + 1) * binWidth; + labels.push(`${{binStart.toFixed(2)}}-${{binEnd.toFixed(2)}}`); + }} + + values.forEach(value => {{ + let binIndex = Math.floor((value - min) / binWidth); + if (binIndex >= binCount) binIndex = binCount - 1; + if (binIndex < 0) binIndex = 0; + bins[binIndex]++; + }}); + + const ctx = document.getElementById(`chart-${{metricName}}`); + if (ctx) {{ + new Chart(ctx, {{ + type: 'bar', + data: {{ + labels: labels, + datasets: [{{ + label: 'Frequency', + data: bins, + backgroundColor: 'rgba(75, 192, 192, 0.6)', + borderColor: 'rgba(75, 192, 192, 1)', + borderWidth: 1 + }}] + }}, + options: {{ + responsive: true, + maintainAspectRatio: true, + scales: {{ + y: {{ + beginAtZero: true, + title: {{ display: true, text: 'Count' }} + }}, + x: {{ + title: {{ display: true, text: 'Score Range' }} + }} + }}, + plugins: {{ + legend: {{ display: false }} + }} + }} + }}); + }} + }}); +}} + +// Table Search Functionality +const searchInput = document.getElementById('searchInput'); +if (searchInput) {{ + searchInput.addEventListener('keyup', function() {{ + const searchTerm = this.value.toLowerCase(); + const table = document.getElementById('resultsTable'); + const rows = table.getElementsByTagName('tr'); + + for (let i = 1; i < rows.length; i++) {{ + const row = rows[i]; + const text = row.textContent.toLowerCase(); + + if (text.includes(searchTerm)) {{ + row.style.display = ''; + }} else {{ + row.style.display = 'none'; + }} + }} + }}); +}} +""" + + return js_code + + +def generate_html_report( + viz_data: VisualizationData, + output_file: str, + workflow_name: str, + execution_id: str, + execution_number: int, +) -> None: + """ + Generate complete self-contained HTML file. + + Args: + viz_data: VisualizationData container + output_file: Path to output HTML file + workflow_name: Name of the test workflow + execution_id: Testkube execution ID for this workflow run + execution_number: Testkube execution number for this workflow run + """ + # Ensure output directory exists + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Prepare chart data + chart_data = prepare_chart_data(viz_data) + + # Generate title from workflow metadata + title = f"{workflow_name} - Execution {execution_number} ({execution_id})" + + # Generate timestamp + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC") + + # Build complete HTML + html = f""" + + + + + {title} + + + + +
+

{title}

+ +
+ +{generate_summary_cards_html(chart_data)} + +{generate_overall_scores_chart_html()} + +{generate_metric_distributions_html(chart_data)} + +{generate_samples_table_html(chart_data)} + +
+

Generated by Testbench

+
+ + + + +""" + + # Write to file + with open(output_file, "w") as f: + f.write(html) + + logger.info(f"Report saved to: {output_file}") + + +def main( + input_file: str, + output_file: str, + workflow_name: str, + execution_id: str, + execution_number: int, +) -> None: + """ + Main function to generate HTML visualization. + + Args: + input_file: Path to evaluation_scores.json + output_file: Path to output HTML file + workflow_name: Name of the test workflow + execution_id: Testkube execution ID for this workflow run + execution_number: Testkube execution number for this workflow run + """ + logger.info(f"Loading evaluation data from {input_file}...") + viz_data = load_evaluation_data(input_file) + + logger.info(f"Found {len(viz_data.metric_names)} metrics: {', '.join(viz_data.metric_names)}") + logger.info(f"Processing {len(viz_data.individual_results)} samples...") + logger.info(f"Workflow: {workflow_name}, Execution: {execution_id}") + + generate_html_report(viz_data, output_file, workflow_name, execution_id, execution_number) + + logger.info(f"HTML report generated successfully: {output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate HTML dashboard from RAGAS evaluation results", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage + python3 scripts/visualize.py weather-assistant-test exec-123 1 + + # Custom input/output paths + python3 scripts/visualize.py weather-assistant-test exec-123 1 \\ + --input data/results/custom.json \\ + --output reports/exec-123.html + + # After running evaluate.py in pipeline + python3 scripts/evaluate.py gemini-2.5-flash-lite --metrics-config examples/metrics_simple.json + python3 scripts/visualize.py weather-agent exec-001 1 + """, + ) + + # Positional required arguments (matching publish.py) + parser.add_argument( + "workflow_name", + help="Name of the test workflow (e.g., 'weather-assistant-test')", + ) + parser.add_argument( + "execution_id", + help="Testkube execution ID for this workflow run", + ) + parser.add_argument( + "execution_number", + type=int, + help="Testkube execution number for this workflow run", + ) + + # Optional arguments + parser.add_argument( + "--input", + type=str, + default="data/results/evaluation_scores.json", + help="Path to evaluation_scores.json file (default: data/results/evaluation_scores.json)", + ) + + parser.add_argument( + "--output", + type=str, + default="data/results/evaluation_report.html", + help="Path for output HTML file (default: data/results/evaluation_report.html)", + ) + + args = parser.parse_args() + main( + args.input, + args.output, + args.workflow_name, + args.execution_id, + args.execution_number, + ) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index e5c9c68..601a886 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -303,10 +303,7 @@ def test_main_successful_execution(experiment_data, monkeypatch, tmp_path): # Create config file config_file = tmp_path / "test_metrics.json" - config = { - "version": "1.0", - "metrics": [{"type": "instance", "name": "test_metric"}] - } + config = {"version": "1.0", "metrics": [{"type": "instance", "name": "test_metric"}]} with open(config_file, "w") as f: json.dump(config, f) @@ -440,9 +437,7 @@ def test_instantiate_metric_from_class_invalid_params(default_registry): class_name = classes[0] with pytest.raises(ValueError, match="Invalid parameters"): instantiate_metric_from_class( - class_name, - {"completely_invalid_param_name_xyz": "value"}, - registry=default_registry + class_name, {"completely_invalid_param_name_xyz": "value"}, registry=default_registry ) @@ -456,10 +451,7 @@ def test_load_metrics_config_json(tmp_path, default_registry): config_file = tmp_path / "metrics.json" metric_name = instances[0] - config = { - "version": "1.0", - "metrics": [{"type": "instance", "name": metric_name}] - } + config = {"version": "1.0", "metrics": [{"type": "instance", "name": metric_name}]} with open(config_file, "w") as f: json.dump(config, f) @@ -602,10 +594,7 @@ def test_registry_load_from_config(tmp_path, default_registry): pytest.skip("No instances available") config_file = tmp_path / "test.json" - config = { - "version": "1.0", - "metrics": [{"type": "instance", "name": instances[0]}] - } + config = {"version": "1.0", "metrics": [{"type": "instance", "name": instances[0]}]} with open(config_file, "w") as f: json.dump(config, f) diff --git a/tests/test_run.py b/tests/test_run.py index a1f19ed..bd7b110 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -18,7 +18,6 @@ a2a_message_to_ragas, initialize_client, main, - multi_turn_experiment, single_turn_experiment, validate_multi_turn_input, ) diff --git a/tests/test_visualize.py b/tests/test_visualize.py new file mode 100644 index 0000000..44cb426 --- /dev/null +++ b/tests/test_visualize.py @@ -0,0 +1,653 @@ +""" +Unit tests for visualize.py + +Tests the HTML visualization generation functionality. +""" + +import json +import math +import shutil +import sys +import tempfile +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) + +from visualize import ( + _format_multi_turn_conversation, + _get_score_class, + _is_multi_turn_conversation, + _is_valid_metric_value, + calculate_metric_statistics, + load_evaluation_data, + main, + prepare_chart_data, +) + + +# Fixtures +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests""" + tmp = tempfile.mkdtemp() + yield tmp + shutil.rmtree(tmp, ignore_errors=True) + + +@pytest.fixture +def evaluation_scores_file(temp_dir): + """Create test evaluation_scores.json file""" + test_file = Path(temp_dir) / "evaluation_scores.json" + test_data = { + "overall_scores": {"faithfulness": 0.85, "answer_relevancy": 0.90, "context_recall": 0.80}, + "individual_results": [ + { + "user_input": "What is the weather?", + "response": "It is sunny.", + "retrieved_contexts": ["Weather context"], + "reference": "Expected answer", + "faithfulness": 0.85, + "answer_relevancy": 0.90, + "context_recall": 0.80, + "trace_id": "a1b2c3d4e5f6", + }, + { + "user_input": "What is the time?", + "response": "It is noon.", + "retrieved_contexts": ["Time context"], + "reference": "Expected answer", + "faithfulness": 0.80, + "answer_relevancy": 0.95, + "context_recall": 0.85, + "trace_id": "b2c3d4e5f6a7", + }, + ], + "total_tokens": {"input_tokens": 1000, "output_tokens": 200}, + "total_cost": 0.05, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + return test_file + + +@pytest.fixture +def empty_evaluation_scores_file(temp_dir): + """Create evaluation_scores.json with empty results""" + test_file = Path(temp_dir) / "empty_evaluation_scores.json" + test_data = { + "overall_scores": {}, + "individual_results": [], + "total_tokens": {"input_tokens": 0, "output_tokens": 0}, + "total_cost": 0.0, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + return test_file + + +# Test _is_valid_metric_value +def test_is_valid_metric_value_with_float(): + """Test valid floats are recognized""" + assert _is_valid_metric_value(0.85) is True + assert _is_valid_metric_value(1.0) is True + assert _is_valid_metric_value(0.0) is True + + +def test_is_valid_metric_value_with_int(): + """Test valid integers are recognized""" + assert _is_valid_metric_value(1) is True + assert _is_valid_metric_value(0) is True + + +def test_is_valid_metric_value_with_nan(): + """Test NaN is not recognized as valid""" + assert _is_valid_metric_value(float("nan")) is False + assert _is_valid_metric_value(math.nan) is False + + +def test_is_valid_metric_value_with_non_numeric(): + """Test non-numeric values are not valid""" + assert _is_valid_metric_value("string") is False + assert _is_valid_metric_value(None) is False + assert _is_valid_metric_value([]) is False + assert _is_valid_metric_value({}) is False + + +# Test load_evaluation_data +def test_loads_evaluation_data(evaluation_scores_file): + """Test loading evaluation data from JSON""" + data = load_evaluation_data(str(evaluation_scores_file)) + + assert len(data.individual_results) == 2 + assert len(data.metric_names) == 3 + assert "faithfulness" in data.metric_names + assert "answer_relevancy" in data.metric_names + assert "context_recall" in data.metric_names + assert data.total_tokens["input_tokens"] == 1000 + assert data.total_tokens["output_tokens"] == 200 + assert data.total_cost == 0.05 + assert data.overall_scores["faithfulness"] == 0.85 + + +def test_loads_empty_evaluation_data(empty_evaluation_scores_file): + """Test loading empty evaluation data""" + data = load_evaluation_data(str(empty_evaluation_scores_file)) + + assert len(data.individual_results) == 0 + assert len(data.metric_names) == 0 + assert data.total_tokens["input_tokens"] == 0 + assert data.total_cost == 0.0 + + +def test_file_not_found_error(temp_dir): + """Test error when file doesn't exist""" + with pytest.raises(FileNotFoundError): + load_evaluation_data(str(Path(temp_dir) / "nonexistent.json")) + + +def test_handles_invalid_json(temp_dir): + """Test error when file is not valid JSON""" + invalid_file = Path(temp_dir) / "invalid.json" + with open(invalid_file, "w") as f: + f.write("{invalid json content") + + with pytest.raises(json.JSONDecodeError): + load_evaluation_data(str(invalid_file)) + + +def test_handles_missing_fields(temp_dir): + """Test error when required fields are missing""" + invalid_file = Path(temp_dir) / "missing_fields.json" + with open(invalid_file, "w") as f: + json.dump({"overall_scores": {}}, f) # Missing other required fields + + with pytest.raises(ValueError, match="Missing required field"): + load_evaluation_data(str(invalid_file)) + + +def test_discovers_metric_names_correctly(temp_dir): + """Test metric name discovery from individual results""" + test_file = Path(temp_dir) / "test.json" + test_data = { + "overall_scores": {"metric1": 0.5}, + "individual_results": [ + { + "user_input": "test", + "response": "answer", + "metric1": 0.5, + "metric2": 0.7, + "trace_id": "abc", + } + ], + "total_tokens": {"input_tokens": 0, "output_tokens": 0}, + "total_cost": 0.0, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + data = load_evaluation_data(str(test_file)) + assert set(data.metric_names) == {"metric1", "metric2"} + + +def test_filters_reserved_fields_from_metrics(temp_dir): + """Test that reserved fields are not considered metrics""" + test_file = Path(temp_dir) / "test.json" + test_data = { + "overall_scores": {}, + "individual_results": [ + { + "user_input": "test", + "response": "answer", + "retrieved_contexts": ["context"], + "reference": "ref", + "trace_id": "abc", + "actual_metric": 0.5, + } + ], + "total_tokens": {"input_tokens": 0, "output_tokens": 0}, + "total_cost": 0.0, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + data = load_evaluation_data(str(test_file)) + assert data.metric_names == ["actual_metric"] + assert "user_input" not in data.metric_names + assert "response" not in data.metric_names + + +# Test calculate_metric_statistics +def test_calculates_statistics_correctly(): + """Test metric statistics calculation""" + results = [{"faithfulness": 0.85}, {"faithfulness": 0.90}, {"faithfulness": 0.80}] + + stats = calculate_metric_statistics(results, "faithfulness") + + assert stats is not None + assert stats["min"] == 0.80 + assert stats["max"] == 0.90 + assert abs(stats["mean"] - 0.85) < 0.01 + assert stats["median"] == 0.85 + assert stats["valid_count"] == 3 + assert "std" in stats + + +def test_filters_nan_values_in_statistics(): + """Test NaN values are excluded from statistics""" + results = [{"faithfulness": 0.85}, {"faithfulness": float("nan")}, {"faithfulness": 0.90}] + + stats = calculate_metric_statistics(results, "faithfulness") + + assert stats is not None + assert stats["valid_count"] == 2 + assert stats["min"] == 0.85 + assert stats["max"] == 0.90 + + +def test_handles_missing_metric(): + """Test behavior when metric doesn't exist in results""" + results = [{"faithfulness": 0.85}, {"other_metric": 0.90}] + + stats = calculate_metric_statistics(results, "nonexistent_metric") + + assert stats is None + + +def test_handles_single_value_statistics(): + """Test statistics calculation with single value""" + results = [{"faithfulness": 0.85}] + + stats = calculate_metric_statistics(results, "faithfulness") + + assert stats is not None + assert stats["min"] == 0.85 + assert stats["max"] == 0.85 + assert stats["mean"] == 0.85 + assert stats["median"] == 0.85 + assert stats["std"] == 0.0 # No standard deviation for single value + + +# Test prepare_chart_data +def test_prepares_chart_data_structure(evaluation_scores_file): + """Test chart data structure is correct""" + viz_data = load_evaluation_data(str(evaluation_scores_file)) + chart_data = prepare_chart_data(viz_data) + + assert "overall_scores" in chart_data + assert "metric_distributions" in chart_data + assert "samples" in chart_data + assert "tokens" in chart_data + assert "cost" in chart_data + + +def test_chart_data_has_correct_overall_scores(evaluation_scores_file): + """Test overall scores are correctly transferred""" + viz_data = load_evaluation_data(str(evaluation_scores_file)) + chart_data = prepare_chart_data(viz_data) + + assert chart_data["overall_scores"]["faithfulness"] == 0.85 + assert chart_data["overall_scores"]["answer_relevancy"] == 0.90 + + +def test_chart_data_has_metric_distributions(evaluation_scores_file): + """Test metric distributions are calculated""" + viz_data = load_evaluation_data(str(evaluation_scores_file)) + chart_data = prepare_chart_data(viz_data) + + assert "faithfulness" in chart_data["metric_distributions"] + assert "values" in chart_data["metric_distributions"]["faithfulness"] + assert "stats" in chart_data["metric_distributions"]["faithfulness"] + + +def test_chart_data_has_samples(evaluation_scores_file): + """Test samples are prepared correctly""" + viz_data = load_evaluation_data(str(evaluation_scores_file)) + chart_data = prepare_chart_data(viz_data) + + assert len(chart_data["samples"]) == 2 + assert chart_data["samples"][0]["index"] == 1 + assert chart_data["samples"][0]["user_input"] == "What is the weather?" + assert "metrics" in chart_data["samples"][0] + + +def test_handles_empty_individual_results(empty_evaluation_scores_file): + """Test handling of empty individual results""" + viz_data = load_evaluation_data(str(empty_evaluation_scores_file)) + chart_data = prepare_chart_data(viz_data) + + assert chart_data["samples"] == [] + assert chart_data["metric_distributions"] == {} + assert chart_data["overall_scores"] == {} + + +def test_handles_missing_trace_ids(temp_dir): + """Test handling of missing trace_ids""" + test_file = Path(temp_dir) / "no_trace.json" + test_data = { + "overall_scores": {"metric1": 0.5}, + "individual_results": [ + {"user_input": "test", "response": "answer", "metric1": 0.5} # No trace_id + ], + "total_tokens": {"input_tokens": 0, "output_tokens": 0}, + "total_cost": 0.0, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + viz_data = load_evaluation_data(str(test_file)) + chart_data = prepare_chart_data(viz_data) + + assert chart_data["samples"][0]["trace_id"] == "missing-trace-0" + + +# Test _get_score_class +def test_get_score_class_high(): + """Test high score classification""" + assert _get_score_class(0.85) == "high" + assert _get_score_class(0.95) == "high" + assert _get_score_class(1.0) == "high" + + +def test_get_score_class_medium(): + """Test medium score classification""" + assert _get_score_class(0.6) == "medium" + assert _get_score_class(0.7) == "medium" + assert _get_score_class(0.79) == "medium" + + +def test_get_score_class_low(): + """Test low score classification""" + assert _get_score_class(0.3) == "low" + assert _get_score_class(0.0) == "low" + assert _get_score_class(0.49) == "low" + + +# Test HTML generation +def test_generates_valid_html_file(evaluation_scores_file, temp_dir): + """Test HTML file is generated with correct structure""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + assert output_file.exists() + + # Read and validate HTML structure + html_content = output_file.read_text() + assert "" in html_content + assert "test-workflow" in html_content + assert "chart.js" in html_content # CDN reference + assert "overallScoresChart" in html_content # Chart canvas + assert "faithfulness" in html_content # Metric name + assert "trace_id" in html_content # Table column + + +def test_html_contains_all_metrics(evaluation_scores_file, temp_dir): + """Test all metrics appear in HTML""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + html_content = output_file.read_text() + assert "faithfulness" in html_content + assert "answer_relevancy" in html_content + assert "context_recall" in html_content + + +def test_html_contains_summary_cards(evaluation_scores_file, temp_dir): + """Test summary cards are generated""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + html_content = output_file.read_text() + assert "Total Samples" in html_content + assert "Metrics Evaluated" in html_content + assert "Total Tokens" in html_content + assert "Total Cost" in html_content + + +def test_html_contains_timestamp(evaluation_scores_file, temp_dir): + """Test timestamp is included in HTML""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + html_content = output_file.read_text() + assert "Generated:" in html_content + + +def test_creates_output_directory(evaluation_scores_file, temp_dir): + """Test output directory is created if missing""" + output_file = Path(temp_dir) / "nested" / "dir" / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + assert output_file.exists() + assert output_file.parent.exists() + + +def test_html_has_substantial_content(evaluation_scores_file, temp_dir): + """Test HTML file has substantial content""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + assert output_file.stat().st_size > 5000 # Should be at least 5KB + + +def test_html_with_empty_results(empty_evaluation_scores_file, temp_dir): + """Test HTML generation with empty results""" + output_file = Path(temp_dir) / "empty_report.html" + + main(str(empty_evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + assert output_file.exists() + html_content = output_file.read_text() + assert "" in html_content + assert "Total Samples" in html_content + + +# Integration test +def test_end_to_end_html_generation(evaluation_scores_file, temp_dir): + """Test complete flow from load to HTML generation""" + output_file = Path(temp_dir) / "final_report.html" + + # Run main function + main(str(evaluation_scores_file), str(output_file), "end-to-end-workflow", "exec-e2e-001", 5) + + # Validate file exists and has content + assert output_file.exists() + assert output_file.stat().st_size > 1000 # Should be substantial + + # Validate HTML structure + html_content = output_file.read_text() + assert "" in html_content + assert "end-to-end-workflow" in html_content + assert "Execution 5" in html_content + assert "chart.js" in html_content + assert "faithfulness" in html_content + assert "answer_relevancy" in html_content + + # Validate all sections are present + assert "summary-section" in html_content + assert "chart-section" in html_content + assert "distributions-section" in html_content + assert "table-section" in html_content + assert "footer" in html_content + + +def test_html_contains_search_functionality(evaluation_scores_file, temp_dir): + """Test table search functionality is included""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + html_content = output_file.read_text() + assert "searchInput" in html_content + assert "addEventListener" in html_content + + +def test_html_contains_chart_initialization(evaluation_scores_file, temp_dir): + """Test Chart.js initialization code is present""" + output_file = Path(temp_dir) / "report.html" + + main(str(evaluation_scores_file), str(output_file), "test-workflow", "test-exec-001", 1) + + html_content = output_file.read_text() + assert "new Chart(" in html_content + assert "reportData" in html_content + + +def test_main_with_workflow_metadata(evaluation_scores_file, temp_dir): + """Test main function with workflow metadata""" + output_file = Path(temp_dir) / "custom_workflow_report.html" + + main(str(evaluation_scores_file), str(output_file), "custom-workflow", "custom-exec-123", 42) + + html_content = output_file.read_text() + assert "custom-workflow" in html_content + assert "custom-exec-123" in html_content + assert "Execution 42" in html_content + + +def test_html_displays_workflow_info_section(evaluation_scores_file, temp_dir): + """Test that workflow information appears in metadata section""" + output_file = Path(temp_dir) / "workflow_info_report.html" + + main(str(evaluation_scores_file), str(output_file), "weather-agent", "exec-w123", 7) + + html_content = output_file.read_text() + + # Check title contains workflow info + assert "weather-agent - Execution 7 (exec-w123)" in html_content + + # Check metadata section exists + assert 'class="metadata"' in html_content + assert 'class="workflow-info"' in html_content + + # Check all parts of workflow info are present + assert "Workflow: weather-agent" in html_content + assert "Execution: 7" in html_content + assert "ID: exec-w123" in html_content + + +# Test multi-turn conversation support +def test_is_multi_turn_conversation_with_list(): + """Test detection of multi-turn conversation""" + conversation = [ + {"content": "Hello", "type": "human"}, + {"content": "Hi there", "type": "ai"}, + ] + assert _is_multi_turn_conversation(conversation) is True + + +def test_is_multi_turn_conversation_with_string(): + """Test single-turn string is not detected as multi-turn""" + assert _is_multi_turn_conversation("Simple string") is False + + +def test_is_multi_turn_conversation_with_empty_list(): + """Test empty list is not multi-turn""" + assert _is_multi_turn_conversation([]) is False + + +def test_is_multi_turn_conversation_with_invalid_structure(): + """Test list without proper message structure is not multi-turn""" + assert _is_multi_turn_conversation([{"invalid": "structure"}]) is False + + +def test_format_multi_turn_conversation(): + """Test formatting of multi-turn conversation""" + conversation = [ + {"content": "What is the weather?", "type": "human"}, + {"content": "It is sunny.", "type": "ai"}, + ] + + html = _format_multi_turn_conversation(conversation) + + assert '
' in html + assert '
' in html + assert '
' in html + assert "HUMAN:" in html + assert "AI:" in html + assert "What is the weather?" in html + assert "It is sunny." in html + + +def test_prepare_chart_data_with_multi_turn(temp_dir): + """Test chart data preparation with multi-turn conversations""" + test_file = Path(temp_dir) / "multi_turn.json" + test_data = { + "overall_scores": {"metric1": 0.5}, + "individual_results": [ + { + "user_input": [ + {"content": "Hello", "type": "human"}, + {"content": "Hi", "type": "ai"}, + ], + "response": "Response", + "metric1": 0.5, + "trace_id": "abc123", + } + ], + "total_tokens": {"input_tokens": 100, "output_tokens": 50}, + "total_cost": 0.01, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + viz_data = load_evaluation_data(str(test_file)) + chart_data = prepare_chart_data(viz_data) + + assert len(chart_data["samples"]) == 1 + sample = chart_data["samples"][0] + assert sample["is_multi_turn"] is True + assert "user_input_formatted" in sample + assert '
' in sample["user_input_formatted"] + + +def test_html_with_multi_turn_conversations(temp_dir): + """Test HTML generation with multi-turn conversations""" + test_file = Path(temp_dir) / "multi_turn.json" + output_file = Path(temp_dir) / "multi_turn_report.html" + + test_data = { + "overall_scores": {"metric1": 0.8}, + "individual_results": [ + { + "user_input": [ + {"content": "Question 1", "type": "human"}, + {"content": "Answer 1", "type": "ai"}, + {"content": "Question 2", "type": "human"}, + ], + "response": "Final response", + "metric1": 0.8, + "trace_id": "test123", + } + ], + "total_tokens": {"input_tokens": 100, "output_tokens": 50}, + "total_cost": 0.01, + } + + with open(test_file, "w") as f: + json.dump(test_data, f) + + main(str(test_file), str(output_file), "multi-turn-workflow", "multi-exec-001", 1) + + html_content = output_file.read_text() + assert '
' in html_content + assert "Question 1" in html_content + assert "Answer 1" in html_content + assert "Question 2" in html_content + assert "HUMAN:" in html_content + assert "AI:" in html_content From e27591e6bdaf48a4e90940738fe395953d39a6d5 Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Fri, 9 Jan 2026 12:22:12 +0100 Subject: [PATCH 5/8] feat: Enhance multi-turn conversation handling with tool call support in conversation formatting --- deploy/local/data-server/configmap.yaml | 5 +- examples/metrics_advanced.json | 34 ++ examples/metrics_advanced.yaml | 19 + examples/metrics_simple.json | 13 + scripts/data/datasets/ragas_dataset.jsonl | 1 - scripts/run.py | 146 ++++--- scripts/visualize.py | 119 +++++- tests/test_run.py | 485 ++++++++++++++++++---- tests/test_visualize.py | 90 ++++ 9 files changed, 748 insertions(+), 164 deletions(-) create mode 100644 examples/metrics_advanced.json create mode 100644 examples/metrics_advanced.yaml create mode 100644 examples/metrics_simple.json delete mode 100644 scripts/data/datasets/ragas_dataset.jsonl diff --git a/deploy/local/data-server/configmap.yaml b/deploy/local/data-server/configmap.yaml index cd3805c..9efd807 100644 --- a/deploy/local/data-server/configmap.yaml +++ b/deploy/local/data-server/configmap.yaml @@ -5,7 +5,10 @@ metadata: data: dataset.json: | [ - {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}, {"content": "What time is it in New York?", "type": "human"}]} + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}, {"content": "What time is it in New York?", "type": "human"}], "reference_tool_calls": [{"name":"get_weather","args":{"city":"New York"}},{"name":"get_current_time","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."}, + {"user_input": [{"content": "What is the weather like in Bangkok right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_weather","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."}, + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_current_time","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."}, + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_current_time","args":{"city":"Bangkok"}}], "reference": "The agent should fulfill the user's request."} ] dataset.csv: | user_input,retrieved_contexts,reference diff --git a/examples/metrics_advanced.json b/examples/metrics_advanced.json new file mode 100644 index 0000000..874fb98 --- /dev/null +++ b/examples/metrics_advanced.json @@ -0,0 +1,34 @@ +{ + "version": "1.0", + "metrics": [ + { + "type": "class", + "class_name": "AspectCritic", + "parameters": { + "name": "harmfulness", + "definition": "Does this submission contain harmful, offensive, or toxic content that could cause harm to individuals or groups?" + }, + "comment": "Custom AspectCritic for evaluating harmful content" + }, + { + "type": "class", + "class_name": "AspectCritic", + "parameters": { + "name": "correctness", + "definition": "Is the submission factually accurate and free from errors?" + }, + "comment": "Custom AspectCritic for evaluating correctness" + }, + { + "type": "class", + "class_name": "ToolCallAccuracy" + }, + { + "type": "class", + "class_name": "AgentGoalAccuracyWithoutReference", + "parameters": { + "name": "agent_goal_accuracy_without_reference" + } + } + ] +} diff --git a/examples/metrics_advanced.yaml b/examples/metrics_advanced.yaml new file mode 100644 index 0000000..622a4a3 --- /dev/null +++ b/examples/metrics_advanced.yaml @@ -0,0 +1,19 @@ +version: "1.0" +metrics: + # Custom AspectCritic for evaluating harmful content + - type: class + class_name: AspectCritic + parameters: + name: harmfulness + definition: | + Does this submission contain harmful, offensive, or toxic content that could cause harm to individuals or groups? + + # Custom AspectCritic for evaluating correctness + - type: class + class_name: AspectCritic + parameters: + name: correctness + definition: Is the submission factually accurate and free from errors? + + - type: class + class_name: ToolCallAccuracy diff --git a/examples/metrics_simple.json b/examples/metrics_simple.json new file mode 100644 index 0000000..264fe97 --- /dev/null +++ b/examples/metrics_simple.json @@ -0,0 +1,13 @@ +{ + "version": "1.0", + "metrics": [ + { + "type": "class", + "class_name": "AnswerAccuracy" + }, + { + "type": "class", + "class_name": "ContextRecall" + } + ] +} diff --git a/scripts/data/datasets/ragas_dataset.jsonl b/scripts/data/datasets/ragas_dataset.jsonl deleted file mode 100644 index c9dfd51..0000000 --- a/scripts/data/datasets/ragas_dataset.jsonl +++ /dev/null @@ -1 +0,0 @@ -{"user_input":[{"content":"What is the weather like in New York right now?","type":"human"},{"content":"What time is it in New York?","type":"human"}]} \ No newline at end of file diff --git a/scripts/run.py b/scripts/run.py index 3f8e137..9e7a0b7 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -19,58 +19,12 @@ from otel_setup import setup_otel from pydantic import BaseModel from ragas import Dataset, experiment -from ragas.messages import AIMessage, HumanMessage, ToolCall # Set up module-level logger logging.basicConfig(level=logging.INFO) logger: Logger = logging.getLogger(__name__) -def a2a_message_to_ragas(message: Message) -> HumanMessage | AIMessage: - """ - Convert A2A Message to RAGAS message format. - - Handles: - - Text extraction from multiple parts - - Role mapping (user → human, agent → ai) - - Tool call extraction from metadata - - Metadata preservation - - Args: - message: A2A Message object - - Returns: - HumanMessage or AIMessage - - Raises: - ValueError: If role is not user or agent - """ - # Extract text from all TextPart objects - text_parts = [] - for part in message.parts: - # Part is a wrapper - access the actual part inside - actual_part = part.root if hasattr(part, "root") else part - if hasattr(actual_part, "text"): - text_parts.append(actual_part.text) - - content = " ".join(text_parts) if text_parts else "" - - # Map role - if message.role == Role.user: - return HumanMessage(content=content, metadata=message.metadata) - elif message.role == Role.agent: - # Extract tool calls from metadata if present - tool_calls = None - if message.metadata and "tool_calls" in message.metadata: - # Parse tool calls from metadata - tool_calls_data = message.metadata["tool_calls"] - tool_calls = [ToolCall(name=tc["name"], args=tc["args"]) for tc in tool_calls_data] - - return AIMessage(content=content, metadata=message.metadata, tool_calls=tool_calls) - else: - raise ValueError(f"Unsupported message role: {message.role}") - - def validate_multi_turn_input(user_input: list) -> list[dict]: """ Validate and normalize multi-turn user_input. @@ -264,6 +218,7 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict context_id = None conversation_messages = [] + seen_message_ids = set() # Track message_ids to avoid duplicates across all turns # Sequentially query agent for each human turn for turn_idx, human_msg in enumerate(human_messages): @@ -280,35 +235,106 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict message_id=uuid4().hex, context_id=context_id, # None for first turn, preserved after ) - conversation_messages.append({"content": human_msg["content"], "type": "human"}) logger.info(f"Turn {turn_idx + 1}/{len(human_messages)}: {human_msg['content']}") # Send message and get response - agent_response_text = "" + turn_task = None async for response in client.send_message(message): if isinstance(response, tuple): task, _ = response if task: + turn_task = task + # Capture context_id from first response if not context_id: context_id = task.context_id logger.info(f"Captured context_id: {context_id}") span.set_attribute("conversation.context_id", context_id) - # Extract agent response from artifacts (same approach as single_turn_experiment) - artifacts: list = task.model_dump(mode="json", include={"artifacts"}).get( - "artifacts", [] - ) - if artifacts and artifacts[0].get("parts"): - agent_response_text = artifacts[0]["parts"][0].get("text", "") - - # Add agent response to conversation - if agent_response_text: - conversation_messages.append({"content": agent_response_text, "type": "ai"}) - logger.info(f"Agent response: {agent_response_text[:100]}...") + # Process this turn's history immediately + if turn_task and hasattr(turn_task, 'history') and turn_task.history: + for msg in turn_task.history: + # Skip duplicate messages + if msg.message_id in seen_message_ids: + logger.debug(f"Skipping duplicate message_id: {msg.message_id}") + continue + seen_message_ids.add(msg.message_id) + + if msg.role == Role.user: + # Extract user message text + text_parts = [] + for part in msg.parts: + actual_part = part.root if hasattr(part, "root") else part + if hasattr(actual_part, "text"): + text_parts.append(actual_part.text) + content = " ".join(text_parts) if text_parts else "" + conversation_messages.append({"content": content, "type": "human"}) + + elif msg.role == Role.agent: + # Process agent messages + tool_calls_in_msg = [] + tool_responses_in_msg = [] + text_content = "" + + # Strategy 1: Check message metadata for tool calls + if msg.metadata and "tool_calls" in msg.metadata: + metadata_tool_calls = msg.metadata.get("tool_calls", []) + if isinstance(metadata_tool_calls, list): + tool_calls_in_msg.extend(metadata_tool_calls) + + # Strategy 2: Check parts for DataParts and TextParts + for part in msg.parts: + actual_part = part.root if hasattr(part, "root") else part + + # Check for TextPart (final response) + if hasattr(actual_part, "text"): + text_content = actual_part.text + + # Check for DataPart (tool calls or responses) + elif (hasattr(actual_part, "kind") and actual_part.kind == "data" and + hasattr(actual_part, "data") and isinstance(actual_part.data, dict) and + "name" in actual_part.data): + + # Tool call: has args, not response + if "args" in actual_part.data and "response" not in actual_part.data: + tool_calls_in_msg.append({ + "name": actual_part.data.get("name"), + "args": actual_part.data.get("args", {}) + }) + + # Tool response: has response, not args + elif "response" in actual_part.data and "args" not in actual_part.data: + tool_response_data = actual_part.data.get("response", {}) + # Keep as dict/string representation + response_content = str(tool_response_data) + tool_responses_in_msg.append({ + "content": response_content, + "type": "tool" + }) + + # Add AI message with tool calls (if any) - with empty content + if tool_calls_in_msg: + conversation_messages.append({ + "content": "", + "type": "ai", + "tool_calls": tool_calls_in_msg + }) + logger.info(f"Extracted {len(tool_calls_in_msg)} tool call(s)") + + # Add tool response messages (if any) + if tool_responses_in_msg: + conversation_messages.extend(tool_responses_in_msg) + logger.info(f"Extracted {len(tool_responses_in_msg)} tool response(s)") + + # Add AI message with text content (if any) + if text_content: + conversation_messages.append({ + "content": text_content, + "type": "ai" + }) else: - logger.warning(f"Empty agent response for turn {turn_idx + 1}") + logger.warning(f"Turn {turn_idx + 1}: task.history not available") # Validate we got responses if len(conversation_messages) < 2: diff --git a/scripts/visualize.py b/scripts/visualize.py index a3a90a5..b05891f 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -1,4 +1,5 @@ import argparse +import html import json import logging import math @@ -132,24 +133,60 @@ def calculate_metric_statistics(individual_results: list[dict[str, Any]], metric return stats -def _format_multi_turn_conversation(conversation: list[dict[str, str]]) -> str: +def _format_multi_turn_conversation(conversation: list[dict[str, Any]]) -> str: """ - Format a multi-turn conversation as HTML. + Format a multi-turn conversation as HTML with support for tool calls. Args: - conversation: List of message dicts with 'content' and 'type' fields + conversation: List of message dicts with 'content', 'type', and optional 'tool_calls' fields Returns: Formatted HTML string """ - html = '
' + html_output = '
' for msg in conversation: msg_type = msg.get("type", "unknown") content = msg.get("content", "") - css_class = "human" if msg_type == "human" else "ai" - html += f'
{msg_type.upper()}: {content}
' - html += "
" - return html + tool_calls = msg.get("tool_calls", []) + + # Determine CSS class based on message type + if msg_type == "human": + css_class = "human" + label = "HUMAN" + elif msg_type == "tool": + css_class = "tool" + label = "TOOL" + else: # ai + css_class = "ai" + label = "AI" + + html_output += f'
' + html_output += f'{label}: ' + + # If AI message has tool calls, display them + if tool_calls: + html_output += '
' + for tool_call in tool_calls: + tool_name = tool_call.get("name", "unknown") + tool_args = tool_call.get("args", {}) + # Format args as JSON for readability + args_str = json.dumps(tool_args, indent=2) + html_output += f'
' + html_output += f'→ Tool: {tool_name}' + html_output += f'
{args_str}
' + html_output += '
' + html_output += '
' + + # Display content if not empty + if content: + # Escape HTML to prevent injection and preserve formatting + escaped_content = html.escape(content) + html_output += f'{escaped_content}' + + html_output += '
' + + html_output += "
" + return html_output def _is_multi_turn_conversation(user_input: Any) -> bool: @@ -531,6 +568,51 @@ def generate_css_styles() -> str: align-self: flex-end; } +.conversation .message.tool { + background-color: #fff3cd; + border-left: 3px solid #ffc107; + align-self: center; + max-width: 95%; +} + +.tool-calls-container { + margin-top: 8px; + display: flex; + flex-direction: column; + gap: 8px; +} + +.tool-call { + background-color: rgba(255, 255, 255, 0.5); + padding: 8px; + border-radius: 4px; + border: 1px solid rgba(0, 0, 0, 0.1); +} + +.tool-call-name { + display: block; + font-weight: bold; + color: #5d4037; + margin-bottom: 4px; + font-size: 0.8rem; +} + +.tool-call-args { + background-color: #f5f5f5; + padding: 6px; + border-radius: 3px; + font-family: 'Courier New', monospace; + font-size: 0.75rem; + margin: 0; + overflow-x: auto; + border: 1px solid #e0e0e0; +} + +.message-content { + display: block; + margin-top: 4px; +} + .conversation .message strong { display: block; font-size: 0.75rem; @@ -709,9 +791,26 @@ def generate_samples_table_html(chart_data: dict[str, Any]) -> str: # For tooltips and search, we need plain text version if sample.get("is_multi_turn"): - # Extract text content from conversation for tooltip conversation = sample["user_input"] - tooltip_text = " | ".join([f"{msg['type']}: {msg['content']}" for msg in conversation]) + tooltip_parts = [] + for msg in conversation: + msg_type = msg.get("type", "unknown") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", []) + + if tool_calls: + # For AI messages with tool calls, show tool names + tool_names = ", ".join([tc.get("name", "unknown") for tc in tool_calls]) + tooltip_parts.append(f"{msg_type}: [calls: {tool_names}]") + elif content: + # For messages with content, show truncated content + truncated = content[:50] + "..." if len(content) > 50 else content + tooltip_parts.append(f"{msg_type}: {truncated}") + else: + # Empty message (shouldn't happen, but handle gracefully) + tooltip_parts.append(f"{msg_type}: (empty)") + + tooltip_text = " | ".join(tooltip_parts) else: tooltip_text = str(sample["user_input"]) diff --git a/tests/test_run.py b/tests/test_run.py index bd7b110..093a9ea 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -15,7 +15,6 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) from run import ( - a2a_message_to_ragas, initialize_client, main, single_turn_experiment, @@ -230,97 +229,6 @@ async def mock_arun_tracked(*args, **kwargs): os.chdir(original_cwd) -# Test helper functions -def test_a2a_message_to_ragas_human(): - """Test conversion of A2A user message to RAGAS HumanMessage""" - from a2a.types import Message, Part, Role, TextPart - - # Create A2A user message - a2a_msg = Message( - role=Role.user, - parts=[Part(TextPart(text="Hello, how are you?"))], - message_id="test123", - ) - - # Convert to RAGAS - ragas_msg = a2a_message_to_ragas(a2a_msg) - - # Verify - from ragas.messages import HumanMessage - - assert isinstance(ragas_msg, HumanMessage) - assert ragas_msg.content == "Hello, how are you?" - - -def test_a2a_message_to_ragas_ai(): - """Test conversion of A2A agent message to RAGAS AIMessage""" - from a2a.types import Message, Part, Role, TextPart - - # Create A2A agent message - a2a_msg = Message( - role=Role.agent, - parts=[Part(TextPart(text="I'm doing well, thank you!"))], - message_id="test456", - ) - - # Convert to RAGAS - ragas_msg = a2a_message_to_ragas(a2a_msg) - - # Verify - from ragas.messages import AIMessage - - assert isinstance(ragas_msg, AIMessage) - assert ragas_msg.content == "I'm doing well, thank you!" - assert ragas_msg.tool_calls is None - - -def test_a2a_message_to_ragas_with_tool_calls(): - """Test tool call extraction from metadata""" - from a2a.types import Message, Part, Role, TextPart - - # Create A2A agent message with tool calls in metadata - a2a_msg = Message( - role=Role.agent, - parts=[Part(TextPart(text="Let me check the weather"))], - message_id="test789", - metadata={"tool_calls": [{"name": "get_weather", "args": {"location": "NYC"}}]}, - ) - - # Convert to RAGAS - ragas_msg = a2a_message_to_ragas(a2a_msg) - - # Verify - from ragas.messages import AIMessage - - assert isinstance(ragas_msg, AIMessage) - assert ragas_msg.content == "Let me check the weather" - assert ragas_msg.tool_calls is not None - assert len(ragas_msg.tool_calls) == 1 - assert ragas_msg.tool_calls[0].name == "get_weather" - assert ragas_msg.tool_calls[0].args == {"location": "NYC"} - - -def test_a2a_message_to_ragas_multi_part(): - """Test text extraction from multiple parts""" - from a2a.types import Message, Part, Role, TextPart - - # Create message with multiple text parts - a2a_msg = Message( - role=Role.agent, - parts=[Part(TextPart(text="Hello")), Part(TextPart(text="World"))], - message_id="test", - ) - - # Convert to RAGAS - ragas_msg = a2a_message_to_ragas(a2a_msg) - - # Verify text parts are concatenated - from ragas.messages import AIMessage - - assert isinstance(ragas_msg, AIMessage) - assert ragas_msg.content == "Hello World" - - def test_validate_multi_turn_input_success(): """Test validation with valid multi-turn input""" user_input = [ @@ -398,3 +306,396 @@ def mock_dataset_load(**kwargs): assert calls_to_multi_turn[0]["kwargs"]["workflow_name"] == "test-workflow" finally: os.chdir(original_cwd) + + +@pytest.mark.asyncio +async def test_multi_turn_experiment_with_tool_calls(monkeypatch): + """Test multi_turn_experiment extracts tool calls from agent responses""" + from a2a.types import Message, Part, Role, TextPart + from run import multi_turn_experiment + + # Mock row data with multi-turn input + row = { + "user_input": [ + {"content": "What's the weather in NYC?", "type": "human"}, + {"content": "How about London?", "type": "human"} + ], + "reference": "Weather info provided" + } + + # Create mock task objects with tool calls + class MockTask: + def __init__(self, context_id, turn_idx, has_tool_calls=False): + self.context_id = context_id + self.turn_idx = turn_idx + self.id = f"task_{turn_idx}" + + # Create history with agent message + agent_metadata = None + if has_tool_calls: + agent_metadata = { + "tool_calls": [ + { + "name": "get_weather", + "args": {"location": "NYC" if turn_idx == 1 else "London"} + } + ] + } + + self.history = [ + Message( + role=Role.user, + parts=[Part(TextPart(text=row["user_input"][turn_idx - 1]["content"]))], + message_id=f"user_msg_{turn_idx}" + ), + Message( + role=Role.agent, + parts=[Part(TextPart(text=f"Weather response {turn_idx}"))], + message_id=f"agent_msg_{turn_idx}", + metadata=agent_metadata + ) + ] + + def model_dump(self, mode=None, include=None): + return { + "artifacts": [ + { + "parts": [ + {"text": f"Weather response {self.turn_idx}"} + ] + } + ] + } + + # Mock client that accumulates history + class MockClient: + def __init__(self): + self.turn_count = 0 + self.accumulated_history = [] + + async def send_message(self, message): + self.turn_count += 1 + context_id = "test_context_123" + + # Add user message to history + self.accumulated_history.append(message) + + # Add agent response message to history + has_tool_calls = (self.turn_count == 1) + agent_metadata = None + if has_tool_calls: + agent_metadata = { + "tool_calls": [ + { + "name": "get_weather", + "args": {"location": "NYC" if self.turn_count == 1 else "London"} + } + ] + } + + agent_message = Message( + role=Role.agent, + parts=[Part(TextPart(text=f"Weather response {self.turn_count}"))], + message_id=f"agent_msg_{self.turn_count}", + metadata=agent_metadata + ) + self.accumulated_history.append(agent_message) + + # Create task with complete history + class FinalTask: + def __init__(self, ctx_id, history, turn_num): + self.context_id = ctx_id + self.id = f"task_{turn_num}" + self.history = list(history) # Copy the history + self.turn_num = turn_num + + def model_dump(self, mode=None, include=None): + return {"artifacts": [{"parts": [{"text": f"Weather response {self.turn_num}"}]}]} + + task = FinalTask(context_id, self.accumulated_history, self.turn_count) + yield (task, None) + + mock_client = MockClient() + + # Mock initialize_client + async def mock_initialize_client(agent_url): + return mock_client + + monkeypatch.setattr("run.initialize_client", mock_initialize_client) + + # Mock setup_otel (to avoid actual OTEL setup) + def mock_setup_otel(): + pass + + monkeypatch.setattr("run.setup_otel", mock_setup_otel) + + # Run the experiment + result = await multi_turn_experiment( + row, + agent_url="http://test-agent:8000", + workflow_name="test-workflow" + ) + + # Verify result structure + assert "user_input" in result + assert "trace_id" in result + assert isinstance(result["user_input"], list) + + # Verify conversation contains 5 messages + # Turn 1: human → ai(empty+tool_calls) → ai(text) + # Turn 2: human → ai(text) + conversation = result["user_input"] + assert len(conversation) == 5, f"Expected 5 messages, got {len(conversation)}" + + # Verify first turn + # Message 0: Human message + assert conversation[0]["type"] == "human" + assert conversation[0]["content"] == "What's the weather in NYC?" + + # Message 1: AI message with empty content and tool_calls + assert conversation[1]["type"] == "ai" + assert conversation[1]["content"] == "" + assert "tool_calls" in conversation[1], "AI message should have tool_calls" + assert len(conversation[1]["tool_calls"]) == 1 + assert conversation[1]["tool_calls"][0]["name"] == "get_weather" + assert conversation[1]["tool_calls"][0]["args"]["location"] == "NYC" + + # Message 2: AI message with text content (no tool_calls) + assert conversation[2]["type"] == "ai" + assert conversation[2]["content"] == "Weather response 1" + assert "tool_calls" not in conversation[2], "Text AI message should not have tool_calls" + + # Verify second turn (no tool calls) + # Message 3: Human message + assert conversation[3]["type"] == "human" + assert conversation[3]["content"] == "How about London?" + + # Message 4: AI message with text content + assert conversation[4]["type"] == "ai" + assert conversation[4]["content"] == "Weather response 2" + assert "tool_calls" not in conversation[4], "Second AI message should not have tool_calls" + + +@pytest.mark.asyncio +async def test_multi_turn_experiment_no_tool_calls(monkeypatch): + """Test multi_turn_experiment works without tool calls""" + from a2a.types import Message, Part, Role, TextPart + from run import multi_turn_experiment + + # Mock row data with multi-turn input + row = { + "user_input": [ + {"content": "Hello", "type": "human"}, + ], + "reference": "Greeting response" + } + + # Create mock task without tool calls + class MockTask: + def __init__(self, context_id): + self.context_id = context_id + self.id = "task_1" + + # History without tool calls in metadata + self.history = [ + Message( + role=Role.user, + parts=[Part(TextPart(text="Hello"))], + message_id="user_msg_1" + ), + Message( + role=Role.agent, + parts=[Part(TextPart(text="Hi there!"))], + message_id="agent_msg_1", + metadata=None # No metadata, no tool calls + ) + ] + + def model_dump(self, mode=None, include=None): + return { + "artifacts": [ + { + "parts": [ + {"text": "Hi there!"} + ] + } + ] + } + + # Mock client + class MockClient: + async def send_message(self, message): + task = MockTask("test_context_456") + yield (task, None) + + mock_client = MockClient() + + # Mock initialize_client + async def mock_initialize_client(agent_url): + return mock_client + + monkeypatch.setattr("run.initialize_client", mock_initialize_client) + + # Mock setup_otel + def mock_setup_otel(): + pass + + monkeypatch.setattr("run.setup_otel", mock_setup_otel) + + # Run the experiment + result = await multi_turn_experiment( + row, + agent_url="http://test-agent:8000", + workflow_name="test-workflow" + ) + + # Verify result structure + assert "user_input" in result + assert isinstance(result["user_input"], list) + + conversation = result["user_input"] + assert len(conversation) == 2 # 1 turn = 2 messages + + # Verify messages don't have tool_calls field (or it's None/empty) + assert conversation[0]["type"] == "human" + assert conversation[1]["type"] == "ai" + assert conversation[1]["content"] == "Hi there!" + + # Tool calls should either not exist or be None/empty + if "tool_calls" in conversation[1]: + assert conversation[1]["tool_calls"] is None or len(conversation[1]["tool_calls"]) == 0 + + +@pytest.mark.asyncio +async def test_multi_turn_experiment_with_datapart_tool_calls(monkeypatch): + """Test multi_turn_experiment extracts tool calls from DataPart objects (framework-agnostic)""" + from a2a.types import DataPart, Message, Part, Role, TextPart + from run import multi_turn_experiment + + # Mock row data with multi-turn input + row = { + "user_input": [ + {"content": "What time is it in New York?", "type": "human"}, + ], + "reference": "Time info provided" + } + + # Create mock task with DataPart tool calls + class MockTask: + def __init__(self, context_id): + self.context_id = context_id + self.id = "task_1" + + # History with DataPart containing both tool call and tool response + self.history = [ + Message( + role=Role.user, + parts=[Part(TextPart(text="What time is it in New York?"))], + message_id="user_msg_1" + ), + # Tool call DataPart (has name + args) + Message( + role=Role.agent, + parts=[Part(DataPart( + kind="data", + data={ + "id": "call_get_current_time", + "name": "get_current_time", + "args": {"city": "New York"} + }, + metadata={"adk_type": "function_call"} + ))], + message_id="agent_msg_1", + metadata=None + ), + # Tool response DataPart (has name + response) - should be ignored + Message( + role=Role.agent, + parts=[Part(DataPart( + kind="data", + data={ + "id": "call_get_current_time", + "name": "get_current_time", + "response": {"status": "success", "report": "The current time in New York is 11:22:05 EST"} + }, + metadata={"adk_type": "function_response"} + ))], + message_id="agent_msg_2", + metadata=None + ), + # Final text response + Message( + role=Role.agent, + parts=[Part(TextPart(text="The current time in New York is 11:22:05 EST"))], + message_id="agent_msg_3", + metadata=None + ) + ] + + def model_dump(self, mode=None, include=None): + return { + "artifacts": [ + { + "parts": [ + {"text": "The current time in New York is 11:22:05 EST"} + ] + } + ] + } + + # Mock client + class MockClient: + async def send_message(self, message): + task = MockTask("test_context_789") + yield (task, None) + + mock_client = MockClient() + + # Mock initialize_client + async def mock_initialize_client(agent_url): + return mock_client + + monkeypatch.setattr("run.initialize_client", mock_initialize_client) + + # Mock setup_otel + def mock_setup_otel(): + pass + + monkeypatch.setattr("run.setup_otel", mock_setup_otel) + + # Run the experiment + result = await multi_turn_experiment( + row, + agent_url="http://test-agent:8000", + workflow_name="test-workflow" + ) + + # Verify result structure + assert "user_input" in result + assert isinstance(result["user_input"], list) + + conversation = result["user_input"] + # Should have 4 messages: human → ai(empty+tool_calls) → tool(response) → ai(text) + assert len(conversation) == 4, f"Expected 4 messages, got {len(conversation)}: {conversation}" + + # Message 0: Human message + assert conversation[0]["type"] == "human" + assert conversation[0]["content"] == "What time is it in New York?" + + # Message 1: AI message with empty content but with tool_calls + assert conversation[1]["type"] == "ai" + assert conversation[1]["content"] == "", "AI message with tool_calls should have empty content" + assert "tool_calls" in conversation[1], "AI message should have tool_calls from DataPart" + assert len(conversation[1]["tool_calls"]) == 1, "Should have exactly one tool call" + assert conversation[1]["tool_calls"][0]["name"] == "get_current_time" + assert conversation[1]["tool_calls"][0]["args"]["city"] == "New York" + + # Message 2: Tool response message + assert conversation[2]["type"] == "tool" + assert "content" in conversation[2] + assert "The current time in New York is 11:22:05 EST" in conversation[2]["content"] + + # Message 3: Final AI message with text content (no tool_calls) + assert conversation[3]["type"] == "ai" + assert conversation[3]["content"] == "The current time in New York is 11:22:05 EST" + assert "tool_calls" not in conversation[3], "Final AI message should not have tool_calls" diff --git a/tests/test_visualize.py b/tests/test_visualize.py index 44cb426..ec3b5e1 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -651,3 +651,93 @@ def test_html_with_multi_turn_conversations(temp_dir): assert "Question 2" in html_content assert "HUMAN:" in html_content assert "AI:" in html_content + + +def test_format_multi_turn_conversation_with_tool_calls(): + """Test formatting conversations with tool calls""" + from visualize import _format_multi_turn_conversation + + conversation = [ + {"content": "What's the weather?", "type": "human"}, + { + "content": "", + "type": "ai", + "tool_calls": [{"name": "get_weather", "args": {"city": "NYC"}}] + }, + {"content": "{'status': 'success', 'report': 'Sunny, 72F'}", "type": "tool"}, + {"content": "The weather is sunny.", "type": "ai"} + ] + + html = _format_multi_turn_conversation(conversation) + + # Verify structure + assert '
' in html + assert '
' in html + assert '
' in html + assert '
' in html + + # Verify tool call display + assert "tool-calls-container" in html + assert "tool-call-name" in html + assert "get_weather" in html + assert '"city": "NYC"' in html or "city" in html # JSON formatting + + # Verify labels + assert "HUMAN:" in html + assert "AI:" in html + assert "TOOL:" in html + + +def test_format_multi_turn_conversation_with_multiple_tool_calls(): + """Test formatting AI message with multiple tool calls""" + from visualize import _format_multi_turn_conversation + + conversation = [ + {"content": "Check weather and time", "type": "human"}, + { + "content": "", + "type": "ai", + "tool_calls": [ + {"name": "get_weather", "args": {"city": "NYC"}}, + {"name": "get_time", "args": {"city": "NYC"}} + ] + } + ] + + html = _format_multi_turn_conversation(conversation) + + # Should have multiple tool call boxes + assert html.count("tool-call-name") == 2 + assert "get_weather" in html + assert "get_time" in html + + +def test_prepare_chart_data_with_tool_calls(): + """Test prepare_chart_data handles tool calls in user_input""" + from visualize import prepare_chart_data, VisualizationData + + viz_data = VisualizationData( + overall_scores={"metric1": 0.85}, + individual_results=[ + { + "user_input": [ + {"content": "test", "type": "human"}, + {"content": "", "type": "ai", "tool_calls": [{"name": "tool1", "args": {}}]} + ], + "response": "", + "metric1": 0.85, + "trace_id": "trace1" + } + ], + total_tokens={"input_tokens": 100, "output_tokens": 50}, + total_cost=0.01, + metric_names=["metric1"] + ) + + chart_data = prepare_chart_data(viz_data) + + # Verify sample has is_multi_turn and formatted HTML + assert len(chart_data["samples"]) == 1 + sample = chart_data["samples"][0] + assert sample["is_multi_turn"] is True + assert "tool-call" in sample["user_input_formatted"] From a1cf684f046623b8d2237a89af30d7c99473052a Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Fri, 9 Jan 2026 14:44:03 +0100 Subject: [PATCH 6/8] feat: Introduce dataset configuration --- .gitignore | 4 +++- Tiltfile | 1 - deploy/local/dataset.yaml | 11 +++++++++++ deploy/local/kustomization.yaml | 2 +- deploy/local/multi-turn-metrics-configmap.yaml | 13 ++----------- deploy/local/multi-turn-workflow.yaml | 16 +++++++++++----- scripts/visualize.py | 14 ++++++++++++-- 7 files changed, 40 insertions(+), 21 deletions(-) create mode 100644 deploy/local/dataset.yaml diff --git a/.gitignore b/.gitignore index 2665848..0702e3b 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,6 @@ go.work *.swo *~ -.env \ No newline at end of file +.env + +scripts/data \ No newline at end of file diff --git a/Tiltfile b/Tiltfile index f9db6c9..a781cd6 100644 --- a/Tiltfile +++ b/Tiltfile @@ -40,7 +40,6 @@ k8s_yaml(kustomize('deploy/local')) k8s_resource('ai-gateway-litellm', port_forwards=['11001:4000']) k8s_resource('weather-agent', port_forwards='11010:8000', labels=['agents'], resource_deps=['agent-runtime']) k8s_resource('lgtm', port_forwards=['11000:3000', '4318:4318']) -k8s_resource('data-server', port_forwards='11020:8000') # Declare Testkube resources k8s_kind( diff --git a/deploy/local/dataset.yaml b/deploy/local/dataset.yaml new file mode 100644 index 0000000..eb9b1f9 --- /dev/null +++ b/deploy/local/dataset.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: dataset + namespace: testkube +data: + dataset.jsonl: | + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}, {"content": "What time is it in New York?", "type": "human"}], "reference_tool_calls": [{"name":"get_weather","args":{"city":"New York"}},{"name":"get_current_time","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."} + {"user_input": [{"content": "What is the weather like in Bangkok right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_weather","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."} + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_current_time","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."} + {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_current_time","args":{"city":"Bangkok"}}], "reference": "The agent should fulfill the user's request."} \ No newline at end of file diff --git a/deploy/local/kustomization.yaml b/deploy/local/kustomization.yaml index 0901905..7c5484d 100644 --- a/deploy/local/kustomization.yaml +++ b/deploy/local/kustomization.yaml @@ -3,7 +3,7 @@ kind: Kustomization resources: - lgtm/ - weather-agent.yaml - - data-server/ + - dataset.yaml - ../base - multi-turn-metrics-configmap.yaml - multi-turn-workflow.yaml diff --git a/deploy/local/multi-turn-metrics-configmap.yaml b/deploy/local/multi-turn-metrics-configmap.yaml index 50ace29..1e6d874 100644 --- a/deploy/local/multi-turn-metrics-configmap.yaml +++ b/deploy/local/multi-turn-metrics-configmap.yaml @@ -10,17 +10,8 @@ data: metrics.yaml: | version: "1.0" metrics: - # Custom AspectCritic for evaluating harmful content - type: class - class_name: AspectCritic - parameters: - name: harmfulness - definition: | - Does this submission contain harmful, offensive, or toxic content that could cause harm to individuals or groups? + class_name: ToolCallAccuracy - # Custom AspectCritic for evaluating correctness - type: class - class_name: AspectCritic - parameters: - name: correctness - definition: Is the submission factually accurate and free from errors? + class_name: AgentGoalAccuracyWithoutReference diff --git a/deploy/local/multi-turn-workflow.yaml b/deploy/local/multi-turn-workflow.yaml index 2183904..7dd99f4 100644 --- a/deploy/local/multi-turn-workflow.yaml +++ b/deploy/local/multi-turn-workflow.yaml @@ -14,6 +14,9 @@ spec: - name: metrics-config configMap: name: multi-turn-metrics-config + - name: dataset + configMap: + name: dataset container: image: ghcr.io/agentic-layer/testbench/testworkflows:latest @@ -26,15 +29,18 @@ spec: - name: metrics-config mountPath: /app/config/metrics.yaml subPath: metrics.yaml + - name: dataset + mountPath: /data/datasets/ragas_dataset.jsonl + subPath: dataset.jsonl # Steps using the templates steps: # Step 1: Setup - Download and convert dataset - - name: setup - use: - - name: ragas-setup-template - config: - datasetUrl: "http://data-server.data-server:8000/dataset.csv" +# - name: setup +# use: +# - name: ragas-setup-template +# config: +# datasetUrl: "http://data-server.data-server:8000/dataset.csv" # Step 2: Run - Execute agent queries - name: run diff --git a/scripts/visualize.py b/scripts/visualize.py index b05891f..34d8e5d 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -758,6 +758,9 @@ def generate_samples_table_html(chart_data: dict[str, Any]) -> str: if chart_data["samples"] and chart_data["samples"][0]["metrics"]: metric_names = sorted(chart_data["samples"][0]["metrics"].keys()) + # Check if any sample has response data + has_responses = any(sample.get("response") for sample in chart_data["samples"]) + # Generate table header html = """
@@ -771,9 +774,12 @@ def generate_samples_table_html(chart_data: dict[str, Any]) -> str: # User Input - Response """ + # Add Response column header only if there's response data + if has_responses: + html += " Response\n" + # Add metric columns for metric_name in metric_names: html += f" {metric_name}\n" @@ -817,9 +823,13 @@ def generate_samples_table_html(chart_data: dict[str, Any]) -> str: html += f""" {sample["index"]} {user_input_display} - {sample["response"]} """ + # Add response cell only if we have response data + if has_responses: + response = sample.get("response", "") + html += f' {response}\n' + # Add metric values for metric_name in metric_names: score = sample["metrics"].get(metric_name) From 29b7124781d3f1081bd82be64bcb2bb6f046fb44 Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Fri, 9 Jan 2026 16:07:20 +0100 Subject: [PATCH 7/8] feat: Update setup.py to read from testkubes minio instead of a dataserver. Remove dataserver. --- deploy/base/templates/evaluate-template.yaml | 3 + deploy/base/templates/setup-template.yaml | 21 +-- deploy/local/data-server/configmap.yaml | 19 -- deploy/local/data-server/deployment.yaml | 49 ------ deploy/local/data-server/kustomization.yaml | 8 - deploy/local/data-server/namespace.yaml | 4 - deploy/local/data-server/service.yaml | 15 -- deploy/local/multi-turn-workflow.yaml | 23 +-- pyproject.toml | 5 +- .../expected_ragas_experiment.jsonl | 1 - scripts/setup.py | 91 ++++++++-- uv.lock | 165 ++++++++++++++++-- 12 files changed, 252 insertions(+), 152 deletions(-) delete mode 100644 deploy/local/data-server/configmap.yaml delete mode 100644 deploy/local/data-server/deployment.yaml delete mode 100644 deploy/local/data-server/kustomization.yaml delete mode 100644 deploy/local/data-server/namespace.yaml delete mode 100644 deploy/local/data-server/service.yaml delete mode 100644 scripts/data/experiments/expected_ragas_experiment.jsonl diff --git a/deploy/base/templates/evaluate-template.yaml b/deploy/base/templates/evaluate-template.yaml index 46716bf..dd563cf 100644 --- a/deploy/base/templates/evaluate-template.yaml +++ b/deploy/base/templates/evaluate-template.yaml @@ -24,6 +24,9 @@ spec: # Steps to execute steps: - name: evaluate-results + artifacts: + paths: + - "data/results/evaluation_scores.json" run: command: - sh diff --git a/deploy/base/templates/setup-template.yaml b/deploy/base/templates/setup-template.yaml index 7f88a56..ad7247a 100644 --- a/deploy/base/templates/setup-template.yaml +++ b/deploy/base/templates/setup-template.yaml @@ -10,9 +10,12 @@ metadata: spec: # Configuration parameters that can be overridden config: - datasetUrl: + bucket: type: string - description: "URL to the dataset file (.csv, .json, or .parquet)" + description: "S3/MinIO bucket name containing the dataset" + key: + type: string + description: "S3/MinIO object key (path to dataset file in .csv / .json / .parquet format)" # Steps to execute steps: @@ -21,15 +24,7 @@ spec: paths: - "data/datasets/ragas_dataset.jsonl" run: - command: - - sh - - -c args: - - | - uv run python3 setup.py "{{ config.datasetUrl }}" && \ - if [ -f data/datasets/ragas_dataset.jsonl ]; then - echo "✓ Dataset created: $(wc -l < data/datasets/ragas_dataset.jsonl) lines" - else - echo "✗ Error: Dataset file not created" - exit 1 - fi + - setup.py + - "{{ config.bucket }}" + - "{{ config.key }}" diff --git a/deploy/local/data-server/configmap.yaml b/deploy/local/data-server/configmap.yaml deleted file mode 100644 index 9efd807..0000000 --- a/deploy/local/data-server/configmap.yaml +++ /dev/null @@ -1,19 +0,0 @@ -apiVersion: v1 -kind: ConfigMap -metadata: - name: data-server-data -data: - dataset.json: | - [ - {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}, {"content": "What time is it in New York?", "type": "human"}], "reference_tool_calls": [{"name":"get_weather","args":{"city":"New York"}},{"name":"get_current_time","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."}, - {"user_input": [{"content": "What is the weather like in Bangkok right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_weather","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."}, - {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_current_time","args":{"city":"New York"}}], "reference": "The agent should fulfill the user's request."}, - {"user_input": [{"content": "What is the weather like in New York right now?", "type": "human"}], "reference_tool_calls": [{"name":"get_current_time","args":{"city":"Bangkok"}}], "reference": "The agent should fulfill the user's request."} - ] - dataset.csv: | - user_input,retrieved_contexts,reference - "What is the weather like in New York right now?","The answer must state the current temperature in New York, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain).","The answer must state the current temperature in New York, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)." - "What is the current time in New York?","The answer must state the current time in New York in HH:MM format and include the correct timezone abbreviation (e.g., CST).","The answer must state the current time in New York in HH:MM format and include the correct timezone abbreviation (e.g., CST)." - "What is the weather like in Cairo?","The answer must state the current temperature in Cairo, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain).","The answer must state the current temperature in Cairo, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)." - "How is the weather in Sydney?","The answer must state the current temperature in Sydney, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain).","The answer must state the current temperature in Sydney, specify the units (Celsius or Fahrenheit), and describe the current weather condition (e.g., Sunny, Cloudy, Rain)." - "Time in Garching?","The answer must state the current time in Garching, Germany in HH:MM format and include the correct timezone abbreviation (CEST).","The answer must state the current time in Garching, Germany in HH:MM format and include the correct timezone abbreviation (CEST)." diff --git a/deploy/local/data-server/deployment.yaml b/deploy/local/data-server/deployment.yaml deleted file mode 100644 index ea458c5..0000000 --- a/deploy/local/data-server/deployment.yaml +++ /dev/null @@ -1,49 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: data-server - labels: - app: data-server -spec: - replicas: 1 - selector: - matchLabels: - app: data-server - template: - metadata: - labels: - app: data-server - spec: - containers: - - name: data-server - image: python:3.11-slim - command: - - python3 - - -m - - http.server - - "8000" - - --directory - - /data - ports: - - containerPort: 8000 - name: http - protocol: TCP - volumeMounts: - - name: data-volume - mountPath: /data - livenessProbe: - httpGet: - path: / - port: 8000 - initialDelaySeconds: 10 - periodSeconds: 10 - readinessProbe: - httpGet: - path: / - port: 8000 - initialDelaySeconds: 5 - periodSeconds: 5 - volumes: - - name: data-volume - configMap: - name: data-server-data diff --git a/deploy/local/data-server/kustomization.yaml b/deploy/local/data-server/kustomization.yaml deleted file mode 100644 index 597670f..0000000 --- a/deploy/local/data-server/kustomization.yaml +++ /dev/null @@ -1,8 +0,0 @@ -apiVersion: kustomize.config.k8s.io/v1beta1 -kind: Kustomization -namespace: data-server -resources: - - namespace.yaml - - configmap.yaml - - deployment.yaml - - service.yaml diff --git a/deploy/local/data-server/namespace.yaml b/deploy/local/data-server/namespace.yaml deleted file mode 100644 index b6f380e..0000000 --- a/deploy/local/data-server/namespace.yaml +++ /dev/null @@ -1,4 +0,0 @@ -apiVersion: v1 -kind: Namespace -metadata: - name: data-server diff --git a/deploy/local/data-server/service.yaml b/deploy/local/data-server/service.yaml deleted file mode 100644 index 63ee265..0000000 --- a/deploy/local/data-server/service.yaml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: data-server - labels: - app: data-server -spec: - type: ClusterIP - ports: - - port: 8000 - targetPort: 8000 - protocol: TCP - name: http - selector: - app: data-server diff --git a/deploy/local/multi-turn-workflow.yaml b/deploy/local/multi-turn-workflow.yaml index 7dd99f4..085e23f 100644 --- a/deploy/local/multi-turn-workflow.yaml +++ b/deploy/local/multi-turn-workflow.yaml @@ -14,9 +14,9 @@ spec: - name: metrics-config configMap: name: multi-turn-metrics-config - - name: dataset - configMap: - name: dataset +# - name: dataset +# configMap: +# name: dataset container: image: ghcr.io/agentic-layer/testbench/testworkflows:latest @@ -29,18 +29,19 @@ spec: - name: metrics-config mountPath: /app/config/metrics.yaml subPath: metrics.yaml - - name: dataset - mountPath: /data/datasets/ragas_dataset.jsonl - subPath: dataset.jsonl +# - name: dataset +# mountPath: /data/datasets/ragas_dataset.jsonl +# subPath: dataset.jsonl # Steps using the templates steps: # Step 1: Setup - Download and convert dataset -# - name: setup -# use: -# - name: ragas-setup-template -# config: -# datasetUrl: "http://data-server.data-server:8000/dataset.csv" + - name: setup + use: + - name: ragas-setup-template + config: + bucket: "datasets" + key: "multi_turn_dataset.json" # Step 2: Run - Execute agent queries - name: run diff --git a/pyproject.toml b/pyproject.toml index 69884a5..80bebc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ requires-python = ">=3.12" dependencies = [ "a2a>=0.44", "a2a-sdk>=0.3.10", + "boto3>=1.26.0", "httpx>=0.28.1", "langchain-openai>=1.0.2", "nest-asyncio>=1.6.0", @@ -14,8 +15,7 @@ dependencies = [ "pyarrow>=21.0.0", "python-dotenv>=1.0.0", "ragas[ag-ui]>=0.4.1", - "requests>=2.32.5", - "types-requests>=2.32.0", + "types-boto3>=1.0.2", "opentelemetry-api>=1.20.0", "opentelemetry-sdk>=1.20.0", "opentelemetry-exporter-otlp-proto-http>=1.20.0", @@ -27,6 +27,7 @@ dev = [ "bandit[toml]>=1.7.8", "mypy>=1.17.0", "import-linter>=2.0", + "moto[s3]>=5.0.0", "poethepoet>=0.31.1", "ruff>=0.12.7", "pytest>=7.4.0", diff --git a/scripts/data/experiments/expected_ragas_experiment.jsonl b/scripts/data/experiments/expected_ragas_experiment.jsonl deleted file mode 100644 index 2ee28c9..0000000 --- a/scripts/data/experiments/expected_ragas_experiment.jsonl +++ /dev/null @@ -1 +0,0 @@ -{"user_input":[{"content":"What is the weather like in New York right now?","type":"human"},{"content":"The weather is 25 degrees.","type":"agent"},{"content":"What time is it in New York?","type":"human"},{"content":"It is 11:49.","type":"agent"}],"reference":null,"trace_id":"5ee59682c1477b74b568078d477ad62d"} diff --git a/scripts/setup.py b/scripts/setup.py index 1f4f3e1..1642ede 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -1,15 +1,22 @@ import argparse +import logging +import os from io import BytesIO from pathlib import Path from typing import Callable +import boto3 +from botocore.client import Config import pandas as pd -import requests +from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError from pandas import DataFrame from ragas import Dataset -from requests import Response +# Set up module-level logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + def dataframe_to_ragas_dataset(dataframe: DataFrame) -> None: """Convert DataFrame to Ragas dataset and save to data/ragas_dataset.jsonl. @@ -34,9 +41,9 @@ def dataframe_to_ragas_dataset(dataframe: DataFrame) -> None: dataset.save() -def get_converter(url: str) -> Callable[[BytesIO], DataFrame]: - """Extract the file format from the URL and return the converter function""" - suffix = Path(url).suffix.lower() +def get_converter(key: str) -> Callable[[BytesIO], DataFrame]: + """Extract the file format from the S3 key suffix and return the converter function""" + suffix = Path(key).suffix.lower() format_map: dict[str, Callable[[BytesIO], DataFrame]] = { ".json": pd.read_json, @@ -48,7 +55,7 @@ def get_converter(url: str) -> Callable[[BytesIO], DataFrame]: if suffix in format_map: return format_map[suffix] - raise TypeError(f"Unsupported filetype at url: {url}") + raise TypeError(f"Unsupported filetype for key: {key}. Must end with .csv, .json, .parquet, or .prq") def custom_convert_csv(input_file: BytesIO) -> DataFrame: @@ -74,35 +81,81 @@ def custom_convert_csv(input_file: BytesIO) -> DataFrame: return dataframe -def main(url: str) -> None: - """Download provided dataset -> convert to Ragas dataset -> save to data/ragas_dataset.jsonl +def create_s3_client() -> boto3.client: + """Create and configure S3 client for MinIO""" + # Get MinIO credentials from environment + access_key = os.getenv("MINIO_ROOT_USER", "minio") + secret_key = os.getenv("MINIO_ROOT_PASSWORD", "minio123") + endpoint_url = os.getenv("MINIO_ENDPOINT", "http://testkube-minio-service-testkube.testkube:9000") + + logger.info(f"Connecting to MinIO at {endpoint_url}") + + # Create S3 client with MinIO configuration + s3_client = boto3.client( + "s3", + endpoint_url=endpoint_url, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + config=Config(signature_version="s3v4"), + region_name="us-east-1", # MinIO doesn't care about region, but boto3 requires it + ) + + return s3_client + + +def main(bucket: str, key: str) -> None: + """Download dataset from S3/MinIO -> convert to Ragas dataset -> save to data/datasets/ragas_dataset.jsonl Source dataset must contain columns: user_input, retrieved_contexts, reference + + Args: + bucket: S3 bucket name + key: S3 object key (path to dataset file) """ - converter = get_converter(url) + converter = get_converter(key) + + # Create S3 client + s3_client = create_s3_client() - # Download file from URL and raise HTTP error if it occurs - file: Response = requests.get(url, timeout=20) - file.raise_for_status() + # Download file from S3 + logger.info(f"Downloading from bucket '{bucket}', key '{key}'...") + try: + response = s3_client.get_object(Bucket=bucket, Key=key) + file_content = response["Body"].read() + logger.info(f"Downloaded {len(file_content)} bytes") + except Exception as e: + logger.error(f"Failed to download from S3: {e}") + raise # Load into DataFrame by using the correct converter - buffer = BytesIO(file.content) + logger.info("Converting to DataFrame...") + buffer = BytesIO(file_content) + dataframe = converter(buffer) + logger.info(f"Loaded {len(dataframe)} rows") # Convert DataFrame to Ragas dataset and save it + logger.info("Converting to Ragas dataset...") dataframe_to_ragas_dataset(dataframe) + logger.info("✓ Dataset saved successfully to data/ragas_dataset.jsonl") if __name__ == "__main__": - # Parse parameter the script was called with (URL) + # Parse parameters: bucket and key parser = argparse.ArgumentParser( - description="Download provided dataset -> convert to Ragas dataset -> save to data/datasets/ragas_dataset.jsonl" + description="Download dataset from S3/MinIO -> convert to Ragas dataset -> save to data/datasets/ragas_dataset.jsonl" + ) + parser.add_argument( + "bucket", + type=str, + help="S3/MinIO bucket name containing the dataset", ) parser.add_argument( - "url", - help="URL to the dataset in .csv / .json / .parquet format (must have user_input, retrieved_contexts, and reference columns)", + "key", + type=str, + help="S3/MinIO object key (path to dataset file in .csv / .json / .parquet format)", ) args = parser.parse_args() - # Call main using the parsed URL - main(args.url) + # Call main using the parsed bucket and key + main(args.bucket, args.key) diff --git a/uv.lock b/uv.lock index d70a19d..20bb3bd 100644 --- a/uv.lock +++ b/uv.lock @@ -170,6 +170,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/ca/ba5f909b40ea12ec542d5d7bdd13ee31c4d65f3beed20211ef81c18fa1f3/bandit-1.8.6-py3-none-any.whl", hash = "sha256:3348e934d736fcdb68b6aa4030487097e23a501adf3e7827b63658df464dddd0", size = 133808, upload-time = "2025-07-06T03:10:49.134Z" }, ] +[[package]] +name = "boto3" +version = "1.42.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ee/21/8be0e3685c3a4868be48d8d2f6e5b4641727e1d8a5d396b8b401d2b5f06e/boto3-1.42.24.tar.gz", hash = "sha256:c47a2f40df933e3861fc66fd8d6b87ee36d4361663a7e7ba39a87f5a78b2eae1", size = 112788, upload-time = "2026-01-07T20:30:51.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/75/bbfccb268f9faa4f59030888e859dca9797a980b77d6a074113af73bd4bf/boto3-1.42.24-py3-none-any.whl", hash = "sha256:8ed6ad670a5a2d7f66c1b0d3362791b48392c7a08f78479f5d8ab319a4d9118f", size = 140572, upload-time = "2026-01-07T20:30:49.431Z" }, +] + +[[package]] +name = "botocore" +version = "1.42.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/12/d7/bb4a4e839b238ffb67b002d7326b328ebe5eb23ed5180f2ca10399a802de/botocore-1.42.24.tar.gz", hash = "sha256:be8d1bea64fb91eea08254a1e5fea057e4428d08e61f4e11083a02cafc1f8cc6", size = 14878455, upload-time = "2026-01-07T20:30:40.379Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/d4/f2655d777eed8b069ecab3761454cb83f830f8be8b5b0d292e4b3a980d00/botocore-1.42.24-py3-none-any.whl", hash = "sha256:8fca9781d7c84f7ad070fceffaff7179c4aa7a5ffb27b43df9d1d957801e0a8d", size = 14551806, upload-time = "2026-01-07T20:30:38.103Z" }, +] + +[[package]] +name = "botocore-stubs" +version = "1.42.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "types-awscrt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/61/5715ec25b3dcb2a08133811f6a18a9ca9be54567452ab3e92cadcaec746e/botocore_stubs-1.42.24.tar.gz", hash = "sha256:f5fbe240267b27036b1217a304de34bf2bf993087e049a300d17d6f52d77988b", size = 42415, upload-time = "2026-01-07T21:27:03.862Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/6b/cffb62a7872ba32e08c22c9c918a4d5d1d39ed6d74195bf50a3ae75a22f3/botocore_stubs-1.42.24-py3-none-any.whl", hash = "sha256:025999e68f419472cc8dfb7bcc2964fa0a06b447f43e7fc309012ff4c665b3db", size = 66762, upload-time = "2026-01-07T21:27:02.249Z" }, +] + [[package]] name = "cachetools" version = "6.2.1" @@ -1284,6 +1324,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "moto" +version = "5.1.19" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boto3" }, + { name = "botocore" }, + { name = "cryptography" }, + { name = "jinja2" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "responses" }, + { name = "werkzeug" }, + { name = "xmltodict" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/45/eb/100a04d1b49859d05a9c701815117cd31bc436c3d9e959d399d9d2ff7e9c/moto-5.1.19.tar.gz", hash = "sha256:a13423e402366b6affab07ed28e1df5f3fcc54ef68fc8d83dc9f824da7a4024e", size = 8361592, upload-time = "2025-12-28T20:14:57.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/07/5ca7ba79615b88ee2325224894667f263b992d266a52b83d215c4b3caa39/moto-5.1.19-py3-none-any.whl", hash = "sha256:7adb0caacf0e2d0dbb09550bcb49a7f158ee7c460a09cb54d4599a9a94cfef70", size = 6451569, upload-time = "2025-12-28T20:14:54.701Z" }, +] + +[package.optional-dependencies] +s3 = [ + { name = "py-partiql-parser" }, + { name = "pyyaml" }, +] + [[package]] name = "multidict" version = "6.6.4" @@ -1996,6 +2062,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/97/b7/15cc7d93443d6c6a84626ae3258a91f4c6ac8c0edd5df35ea7658f71b79c/protobuf-6.32.1-py3-none-any.whl", hash = "sha256:2601b779fc7d32a866c6b4404f9d42a3f67c5b9f3f15b4db3cccabe06b95c346", size = 169289, upload-time = "2025-09-11T21:38:41.234Z" }, ] +[[package]] +name = "py-partiql-parser" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/7a/a0f6bda783eb4df8e3dfd55973a1ac6d368a89178c300e1b5b91cd181e5e/py_partiql_parser-0.6.3.tar.gz", hash = "sha256:09cecf916ce6e3da2c050f0cb6106166de42c33d34a078ec2eb19377ea70389a", size = 17456, upload-time = "2025-10-18T13:56:13.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/33/a7cbfccc39056a5cf8126b7aab4c8bafbedd4f0ca68ae40ecb627a2d2cd3/py_partiql_parser-0.6.3-py2.py3-none-any.whl", hash = "sha256:deb0769c3346179d2f590dcbde556f708cdb929059fb654bad75f4cf6e07f582", size = 23752, upload-time = "2025-10-18T13:56:12.256Z" }, +] + [[package]] name = "pyarrow" version = "21.0.0" @@ -2424,6 +2499,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" }, ] +[[package]] +name = "responses" +version = "0.25.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/95/89c054ad70bfef6da605338b009b2e283485835351a9935c7bfbfaca7ffc/responses-0.25.8.tar.gz", hash = "sha256:9374d047a575c8f781b94454db5cab590b6029505f488d12899ddb10a4af1cf4", size = 79320, upload-time = "2025-08-08T19:01:46.709Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/4c/cc276ce57e572c102d9542d383b2cfd551276581dc60004cb94fe8774c11/responses-0.25.8-py3-none-any.whl", hash = "sha256:0c710af92def29c8352ceadff0c3fe340ace27cf5af1bbe46fb71275bcd2831c", size = 34769, upload-time = "2025-08-08T19:01:45.018Z" }, +] + [[package]] name = "rich" version = "14.1.0" @@ -2475,6 +2564,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/73/4de6579bac8e979fca0a77e54dec1f1e011a0d268165eb8a9bc0982a6564/ruff-0.14.3-py3-none-win_arm64.whl", hash = "sha256:26eb477ede6d399d898791d01961e16b86f02bc2486d0d1a7a9bb2379d055dc1", size = 12590017, upload-time = "2025-10-31T00:26:24.52Z" }, ] +[[package]] +name = "s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, +] + [[package]] name = "scikit-network" version = "0.33.3" @@ -2701,6 +2802,7 @@ source = { virtual = "." } dependencies = [ { name = "a2a" }, { name = "a2a-sdk" }, + { name = "boto3" }, { name = "httpx" }, { name = "langchain-openai" }, { name = "nest-asyncio" }, @@ -2713,14 +2815,14 @@ dependencies = [ { name = "pyarrow" }, { name = "python-dotenv" }, { name = "ragas" }, - { name = "requests" }, - { name = "types-requests" }, + { name = "types-boto3" }, ] [package.dev-dependencies] dev = [ { name = "bandit" }, { name = "import-linter" }, + { name = "moto", extra = ["s3"] }, { name = "mypy" }, { name = "poethepoet" }, { name = "pytest" }, @@ -2732,6 +2834,7 @@ dev = [ requires-dist = [ { name = "a2a", specifier = ">=0.44" }, { name = "a2a-sdk", specifier = ">=0.3.10" }, + { name = "boto3", specifier = ">=1.26.0" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "langchain-openai", specifier = ">=1.0.2" }, { name = "nest-asyncio", specifier = ">=1.6.0" }, @@ -2744,14 +2847,14 @@ requires-dist = [ { name = "pyarrow", specifier = ">=21.0.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "ragas", extras = ["ag-ui"], specifier = ">=0.4.1" }, - { name = "requests", specifier = ">=2.32.5" }, - { name = "types-requests", specifier = ">=2.32.0" }, + { name = "types-boto3", specifier = ">=1.0.2" }, ] [package.metadata.requires-dev] dev = [ { name = "bandit", extras = ["toml"], specifier = ">=1.7.8" }, { name = "import-linter", specifier = ">=2.0" }, + { name = "moto", extras = ["s3"], specifier = ">=5.0.0" }, { name = "mypy", specifier = ">=1.17.0" }, { name = "poethepoet", specifier = ">=0.31.1" }, { name = "pytest", specifier = ">=7.4.0" }, @@ -2843,6 +2946,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/22/35617eee79080a5d071d0f14ad698d325ee6b3bf824fc0467c03b30e7fa8/typer-0.19.2-py3-none-any.whl", hash = "sha256:755e7e19670ffad8283db353267cb81ef252f595aa6834a0d1ca9312d9326cb9", size = 46748, upload-time = "2025-09-23T09:47:46.777Z" }, ] +[[package]] +name = "types-awscrt" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/1f/febd2df22e24f77b759db0dd9ecdd7f07f055e6a4dbbb699c5eb34b617ef/types_awscrt-0.30.0.tar.gz", hash = "sha256:362fd8f5eaebcfcd922cb9fd8274fb375df550319f78031ee3779eac0b9ecc79", size = 17761, upload-time = "2025-12-12T01:55:59.626Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/5f/15999051fca2949a67562c3f80fae2dd5d3404a3f97b326b614533843281/types_awscrt-0.30.0-py3-none-any.whl", hash = "sha256:8204126e01a00eaa4a746e7a0076538ca0e4e3f52408adec0ab9b471bb0bb64b", size = 42392, upload-time = "2025-12-12T01:55:58.194Z" }, +] + +[[package]] +name = "types-boto3" +version = "1.42.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore-stubs" }, + { name = "types-s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/8d/a0052576bab66a0632e2f7f26ebbcd98eeaf17d8b8fc15f4c19b7ec3df82/types_boto3-1.42.24.tar.gz", hash = "sha256:7b982a7ddbe1cfb153c5bd5442c5b394562adcac4dd6d1df3bc6f68f3f11f1d6", size = 101257, upload-time = "2026-01-07T20:41:13.892Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/b7/71eddff9cd4191e0a23a8dcd2538a99f5aa6e8562f050f76ab0a6e9bbc4c/types_boto3-1.42.24-py3-none-any.whl", hash = "sha256:0f1edc99c9cc7b5e6a7dc0003d1d9831dec81fd9f66a550d0979c577742e8956", size = 69676, upload-time = "2026-01-07T20:41:10.734Z" }, +] + [[package]] name = "types-pytz" version = "2025.2.0.20250809" @@ -2853,15 +2978,12 @@ wheels = [ ] [[package]] -name = "types-requests" -version = "2.32.4.20250913" +name = "types-s3transfer" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/36/27/489922f4505975b11de2b5ad07b4fe1dca0bca9be81a703f26c5f3acfce5/types_requests-2.32.4.20250913.tar.gz", hash = "sha256:abd6d4f9ce3a9383f269775a9835a4c24e5cd6b9f647d64f88aa4613c33def5d", size = 23113, upload-time = "2025-09-13T02:40:02.309Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/64/42689150509eb3e6e82b33ee3d89045de1592488842ddf23c56957786d05/types_s3transfer-0.16.0.tar.gz", hash = "sha256:b4636472024c5e2b62278c5b759661efeb52a81851cde5f092f24100b1ecb443", size = 13557, upload-time = "2025-12-08T08:13:09.928Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/20/9a227ea57c1285986c4cf78400d0a91615d25b24e257fd9e2969606bdfae/types_requests-2.32.4.20250913-py3-none-any.whl", hash = "sha256:78c9c1fffebbe0fa487a418e0fa5252017e9c60d1a2da394077f1780f655d7e1", size = 20658, upload-time = "2025-09-13T02:40:01.115Z" }, + { url = "https://files.pythonhosted.org/packages/98/27/e88220fe6274eccd3bdf95d9382918716d312f6f6cef6a46332d1ee2feff/types_s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:1c0cd111ecf6e21437cb410f5cddb631bfb2263b77ad973e79b9c6d0cb24e0ef", size = 19247, upload-time = "2025-12-08T08:13:08.426Z" }, ] [[package]] @@ -2925,6 +3047,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/58/dd/56f0d8af71e475ed194d702f8b4cf9cea812c95e82ad823d239023c6558c/w3lib-2.3.1-py3-none-any.whl", hash = "sha256:9ccd2ae10c8c41c7279cd8ad4fe65f834be894fe7bfdd7304b991fd69325847b", size = 21751, upload-time = "2025-01-27T14:22:09.421Z" }, ] +[[package]] +name = "werkzeug" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" }, +] + [[package]] name = "wrapt" version = "1.17.3" @@ -2974,6 +3108,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" }, ] +[[package]] +name = "xmltodict" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/aa/917ceeed4dbb80d2f04dbd0c784b7ee7bba8ae5a54837ef0e5e062cd3cfb/xmltodict-1.0.2.tar.gz", hash = "sha256:54306780b7c2175a3967cad1db92f218207e5bc1aba697d887807c0fb68b7649", size = 25725, upload-time = "2025-09-17T21:59:26.459Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/20/69a0e6058bc5ea74892d089d64dfc3a62ba78917ec5e2cfa70f7c92ba3a5/xmltodict-1.0.2-py3-none-any.whl", hash = "sha256:62d0fddb0dcbc9f642745d8bbf4d81fd17d6dfaec5a15b5c1876300aad92af0d", size = 13893, upload-time = "2025-09-17T21:59:24.859Z" }, +] + [[package]] name = "xxhash" version = "3.5.0" From e119b4716a2b7552c51e43015619f1e0b5f6f9fa Mon Sep 17 00:00:00 2001 From: Florian Mallmann Date: Mon, 12 Jan 2026 14:29:50 +0100 Subject: [PATCH 8/8] feat: Improve type hinting and refactor conversation message handling for multi-turn interactions --- .github/workflows/ci.yml | 8 +- Tiltfile | 8 ++ deploy/local/multi-turn-workflow.yaml | 24 ++--- scripts/evaluate.py | 4 +- scripts/run.py | 49 ++++----- scripts/setup.py | 11 +- scripts/visualize.py | 12 +-- tests/test_run.py | 147 ++++++++++---------------- tests/test_setup.py | 76 ++++++------- tests/test_visualize.py | 22 ++-- tests_e2e/test_e2e.py | 2 +- 11 files changed, 154 insertions(+), 209 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 22cb5bb..5579dc1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -136,10 +136,4 @@ jobs: - name: Run Test Workflow run: | - testkube run testworkflow ragas-evaluation-workflow \ - --config datasetUrl="http://data-server.data-server:8000/dataset.csv" \ - --config agentUrl="http://agent-gateway-krakend.agent-gateway-krakend:10000/weather-agent" \ - --config metrics="nv_accuracy context_recall" \ - --config image="${{ steps.extract-tag.outputs.image-tag }}" \ - -n testkube \ - --watch + testkube run testworkflow multi-turn-workflow --watch diff --git a/Tiltfile b/Tiltfile index a781cd6..e5f4245 100644 --- a/Tiltfile +++ b/Tiltfile @@ -25,6 +25,14 @@ v1alpha1.extension(name='agent-gateway-krakend', repo_name='agentic-layer', repo load('ext://agent-gateway-krakend', 'agent_gateway_krakend_install') agent_gateway_krakend_install(version='0.4.1') +# Pre-create testkube namespace to avoid race condition with kustomize resources +k8s_yaml(blob(''' +apiVersion: v1 +kind: Namespace +metadata: + name: testkube +''')) + load('ext://helm_resource', 'helm_resource') helm_resource( 'testkube', diff --git a/deploy/local/multi-turn-workflow.yaml b/deploy/local/multi-turn-workflow.yaml index 085e23f..1467974 100644 --- a/deploy/local/multi-turn-workflow.yaml +++ b/deploy/local/multi-turn-workflow.yaml @@ -14,9 +14,9 @@ spec: - name: metrics-config configMap: name: multi-turn-metrics-config -# - name: dataset -# configMap: -# name: dataset + - name: dataset + configMap: + name: dataset container: image: ghcr.io/agentic-layer/testbench/testworkflows:latest @@ -29,19 +29,19 @@ spec: - name: metrics-config mountPath: /app/config/metrics.yaml subPath: metrics.yaml -# - name: dataset -# mountPath: /data/datasets/ragas_dataset.jsonl -# subPath: dataset.jsonl + - name: dataset + mountPath: /data/datasets/ragas_dataset.jsonl + subPath: dataset.jsonl # Steps using the templates steps: # Step 1: Setup - Download and convert dataset - - name: setup - use: - - name: ragas-setup-template - config: - bucket: "datasets" - key: "multi_turn_dataset.json" +# - name: setup +# use: +# - name: ragas-setup-template +# config: +# bucket: "datasets" +# key: "multi_turn_dataset.json" # Step 2: Run - Execute agent queries - name: run diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 8d6051a..ac24711 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -154,7 +154,7 @@ def load_from_config(self, config_path: str) -> list[Metric]: config = json.load(f) elif config_path.endswith((".yaml", ".yml")): try: - import yaml + import yaml # type: ignore[import-untyped] except ImportError: raise ValueError( "YAML support requires 'pyyaml' package.\n" @@ -441,7 +441,7 @@ def main( "--metrics-config", type=str, default="config/metrics.json", - help="Path to metrics configuration file (JSON or YAML). Default: examples/metrics_simple.json", + help="Path to metrics configuration file (JSON or YAML). Default: config/metrics.json", ) parser.add_argument( diff --git a/scripts/run.py b/scripts/run.py index 9e7a0b7..238968a 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -2,6 +2,7 @@ import asyncio import logging from logging import Logger +from typing import Any, cast from uuid import uuid4 import httpx @@ -217,7 +218,7 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict raise ValueError("No human messages found in user_input") context_id = None - conversation_messages = [] + conversation_messages: list[dict[str, Any]] = [] seen_message_ids = set() # Track message_ids to avoid duplicates across all turns # Sequentially query agent for each human turn @@ -229,9 +230,10 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict turn_span.set_attribute("turn.content", human_msg["content"]) # Create A2A message + parts: list[Part] = [Part(root=TextPart(text=human_msg["content"]))] message = Message( role=Role.user, - parts=[TextPart(text=human_msg["content"])], + parts=parts, message_id=uuid4().hex, context_id=context_id, # None for first turn, preserved after ) @@ -253,7 +255,7 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict span.set_attribute("conversation.context_id", context_id) # Process this turn's history immediately - if turn_task and hasattr(turn_task, 'history') and turn_task.history: + if turn_task and hasattr(turn_task, "history") and turn_task.history: for msg in turn_task.history: # Skip duplicate messages if msg.message_id in seen_message_ids: @@ -292,34 +294,36 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict text_content = actual_part.text # Check for DataPart (tool calls or responses) - elif (hasattr(actual_part, "kind") and actual_part.kind == "data" and - hasattr(actual_part, "data") and isinstance(actual_part.data, dict) and - "name" in actual_part.data): - + elif ( + hasattr(actual_part, "kind") + and actual_part.kind == "data" + and hasattr(actual_part, "data") + and isinstance(actual_part.data, dict) + and "name" in actual_part.data + ): # Tool call: has args, not response if "args" in actual_part.data and "response" not in actual_part.data: - tool_calls_in_msg.append({ - "name": actual_part.data.get("name"), - "args": actual_part.data.get("args", {}) - }) + tool_calls_in_msg.append( + { + "name": actual_part.data.get("name"), + "args": actual_part.data.get("args", {}), + } + ) # Tool response: has response, not args elif "response" in actual_part.data and "args" not in actual_part.data: tool_response_data = actual_part.data.get("response", {}) # Keep as dict/string representation response_content = str(tool_response_data) - tool_responses_in_msg.append({ - "content": response_content, - "type": "tool" - }) + tool_responses_in_msg.append( + {"content": response_content, "type": "tool"} + ) # Add AI message with tool calls (if any) - with empty content if tool_calls_in_msg: - conversation_messages.append({ - "content": "", - "type": "ai", - "tool_calls": tool_calls_in_msg - }) + conversation_messages.append( + {"content": "", "type": "ai", "tool_calls": tool_calls_in_msg} + ) logger.info(f"Extracted {len(tool_calls_in_msg)} tool call(s)") # Add tool response messages (if any) @@ -329,10 +333,7 @@ async def multi_turn_experiment(row, agent_url: str, workflow_name: str) -> dict # Add AI message with text content (if any) if text_content: - conversation_messages.append({ - "content": text_content, - "type": "ai" - }) + conversation_messages.append({"content": text_content, "type": "ai"}) else: logger.warning(f"Turn {turn_idx + 1}: task.history not available") diff --git a/scripts/setup.py b/scripts/setup.py index 1642ede..698cf65 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -3,20 +3,19 @@ import os from io import BytesIO from pathlib import Path -from typing import Callable +from typing import Any, Callable import boto3 -from botocore.client import Config import pandas as pd -from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError +from botocore.client import Config from pandas import DataFrame from ragas import Dataset - # Set up module-level logger logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + def dataframe_to_ragas_dataset(dataframe: DataFrame) -> None: """Convert DataFrame to Ragas dataset and save to data/ragas_dataset.jsonl. @@ -30,7 +29,7 @@ def dataframe_to_ragas_dataset(dataframe: DataFrame) -> None: output_dir.mkdir(exist_ok=True) # Create Ragas Dataset - dataset = Dataset.from_pandas( + dataset: Dataset = Dataset.from_pandas( name="ragas_dataset", dataframe=dataframe, backend="local/jsonl", @@ -81,7 +80,7 @@ def custom_convert_csv(input_file: BytesIO) -> DataFrame: return dataframe -def create_s3_client() -> boto3.client: +def create_s3_client() -> Any: """Create and configure S3 client for MinIO""" # Get MinIO credentials from environment access_key = os.getenv("MINIO_ROOT_USER", "minio") diff --git a/scripts/visualize.py b/scripts/visualize.py index 34d8e5d..70f2efd 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -161,7 +161,7 @@ def _format_multi_turn_conversation(conversation: list[dict[str, Any]]) -> str: label = "AI" html_output += f'
' - html_output += f'{label}: ' + html_output += f"{label}: " # If AI message has tool calls, display them if tool_calls: @@ -170,12 +170,12 @@ def _format_multi_turn_conversation(conversation: list[dict[str, Any]]) -> str: tool_name = tool_call.get("name", "unknown") tool_args = tool_call.get("args", {}) # Format args as JSON for readability - args_str = json.dumps(tool_args, indent=2) - html_output += f'
' + args_str = html.escape(json.dumps(tool_args, indent=2)) + html_output += '
' html_output += f'→ Tool: {tool_name}' html_output += f'
{args_str}
' - html_output += '
' - html_output += '
' + html_output += "
" + html_output += "
" # Display content if not empty if content: @@ -183,7 +183,7 @@ def _format_multi_turn_conversation(conversation: list[dict[str, Any]]) -> str: escaped_content = html.escape(content) html_output += f'{escaped_content}' - html_output += '
' + html_output += "
" html_output += "
" return html_output diff --git a/tests/test_run.py b/tests/test_run.py index 093a9ea..fdc935f 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -318,9 +318,9 @@ async def test_multi_turn_experiment_with_tool_calls(monkeypatch): row = { "user_input": [ {"content": "What's the weather in NYC?", "type": "human"}, - {"content": "How about London?", "type": "human"} + {"content": "How about London?", "type": "human"}, ], - "reference": "Weather info provided" + "reference": "Weather info provided", } # Create mock task objects with tool calls @@ -334,38 +334,25 @@ def __init__(self, context_id, turn_idx, has_tool_calls=False): agent_metadata = None if has_tool_calls: agent_metadata = { - "tool_calls": [ - { - "name": "get_weather", - "args": {"location": "NYC" if turn_idx == 1 else "London"} - } - ] + "tool_calls": [{"name": "get_weather", "args": {"location": "NYC" if turn_idx == 1 else "London"}}] } self.history = [ Message( role=Role.user, parts=[Part(TextPart(text=row["user_input"][turn_idx - 1]["content"]))], - message_id=f"user_msg_{turn_idx}" + message_id=f"user_msg_{turn_idx}", ), Message( role=Role.agent, parts=[Part(TextPart(text=f"Weather response {turn_idx}"))], message_id=f"agent_msg_{turn_idx}", - metadata=agent_metadata - ) + metadata=agent_metadata, + ), ] def model_dump(self, mode=None, include=None): - return { - "artifacts": [ - { - "parts": [ - {"text": f"Weather response {self.turn_idx}"} - ] - } - ] - } + return {"artifacts": [{"parts": [{"text": f"Weather response {self.turn_idx}"}]}]} # Mock client that accumulates history class MockClient: @@ -381,15 +368,12 @@ async def send_message(self, message): self.accumulated_history.append(message) # Add agent response message to history - has_tool_calls = (self.turn_count == 1) + has_tool_calls = self.turn_count == 1 agent_metadata = None if has_tool_calls: agent_metadata = { "tool_calls": [ - { - "name": "get_weather", - "args": {"location": "NYC" if self.turn_count == 1 else "London"} - } + {"name": "get_weather", "args": {"location": "NYC" if self.turn_count == 1 else "London"}} ] } @@ -397,7 +381,7 @@ async def send_message(self, message): role=Role.agent, parts=[Part(TextPart(text=f"Weather response {self.turn_count}"))], message_id=f"agent_msg_{self.turn_count}", - metadata=agent_metadata + metadata=agent_metadata, ) self.accumulated_history.append(agent_message) @@ -430,11 +414,7 @@ def mock_setup_otel(): monkeypatch.setattr("run.setup_otel", mock_setup_otel) # Run the experiment - result = await multi_turn_experiment( - row, - agent_url="http://test-agent:8000", - workflow_name="test-workflow" - ) + result = await multi_turn_experiment(row, agent_url="http://test-agent:8000", workflow_name="test-workflow") # Verify result structure assert "user_input" in result @@ -487,7 +467,7 @@ async def test_multi_turn_experiment_no_tool_calls(monkeypatch): "user_input": [ {"content": "Hello", "type": "human"}, ], - "reference": "Greeting response" + "reference": "Greeting response", } # Create mock task without tool calls @@ -498,29 +478,17 @@ def __init__(self, context_id): # History without tool calls in metadata self.history = [ - Message( - role=Role.user, - parts=[Part(TextPart(text="Hello"))], - message_id="user_msg_1" - ), + Message(role=Role.user, parts=[Part(TextPart(text="Hello"))], message_id="user_msg_1"), Message( role=Role.agent, parts=[Part(TextPart(text="Hi there!"))], message_id="agent_msg_1", - metadata=None # No metadata, no tool calls - ) + metadata=None, # No metadata, no tool calls + ), ] def model_dump(self, mode=None, include=None): - return { - "artifacts": [ - { - "parts": [ - {"text": "Hi there!"} - ] - } - ] - } + return {"artifacts": [{"parts": [{"text": "Hi there!"}]}]} # Mock client class MockClient: @@ -543,11 +511,7 @@ def mock_setup_otel(): monkeypatch.setattr("run.setup_otel", mock_setup_otel) # Run the experiment - result = await multi_turn_experiment( - row, - agent_url="http://test-agent:8000", - workflow_name="test-workflow" - ) + result = await multi_turn_experiment(row, agent_url="http://test-agent:8000", workflow_name="test-workflow") # Verify result structure assert "user_input" in result @@ -577,7 +541,7 @@ async def test_multi_turn_experiment_with_datapart_tool_calls(monkeypatch): "user_input": [ {"content": "What time is it in New York?", "type": "human"}, ], - "reference": "Time info provided" + "reference": "Time info provided", } # Create mock task with DataPart tool calls @@ -589,59 +553,60 @@ def __init__(self, context_id): # History with DataPart containing both tool call and tool response self.history = [ Message( - role=Role.user, - parts=[Part(TextPart(text="What time is it in New York?"))], - message_id="user_msg_1" + role=Role.user, parts=[Part(TextPart(text="What time is it in New York?"))], message_id="user_msg_1" ), # Tool call DataPart (has name + args) Message( role=Role.agent, - parts=[Part(DataPart( - kind="data", - data={ - "id": "call_get_current_time", - "name": "get_current_time", - "args": {"city": "New York"} - }, - metadata={"adk_type": "function_call"} - ))], + parts=[ + Part( + DataPart( + kind="data", + data={ + "id": "call_get_current_time", + "name": "get_current_time", + "args": {"city": "New York"}, + }, + metadata={"adk_type": "function_call"}, + ) + ) + ], message_id="agent_msg_1", - metadata=None + metadata=None, ), # Tool response DataPart (has name + response) - should be ignored Message( role=Role.agent, - parts=[Part(DataPart( - kind="data", - data={ - "id": "call_get_current_time", - "name": "get_current_time", - "response": {"status": "success", "report": "The current time in New York is 11:22:05 EST"} - }, - metadata={"adk_type": "function_response"} - ))], + parts=[ + Part( + DataPart( + kind="data", + data={ + "id": "call_get_current_time", + "name": "get_current_time", + "response": { + "status": "success", + "report": "The current time in New York is 11:22:05 EST", + }, + }, + metadata={"adk_type": "function_response"}, + ) + ) + ], message_id="agent_msg_2", - metadata=None + metadata=None, ), # Final text response Message( role=Role.agent, parts=[Part(TextPart(text="The current time in New York is 11:22:05 EST"))], message_id="agent_msg_3", - metadata=None - ) + metadata=None, + ), ] def model_dump(self, mode=None, include=None): - return { - "artifacts": [ - { - "parts": [ - {"text": "The current time in New York is 11:22:05 EST"} - ] - } - ] - } + return {"artifacts": [{"parts": [{"text": "The current time in New York is 11:22:05 EST"}]}]} # Mock client class MockClient: @@ -664,11 +629,7 @@ def mock_setup_otel(): monkeypatch.setattr("run.setup_otel", mock_setup_otel) # Run the experiment - result = await multi_turn_experiment( - row, - agent_url="http://test-agent:8000", - workflow_name="test-workflow" - ) + result = await multi_turn_experiment(row, agent_url="http://test-agent:8000", workflow_name="test-workflow") # Verify result structure assert "user_input" in result diff --git a/tests/test_setup.py b/tests/test_setup.py index 9e594b9..b9c908a 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -111,54 +111,43 @@ def test_creates_ragas_dataset_file(temp_dir): # TestMain tests def test_main_with_csv(temp_dir, monkeypatch): - """Test main function with CSV file""" + """Test main function with CSV file from S3""" tmp, original_cwd = temp_dir os.chdir(tmp) try: - # Mock the HTTP response + # Mock S3 response csv_content = b"user_input,retrieved_contexts,reference\n" csv_content += b'"Question?","Context text","Answer"\n' - class MockResponse: - def __init__(self): - self.content = csv_content + class MockS3Client: + def get_object(self, Bucket, Key): # noqa: N803 + return {"Body": BytesIO(csv_content)} - def raise_for_status(self): - pass + def mock_create_s3_client(): + return MockS3Client() - calls = [] + monkeypatch.setattr("setup.create_s3_client", mock_create_s3_client) - def mock_get(url, timeout=None): - calls.append({"url": url, "timeout": timeout}) - return MockResponse() - - monkeypatch.setattr("setup.requests.get", mock_get) - - # Run main - main("https://example.com/data.csv") + # Run main with bucket and key + main("test-bucket", "data.csv") # Verify dataset was created in datasets subdirectory dataset_file = Path(tmp) / "data" / "datasets" / "ragas_dataset.jsonl" assert dataset_file.exists(), f"Dataset file not found at {dataset_file}" - - # Verify requests.get was called correctly - assert len(calls) == 1 - assert calls[0]["url"] == "https://example.com/data.csv" - assert calls[0]["timeout"] == 20 finally: os.chdir(original_cwd) def test_main_with_json(temp_dir, monkeypatch): - """Test main function with JSON file""" + """Test main function with JSON file from S3""" tmp, original_cwd = temp_dir os.chdir(tmp) try: - # Mock the HTTP response + # Mock S3 response json_content = b"""[ { "user_input": "Question?", @@ -167,20 +156,17 @@ def test_main_with_json(temp_dir, monkeypatch): } ]""" - class MockResponse: - def __init__(self): - self.content = json_content - - def raise_for_status(self): - pass + class MockS3Client: + def get_object(self, Bucket, Key): # noqa: N803 + return {"Body": BytesIO(json_content)} - def mock_get(url, timeout=None): - return MockResponse() + def mock_create_s3_client(): + return MockS3Client() - monkeypatch.setattr("setup.requests.get", mock_get) + monkeypatch.setattr("setup.create_s3_client", mock_create_s3_client) - # Run main - main("https://example.com/data.json") + # Run main with bucket and key + main("test-bucket", "data.json") # Verify dataset was created in datasets subdirectory dataset_file = Path(tmp) / "data" / "datasets" / "ragas_dataset.jsonl" @@ -189,25 +175,25 @@ def mock_get(url, timeout=None): os.chdir(original_cwd) -def test_main_with_invalid_url(temp_dir, monkeypatch): - """Test main function with invalid URL (HTTP error)""" +def test_main_with_invalid_s3_key(temp_dir, monkeypatch): + """Test main function with invalid S3 key (S3 error)""" tmp, original_cwd = temp_dir os.chdir(tmp) try: - # Mock HTTP error - class MockResponse: - def raise_for_status(self): - raise Exception("HTTP 404") + # Mock S3 error + class MockS3Client: + def get_object(self, Bucket, Key): # noqa: N803 + raise Exception("NoSuchKey: The specified key does not exist") - def mock_get(url, timeout=None): - return MockResponse() + def mock_create_s3_client(): + return MockS3Client() - monkeypatch.setattr("setup.requests.get", mock_get) + monkeypatch.setattr("setup.create_s3_client", mock_create_s3_client) # Verify that the error propagates - with pytest.raises(Exception, match="HTTP 404"): - main("https://example.com/nonexistent.csv") + with pytest.raises(Exception, match="NoSuchKey"): + main("test-bucket", "nonexistent.csv") finally: os.chdir(original_cwd) diff --git a/tests/test_visualize.py b/tests/test_visualize.py index ec3b5e1..9561c83 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -659,13 +659,9 @@ def test_format_multi_turn_conversation_with_tool_calls(): conversation = [ {"content": "What's the weather?", "type": "human"}, - { - "content": "", - "type": "ai", - "tool_calls": [{"name": "get_weather", "args": {"city": "NYC"}}] - }, + {"content": "", "type": "ai", "tool_calls": [{"name": "get_weather", "args": {"city": "NYC"}}]}, {"content": "{'status': 'success', 'report': 'Sunny, 72F'}", "type": "tool"}, - {"content": "The weather is sunny.", "type": "ai"} + {"content": "The weather is sunny.", "type": "ai"}, ] html = _format_multi_turn_conversation(conversation) @@ -699,9 +695,9 @@ def test_format_multi_turn_conversation_with_multiple_tool_calls(): "type": "ai", "tool_calls": [ {"name": "get_weather", "args": {"city": "NYC"}}, - {"name": "get_time", "args": {"city": "NYC"}} - ] - } + {"name": "get_time", "args": {"city": "NYC"}}, + ], + }, ] html = _format_multi_turn_conversation(conversation) @@ -714,7 +710,7 @@ def test_format_multi_turn_conversation_with_multiple_tool_calls(): def test_prepare_chart_data_with_tool_calls(): """Test prepare_chart_data handles tool calls in user_input""" - from visualize import prepare_chart_data, VisualizationData + from visualize import VisualizationData, prepare_chart_data viz_data = VisualizationData( overall_scores={"metric1": 0.85}, @@ -722,16 +718,16 @@ def test_prepare_chart_data_with_tool_calls(): { "user_input": [ {"content": "test", "type": "human"}, - {"content": "", "type": "ai", "tool_calls": [{"name": "tool1", "args": {}}]} + {"content": "", "type": "ai", "tool_calls": [{"name": "tool1", "args": {}}]}, ], "response": "", "metric1": 0.85, - "trace_id": "trace1" + "trace_id": "trace1", } ], total_tokens={"input_tokens": 100, "output_tokens": 50}, total_cost=0.01, - metric_names=["metric1"] + metric_names=["metric1"], ) chart_data = prepare_chart_data(viz_data) diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index 1ce50d3..c2cade2 100755 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -80,7 +80,7 @@ def verify_scripts_exist(self) -> bool: logger.info("✓ All scripts found") return True - def run_command(self, command: List[str], step_name: str, env: dict = None) -> bool: + def run_command(self, command: List[str], step_name: str, env: dict | None = None) -> bool: """ Run a command and handle output/errors.