Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hud/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .remove import remove_command
from .rft import rft_command
from .rft_status import rft_status_command
from .rollout import rollout_app
from .utils.config import set_env_values
from .utils.cursor import get_cursor_config_path, list_cursor_servers, parse_cursor_config
from .utils.logging import CaptureLogger
Expand All @@ -40,6 +41,7 @@
rich_markup_mode="rich",
pretty_exceptions_enable=False, # Disable Rich's verbose tracebacks
)
app.add_typer(rollout_app, name="rollout")

console = Console()

Expand Down
155 changes: 155 additions & 0 deletions hud/cli/rollout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""CLI commands for collecting rollout trajectories."""

from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any

import typer

from hud.rl.collector import collect_rollouts, write_rollouts_jsonl
from hud.types import AgentType
from hud.utils.hud_console import HUDConsole

rollout_app = typer.Typer(help="Collect rollout trajectories for RL/RFT workflows.")
hud_console = HUDConsole()


def _resolve_agent(
*,
agent: AgentType,
model: str | None,
allowed_tools: list[str] | None,
disallowed_tools: list[str] | None,
verbose: bool,
) -> tuple[AgentType, dict[str, Any]]:
config: dict[str, Any] = {"verbose": verbose, "validate_api_key": False}
if model is not None and agent != AgentType.INTEGRATION_TEST:
config["model"] = model
if allowed_tools:
config["allowed_tools"] = allowed_tools
if disallowed_tools:
config["disallowed_tools"] = disallowed_tools
return agent, config


@rollout_app.command("collect")
def collect_command(
source: str | None = typer.Argument(
None,
help=(
"HuggingFace dataset name (e.g. hud-evals/SheetBench-50) or path to "
"a local JSON/JSONL tasks file."
),
),
output: Path = typer.Option( # noqa: B008
Path("rollouts.jsonl"),
"--output",
"-o",
help="Output JSONL file for collected trajectories.",
),
agent: AgentType = typer.Option( # noqa: B008
AgentType.CLAUDE,
"--agent",
help="Agent backend to use for rollout collection.",
),
model: str | None = typer.Option(
None,
"--model",
help="Model name for the selected agent backend.",
),
allowed_tools: str | None = typer.Option(
None,
"--allowed-tools",
help="Comma-separated list of allowed tools.",
),
disallowed_tools: str | None = typer.Option(
None,
"--disallowed-tools",
help="Comma-separated list of disallowed tools.",
),
max_concurrent: int = typer.Option(
30,
"--max-concurrent",
min=1,
help="Maximum number of concurrent tasks.",
),
max_steps: int = typer.Option(
50,
"--max-steps",
min=1,
help="Maximum steps per rollout.",
),
group_size: int = typer.Option(
1,
"--group-size",
min=1,
help="Number of rollouts to collect per task.",
),
split: str = typer.Option(
"train",
"--split",
help="Dataset split when source is a HuggingFace dataset.",
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Enable verbose agent logs.",
),
) -> None:
"""Collect and export rollout trajectories."""
if source is None:
from .utils.tasks import find_tasks_file

try:
source = find_tasks_file(None, msg="Select a tasks file to collect rollouts from")
hud_console.success(f"Selected: {source}")
except Exception:
hud_console.error(
"No source provided and no task/eval JSON files found in current directory"
)
raise typer.Exit(1) from None

allowed_tools_list = (
[tool.strip() for tool in allowed_tools.split(",") if tool.strip()]
if allowed_tools
else None
)
disallowed_tools_list = (
[tool.strip() for tool in disallowed_tools.split(",") if tool.strip()]
if disallowed_tools
else None
)
agent_type, agent_params = _resolve_agent(
agent=agent,
model=model,
allowed_tools=allowed_tools_list,
disallowed_tools=disallowed_tools_list,
verbose=verbose,
)

run_name = f"Rollout Collection: {Path(source).name if Path(source).exists() else source}"
records = asyncio.run(
collect_rollouts(
name=run_name,
source=source,
agent_type=agent_type,
agent_params=agent_params,
max_concurrent=max_concurrent,
max_steps=max_steps,
split=split,
group_size=group_size,
metadata={"source": source, "group_size": group_size},
auto_respond=True,
)
)

if not records:
hud_console.warning("No rollouts were collected.")
return

output_path = write_rollouts_jsonl(records, output)
hud_console.success(f"Collected {len(records)} rollouts")
hud_console.info(f"Saved to: {output_path}")
95 changes: 95 additions & 0 deletions hud/cli/tests/test_rollout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, patch

import pytest
import typer

import hud.cli.rollout as rollout_cli
from hud.rl.schema import RolloutRecord
from hud.types import AgentType

if TYPE_CHECKING:
from pathlib import Path


@patch("hud.cli.rollout.write_rollouts_jsonl")
@patch("hud.cli.rollout.collect_rollouts", new_callable=AsyncMock)
@patch("hud.cli.rollout._resolve_agent")
def test_collect_command_happy_path(
mock_resolve_agent,
mock_collect_rollouts,
mock_write_rollouts,
tmp_path: Path,
) -> None:
mock_resolve_agent.return_value = (AgentType.CLAUDE, {"model": "x"})
mock_collect_rollouts.return_value = [
RolloutRecord(
rollout_id="rollout_1",
source="tasks.json",
task_index=0,
repeat_index=0,
prompt="Prompt",
)
]
mock_write_rollouts.return_value = tmp_path / "rollouts.jsonl"

rollout_cli.collect_command(
source="tasks.json",
output=tmp_path / "rollouts.jsonl",
agent=AgentType.CLAUDE,
model=None,
allowed_tools="act,observe",
disallowed_tools=None,
max_concurrent=5,
max_steps=7,
group_size=2,
split="train",
verbose=False,
)

mock_resolve_agent.assert_called_once()
mock_collect_rollouts.assert_awaited_once()
assert mock_collect_rollouts.call_args.kwargs["source"] == "tasks.json"
assert mock_collect_rollouts.call_args.kwargs["group_size"] == 2
mock_write_rollouts.assert_called_once()


@patch("hud.cli.utils.tasks.find_tasks_file", side_effect=FileNotFoundError())
def test_collect_command_exits_without_source(mock_find_tasks_file) -> None:
with pytest.raises(typer.Exit):
rollout_cli.collect_command(source=None)

mock_find_tasks_file.assert_called_once()


def test_resolve_agent_only_sets_model_when_explicit() -> None:
_, config_without_model = rollout_cli._resolve_agent(
agent=AgentType.OPENAI,
model=None,
allowed_tools=None,
disallowed_tools=None,
verbose=False,
)
assert "model" not in config_without_model

_, config_with_model = rollout_cli._resolve_agent(
agent=AgentType.OPENAI,
model="gpt-5-mini",
allowed_tools=None,
disallowed_tools=None,
verbose=False,
)
assert config_with_model["model"] == "gpt-5-mini"


def test_resolve_agent_never_sets_model_for_integration_test() -> None:
_, config = rollout_cli._resolve_agent(
agent=AgentType.INTEGRATION_TEST,
model="integration-test",
allowed_tools=None,
disallowed_tools=None,
verbose=False,
)
assert "model" not in config
21 changes: 17 additions & 4 deletions hud/datasets/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,36 @@
__all__ = ["load_dataset", "load_tasks", "save_tasks"]


def _load_raw_from_file(path: Path) -> list[dict[str, Any]]:
def _load_raw_from_file(path: Path, *, strict: bool = False) -> list[dict[str, Any]]:
"""Load raw task dicts from a local JSON or JSONL file."""
raw_items: list[dict[str, Any]] = []

def _append_if_dict(item: Any, *, context: str) -> None:
if isinstance(item, dict):
raw_items.append(item)
elif strict:
raise ValueError(f"{context}: expected object, got {type(item).__name__}")

if path.suffix == ".jsonl":
# JSONL: one task per line
with open(path, encoding="utf-8") as f:
for line in f:
for line_no, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
item = json.loads(line)
# Handle case where line contains a list
if isinstance(item, list):
raw_items.extend(i for i in item if isinstance(i, dict))
for idx, entry in enumerate(item):
_append_if_dict(entry, context=f"line {line_no} item {idx}")
elif isinstance(item, dict):
raw_items.append(item)
else:
if strict:
raise ValueError(
f"line {line_no}: expected object or list[object], "
f"got {type(item).__name__}"
)
raise ValueError(
f"Invalid JSONL format: expected dict or list, got {type(item)}"
)
Expand All @@ -53,7 +65,8 @@ def _load_raw_from_file(path: Path) -> list[dict[str, Any]]:
data = json.load(f)

if isinstance(data, list):
raw_items = [item for item in data if isinstance(item, dict)]
for idx, entry in enumerate(data):
_append_if_dict(entry, context=f"item {idx}")
elif isinstance(data, dict):
raw_items = [data]
else:
Expand Down
14 changes: 14 additions & 0 deletions hud/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""RL module for HUD."""

from __future__ import annotations

from .collector import build_rollout_records, collect_rollouts, write_rollouts_jsonl
from .schema import RolloutRecord, make_rollout_id

__all__ = [
"RolloutRecord",
"build_rollout_records",
"collect_rollouts",
"make_rollout_id",
"write_rollouts_jsonl",
]
Loading
Loading