Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion src/mjlab/scripts/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mjlab.utils.torch import configure_torch_backends
from mjlab.utils.wrappers import VideoRecorder
from mjlab.viewer import NativeMujocoViewer, ViserPlayViewer
from mjlab.viewer.viser.viewer import CheckpointManager


@dataclass(frozen=True)
Expand Down Expand Up @@ -194,6 +195,63 @@ def __call__(self, obs) -> torch.Tensor:
)
policy = runner.get_inference_policy(device=device)

ckpt_manager = None
if TRAINED_MODE and cfg.wandb_run_path is not None:
from datetime import datetime, timezone

import wandb

def parse_wandb_dt(value: str | datetime) -> datetime:
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value

api = wandb.Api()
run_path = str(cfg.wandb_run_path)
wandb_run = api.run(run_path)

def fetch_available() -> list[tuple[str, str]]:
run = api.run(run_path)
now = datetime.now(tz=timezone.utc)
entries: list[tuple[str, str, int]] = []
for f in run.files():
if not f.name.endswith(".pt"):
continue
step = int(f.name.split("_")[1].split(".")[0])
s = int((now - parse_wandb_dt(f.updated_at)).total_seconds())
for div, unit in ((86400, "d"), (3600, "h"), (60, "m")):
if s >= div:
t = f"{s // div}{unit} ago"
break
else:
t = f"{s}s ago"
entries.append((f.name, t, step))
entries.sort(key=lambda x: x[2])
return [(name, t) for name, t, _ in entries]

_log_root = log_root_path # type: ignore[possibly-undefined]
_runner = runner # type: ignore[possibly-undefined]

def load_checkpoint(name: str):
path, _ = get_wandb_checkpoint_path(_log_root, Path(run_path), name)
_runner.load(
str(path),
load_cfg={"actor": True},
strict=True,
map_location=device,
)
return _runner.get_inference_policy(device=device)

assert resume_path is not None
ckpt_manager = CheckpointManager(
run_name=parse_wandb_dt(wandb_run.created_at).strftime("%Y-%m-%d_%H-%M-%S"),
run_url=wandb_run.url,
run_status=wandb_run.state,
current_name=resume_path.name,
fetch_available=fetch_available,
load_checkpoint=load_checkpoint,
)

# Handle "auto" viewer selection.
if cfg.viewer == "auto":
has_display = bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
Expand All @@ -205,7 +263,7 @@ def __call__(self, obs) -> torch.Tensor:
if resolved_viewer == "native":
NativeMujocoViewer(env, policy).run()
elif resolved_viewer == "viser":
ViserPlayViewer(env, policy).run()
ViserPlayViewer(env, policy, checkpoint_manager=ckpt_manager).run()
else:
raise RuntimeError(f"Unsupported viewer backend: {resolved_viewer}")

Expand Down
1 change: 1 addition & 0 deletions src/mjlab/viewer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ViewerAction(Enum):
TOGGLE_PLOTS = "toggle_plots"
TOGGLE_DEBUG_VIS = "toggle_debug_vis"
TOGGLE_SHOW_ALL_ENVS = "toggle_show_all_envs"
FETCH_CHECKPOINT = "fetch_checkpoint"
CUSTOM = "custom"


Expand Down
95 changes: 95 additions & 0 deletions src/mjlab/viewer/viser/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from __future__ import annotations

import time
import webbrowser
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from threading import Lock
from typing import Any, Optional

