From e97fb4b914e9c87a92c39ad505845d2d8fff37c2 Mon Sep 17 00:00:00 2001 From: MagellaX Date: Tue, 17 Feb 2026 22:00:31 +0530 Subject: [PATCH 1/6] feat(rl): add rollout collector and rollout CLI --- hud/cli/__init__.py | 2 + hud/cli/rollout.py | 203 ++++++++++++++++++++++++++++++++ hud/cli/tests/test_rollout.py | 64 +++++++++++ hud/rl/__init__.py | 13 +++ hud/rl/collector.py | 204 +++++++++++++++++++++++++++++++++ hud/rl/schema.py | 37 ++++++ hud/rl/tests/test_collector.py | 103 +++++++++++++++++ 7 files changed, 626 insertions(+) create mode 100644 hud/cli/rollout.py create mode 100644 hud/cli/tests/test_rollout.py create mode 100644 hud/rl/collector.py create mode 100644 hud/rl/schema.py create mode 100644 hud/rl/tests/test_collector.py diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index 02b788d78..89ddebcb3 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -30,6 +30,7 @@ from .pull import pull_command from .push import push_command from .remove import remove_command +from .rollout import rollout_app from .utils.config import set_env_values from .utils.cursor import get_cursor_config_path, list_cursor_servers, parse_cursor_config from .utils.logging import CaptureLogger @@ -42,6 +43,7 @@ rich_markup_mode="rich", pretty_exceptions_enable=False, # Disable Rich's verbose tracebacks ) +app.add_typer(rollout_app, name="rollout") console = Console() diff --git a/hud/cli/rollout.py b/hud/cli/rollout.py new file mode 100644 index 000000000..6b1ed4197 --- /dev/null +++ b/hud/cli/rollout.py @@ -0,0 +1,203 @@ +"""CLI commands for collecting rollout trajectories.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import typer + +from hud.rl.collector import collect_rollouts, write_rollouts_jsonl +from hud.types import AgentType +from hud.utils.hud_console import HUDConsole + +if TYPE_CHECKING: + from hud.agents import MCPAgent + +rollout_app = typer.Typer(help="Collect rollout trajectories for RL/RFT workflows.") +hud_console = HUDConsole() + + +def _resolve_agent( + *, + agent: AgentType, + model: str | None, + allowed_tools: list[str] | None, + verbose: bool, + vllm_base_url: str | None, +) -> tuple[type[MCPAgent], dict[str, Any]]: + if agent == AgentType.INTEGRATION_TEST: + from hud.agents.misc.integration_test_agent import IntegrationTestRunner + + config: dict[str, Any] = {"verbose": verbose} + if allowed_tools: + config["allowed_tools"] = allowed_tools + return IntegrationTestRunner, config + + if agent == AgentType.VLLM: + from hud.agents.openai_chat_generic import GenericOpenAIChatAgent + + from .eval import _build_vllm_config + + return GenericOpenAIChatAgent, _build_vllm_config( + vllm_base_url=vllm_base_url, + model=model, + allowed_tools=allowed_tools, + verbose=verbose, + ) + + if agent == AgentType.OPENAI: + from hud.agents import OperatorAgent + + config = {"verbose": verbose, "validate_api_key": False} + if allowed_tools: + config["allowed_tools"] = allowed_tools + return OperatorAgent, config + + if agent == AgentType.GEMINI: + from hud.agents import GeminiAgent + + config = { + "model": model or "gemini-2.5-computer-use-preview-10-2025", + "verbose": verbose, + "validate_api_key": False, + } + if allowed_tools: + config["allowed_tools"] = allowed_tools + return GeminiAgent, config + + if agent == AgentType.LITELLM: + from hud.agents.lite_llm import LiteAgent + + config = {"model_name": model or "gpt-4o-mini", "verbose": verbose} + if allowed_tools: + config["allowed_tools"] = allowed_tools + return LiteAgent, config + + from hud.agents import ClaudeAgent + + config = { + "model": model or "claude-sonnet-4-20250514", + "verbose": verbose, + "validate_api_key": False, + } + if allowed_tools: + config["allowed_tools"] = allowed_tools + return ClaudeAgent, config + + +@rollout_app.command("collect") +def collect_command( + source: str | None = typer.Argument( + None, + help=( + "HuggingFace dataset name (e.g. hud-evals/SheetBench-50) or path to " + "a local JSON/JSONL tasks file." + ), + ), + output: Path = typer.Option( # noqa: B008 + Path("rollouts.jsonl"), + "--output", + "-o", + help="Output JSONL file for collected trajectories.", + ), + agent: AgentType = typer.Option( # noqa: B008 + AgentType.CLAUDE, + "--agent", + help="Agent backend to use for rollout collection.", + ), + model: str | None = typer.Option( + None, + "--model", + help="Model name for the selected agent backend.", + ), + allowed_tools: str | None = typer.Option( + None, + "--allowed-tools", + help="Comma-separated list of allowed tools.", + ), + max_concurrent: int = typer.Option( + 30, + "--max-concurrent", + min=1, + help="Maximum number of concurrent tasks.", + ), + max_steps: int = typer.Option( + 50, + "--max-steps", + min=1, + help="Maximum steps per rollout.", + ), + group_size: int = typer.Option( + 1, + "--group-size", + min=1, + help="Number of rollouts to collect per task.", + ), + split: str = typer.Option( + "train", + "--split", + help="Dataset split when source is a HuggingFace dataset.", + ), + verbose: bool = typer.Option( + False, + "--verbose", + "-v", + help="Enable verbose agent logs.", + ), + vllm_base_url: str | None = typer.Option( + None, + "--vllm-base-url", + help="Base URL for vLLM server (used with --agent vllm).", + ), +) -> None: + """Collect and export rollout trajectories.""" + if source is None: + from .utils.tasks import find_tasks_file + + try: + source = find_tasks_file(None, msg="Select a tasks file to collect rollouts from") + hud_console.success(f"Selected: {source}") + except (FileNotFoundError, Exception): + hud_console.error( + "No source provided and no task/eval JSON files found in current directory" + ) + raise typer.Exit(1) from None + + allowed_tools_list = ( + [tool.strip() for tool in allowed_tools.split(",") if tool.strip()] + if allowed_tools + else None + ) + agent_class, agent_config = _resolve_agent( + agent=agent, + model=model, + allowed_tools=allowed_tools_list, + verbose=verbose, + vllm_base_url=vllm_base_url, + ) + + run_name = f"Rollout Collection: {Path(source).name if Path(source).exists() else source}" + records = asyncio.run( + collect_rollouts( + name=run_name, + source=source, + agent_class=agent_class, + agent_config=agent_config, + max_concurrent=max_concurrent, + max_steps=max_steps, + split=split, + group_size=group_size, + metadata={"source": source, "group_size": group_size}, + auto_respond=True, + ) + ) + + if not records: + hud_console.warning("No rollouts were collected.") + return + + output_path = write_rollouts_jsonl(records, output) + hud_console.success(f"Collected {len(records)} rollouts") + hud_console.info(f"Saved to: {output_path}") diff --git a/hud/cli/tests/test_rollout.py b/hud/cli/tests/test_rollout.py new file mode 100644 index 000000000..9c580a970 --- /dev/null +++ b/hud/cli/tests/test_rollout.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch + +import pytest +import typer + +import hud.cli.rollout as rollout_cli +from hud.rl.schema import RolloutRecord +from hud.types import AgentType + +if TYPE_CHECKING: + from pathlib import Path + + +@patch("hud.cli.rollout.write_rollouts_jsonl") +@patch("hud.cli.rollout.collect_rollouts", new_callable=AsyncMock) +@patch("hud.cli.rollout._resolve_agent") +def test_collect_command_happy_path( + mock_resolve_agent, + mock_collect_rollouts, + mock_write_rollouts, + tmp_path: Path, +) -> None: + mock_resolve_agent.return_value = (object, {"model": "x"}) + mock_collect_rollouts.return_value = [ + RolloutRecord( + rollout_id="rollout_1", + source="tasks.json", + task_index=0, + repeat_index=0, + prompt="Prompt", + ) + ] + mock_write_rollouts.return_value = tmp_path / "rollouts.jsonl" + + rollout_cli.collect_command( + source="tasks.json", + output=tmp_path / "rollouts.jsonl", + agent=AgentType.CLAUDE, + model=None, + allowed_tools="act,observe", + max_concurrent=5, + max_steps=7, + group_size=2, + split="train", + verbose=False, + vllm_base_url=None, + ) + + mock_resolve_agent.assert_called_once() + mock_collect_rollouts.assert_awaited_once() + assert mock_collect_rollouts.call_args.kwargs["source"] == "tasks.json" + assert mock_collect_rollouts.call_args.kwargs["group_size"] == 2 + mock_write_rollouts.assert_called_once() + + +@patch("hud.cli.utils.tasks.find_tasks_file", side_effect=FileNotFoundError()) +def test_collect_command_exits_without_source(mock_find_tasks_file) -> None: + with pytest.raises(typer.Exit): + rollout_cli.collect_command(source=None) + + mock_find_tasks_file.assert_called_once() diff --git a/hud/rl/__init__.py b/hud/rl/__init__.py index 604974ce0..fd581c933 100644 --- a/hud/rl/__init__.py +++ b/hud/rl/__init__.py @@ -1 +1,14 @@ """RL module for HUD.""" + +from __future__ import annotations + +from .collector import build_rollout_records, collect_rollouts, write_rollouts_jsonl +from .schema import RolloutRecord, make_rollout_id + +__all__ = [ + "RolloutRecord", + "build_rollout_records", + "collect_rollouts", + "make_rollout_id", + "write_rollouts_jsonl", +] diff --git a/hud/rl/collector.py b/hud/rl/collector.py new file mode 100644 index 000000000..1d5a8d5c5 --- /dev/null +++ b/hud/rl/collector.py @@ -0,0 +1,204 @@ +"""Rollout collection utilities built on top of run_dataset.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast + +from pydantic import BaseModel + +from hud.datasets import run_dataset +from hud.types import Trace +from hud.utils.tasks import load_tasks + +from .schema import RolloutRecord, make_rollout_id + +if TYPE_CHECKING: + from collections.abc import Sequence + + from hud.agents import MCPAgent + + +def _to_jsonable(value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {str(k): _to_jsonable(v) for k, v in value.items()} + if isinstance(value, list): + return [_to_jsonable(item) for item in value] + if isinstance(value, tuple): + return [_to_jsonable(item) for item in value] + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return str(value) + + +def _source_with_split(source: str, split: str) -> str: + path = Path(source) + if path.exists() or ":" in source or "/" not in source: + return source + if split != "train": + return f"{source}:{split}" + return source + + +def _load_raw_tasks(source: str | Sequence[dict[str, Any]], split: str) -> list[dict[str, Any]]: + if isinstance(source, str): + loaded = load_tasks(_source_with_split(source, split), raw=True) + return cast("list[dict[str, Any]]", loaded) + + tasks: list[dict[str, Any]] = [] + for item in source: + if not isinstance(item, dict): + raise TypeError(f"Expected task dict, got {type(item)}") + tasks.append(cast("dict[str, Any]", _to_jsonable(item))) + return tasks + + +def _coerce_trace(result: Any) -> Trace: + if isinstance(result, Trace): + return result + if result is None: + return Trace(isError=True, content="No trace returned", reward=0.0, done=True) + if isinstance(result, dict): + try: + return Trace.model_validate(result) + except Exception: + payload = json.dumps(_to_jsonable(result), ensure_ascii=False) + return Trace(isError=True, content=payload, reward=0.0, done=True) + + if isinstance(result, BaseModel): + try: + return Trace.model_validate(result.model_dump(mode="json")) + except Exception: + return Trace(isError=True, content=str(result), reward=0.0, done=True) + + return Trace(isError=True, content=str(result), reward=0.0, done=True) + + +def _expand_tasks(tasks: Sequence[dict[str, Any]], group_size: int) -> list[dict[str, Any]]: + expanded: list[dict[str, Any]] = [] + for task in tasks: + expanded.extend(dict(task) for _ in range(group_size)) + return expanded + + +def build_rollout_records( + *, + source: str, + tasks: Sequence[dict[str, Any]], + results: Sequence[Any], + group_size: int = 1, +) -> list[RolloutRecord]: + """Convert run_dataset results into rollout records.""" + if group_size < 1: + raise ValueError("group_size must be >= 1") + + records: list[RolloutRecord] = [] + expected = len(tasks) * group_size + bounded = min(len(results), expected) + + for result_index in range(bounded): + task_index = result_index // group_size + repeat_index = result_index % group_size + task_dict = dict(tasks[task_index]) + prompt = str(task_dict.get("prompt") or f"Task {task_index}") + trace = _coerce_trace(results[result_index]) + task_id_raw = task_dict.get("id") + + records.append( + RolloutRecord( + rollout_id=make_rollout_id(source, task_index, repeat_index, prompt), + source=source, + task_index=task_index, + repeat_index=repeat_index, + task_id=str(task_id_raw) if task_id_raw is not None else None, + prompt=prompt, + reward=trace.reward, + done=trace.done, + is_error=trace.isError, + content=trace.content, + info=cast("dict[str, Any]", _to_jsonable(trace.info)), + task=cast("dict[str, Any]", _to_jsonable(task_dict)), + trace=[_to_jsonable(step) for step in trace.trace], + messages=_to_jsonable(trace.messages), + ) + ) + + for result_index in range(bounded, len(results)): + trace = _coerce_trace(results[result_index]) + fallback_prompt = "Unknown task" + records.append( + RolloutRecord( + rollout_id=make_rollout_id(source, -1, result_index - bounded, fallback_prompt), + source=source, + task_index=-1, + repeat_index=result_index - bounded, + prompt=fallback_prompt, + reward=trace.reward, + done=trace.done, + is_error=trace.isError, + content=trace.content, + info=cast("dict[str, Any]", _to_jsonable(trace.info)), + task={}, + trace=[_to_jsonable(step) for step in trace.trace], + messages=_to_jsonable(trace.messages), + ) + ) + + return records + + +async def collect_rollouts( + *, + name: str, + source: str | Sequence[dict[str, Any]], + agent_class: type[MCPAgent], + agent_config: dict[str, Any] | None = None, + max_concurrent: int = 30, + metadata: dict[str, Any] | None = None, + max_steps: int = 10, + split: str = "train", + group_size: int = 1, + auto_respond: bool = True, +) -> list[RolloutRecord]: + """Collect rollouts by executing tasks with run_dataset.""" + if group_size < 1: + raise ValueError("group_size must be >= 1") + + raw_tasks = _load_raw_tasks(source, split=split) + if not raw_tasks: + return [] + + expanded_tasks = _expand_tasks(raw_tasks, group_size=group_size) + source_name = source if isinstance(source, str) else name + results = await run_dataset( + name=name, + dataset=expanded_tasks, + agent_class=agent_class, + agent_config=agent_config, + max_concurrent=max_concurrent, + metadata=metadata, + max_steps=max_steps, + split=split, + auto_respond=auto_respond, + ) + + return build_rollout_records( + source=source_name, + tasks=raw_tasks, + results=results, + group_size=group_size, + ) + + +def write_rollouts_jsonl(records: Sequence[RolloutRecord], output_path: str | Path) -> Path: + """Write rollout records to JSONL.""" + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as fh: + for record in records: + fh.write(record.model_dump_json()) + fh.write("\n") + return path diff --git a/hud/rl/schema.py b/hud/rl/schema.py new file mode 100644 index 000000000..ff450bb67 --- /dev/null +++ b/hud/rl/schema.py @@ -0,0 +1,37 @@ +"""Schema models for rollout collection.""" + +from __future__ import annotations + +import hashlib +from typing import Any + +from pydantic import BaseModel, Field + +SCHEMA_VERSION = "hud.rollout.v1" + + +def make_rollout_id(source: str, task_index: int, repeat_index: int, prompt: str) -> str: + """Build a stable rollout identifier from task identity.""" + seed = f"{source}|{task_index}|{repeat_index}|{prompt}" + digest = hashlib.sha256(seed.encode("utf-8")).hexdigest() + return f"rollout_{digest[:16]}" + + +class RolloutRecord(BaseModel): + """Serialized rollout record for offline RL/RFT pipelines.""" + + schema_version: str = Field(default=SCHEMA_VERSION) + rollout_id: str + source: str + task_index: int + repeat_index: int + task_id: str | None = None + prompt: str + reward: float = 0.0 + done: bool = True + is_error: bool = False + content: str | None = None + info: dict[str, Any] = Field(default_factory=dict) + task: dict[str, Any] = Field(default_factory=dict) + trace: list[dict[str, Any]] = Field(default_factory=list) + messages: list[Any] = Field(default_factory=list) diff --git a/hud/rl/tests/test_collector.py b/hud/rl/tests/test_collector.py new file mode 100644 index 000000000..3d7315361 --- /dev/null +++ b/hud/rl/tests/test_collector.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import json +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import pytest + +from hud.rl.collector import ( + build_rollout_records, + collect_rollouts, + write_rollouts_jsonl, +) +from hud.rl.schema import make_rollout_id +from hud.types import Trace, TraceStep + + +def test_make_rollout_id_is_stable() -> None: + first = make_rollout_id("source", 1, 2, "prompt") + second = make_rollout_id("source", 1, 2, "prompt") + changed = make_rollout_id("source", 1, 3, "prompt") + + assert first == second + assert first != changed + assert first.startswith("rollout_") + + +def test_build_rollout_records_with_group_size_and_errors() -> None: + task = {"id": "task-1", "prompt": "Solve 2+2", "mcp_config": {"local": {"url": "x"}}} + trace = Trace( + reward=1.0, + done=True, + content="4", + trace=[TraceStep(category="agent", request={"prompt": "Solve 2+2"}, result={"text": "4"})], + ) + + records = build_rollout_records( + source="tasks.json", + tasks=[task], + results=[trace, None], + group_size=2, + ) + + assert len(records) == 2 + assert records[0].task_index == 0 + assert records[0].repeat_index == 0 + assert records[0].prompt == "Solve 2+2" + assert records[0].trace[0]["category"] == "agent" + + assert records[1].task_index == 0 + assert records[1].repeat_index == 1 + assert records[1].is_error is True + assert records[1].content == "No trace returned" + + +@pytest.mark.asyncio +async def test_collect_rollouts_repeats_tasks_and_uses_split() -> None: + raw_task = {"prompt": "Task prompt", "mcp_config": {"local": {"url": "http://localhost"}}} + trace = Trace(reward=0.5, done=True, content="done") + + with ( + patch("hud.rl.collector.load_tasks", return_value=[raw_task]) as mock_load_tasks, + patch("hud.rl.collector.run_dataset", new_callable=AsyncMock) as mock_run_dataset, + ): + mock_run_dataset.return_value = [trace, trace] + + records = await collect_rollouts( + name="test-rollout", + source="hud-evals/demo", + agent_class=cast("Any", object), + group_size=2, + split="test", + max_steps=3, + max_concurrent=7, + metadata={"suite": "demo"}, + auto_respond=True, + ) + + mock_load_tasks.assert_called_once_with("hud-evals/demo:test", raw=True) + called_dataset = mock_run_dataset.call_args.kwargs["dataset"] + assert len(called_dataset) == 2 + assert called_dataset[0]["prompt"] == "Task prompt" + assert len(records) == 2 + assert records[0].reward == 0.5 + + +def test_write_rollouts_jsonl(tmp_path) -> None: + record = build_rollout_records( + source="tasks.json", + tasks=[{"prompt": "A", "mcp_config": {"x": 1}}], + results=[Trace(reward=1.0, done=True)], + group_size=1, + )[0] + output_path = tmp_path / "rollouts" / "data.jsonl" + + written = write_rollouts_jsonl([record], output_path) + + assert written == output_path + lines = output_path.read_text(encoding="utf-8").splitlines() + assert len(lines) == 1 + payload = cast("dict[str, Any]", json.loads(lines[0])) + assert payload["schema_version"] == "hud.rollout.v1" + assert payload["task_index"] == 0 From bfb6e6b132221ca052a42821bd9596ccbd30f51d Mon Sep 17 00:00:00 2001 From: MagellaX Date: Wed, 18 Feb 2026 01:28:08 +0530 Subject: [PATCH 2/6] fix(rollout): enforce strict task parsing and wire auto_respond --- hud/rl/collector.py | 63 +++++++++++++++++++++++++++++++++- hud/rl/tests/test_collector.py | 39 +++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/hud/rl/collector.py b/hud/rl/collector.py index b384b87d6..2738fa765 100644 --- a/hud/rl/collector.py +++ b/hud/rl/collector.py @@ -42,9 +42,67 @@ def _source_with_split(source: str, split: str) -> str: return source +def _load_raw_from_file_strict(path: Path) -> list[dict[str, Any]]: + raw_items: list[dict[str, Any]] = [] + + if path.suffix.lower() == ".jsonl": + with open(path, encoding="utf-8") as f: + for line_no, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + try: + value = json.loads(line) + except json.JSONDecodeError as e: + raise ValueError(f"line {line_no}: invalid JSON ({e.msg})") from e + + if isinstance(value, dict): + raw_items.append(value) + continue + + if isinstance(value, list): + for idx, entry in enumerate(value): + if isinstance(entry, dict): + raw_items.append(entry) + else: + raise ValueError( + f"line {line_no} item {idx}: expected object, got " + f"{type(entry).__name__}" + ) + continue + + raise ValueError( + f"line {line_no}: expected object or list[object], got {type(value).__name__}" + ) + return raw_items + + with open(path, encoding="utf-8") as f: + value = json.load(f) + + if isinstance(value, dict): + return [value] + + if isinstance(value, list): + for idx, entry in enumerate(value): + if isinstance(entry, dict): + raw_items.append(entry) + else: + raise ValueError(f"item {idx}: expected object, got {type(entry).__name__}") + return raw_items + + raise ValueError(f"{path.name}: expected top-level object or array, got {type(value).__name__}") + + def _load_raw_tasks(source: str | Sequence[dict[str, Any]], split: str) -> list[dict[str, Any]]: if isinstance(source, str): + path = Path(source) + if path.exists() and path.suffix.lower() in {".json", ".jsonl"}: + return _load_raw_from_file_strict(path) + loaded = load_tasks(_source_with_split(source, split), raw=True) + if not all(isinstance(item, dict) for item in loaded): + raise ValueError("Loaded tasks must be objects") return cast("list[dict[str, Any]]", loaded) tasks: list[dict[str, Any]] = [] @@ -190,10 +248,13 @@ async def collect_rollouts( expanded_tasks = _expand_tasks(raw_tasks, group_size=group_size) source_name = source if isinstance(source, str) else name + final_agent_params = dict(agent_params or {}) + final_agent_params.setdefault("auto_respond", auto_respond) + results = await run_dataset( expanded_tasks, agent_type, - agent_params=agent_params, + agent_params=final_agent_params, group_size=1, quiet=True, max_concurrent=max_concurrent, diff --git a/hud/rl/tests/test_collector.py b/hud/rl/tests/test_collector.py index 20ec9f047..1982ba0b2 100644 --- a/hud/rl/tests/test_collector.py +++ b/hud/rl/tests/test_collector.py @@ -84,6 +84,45 @@ async def test_collect_rollouts_repeats_tasks_and_uses_split() -> None: assert records[0].reward == 0.5 +@pytest.mark.asyncio +async def test_collect_rollouts_passes_auto_respond_into_agent_params() -> None: + raw_task = {"prompt": "Task prompt", "mcp_config": {"local": {"url": "http://localhost"}}} + + with ( + patch("hud.rl.collector.load_tasks", return_value=[raw_task]), + patch("hud.rl.collector.run_dataset", new_callable=AsyncMock) as mock_run_dataset, + ): + mock_run_dataset.return_value = [Trace(reward=1.0, done=True)] + + await collect_rollouts( + name="test-rollout", + source="hud-evals/demo", + agent_type="claude", + agent_params={"model": "claude-sonnet-4-5"}, + auto_respond=False, + ) + + params = mock_run_dataset.call_args.kwargs["agent_params"] + assert params["model"] == "claude-sonnet-4-5" + assert params["auto_respond"] is False + + +@pytest.mark.asyncio +async def test_collect_rollouts_raises_on_non_dict_entries_in_local_json(tmp_path) -> None: + tasks_path = tmp_path / "tasks.json" + tasks_path.write_text( + json.dumps([{"prompt": "ok", "mcp_config": {"local": {"url": "x"}}}, "not-an-object"]), + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="item 1: expected object"): + await collect_rollouts( + name="test-rollout", + source=str(tasks_path), + agent_type="claude", + ) + + def test_write_rollouts_jsonl(tmp_path) -> None: record = build_rollout_records( source="tasks.json", From 27b4347ed9ed0bbcad638450ff146f3217879495 Mon Sep 17 00:00:00 2001 From: MagellaX Date: Wed, 18 Feb 2026 02:10:48 +0530 Subject: [PATCH 3/6] fix(rollout): address split/source and shared loader feedback --- hud/cli/rollout.py | 4 +-- hud/datasets/loader.py | 21 +++++++++--- hud/rl/collector.py | 59 +++------------------------------- hud/rl/tests/test_collector.py | 16 ++++++++- 4 files changed, 38 insertions(+), 62 deletions(-) diff --git a/hud/cli/rollout.py b/hud/cli/rollout.py index 92eecfc73..5c874c38a 100644 --- a/hud/cli/rollout.py +++ b/hud/cli/rollout.py @@ -27,10 +27,10 @@ def _resolve_agent( config: dict[str, Any] = {"verbose": verbose, "validate_api_key": False} model_defaults: dict[AgentType, str] = { AgentType.CLAUDE: "claude-sonnet-4-5", - AgentType.OPENAI: "gpt-5.1", + AgentType.OPENAI: "gpt-5", AgentType.OPERATOR: "computer-use-preview", AgentType.GEMINI: "gemini-3-pro-preview", - AgentType.GEMINI_CUA: "gemini-2.5-computer-use-preview-10-2025", + AgentType.GEMINI_CUA: "gemini-2.5-computer-use-preview", AgentType.OPENAI_COMPATIBLE: "gpt-5-mini", AgentType.INTEGRATION_TEST: "integration-test", } diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index 996402c95..12cd3f9c9 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -26,24 +26,36 @@ __all__ = ["load_dataset", "load_tasks", "save_tasks"] -def _load_raw_from_file(path: Path) -> list[dict[str, Any]]: +def _load_raw_from_file(path: Path, *, strict: bool = False) -> list[dict[str, Any]]: """Load raw task dicts from a local JSON or JSONL file.""" raw_items: list[dict[str, Any]] = [] + def _append_if_dict(item: Any, *, context: str) -> None: + if isinstance(item, dict): + raw_items.append(item) + elif strict: + raise ValueError(f"{context}: expected object, got {type(item).__name__}") + if path.suffix == ".jsonl": # JSONL: one task per line with open(path, encoding="utf-8") as f: - for line in f: + for line_no, line in enumerate(f, 1): line = line.strip() if not line: continue item = json.loads(line) # Handle case where line contains a list if isinstance(item, list): - raw_items.extend(i for i in item if isinstance(i, dict)) + for idx, entry in enumerate(item): + _append_if_dict(entry, context=f"line {line_no} item {idx}") elif isinstance(item, dict): raw_items.append(item) else: + if strict: + raise ValueError( + f"line {line_no}: expected object or list[object], " + f"got {type(item).__name__}" + ) raise ValueError( f"Invalid JSONL format: expected dict or list, got {type(item)}" ) @@ -53,7 +65,8 @@ def _load_raw_from_file(path: Path) -> list[dict[str, Any]]: data = json.load(f) if isinstance(data, list): - raw_items = [item for item in data if isinstance(item, dict)] + for idx, entry in enumerate(data): + _append_if_dict(entry, context=f"item {idx}") elif isinstance(data, dict): raw_items = [data] else: diff --git a/hud/rl/collector.py b/hud/rl/collector.py index 2738fa765..c69ea4078 100644 --- a/hud/rl/collector.py +++ b/hud/rl/collector.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from hud.datasets import load_tasks, run_dataset +from hud.datasets.loader import _load_raw_from_file from hud.types import Trace from .schema import RolloutRecord, make_rollout_id @@ -35,70 +36,18 @@ def _to_jsonable(value: Any) -> Any: def _source_with_split(source: str, split: str) -> str: path = Path(source) - if path.exists() or ":" in source or "/" not in source: + if path.is_file() or ":" in source or "/" not in source: return source if split != "train": return f"{source}:{split}" return source -def _load_raw_from_file_strict(path: Path) -> list[dict[str, Any]]: - raw_items: list[dict[str, Any]] = [] - - if path.suffix.lower() == ".jsonl": - with open(path, encoding="utf-8") as f: - for line_no, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - - try: - value = json.loads(line) - except json.JSONDecodeError as e: - raise ValueError(f"line {line_no}: invalid JSON ({e.msg})") from e - - if isinstance(value, dict): - raw_items.append(value) - continue - - if isinstance(value, list): - for idx, entry in enumerate(value): - if isinstance(entry, dict): - raw_items.append(entry) - else: - raise ValueError( - f"line {line_no} item {idx}: expected object, got " - f"{type(entry).__name__}" - ) - continue - - raise ValueError( - f"line {line_no}: expected object or list[object], got {type(value).__name__}" - ) - return raw_items - - with open(path, encoding="utf-8") as f: - value = json.load(f) - - if isinstance(value, dict): - return [value] - - if isinstance(value, list): - for idx, entry in enumerate(value): - if isinstance(entry, dict): - raw_items.append(entry) - else: - raise ValueError(f"item {idx}: expected object, got {type(entry).__name__}") - return raw_items - - raise ValueError(f"{path.name}: expected top-level object or array, got {type(value).__name__}") - - def _load_raw_tasks(source: str | Sequence[dict[str, Any]], split: str) -> list[dict[str, Any]]: if isinstance(source, str): path = Path(source) if path.exists() and path.suffix.lower() in {".json", ".jsonl"}: - return _load_raw_from_file_strict(path) + return _load_raw_from_file(path, strict=True) loaded = load_tasks(_source_with_split(source, split), raw=True) if not all(isinstance(item, dict) for item in loaded): @@ -247,7 +196,7 @@ async def collect_rollouts( return [] expanded_tasks = _expand_tasks(raw_tasks, group_size=group_size) - source_name = source if isinstance(source, str) else name + source_name = _source_with_split(source, split) if isinstance(source, str) else name final_agent_params = dict(agent_params or {}) final_agent_params.setdefault("auto_respond", auto_respond) diff --git a/hud/rl/tests/test_collector.py b/hud/rl/tests/test_collector.py index 1982ba0b2..f30f5d328 100644 --- a/hud/rl/tests/test_collector.py +++ b/hud/rl/tests/test_collector.py @@ -1,12 +1,13 @@ from __future__ import annotations import json -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from unittest.mock import AsyncMock, patch import pytest from hud.rl.collector import ( + _source_with_split, build_rollout_records, collect_rollouts, write_rollouts_jsonl, @@ -14,6 +15,9 @@ from hud.rl.schema import make_rollout_id from hud.types import Trace, TraceStep +if TYPE_CHECKING: + from pathlib import Path + def test_make_rollout_id_is_stable() -> None: first = make_rollout_id("source", 1, 2, "prompt") @@ -81,6 +85,8 @@ async def test_collect_rollouts_repeats_tasks_and_uses_split() -> None: assert len(called_dataset) == 2 assert called_dataset[0]["prompt"] == "Task prompt" assert len(records) == 2 + assert records[0].source == "hud-evals/demo:test" + assert records[0].rollout_id == make_rollout_id("hud-evals/demo:test", 0, 0, "Task prompt") assert records[0].reward == 0.5 @@ -107,6 +113,14 @@ async def test_collect_rollouts_passes_auto_respond_into_agent_params() -> None: assert params["auto_respond"] is False +def test_source_with_split_keeps_hf_split_when_matching_directory_exists( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "hud-evals" / "demo").mkdir(parents=True) + assert _source_with_split("hud-evals/demo", "test") == "hud-evals/demo:test" + + @pytest.mark.asyncio async def test_collect_rollouts_raises_on_non_dict_entries_in_local_json(tmp_path) -> None: tasks_path = tmp_path / "tasks.json" From 5f232de5596c068ca83ff6e9cfe3ea920045a3ba Mon Sep 17 00:00:00 2001 From: MagellaX Date: Wed, 18 Feb 2026 12:06:56 +0530 Subject: [PATCH 4/6] fix(rollout): preserve agent-native model defaults --- hud/cli/rollout.py | 13 ++----------- hud/cli/tests/test_rollout.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/hud/cli/rollout.py b/hud/cli/rollout.py index 5c874c38a..8357e5967 100644 --- a/hud/cli/rollout.py +++ b/hud/cli/rollout.py @@ -25,17 +25,8 @@ def _resolve_agent( verbose: bool, ) -> tuple[AgentType, dict[str, Any]]: config: dict[str, Any] = {"verbose": verbose, "validate_api_key": False} - model_defaults: dict[AgentType, str] = { - AgentType.CLAUDE: "claude-sonnet-4-5", - AgentType.OPENAI: "gpt-5", - AgentType.OPERATOR: "computer-use-preview", - AgentType.GEMINI: "gemini-3-pro-preview", - AgentType.GEMINI_CUA: "gemini-2.5-computer-use-preview", - AgentType.OPENAI_COMPATIBLE: "gpt-5-mini", - AgentType.INTEGRATION_TEST: "integration-test", - } - if agent != AgentType.INTEGRATION_TEST: - config["model"] = model or model_defaults[agent] + if model is not None and agent != AgentType.INTEGRATION_TEST: + config["model"] = model if allowed_tools: config["allowed_tools"] = allowed_tools if disallowed_tools: diff --git a/hud/cli/tests/test_rollout.py b/hud/cli/tests/test_rollout.py index e20d28c0d..b81239a24 100644 --- a/hud/cli/tests/test_rollout.py +++ b/hud/cli/tests/test_rollout.py @@ -62,3 +62,34 @@ def test_collect_command_exits_without_source(mock_find_tasks_file) -> None: rollout_cli.collect_command(source=None) mock_find_tasks_file.assert_called_once() + + +def test_resolve_agent_only_sets_model_when_explicit() -> None: + _, config_without_model = rollout_cli._resolve_agent( + agent=AgentType.OPENAI, + model=None, + allowed_tools=None, + disallowed_tools=None, + verbose=False, + ) + assert "model" not in config_without_model + + _, config_with_model = rollout_cli._resolve_agent( + agent=AgentType.OPENAI, + model="gpt-5-mini", + allowed_tools=None, + disallowed_tools=None, + verbose=False, + ) + assert config_with_model["model"] == "gpt-5-mini" + + +def test_resolve_agent_never_sets_model_for_integration_test() -> None: + _, config = rollout_cli._resolve_agent( + agent=AgentType.INTEGRATION_TEST, + model="integration-test", + allowed_tools=None, + disallowed_tools=None, + verbose=False, + ) + assert "model" not in config From b46a3589ea53aeb3941cc11ed22ac822aab1b4ec Mon Sep 17 00:00:00 2001 From: MagellaX Date: Wed, 18 Feb 2026 12:35:26 +0530 Subject: [PATCH 5/6] fix(rollout): preserve split for slash-free hf datasets --- hud/rl/collector.py | 2 +- hud/rl/tests/test_collector.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/hud/rl/collector.py b/hud/rl/collector.py index c69ea4078..11a922aef 100644 --- a/hud/rl/collector.py +++ b/hud/rl/collector.py @@ -36,7 +36,7 @@ def _to_jsonable(value: Any) -> Any: def _source_with_split(source: str, split: str) -> str: path = Path(source) - if path.is_file() or ":" in source or "/" not in source: + if path.is_file() or ":" in source: return source if split != "train": return f"{source}:{split}" diff --git a/hud/rl/tests/test_collector.py b/hud/rl/tests/test_collector.py index f30f5d328..9533e9754 100644 --- a/hud/rl/tests/test_collector.py +++ b/hud/rl/tests/test_collector.py @@ -121,6 +121,10 @@ def test_source_with_split_keeps_hf_split_when_matching_directory_exists( assert _source_with_split("hud-evals/demo", "test") == "hud-evals/demo:test" +def test_source_with_split_keeps_hf_split_for_slash_free_dataset() -> None: + assert _source_with_split("imdb", "test") == "imdb:test" + + @pytest.mark.asyncio async def test_collect_rollouts_raises_on_non_dict_entries_in_local_json(tmp_path) -> None: tasks_path = tmp_path / "tasks.json" From 92658ed13ae89c991e32a4a155b6162bc634c1e1 Mon Sep 17 00:00:00 2001 From: MagellaX Date: Wed, 18 Feb 2026 12:51:31 +0530 Subject: [PATCH 6/6] chore(rollout): remove redundant exception tuple --- hud/cli/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hud/cli/rollout.py b/hud/cli/rollout.py index 8357e5967..8fea320a5 100644 --- a/hud/cli/rollout.py +++ b/hud/cli/rollout.py @@ -106,7 +106,7 @@ def collect_command( try: source = find_tasks_file(None, msg="Select a tasks file to collect rollouts from") hud_console.success(f"Selected: {source}") - except (FileNotFoundError, Exception): + except Exception: hud_console.error( "No source provided and no task/eval JSON files found in current directory" )