Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
)
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
from otel_setup import setup_otel
from pydantic import BaseModel
from ragas import Dataset, experiment

from otel_setup import setup_otel

# Set up module-level logger
logging.basicConfig(level=logging.INFO)
logger: Logger = logging.getLogger(__name__)
Expand All @@ -43,14 +42,15 @@ async def initialize_client(agent_url: str) -> Client:


@experiment()
async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[str, str | list]:
async def run_agent_experiment(row, agent_url: str, workflow_name: str, timeout: int = 300) -> dict[str, str | list]:
"""
Experiment function that processes each row from the dataset.

Args:
row: A dictionary containing 'user_input', 'retrieved_contexts', and 'reference' fields
agent_url: The URL of the agent to query
workflow_name: Name of the test workflow for span labeling
timeout: Request timeout in seconds (default: 300)

Returns:
Dictionary with original row data plus 'response' and 'trace_id'
Expand All @@ -76,7 +76,7 @@ async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[
span.set_attribute("workflow.name", workflow_name)

try:
async with httpx.AsyncClient():
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)):
client = await initialize_client(agent_url)

# Get the input from the row
Expand Down Expand Up @@ -126,7 +126,7 @@ async def run_agent_experiment(row, agent_url: str, workflow_name: str) -> dict[
return result


async def main(agent_url: str, workflow_name: str) -> None:
async def main(agent_url: str, workflow_name: str, timeout: int) -> None:
"""Main function to load Ragas Dataset and run Experiment."""

# Initialize OpenTelemetry tracing
Expand All @@ -139,7 +139,9 @@ async def main(agent_url: str, workflow_name: str) -> None:

# Run the experiment
logger.info("Starting experiment...")
await run_agent_experiment.arun(dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name)
await run_agent_experiment.arun(
dataset, name="ragas_experiment", agent_url=agent_url, workflow_name=workflow_name, timeout=timeout
)

logger.info("Experiment completed successfully")
logger.info("Results saved to data/experiments/ragas_experiment.jsonl")
Expand All @@ -157,7 +159,13 @@ async def main(agent_url: str, workflow_name: str) -> None:
default="local-test",
help="Name of the test workflow (e.g., 'weather-assistant-test'). Default: 'local-test'",
)
parser.add_argument(
"--timeout",
type=float,
default=300.0,
help="Request timeout in seconds (default: 300)",
)
args = parser.parse_args()

# Call main with parsed arguments
asyncio.run(main(args.url, args.workflow_name))
asyncio.run(main(args.url, args.workflow_name, args.timeout))
14 changes: 5 additions & 9 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

def mock_httpx_client():
def mock_httpx_client(**kwargs):
return MockAsyncClient()

monkeypatch.setattr("run.initialize_client", mock_init_client)
Expand All @@ -111,9 +111,7 @@ def mock_httpx_client():

# Call the function
result = await run_agent_experiment.func(
test_row,
agent_url="http://test-agent:8000",
workflow_name="test-workflow"
test_row, agent_url="http://test-agent:8000", workflow_name="test-workflow"
)

# Verify result structure
Expand Down Expand Up @@ -141,7 +139,7 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

def mock_httpx_client():
def mock_httpx_client(**kwargs):
return MockAsyncClient()

monkeypatch.setattr("run.initialize_client", mock_init_client)
Expand All @@ -156,9 +154,7 @@ def mock_httpx_client():

# Call the function
result = await run_agent_experiment.func(
test_row,
agent_url="http://test-agent:8000",
workflow_name="test-workflow"
test_row, agent_url="http://test-agent:8000", workflow_name="test-workflow"
)

# Verify error is captured in response
Expand Down Expand Up @@ -210,7 +206,7 @@ async def mock_arun_tracked(*args, **kwargs):
monkeypatch.setattr("run.run_agent_experiment.arun", mock_arun_tracked)

# Run main
await main("http://test-agent:8000", "test-workflow")
await main("http://test-agent:8000", "test-workflow", 300)

# Verify Dataset.load was called
assert len(calls_to_load) == 1
Expand Down