From ff8e230654a340fb0c703b6c903cb92fe6c59570 Mon Sep 17 00:00:00 2001 From: MiaAppel Date: Tue, 16 Dec 2025 16:55:01 +0100 Subject: [PATCH] FEAT: add variable timout for run.py --- scripts/run.py | 22 +++++++++++++++------- tests/test_run.py | 14 +++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/scripts/run.py b/scripts/run.py index b93f592..fcd8c91 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -16,11 +16,10 @@ ) from opentelemetry import trace from opentelemetry.trace import Status, StatusCode +from otel_setup import setup_otel from pydantic import BaseModel from ragas import Dataset, experiment -from otel_setup import setup_otel - # Set up module-level logger logging.basicConfig(level=logging.INFO) logger: Logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ async def initialize_client(agent_url: str) -> Client: @experiment() -async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[str, str | list]: +async def run_agent_experiment(row, agent_url: str, workflow_name: str, timeout: int = 300) -> dict[str, str | list]: """ Experiment function that processes each row from the dataset. @@ -51,6 +50,7 @@ async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[ row: A dictionary containing 'user_input', 'retrieved_contexts', and 'reference' fields agent_url: The URL of the agent to query workflow_name: Name of the test workflow for span labeling + timeout: Request timeout in seconds (default: 300) Returns: Dictionary with original row data plus 'response' and 'trace_id' @@ -76,7 +76,7 @@ async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[ span.set_attribute("workflow.name", workflow_name) try: - async with httpx.AsyncClient(): + async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)): client = await initialize_client(agent_url) # Get the input from the row @@ -126,7 +126,7 @@ async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[ return result -async def main(agent_url: str, workflow_name: str) -> None: +async def main(agent_url: str, workflow_name: str, timeout: int) -> None: """Main function to load Ragas Dataset and run Experiment.""" # Initialize OpenTelemetry tracing @@ -139,7 +139,9 @@ async def main(agent_url: str, workflow_name: str) -> None: # Run the experiment logger.info("Starting experiment...") - await run_agent_experiment.arun(dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name) + await run_agent_experiment.arun( + dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name, timeout=timeout + ) logger.info("Experiment completed successfully") logger.info("Results saved to data/experiments/ragas_experiment.jsonl") @@ -157,7 +159,13 @@ async def main(agent_url: str, workflow_name: str) -> None: default="local-test", help="Name of the test workflow (e.g., 'weather-assistant-test'). Default: 'local-test'", ) + parser.add_argument( + "--timeout", + type=float, + default=300.0, + help="Request timeout in seconds (default: 300)", + ) args = parser.parse_args() # Call main with parsed arguments - asyncio.run(main(args.url, args.workflow_name)) + asyncio.run(main(args.url, args.workflow_name, args.timeout)) diff --git a/tests/test_run.py b/tests/test_run.py index bfbbe60..7aeeb06 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -96,7 +96,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): pass - def mock_httpx_client(): + def mock_httpx_client(**kwargs): return MockAsyncClient() monkeypatch.setattr("run.initialize_client", mock_init_client) @@ -111,9 +111,7 @@ def mock_httpx_client(): # Call the function result = await run_agent_experiment.func( - test_row, - agent_url="http://test-agent:8000", - workflow_name="test-workflow" + test_row, agent_url="http://test-agent:8000", workflow_name="test-workflow" ) # Verify result structure @@ -141,7 +139,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): pass - def mock_httpx_client(): + def mock_httpx_client(**kwargs): return MockAsyncClient() monkeypatch.setattr("run.initialize_client", mock_init_client) @@ -156,9 +154,7 @@ def mock_httpx_client(): # Call the function result = await run_agent_experiment.func( - test_row, - agent_url="http://test-agent:8000", - workflow_name="test-workflow" + test_row, agent_url="http://test-agent:8000", workflow_name="test-workflow" ) # Verify error is captured in response @@ -210,7 +206,7 @@ async def mock_arun_tracked(*args, **kwargs): monkeypatch.setattr("run.run_agent_experiment.arun", mock_arun_tracked) # Run main - await main("http://test-agent:8000", "test-workflow") + await main("http://test-agent:8000", "test-workflow", 300) # Verify Dataset.load was called assert len(calls_to_load) == 1