import viser
from typing_extensions import override
Expand All @@ -19,6 +23,7 @@
EnvProtocol,
PolicyProtocol,
VerbosityLevel,
ViewerAction,
)
from mjlab.viewer.viser.overlays import (
ViserCameraOverlays,
Expand All @@ -29,6 +34,16 @@
from mjlab.viewer.viser.scene import ViserMujocoScene


@dataclass
class CheckpointManager:
run_name: str
run_url: str
run_status: str
current_name: str
fetch_available: Callable[[], list[tuple[str, str]]]
load_checkpoint: Callable[[str], PolicyProtocol]


class UpdateReason(Enum):
ACTION = auto()
ENV_SWITCH = auto()
Expand All @@ -45,8 +60,10 @@ def __init__(
frame_rate: float = 60.0,
verbosity: VerbosityLevel = VerbosityLevel.SILENT,
viser_server: viser.ViserServer | None = None,
checkpoint_manager: CheckpointManager | None = None,
) -> None:
super().__init__(env, policy, frame_rate, verbosity)
self._ckpt_mgr = checkpoint_manager
self._term_overlays: ViserTermOverlays | None = None
self._camera_overlays: ViserCameraOverlays | None = None
self._debug_overlays: ViserDebugOverlays | None = None
Expand Down Expand Up @@ -168,6 +185,84 @@ def _debug_viz_extra() -> None:
# Groups tab (geoms and sites).
self._scene.create_groups_gui(tabs)

if self._ckpt_mgr is not None:
with tabs.add_tab("W&B Run", icon=viser.Icon.CLOUD):
self._server.gui.add_html(
f'<div style="font-size: 0.85em; line-height: 1.25;'
f' padding: 0 1em 0.5em 1em;">'
f"<strong>Name:</strong> {self._ckpt_mgr.run_name}<br/>"
f"<strong>Status:</strong> {self._ckpt_mgr.run_status}"
f"</div>"
)

open_button = self._server.gui.add_button(
"Open Run",
icon=viser.Icon.EXTERNAL_LINK,
)

@open_button.on_click
def _(_) -> None:
assert self._ckpt_mgr is not None
webbrowser.open(self._ckpt_mgr.run_url)

self._ckpt_dropdown = self._server.gui.add_dropdown(
"Checkpoint",
options=[self._ckpt_mgr.current_name],
initial_value=self._ckpt_mgr.current_name,
)

self._ckpt_updating = False

@self._ckpt_dropdown.on_update
def _(_) -> None:
if not self._ckpt_updating:
self._actions.append((ViewerAction.FETCH_CHECKPOINT, "selected"))

ckpt_buttons = self._server.gui.add_button_group(
"",
options=["Refresh", "Use Latest"],
)

@ckpt_buttons.on_click
def _(event) -> None:
if event.target.value == "Refresh":
self._actions.append((ViewerAction.FETCH_CHECKPOINT, "refresh"))
else:
self._actions.append((ViewerAction.FETCH_CHECKPOINT, "latest"))

self._actions.append((ViewerAction.FETCH_CHECKPOINT, "refresh"))

@override
def _handle_custom_action(self, action: ViewerAction, payload: Optional[Any]) -> bool:
if action != ViewerAction.FETCH_CHECKPOINT or self._ckpt_mgr is None:
return action == ViewerAction.FETCH_CHECKPOINT

if payload in ("refresh", "latest"):
entries = self._ckpt_mgr.fetch_available()
labels = [f"{n} ({t})" if t else n for n, t in entries]
self._ckpt_updating = True
self._ckpt_dropdown.options = labels
cur = next(
(lbl for lbl in labels if lbl.startswith(self._ckpt_mgr.current_name)),
self._ckpt_mgr.current_name,
)
self._ckpt_dropdown.value = cur
self._ckpt_updating = False
if payload == "refresh":
return True
payload = entries[-1][0]
else:
payload = self._ckpt_dropdown.value.split(" (")[0]

name = payload
if name != self._ckpt_mgr.current_name:
print(f"[INFO]: Loading {name}...")
self.policy = self._ckpt_mgr.load_checkpoint(name)
self._ckpt_mgr.current_name = name
self.reset_environment()
print(f"[INFO]: Loaded {name}")
return True

@override
def _process_actions(self) -> None:
"""Process queued actions and sync UI state."""
Expand Down