diff --git a/src/mjlab/scripts/play.py b/src/mjlab/scripts/play.py index 3fd8602d9..f12528e06 100644 --- a/src/mjlab/scripts/play.py +++ b/src/mjlab/scripts/play.py @@ -17,6 +17,7 @@ from mjlab.utils.torch import configure_torch_backends from mjlab.utils.wrappers import VideoRecorder from mjlab.viewer import NativeMujocoViewer, ViserPlayViewer +from mjlab.viewer.viser.viewer import CheckpointManager @dataclass(frozen=True) @@ -194,6 +195,63 @@ def __call__(self, obs) -> torch.Tensor: ) policy = runner.get_inference_policy(device=device) + ckpt_manager = None + if TRAINED_MODE and cfg.wandb_run_path is not None: + from datetime import datetime, timezone + + import wandb + + def parse_wandb_dt(value: str | datetime) -> datetime: + if isinstance(value, str): + return datetime.fromisoformat(value.replace("Z", "+00:00")) + return value + + api = wandb.Api() + run_path = str(cfg.wandb_run_path) + wandb_run = api.run(run_path) + + def fetch_available() -> list[tuple[str, str]]: + run = api.run(run_path) + now = datetime.now(tz=timezone.utc) + entries: list[tuple[str, str, int]] = [] + for f in run.files(): + if not f.name.endswith(".pt"): + continue + step = int(f.name.split("_")[1].split(".")[0]) + s = int((now - parse_wandb_dt(f.updated_at)).total_seconds()) + for div, unit in ((86400, "d"), (3600, "h"), (60, "m")): + if s >= div: + t = f"{s // div}{unit} ago" + break + else: + t = f"{s}s ago" + entries.append((f.name, t, step)) + entries.sort(key=lambda x: x[2]) + return [(name, t) for name, t, _ in entries] + + _log_root = log_root_path # type: ignore[possibly-undefined] + _runner = runner # type: ignore[possibly-undefined] + + def load_checkpoint(name: str): + path, _ = get_wandb_checkpoint_path(_log_root, Path(run_path), name) + _runner.load( + str(path), + load_cfg={"actor": True}, + strict=True, + map_location=device, + ) + return _runner.get_inference_policy(device=device) + + assert resume_path is not None + ckpt_manager = CheckpointManager( + run_name=parse_wandb_dt(wandb_run.created_at).strftime("%Y-%m-%d_%H-%M-%S"), + run_url=wandb_run.url, + run_status=wandb_run.state, + current_name=resume_path.name, + fetch_available=fetch_available, + load_checkpoint=load_checkpoint, + ) + # Handle "auto" viewer selection. if cfg.viewer == "auto": has_display = bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")) @@ -205,7 +263,7 @@ def __call__(self, obs) -> torch.Tensor: if resolved_viewer == "native": NativeMujocoViewer(env, policy).run() elif resolved_viewer == "viser": - ViserPlayViewer(env, policy).run() + ViserPlayViewer(env, policy, checkpoint_manager=ckpt_manager).run() else: raise RuntimeError(f"Unsupported viewer backend: {resolved_viewer}") diff --git a/src/mjlab/viewer/base.py b/src/mjlab/viewer/base.py index 4abde2453..eb1f11b8f 100644 --- a/src/mjlab/viewer/base.py +++ b/src/mjlab/viewer/base.py @@ -141,6 +141,7 @@ class ViewerAction(Enum): TOGGLE_PLOTS = "toggle_plots" TOGGLE_DEBUG_VIS = "toggle_debug_vis" TOGGLE_SHOW_ALL_ENVS = "toggle_show_all_envs" + FETCH_CHECKPOINT = "fetch_checkpoint" CUSTOM = "custom" diff --git a/src/mjlab/viewer/viser/viewer.py b/src/mjlab/viewer/viser/viewer.py index e5d906b3a..85c3355a9 100644 --- a/src/mjlab/viewer/viser/viewer.py +++ b/src/mjlab/viewer/viser/viewer.py @@ -6,9 +6,13 @@ from __future__ import annotations import time +import webbrowser +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from enum import Enum, auto from threading import Lock +from typing import Any, Optional import viser from typing_extensions import override @@ -19,6 +23,7 @@ EnvProtocol, PolicyProtocol, VerbosityLevel, + ViewerAction, ) from mjlab.viewer.viser.overlays import ( ViserCameraOverlays, @@ -29,6 +34,16 @@ from mjlab.viewer.viser.scene import ViserMujocoScene +@dataclass +class CheckpointManager: + run_name: str + run_url: str + run_status: str + current_name: str + fetch_available: Callable[[], list[tuple[str, str]]] + load_checkpoint: Callable[[str], PolicyProtocol] + + class UpdateReason(Enum): ACTION = auto() ENV_SWITCH = auto() @@ -45,8 +60,10 @@ def __init__( frame_rate: float = 60.0, verbosity: VerbosityLevel = VerbosityLevel.SILENT, viser_server: viser.ViserServer | None = None, + checkpoint_manager: CheckpointManager | None = None, ) -> None: super().__init__(env, policy, frame_rate, verbosity) + self._ckpt_mgr = checkpoint_manager self._term_overlays: ViserTermOverlays | None = None self._camera_overlays: ViserCameraOverlays | None = None self._debug_overlays: ViserDebugOverlays | None = None @@ -168,6 +185,84 @@ def _debug_viz_extra() -> None: # Groups tab (geoms and sites). self._scene.create_groups_gui(tabs) + if self._ckpt_mgr is not None: + with tabs.add_tab("W&B Run", icon=viser.Icon.CLOUD): + self._server.gui.add_html( + f'