From 8ac396e65e8cebae333e28114f58002c9bbd9d5e Mon Sep 17 00:00:00 2001 From: mark Date: Mon, 15 Dec 2025 16:24:44 +0000 Subject: [PATCH 1/5] added policy rollout examples --- examples/4_rollout_neuracore_policy.py | 832 ++++++++++++++++++ .../5_rollout_neuracore_policy_minimal.py | 294 +++++++ 2 files changed, 1126 insertions(+) create mode 100644 examples/4_rollout_neuracore_policy.py create mode 100644 examples/5_rollout_neuracore_policy_minimal.py diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py new file mode 100644 index 0000000..590854d --- /dev/null +++ b/examples/4_rollout_neuracore_policy.py @@ -0,0 +1,832 @@ +#!/usr/bin/env python3 +"""Piper Robot Test with NeuraCore policy. + +This script loads a trained NeuraCore policy, reads status from the piper robot +controlled by the Meta Quest controller, and replays the prediction horizon virtually +on Viser to test the stability of the policy output. +""" + +import argparse +import sys +import threading +import time +import traceback +from pathlib import Path + +import neuracore as nc +import numpy as np +from neuracore_types import ( + CameraData, + JointData, + ParallelGripperOpenAmountData, + SyncPoint, +) + +# Add parent directory to path to import pink_ik_solver and piper_controller +sys.path.insert(0, str(Path(__file__).parent.parent)) + +# Add meta_quest_teleop to path +sys.path.insert(0, str(Path(__file__).parent.parent / "meta_quest_teleop")) + +from common.configs import ( + CAMERA_FRAME_STREAMING_RATE, + CONTROLLER_BETA, + CONTROLLER_D_CUTOFF, + CONTROLLER_DATA_RATE, + CONTROLLER_MIN_CUTOFF, + DAMPING_COST, + FRAME_TASK_GAIN, + GRIPPER_FRAME_NAME, + IK_SOLVER_RATE, + JOINT_STATE_STREAMING_RATE, + LM_DAMPING, + NEUTRAL_JOINT_ANGLES, + ORIENTATION_COST, + POSITION_COST, + POSTURE_COST_VECTOR, + ROBOT_RATE, + SOLVER_DAMPING_VALUE, + SOLVER_NAME, + URDF_PATH, + VISUALIZATION_RATE, +) +from common.data_manager import DataManager, RobotActivityState +from common.policy_state import PolicyState +from common.robot_visualizer import RobotVisualizer +from common.threads.camera import camera_thread +from common.threads.ik_solver import ik_solver_thread +from common.threads.joint_state import joint_state_thread +from common.threads.quest_reader import quest_reader_thread + +from meta_quest_teleop.reader import MetaQuestReader +from pink_ik_solver import PinkIKSolver +from piper_controller import PiperController + +POLICY_EXECUTION_RATE = 100.0 # Hz +PREDICTION_HORIZON_EXECUTION_RATIO = ( + 0.8 # percentage of the prediction horizon that is executed +) +MAX_SAFETY_THRESHOLD = 20.0 # degrees +MAX_ACTION_ERROR_THRESHOLD = 3.0 # degrees +TARGET_MODE = ( + PolicyState.ExecutionMode.TARGETING_TIME +) # "targeting_time" or "targeting_pose" +TARGETING_POSE_TIME_THRESHOLD = 1.0 # seconds + +GRIPPER_NAME = "gripper" +JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] + + +def toggle_robot_enabled_status( + data_manager: DataManager, + robot_controller: PiperController, + visualizer: RobotVisualizer, +) -> None: + """Handle Button A press to toggle robot enable/disable state.""" + robot_activity_state = data_manager.get_robot_activity_state() + if robot_activity_state == RobotActivityState.ENABLED: + # Disable robot + data_manager.set_robot_activity_state(RobotActivityState.DISABLED) + robot_controller.graceful_stop() + # Reset teleop state when disabling robot + data_manager.set_teleop_state(False, None, None) + visualizer.update_toggle_robot_enabled_status(False) + print("โœ“ 
๐Ÿ”ด Robot disabled (Button A)") + elif robot_activity_state == RobotActivityState.DISABLED: + if robot_controller.resume_robot(): + data_manager.set_robot_activity_state(RobotActivityState.ENABLED) + visualizer.update_toggle_robot_enabled_status(True) + print("โœ“ ๐ŸŸข Robot enabled (Button A)") + else: + print("โœ— Failed to enable robot") + + +def home_robot(data_manager: DataManager, robot_controller: PiperController) -> None: + """Handle Button B press to move robot to home position.""" + robot_activity_state = data_manager.get_robot_activity_state() + if robot_activity_state == RobotActivityState.ENABLED: + print("๐Ÿ  Button B pressed - Moving to home position...") + # Set state to HOMING to prevent IK thread from sending robot commands + data_manager.set_robot_activity_state(RobotActivityState.HOMING) + # Disable teleop during homing + data_manager.set_teleop_state(False, None, None) + ok = robot_controller.move_to_home() + if not ok: + print("โœ— Failed to initiate home move") + # Revert to ENABLED on failure + data_manager.set_robot_activity_state(RobotActivityState.ENABLED) + else: + print("โš ๏ธ Button B pressed but robot is not enabled") + + +def run_policy( + data_manager: DataManager, + policy: nc.policy, + policy_state: PolicyState, + visualizer: RobotVisualizer, +) -> bool: + """Handle Run Policy button press to capture state and get policy prediction.""" + # Get current joint positions + current_joint_angles = data_manager.get_current_joint_angles() + if current_joint_angles is None: + print("โš ๏ธ No current joint angles available") + return False + + # Get current gripper open value + gripper_open_value = data_manager.get_current_gripper_open_value() + if gripper_open_value is None: + print("โš ๏ธ No gripper open value available") + return False + + # Get current RGB image + rgb_image = data_manager.get_rgb_image() + if rgb_image is None: + print("โš ๏ธ No RGB image available") + return False + + # Prepare data for NeuraCore logging + joint_angles_rad = np.radians(current_joint_angles) + joint_positions_dict = { + JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) + } + gripper_open_amounts_dict = {GRIPPER_NAME: gripper_open_value} + + # Log joint positions parallel gripper open amounts and RGB image to NeuraCore + try: + # nc.log_joint_positions(joint_positions_dict) + # nc.log_gripper_data(open_amounts=gripper_open_amounts_dict) + # nc.log_rgb("camera", rgb_image) + + # create neuracore sync point + timestamp = time.time() + joint_positions = JointData( + values=joint_positions_dict, + timestamp=timestamp, + ) + parallel_gripper_open_amounts = ParallelGripperOpenAmountData( + open_amounts=gripper_open_amounts_dict, + timestamp=timestamp, + ) + cam_data = CameraData(frame=rgb_image, timestamp=timestamp) + sync_point = SyncPoint( + joint_positions=joint_positions, + parallel_gripper_open_amounts=parallel_gripper_open_amounts, + rgb_images={"camera": cam_data}, + timestamp=timestamp, + ) + + # Get policy prediction + start_time = time.time() + # predicted_sync_points = policy.predict(timeout=5) + predicted_sync_points = policy.predict(sync_point, timeout=5) + end_time = time.time() + print( + f" โœ“ Got {len(predicted_sync_points)} actions in {end_time - start_time:.3f} seconds" + ) + + prediction_ratio = visualizer.get_prediction_ratio() + policy_state.set_execution_ratio(prediction_ratio) + + # Save full prediction horizon (clipping happens when locking for execution) + prediction_horizon_sync_points = predicted_sync_points + + # Set policy inputs + 
policy_state.set_policy_rgb_image_input(rgb_image) + policy_state.set_policy_state_input(current_joint_angles) + + # Store prediction horizon actions in policy state + policy_state.set_prediction_horizon_sync_points(prediction_horizon_sync_points) + visualizer.update_ghost_robot_visibility(True) + policy_state.set_ghost_robot_playing(True) + policy_state.reset_ghost_action_index() + + except Exception as e: + print(f"โœ— Failed to get policy prediction: {e}") + traceback.print_exc() + return False + + return True + + +def start_policy_execution( + data_manager: DataManager, policy_state: PolicyState +) -> bool: + """Handle Execute Policy button press to start policy execution.""" + # Check if policy execution is already active + if ( + data_manager.get_robot_activity_state() == RobotActivityState.POLICY_CONTROLLED + and not policy_state.get_continuous_play_active() + ): + print("โš ๏ธ Policy execution already in progress") + return False + # Check if robot is enabled + elif data_manager.get_robot_activity_state() == RobotActivityState.DISABLED: + print("โš ๏ธ Cannot execute policy: Robot is disabled") + return False + + # Get prediction horizon + prediction_horizon_sync_points = policy_state.get_prediction_horizon_sync_points() + prediction_horizon_length = len(prediction_horizon_sync_points) + if prediction_horizon_length == 0: + print("โš ๏ธ No prediction horizon available. Make sure policy was run first.") + return False + first_sync_point = prediction_horizon_sync_points[0] + if first_sync_point.joint_target_positions is None: + print("โš ๏ธ First prediction in horizon has no joint targets") + return False + + # Safety check: verify robot is close enough to first action + current_joint_angles = data_manager.get_current_joint_angles() + if current_joint_angles is None: + print("โš ๏ธ Cannot execute policy: No current joint angles available") + return False + + current_joint_target_positions_rad = first_sync_point.joint_target_positions.numpy( + order=JOINT_NAMES + ) + + joint_differences = np.abs( + current_joint_angles - np.degrees(current_joint_target_positions_rad) + ) + + if np.any(joint_differences > MAX_SAFETY_THRESHOLD): + print("โš ๏ธ Cannot execute policy: Robot too far from first action") + print(f" Differences: {[f'{d:.3f}' for d in joint_differences]}") + print(f" Threshold: {MAX_SAFETY_THRESHOLD}ยฐ") + return False + + # All checks passed - start execution + + # Stop ghost visualization + policy_state.set_ghost_robot_playing(False) + + # Deactivate teleop + data_manager.set_teleop_state(False, None, None) + + # Lock policy inputs and start execution + policy_state.start_policy_execution() + + # Change robot state to POLICY_CONTROLLED + data_manager.set_robot_activity_state(RobotActivityState.POLICY_CONTROLLED) + + return True + + +def run_and_start_policy_execution( + data_manager: DataManager, + policy: nc.policy, + policy_state: PolicyState, + visualizer: RobotVisualizer, +) -> None: + """Handle Run and Execute Policy button press to capture state, get policy prediction, and immediately execute it.""" + print("Run and Execute Policy for one prediction horizon") + run_policy(data_manager, policy, policy_state, visualizer) + start_policy_execution(data_manager, policy_state) + + +def end_policy_play( + data_manager: DataManager, + policy_state: PolicyState, + visualizer: RobotVisualizer, + policy_status_message: str, +) -> None: + """End continuous play and set robot activity state to ENABLED and update policy status.""" + if policy_state.get_continuous_play_active(): 
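+        # Clear the continuous-play flag before tearing down execution state so the
+        # policy execution thread does not immediately request another prediction.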
+ policy_state.set_continuous_play_active(False) + visualizer.update_play_policy_button_status(False) + policy_state.end_policy_execution() + data_manager.set_robot_activity_state(RobotActivityState.ENABLED) + data_manager.set_teleop_state(False, None, None) + visualizer.update_policy_status(policy_status_message) + + +def play_policy( + data_manager: DataManager, + policy: nc.policy, + policy_state: PolicyState, + visualizer: RobotVisualizer, +) -> None: + """Handle Play Policy button press to start/stop continuous policy execution.""" + if not policy_state.get_continuous_play_active(): + # Start continuous play + print("โ–ถ๏ธ Play Policy button pressed - Starting continuous policy execution...") + + # Run policy to get prediction horizon + success = run_policy(data_manager, policy, policy_state, visualizer) + if not success: + print("โš ๏ธ Failed to run policy") + end_policy_play( + data_manager, + policy_state, + visualizer, + "Continuous play stopped - prediction failed", + ) + return + + # Execute policy + success = start_policy_execution(data_manager, policy_state) + if not success: + print("โš ๏ธ Failed to execute policy") + end_policy_play( + data_manager, + policy_state, + visualizer, + "Continuous play stopped - execution failed", + ) + return + + policy_state.set_continuous_play_active(True) + visualizer.update_play_policy_button_status(True) + + else: + # Stop continuous play + print("โน๏ธ Stop Policy button pressed - Stopping continuous policy execution...") + policy_state.set_continuous_play_active(False) + end_policy_play( + data_manager, policy_state, visualizer, "Policy execution stopped " + ) + + print("โœ“ Policy execution stopped and robot enabled") + + +def policy_execution_thread( + policy: nc.policy, + data_manager: DataManager, + policy_state: PolicyState, + robot_controller: PiperController, + visualizer: RobotVisualizer, +) -> None: + """Policy execution thread.""" + dt_execution = 1.0 / POLICY_EXECUTION_RATE + while True: + start_time = time.time() + + if ( + data_manager.get_robot_activity_state() + == RobotActivityState.POLICY_CONTROLLED + ): + locked_horizon_sync_points = ( + policy_state.get_locked_prediction_horizon_sync_points() + ) + execution_index = policy_state.get_execution_action_index() + locked_horizon_length = policy_state.get_locked_prediction_horizon_length() + if execution_index < locked_horizon_length: + current_sync_point = locked_horizon_sync_points[execution_index] + # Check if previous goal was achieved, if any + current_joint_angles = data_manager.get_current_joint_angles() + if ( + execution_index > 0 + and current_joint_angles is not None + and policy_state.get_execution_mode() + == PolicyState.ExecutionMode.TARGETING_POSE + ): + targeting_pose_start_time = time.time() + while ( + time.time() - targeting_pose_start_time + < TARGETING_POSE_TIME_THRESHOLD + ): + previous_sync_point = locked_horizon_sync_points[ + execution_index - 1 + ] + if previous_sync_point.joint_target_positions is None: + break + previous_joint_target_positions_rad = ( + previous_sync_point.joint_target_positions.numpy( + order=JOINT_NAMES + ) + ) + previous_joint_target_positions_deg = np.degrees( + previous_joint_target_positions_rad + ) + joint_errors = np.abs( + current_joint_angles - previous_joint_target_positions_deg + ) + if np.any(joint_errors <= MAX_ACTION_ERROR_THRESHOLD): + break + time.sleep(0.001) + + # Send current action to robot (if available) + if current_sync_point.joint_target_positions is not None: + current_joint_target_positions_rad = ( + 
current_sync_point.joint_target_positions.numpy( + order=JOINT_NAMES + ) + ) + current_joint_target_positions_deg = np.degrees( + current_joint_target_positions_rad + ) + robot_controller.set_target_joint_angles( + current_joint_target_positions_deg + ) + + # Send current gripper open value to robot (if available) + if ( + current_sync_point.parallel_gripper_open_amounts is not None + and GRIPPER_NAME + in current_sync_point.parallel_gripper_open_amounts.open_amounts + ): + current_gripper_open_value = ( + current_sync_point.parallel_gripper_open_amounts.open_amounts[ + GRIPPER_NAME + ] + ) + robot_controller.set_gripper_open_value(current_gripper_open_value) + + # Update execution index + policy_state.increment_execution_action_index() + + # Update status + visualizer.update_policy_status( + f"Executing policy: {execution_index + 1}/{locked_horizon_length}" + ) + # Check if continuous play is active + elif policy_state.get_continuous_play_active(): + # Automatically get new prediction and execute + try: + # End policy execution to clear input lock + policy_state.end_policy_execution() + # Run policy to get prediction horizon + success = run_policy(data_manager, policy, policy_state, visualizer) + if not success: + print("โš ๏ธ Failed to run policy") + end_policy_play( + data_manager, + policy_state, + visualizer, + "Continuous play stopped - prediction failed", + ) + continue + + # Execute policy + success = start_policy_execution(data_manager, policy_state) + if not success: + print("โš ๏ธ Failed to execute policy") + end_policy_play( + data_manager, + policy_state, + visualizer, + "Continuous play stopped - execution failed", + ) + continue + + except Exception as e: + print(f"โœ— Failed to get next policy prediction: {e}") + traceback.print_exc() + end_policy_play( + data_manager, + policy_state, + visualizer, + "Continuous play stopped - prediction failed", + ) + else: + # Execution complete + print("โœ“ Policy execution completed") + end_policy_play( + data_manager, policy_state, visualizer, "Policy execution completed" + ) + + # NOTE: this was added here to prevent OpenGL in visualization from blocking CUDA for policy execution + update_visualization(data_manager, policy_state, visualizer) + + dt_execution = 1.0 / visualizer.get_policy_execution_rate() + elapsed = time.time() - start_time + if elapsed < dt_execution: + time.sleep(dt_execution - elapsed) + + +def update_visualization( + data_manager: DataManager, + policy_state: PolicyState, + visualizer: RobotVisualizer, +) -> None: + """Update visualization.""" + # Update actual robot visualization + current_joint_angles = data_manager.get_current_joint_angles() + if current_joint_angles is not None: + joint_config_rad = np.radians(current_joint_angles) + visualizer.update_robot_pose(joint_config_rad) + + # Get policy state for ghost robot + prediction_horizon_sync_points = policy_state.get_prediction_horizon_sync_points() + prediction_horizon_length = len(prediction_horizon_sync_points) + ghost_robot_playing = policy_state.get_ghost_robot_playing() + ghost_action_index = policy_state.get_ghost_action_index() + + # Update ghost robot based on current state + robot_activity_state = data_manager.get_robot_activity_state() + if robot_activity_state == RobotActivityState.POLICY_CONTROLLED: + # During policy execution, make ghost robot show target joint angles + visualizer.update_ghost_robot_visibility(True) + target_joint_angles = data_manager.get_target_joint_angles() + if target_joint_angles is not None: + joint_config_rad = 
np.radians(target_joint_angles) + visualizer.update_ghost_robot_pose(joint_config_rad) + # Disable buttons during execution + visualizer.set_start_policy_execution_button_disabled(True) + visualizer.set_run_policy_button_disabled(True) + visualizer.set_run_and_start_policy_execution_button_disabled(True) + # Play/Stop button is enabled during execution so we can stop if needed + visualizer.set_play_policy_button_disabled(False) + + elif ( + robot_activity_state == RobotActivityState.ENABLED + and data_manager.get_teleop_active() + ): + # During teleoperation, make ghost robot show target joint angles + visualizer.update_ghost_robot_visibility(True) + target_joint_angles = data_manager.get_target_joint_angles() + if target_joint_angles is not None: + joint_config_rad = np.radians(target_joint_angles) + visualizer.update_ghost_robot_pose(joint_config_rad) + + elif ghost_robot_playing and prediction_horizon_length > 0: + # Enable execute policy button + visualizer.set_start_policy_execution_button_disabled(False) + # show ghost robot + visualizer.update_ghost_robot_visibility(True) + # Update ghost robot with prediction horizon actions (preview mode) + if ghost_action_index < prediction_horizon_length: + ghost_sync_point = prediction_horizon_sync_points[ghost_action_index] + if ghost_sync_point.joint_target_positions is not None: + ghost_joint_config = ghost_sync_point.joint_target_positions.numpy( + order=JOINT_NAMES + ) + visualizer.update_ghost_robot_pose(ghost_joint_config) + next_index = (ghost_action_index + 1) % prediction_horizon_length + policy_state.set_ghost_action_index(next_index) + else: + policy_state.reset_ghost_action_index() + + else: + # When not playing, hide the ghost robot + visualizer.update_ghost_robot_visibility(False) + + # Update button state and policy status when not policy controlled + robot_enabled = robot_activity_state == RobotActivityState.ENABLED + has_horizon = prediction_horizon_length > 0 + + # Update button enabled state + visualizer.set_start_policy_execution_button_disabled( + not (robot_enabled and has_horizon) + ) + visualizer.set_run_policy_button_disabled(not robot_enabled) + visualizer.set_run_and_start_policy_execution_button_disabled(not robot_enabled) + visualizer.set_play_policy_button_disabled(not robot_enabled) + + # Update policy status + if not has_horizon: + visualizer.update_policy_status( + "Ready - Press Right Joystick or 'Run Policy' button to get prediction" + ) + elif not robot_enabled: + visualizer.update_policy_status("Robot not enabled") + else: + visualizer.update_policy_status( + f"Ready - {prediction_horizon_length} actions in horizon" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Piper Robot Test with NeuraCore Policy - REAL ROBOT CONTROL" + ) + parser.add_argument( + "--ip-address", + type=str, + default=None, + help="IP address of Meta Quest device (optional, defaults to None for auto-discovery)", + ) + parser.add_argument( + "--train-run-name", + type=str, + default=None, + help="Name of the training run to load policy from (for cloud training). Mutually exclusive with --model-path.", + ) + parser.add_argument( + "--model-path", + type=str, + default=None, + help="Path to local model file to load policy from. 
Mutually exclusive with --train-run-name.", + ) + args = parser.parse_args() + + # Validate that exactly one of train-run-name or model-path is provided + if (args.train_run_name is None) == (args.model_path is None): + parser.error( + "Exactly one of --train-run-name or --model-path must be provided (not both, not neither)" + ) + + print("=" * 60) + print("PIPER ROBOT TEST WITH NEURACORE POLICY") + print("=" * 60) + print("Thread frequencies:") + print(f" ๐ŸŽฎ Quest Controller: {CONTROLLER_DATA_RATE} Hz") + print(f" ๐Ÿงฎ IK Solver: {IK_SOLVER_RATE} Hz") + print(f" ๐Ÿค– Robot Controller: {ROBOT_RATE} Hz") + print(f" ๐Ÿ“ธ Camera Frame: {CAMERA_FRAME_STREAMING_RATE} Hz") + print(f" ๐Ÿ“Š Joint State: {JOINT_STATE_STREAMING_RATE} Hz") + print(f" ๐Ÿ–ฅ๏ธ Visualization: {VISUALIZATION_RATE} Hz") + + # Connect to NeuraCore + print("\n๐Ÿ”ง Initializing NeuraCore...") + nc.login() + nc.connect_robot( + robot_name="AgileX PiPER", + urdf_path=str(URDF_PATH), + overwrite=False, + ) + + # Load policy from either train run name or model path + if args.train_run_name is not None: + print(f"\n๐Ÿค– Loading policy from training run: {args.train_run_name}...") + policy = nc.policy(train_run_name=args.train_run_name) + else: + print(f"\n๐Ÿค– Loading policy from model file: {args.model_path}...") + policy = nc.policy(model_file=args.model_path, device="cuda") + print(" โœ“ Policy loaded successfully") + + # Initialize policy state + policy_state = PolicyState() + policy_state.set_execution_mode(TARGET_MODE) + + # Initialize shared state + data_manager = DataManager() + data_manager.set_controller_filter_params( + CONTROLLER_MIN_CUTOFF, + CONTROLLER_BETA, + CONTROLLER_D_CUTOFF, + ) + + # Initialize robot controller + print("\n๐Ÿค– Initializing Piper robot controller...") + robot_controller = PiperController( + can_interface="can0", + robot_rate=ROBOT_RATE, + control_mode=PiperController.ControlMode.JOINT_SPACE, + neutral_joint_angles=NEUTRAL_JOINT_ANGLES, + debug_mode=False, + ) + + # Start robot control loop + print("\n๐Ÿš€ Starting robot control loop...") + robot_controller.start_control_loop() + + # Start joint state thread + print("\n๐Ÿ“Š Starting joint state thread...") + joint_state_thread_obj = threading.Thread( + target=joint_state_thread, args=(data_manager, robot_controller), daemon=True + ) + joint_state_thread_obj.start() + + # Initialize Meta Quest reader + print("\n๐ŸŽฎ Initializing Meta Quest reader...") + quest_reader = MetaQuestReader(ip_address=args.ip_address, port=5555, run=True) + + # Start data collection thread + print("\n๐ŸŽฎ Starting quest reader thread...") + quest_thread = threading.Thread( + target=quest_reader_thread, args=(data_manager, quest_reader), daemon=True + ) + quest_thread.start() + + # set initial configuration to current joint angles + current_joint_angles = data_manager.get_current_joint_angles() + if current_joint_angles is not None: + initial_joint_angles = np.radians(current_joint_angles) + else: + initial_joint_angles = np.radians(NEUTRAL_JOINT_ANGLES) + + # Create Pink IK solver + print("\n๐Ÿ”ง Creating Pink IK solver...") + ik_solver = PinkIKSolver( + urdf_path=URDF_PATH, + end_effector_frame=GRIPPER_FRAME_NAME, + solver_name=SOLVER_NAME, + position_cost=POSITION_COST, + orientation_cost=ORIENTATION_COST, + frame_task_gain=FRAME_TASK_GAIN, + lm_damping=LM_DAMPING, + damping_cost=DAMPING_COST, + solver_damping_value=SOLVER_DAMPING_VALUE, + integration_time_step=1 / IK_SOLVER_RATE, + initial_configuration=initial_joint_angles, + 
posture_cost_vector=np.array(POSTURE_COST_VECTOR), + ) + + # Start IK solver thread + print("\n๐Ÿงฎ Starting IK solver thread...") + ik_thread = threading.Thread( + target=ik_solver_thread, args=(data_manager, ik_solver), daemon=True + ) + ik_thread.start() + + # Start camera thread + print("\n๐Ÿ“ท Starting camera thread...") + camera_thread_obj = threading.Thread( + target=camera_thread, args=(data_manager,), daemon=True + ) + camera_thread_obj.start() + + # Set up visualization + print("\n๐Ÿ–ฅ๏ธ Starting Viser visualization...") + visualizer = RobotVisualizer(str(URDF_PATH)) + visualizer.add_policy_controls( + initial_prediction_ratio=PREDICTION_HORIZON_EXECUTION_RATIO, + initial_policy_rate=POLICY_EXECUTION_RATE, + initial_robot_rate=ROBOT_RATE, + initial_execution_mode=TARGET_MODE.value, + ) + visualizer.add_toggle_robot_enabled_status_button() + visualizer.add_homing_controls() + visualizer.add_policy_buttons() + + # Set up button callbacks + visualizer.set_toggle_robot_enabled_status_callback( + lambda: toggle_robot_enabled_status(data_manager, robot_controller, visualizer) + ) + visualizer.set_go_home_callback(lambda: home_robot(data_manager, robot_controller)) + visualizer.set_run_policy_callback( + lambda: (run_policy(data_manager, policy, policy_state, visualizer), None)[1] + ) + visualizer.set_start_policy_execution_callback( + lambda: (start_policy_execution(data_manager, policy_state), None)[1] + ) + visualizer.set_run_and_start_policy_execution_callback( + lambda: run_and_start_policy_execution( + data_manager, policy, policy_state, visualizer + ) + ) + visualizer.set_play_policy_callback( + lambda: play_policy(data_manager, policy, policy_state, visualizer) + ) + # Set up execution mode dropdown callback to sync with PolicyState + visualizer.set_execution_mode_callback( + lambda: policy_state.set_execution_mode( + PolicyState.ExecutionMode(visualizer.get_execution_mode()) + ) + ) + + # Register Quest reader button callbacks (after visualizer is created) + quest_reader.on( + "button_a_pressed", + lambda: toggle_robot_enabled_status(data_manager, robot_controller, visualizer), + ) + quest_reader.on( + "button_b_pressed", lambda: home_robot(data_manager, robot_controller) + ) + + # Start policy execution thread + print("\n๐Ÿค– Starting policy execution thread...") + policy_execution_thread_obj = threading.Thread( + target=policy_execution_thread, + args=(policy, data_manager, policy_state, robot_controller, visualizer), + daemon=True, + ) + policy_execution_thread_obj.start() + + print() + print("๐Ÿš€ Starting teleoperation with policy testing...") + print("๐ŸŽฎ CONTROLS:") + print(" 1. Press BUTTON A or Enable Robot button to enable/disable robot") + print(" 2. You have same control over the robot as in teleoperation.") + print(" - Hold RIGHT GRIP to activate teleoperation") + print(" - Move controller - robot follows!") + print(" - Hold RIGHT TRIGGER to close gripper") + print(" - Press BUTTON A or Enable Robot button to enable/disable robot") + print(" - Press BUTTON B or Home Robot button to send robot home") + print(" 3. Click 'Run Policy' button to run policy (without executing)") + print(" 4. Click 'Execute Policy' button to execute prediction horizon") + print(" 5. Click 'Run and Execute Policy' button to run and execute policy") + print(" 6. 
Click 'Play Policy' button to play policy") + print("โš ๏ธ Press Ctrl+C to exit") + print() + print("๐ŸŒ Open browser: http://localhost:8080") + + try: + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\n๐Ÿ‘‹ Interrupt received - shutting down gracefully...") + except Exception as e: + print(f"\nโŒ Demo error: {e}") + traceback.print_exc() + + # Cleanup + print("\n๐Ÿงน Cleaning up...") + + # Disconnect policy + policy.disconnect() + + # shutdown threads + data_manager.request_shutdown() + data_manager.set_robot_activity_state(RobotActivityState.DISABLED) + quest_thread.join() + quest_reader.stop() + ik_thread.join() + camera_thread_obj.join() + robot_controller.cleanup() + + nc.logout() + + print("\n๐Ÿ‘‹ Demo stopped.") diff --git a/examples/5_rollout_neuracore_policy_minimal.py b/examples/5_rollout_neuracore_policy_minimal.py new file mode 100644 index 0000000..ae9d884 --- /dev/null +++ b/examples/5_rollout_neuracore_policy_minimal.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +"""Minimal Piper Robot Policy Test - Terminal only, no GUI. + +Simple script that: +1. Enables robot +2. Sends robot home +3. Runs policy in continuous loop (get image, run policy, execute horizon, repeat) +4. On cancellation: sends robot home and exits +""" + +import argparse +import sys +import threading +import time +import traceback +from pathlib import Path + +import neuracore as nc +import numpy as np + +# Add parent directory to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.configs import ( + NEUTRAL_JOINT_ANGLES, + ROBOT_RATE, + URDF_PATH, +) +from common.data_manager import DataManager, RobotActivityState +from common.policy_state import PolicyState +from common.threads.camera import camera_thread +from common.threads.joint_state import joint_state_thread + +from piper_controller import PiperController + +# Constants +POLICY_EXECUTION_RATE = 100.0 # Hz +PREDICTION_HORIZON_EXECUTION_RATIO = 0.5 + +GRIPPER_NAME = "gripper" +JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] + + +def run_policy( + data_manager: DataManager, + policy: nc.policy, + policy_state: PolicyState, +) -> bool: + """Run policy and get prediction horizon.""" + # Get current state + current_joint_angles = data_manager.get_current_joint_angles() + if current_joint_angles is None: + print("โš ๏ธ No joint angles available") + return False + + gripper_open_value = data_manager.get_current_gripper_open_value() + if gripper_open_value is None: + print("โš ๏ธ No gripper value available") + return False + + rgb_image = data_manager.get_rgb_image() + if rgb_image is None: + print("โš ๏ธ No RGB image available") + return False + + # Get current gripper open value + gripper_open_value = data_manager.get_current_gripper_open_value() + if gripper_open_value is None: + print("โš ๏ธ No gripper open value available") + return False + + # Get current RGB image + rgb_image = data_manager.get_rgb_image() + if rgb_image is None: + print("โš ๏ธ No RGB image available") + return False + + # Prepare data for NeuraCore logging + joint_angles_rad = np.radians(current_joint_angles) + joint_positions_dict = { + JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) + } + gripper_open_amounts_dict = {GRIPPER_NAME: gripper_open_value} + + # Log joint positions parallel gripper open amounts and RGB image to NeuraCore + nc.log_joint_positions(joint_positions_dict) + nc.log_gripper_data(open_amounts=gripper_open_amounts_dict) + nc.log_rgb("camera", rgb_image) + + # timestamp = 
time.time() + # sync_point = SyncPoint( + # joint_positions=JointData(values=joint_positions_dict, timestamp=timestamp), + # parallel_gripper_open_amounts=ParallelGripperOpenAmountData( + # open_amounts=gripper_open_amounts_dict, timestamp=timestamp + # ), + # rgb_images={"camera": CameraData(frame=rgb_image, timestamp=timestamp)}, + # timestamp=timestamp, + # ) + + # Get policy prediction + try: + start_time = time.time() + predicted_sync_points = policy.predict(timeout=5) + elapsed = time.time() - start_time + print(f"โœ“ Got {len(predicted_sync_points)} actions in {elapsed:.3f}s") + + # Save full horizon and set execution ratio (clipping occurs on lock) + policy_state.set_execution_ratio(PREDICTION_HORIZON_EXECUTION_RATIO) + policy_state.set_prediction_horizon_sync_points(predicted_sync_points) + return True + + except Exception as e: + print(f"โœ— Policy prediction failed: {e}") + traceback.print_exc() + return False + + +def execute_horizon( + data_manager: DataManager, + policy_state: PolicyState, + robot_controller: PiperController, +) -> None: + """Execute prediction horizon.""" + policy_state.start_policy_execution() + data_manager.set_robot_activity_state(RobotActivityState.POLICY_CONTROLLED) + + locked_horizon_sync_points = ( + policy_state.get_locked_prediction_horizon_sync_points() + ) + horizon_length = policy_state.get_locked_prediction_horizon_length() + dt = 1.0 / POLICY_EXECUTION_RATE + + for i in range(horizon_length): + sync_point = locked_horizon_sync_points[i] + if sync_point.joint_target_positions is not None: + joint_targets_rad = sync_point.joint_target_positions.numpy( + order=JOINT_NAMES + ) + joint_targets_deg = np.degrees(joint_targets_rad) + robot_controller.set_target_joint_angles(joint_targets_deg) + + if ( + sync_point.parallel_gripper_open_amounts is not None + and GRIPPER_NAME in sync_point.parallel_gripper_open_amounts.open_amounts + ): + gripper_value = sync_point.parallel_gripper_open_amounts.open_amounts[ + GRIPPER_NAME + ] + robot_controller.set_gripper_open_value(gripper_value) + + # Sleep to maintain rate + time.sleep(dt) + + # End execution + policy_state.end_policy_execution() + data_manager.set_robot_activity_state(RobotActivityState.ENABLED) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Minimal Piper Policy Test") + parser.add_argument( + "--model-file", + type=str, + required=True, + help="Path to model file (.nc.zip)", + ) + args = parser.parse_args() + + print("=" * 60) + print("PIPER POLICY ROLLOUT") + print("=" * 60) + + # Initialize NeuraCore + print("\n๐Ÿ”ง Initializing NeuraCore...") + nc.login() + nc.connect_robot( + robot_name="AgileX PiPER", + urdf_path=str(URDF_PATH), + overwrite=False, + ) + + # Load policy + print(f"\n๐Ÿค– Loading policy from: {args.model_file}...") + policy = nc.policy(model_file=args.model_file, device="cuda") + print("โœ“ Policy loaded") + + # Initialize state + data_manager = DataManager() + policy_state = PolicyState() + + # Initialize robot controller + print("\n๐Ÿค– Initializing robot controller...") + robot_controller = PiperController( + can_interface="can0", + robot_rate=ROBOT_RATE, + control_mode=PiperController.ControlMode.JOINT_SPACE, + neutral_joint_angles=NEUTRAL_JOINT_ANGLES, + debug_mode=False, + ) + robot_controller.start_control_loop() + + # Start joint state thread + print("\n๐Ÿ“Š Starting joint state thread...") + joint_state_thread_obj = threading.Thread( + target=joint_state_thread, args=(data_manager, robot_controller), daemon=True + ) + 
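+    # daemon=True lets the interpreter exit even if this thread is still running;
+    # it is also joined explicitly during cleanup at the end of the script.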
joint_state_thread_obj.start() + + # Start camera thread + print("\n๐Ÿ“ท Starting camera thread...") + camera_thread_obj = threading.Thread( + target=camera_thread, args=(data_manager,), daemon=True + ) + camera_thread_obj.start() + + # Wait for threads to initialize + print("\nโณ Waiting for initialization...") + time.sleep(2.0) + + try: + # Enable robot + print("\n๐ŸŸข Enabling robot...") + robot_controller.resume_robot() + data_manager.set_robot_activity_state(RobotActivityState.ENABLED) + print("โœ“ Robot enabled") + + # Home robot + print("\n๐Ÿ  Moving to home position...") + robot_controller.move_to_home() + data_manager.set_robot_activity_state(RobotActivityState.HOMING) + + # Wait for homing to complete + start_time = time.time() + while ( + data_manager.get_robot_activity_state() == RobotActivityState.HOMING + and not robot_controller.is_robot_homed() + and time.time() - start_time < 5.0 + ): + time.sleep(0.1) + print("โœ“ Robot homed") + + # Policy execution loop + print("\n๐Ÿš€ Starting policy execution loop...") + print("Press Ctrl+C to stop\n") + + while True: + # Run policy + if not run_policy(data_manager, policy, policy_state): + print("โš ๏ธ Policy run failed, retrying...") + time.sleep(0.5) + continue + + # Execute horizon + execute_horizon(data_manager, policy_state, robot_controller) + + except KeyboardInterrupt: + print("\n\n๐Ÿ‘‹ Interrupt received - shutting down...") + + except Exception as e: + print(f"\nโŒ Error: {e}") + traceback.print_exc() + + finally: + # Cleanup + print("\n๐Ÿงน Cleaning up...") + + # Home robot + print("\n๐Ÿ  Moving to home position...") + data_manager.set_robot_activity_state(RobotActivityState.HOMING) + robot_controller.move_to_home() + + # Wait for homing to complete + start_time = time.time() + while ( + data_manager.get_robot_activity_state() == RobotActivityState.HOMING + and not robot_controller.is_robot_homed() + and time.time() - start_time < 5.0 + ): + time.sleep(0.1) + print("โœ“ Robot homed") + + # Shutdown + policy.disconnect() + data_manager.set_robot_activity_state(RobotActivityState.DISABLED) + data_manager.request_shutdown() + joint_state_thread_obj.join() + camera_thread_obj.join() + time.sleep(0.5) # Give threads time to stop + + robot_controller.cleanup() + nc.logout() + + print("โœ“ Cleanup complete") + print("\n๐Ÿ‘‹ Done.") From e48fc7d00ada67dc947fab78bed0dd545faa3f88 Mon Sep 17 00:00:00 2001 From: mark Date: Mon, 15 Dec 2025 17:24:34 +0000 Subject: [PATCH 2/5] move constants to config --- examples/4_rollout_neuracore_policy.py | 33 ++++++++----------- .../5_rollout_neuracore_policy_minimal.py | 18 +++++----- examples/common/configs.py | 12 +++++++ 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py index 590854d..9fbf882 100644 --- a/examples/4_rollout_neuracore_policy.py +++ b/examples/4_rollout_neuracore_policy.py @@ -37,16 +37,23 @@ DAMPING_COST, FRAME_TASK_GAIN, GRIPPER_FRAME_NAME, + GRIPPER_LOGGING_NAME, IK_SOLVER_RATE, + JOINT_NAMES, JOINT_STATE_STREAMING_RATE, LM_DAMPING, + MAX_ACTION_ERROR_THRESHOLD, + MAX_SAFETY_THRESHOLD, NEUTRAL_JOINT_ANGLES, ORIENTATION_COST, + POLICY_EXECUTION_RATE, POSITION_COST, POSTURE_COST_VECTOR, + PREDICTION_HORIZON_EXECUTION_RATIO, ROBOT_RATE, SOLVER_DAMPING_VALUE, SOLVER_NAME, + TARGETING_POSE_TIME_THRESHOLD, URDF_PATH, VISUALIZATION_RATE, ) @@ -62,20 +69,6 @@ from pink_ik_solver import PinkIKSolver from piper_controller import PiperController -POLICY_EXECUTION_RATE = 100.0 # Hz 
-PREDICTION_HORIZON_EXECUTION_RATIO = ( - 0.8 # percentage of the prediction horizon that is executed -) -MAX_SAFETY_THRESHOLD = 20.0 # degrees -MAX_ACTION_ERROR_THRESHOLD = 3.0 # degrees -TARGET_MODE = ( - PolicyState.ExecutionMode.TARGETING_TIME -) # "targeting_time" or "targeting_pose" -TARGETING_POSE_TIME_THRESHOLD = 1.0 # seconds - -GRIPPER_NAME = "gripper" -JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] - def toggle_robot_enabled_status( data_manager: DataManager, @@ -149,7 +142,7 @@ def run_policy( joint_positions_dict = { JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) } - gripper_open_amounts_dict = {GRIPPER_NAME: gripper_open_value} + gripper_open_amounts_dict = {GRIPPER_LOGGING_NAME: gripper_open_value} # Log joint positions parallel gripper open amounts and RGB image to NeuraCore try: @@ -421,13 +414,13 @@ def policy_execution_thread( # Send current gripper open value to robot (if available) if ( - current_sync_point.parallel_gripper_open_amounts is not None - and GRIPPER_NAME + current_sync_point.parallel_gripper_open_amounts + is not None in current_sync_point.parallel_gripper_open_amounts.open_amounts ): current_gripper_open_value = ( current_sync_point.parallel_gripper_open_amounts.open_amounts[ - GRIPPER_NAME + GRIPPER_LOGGING_NAME ] ) robot_controller.set_gripper_open_value(current_gripper_open_value) @@ -647,7 +640,7 @@ def update_visualization( # Initialize policy state policy_state = PolicyState() - policy_state.set_execution_mode(TARGET_MODE) + policy_state.set_execution_mode(PolicyState.ExecutionMode.TARGETING_TIME) # Initialize shared state data_manager = DataManager() @@ -734,7 +727,7 @@ def update_visualization( initial_prediction_ratio=PREDICTION_HORIZON_EXECUTION_RATIO, initial_policy_rate=POLICY_EXECUTION_RATE, initial_robot_rate=ROBOT_RATE, - initial_execution_mode=TARGET_MODE.value, + initial_execution_mode=PolicyState.ExecutionMode.TARGETING_TIME.value, ) visualizer.add_toggle_robot_enabled_status_button() visualizer.add_homing_controls() diff --git a/examples/5_rollout_neuracore_policy_minimal.py b/examples/5_rollout_neuracore_policy_minimal.py index ae9d884..ca0f263 100644 --- a/examples/5_rollout_neuracore_policy_minimal.py +++ b/examples/5_rollout_neuracore_policy_minimal.py @@ -22,7 +22,11 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) from common.configs import ( + GRIPPER_LOGGING_NAME, + JOINT_NAMES, NEUTRAL_JOINT_ANGLES, + POLICY_EXECUTION_RATE, + PREDICTION_HORIZON_EXECUTION_RATIO, ROBOT_RATE, URDF_PATH, ) @@ -33,13 +37,6 @@ from piper_controller import PiperController -# Constants -POLICY_EXECUTION_RATE = 100.0 # Hz -PREDICTION_HORIZON_EXECUTION_RATIO = 0.5 - -GRIPPER_NAME = "gripper" -JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] - def run_policy( data_manager: DataManager, @@ -80,7 +77,7 @@ def run_policy( joint_positions_dict = { JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) } - gripper_open_amounts_dict = {GRIPPER_NAME: gripper_open_value} + gripper_open_amounts_dict = {GRIPPER_LOGGING_NAME: gripper_open_value} # Log joint positions parallel gripper open amounts and RGB image to NeuraCore nc.log_joint_positions(joint_positions_dict) @@ -141,10 +138,11 @@ def execute_horizon( if ( sync_point.parallel_gripper_open_amounts is not None - and GRIPPER_NAME in sync_point.parallel_gripper_open_amounts.open_amounts + and GRIPPER_LOGGING_NAME + in sync_point.parallel_gripper_open_amounts.open_amounts ): gripper_value = 
sync_point.parallel_gripper_open_amounts.open_amounts[ - GRIPPER_NAME + GRIPPER_LOGGING_NAME ] robot_controller.set_gripper_open_value(gripper_value) diff --git a/examples/common/configs.py b/examples/common/configs.py index 2bcfefe..32050d5 100644 --- a/examples/common/configs.py +++ b/examples/common/configs.py @@ -48,3 +48,15 @@ # Posture task cost vector (one weight per joint) POSTURE_COST_VECTOR = [0.0, 0.0, 0.0, 0.05, 0.0, 0.0] + + +POLICY_EXECUTION_RATE = 100.0 # Hz +PREDICTION_HORIZON_EXECUTION_RATIO = ( + 0.8 # percentage of the prediction horizon that is executed +) +MAX_SAFETY_THRESHOLD = 20.0 # degrees +MAX_ACTION_ERROR_THRESHOLD = 3.0 # degrees +TARGETING_POSE_TIME_THRESHOLD = 1.0 # seconds + +GRIPPER_LOGGING_NAME = "gripper" +JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] From 74b0b68cce47ae22b1ab843c0330e446d0f671cc Mon Sep 17 00:00:00 2001 From: mark Date: Tue, 16 Dec 2025 12:19:49 +0000 Subject: [PATCH 3/5] new neuracore changes --- examples/4_rollout_neuracore_policy.py | 56 +++++++++---------- .../5_rollout_neuracore_policy_minimal.py | 10 ---- examples/common/configs.py | 1 + 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py index 9fbf882..07244e8 100644 --- a/examples/4_rollout_neuracore_policy.py +++ b/examples/4_rollout_neuracore_policy.py @@ -15,12 +15,6 @@ import neuracore as nc import numpy as np -from neuracore_types import ( - CameraData, - JointData, - ParallelGripperOpenAmountData, - SyncPoint, -) # Add parent directory to path to import pink_ik_solver and piper_controller sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -30,6 +24,7 @@ from common.configs import ( CAMERA_FRAME_STREAMING_RATE, + CAMERA_LOGGING_NAME, CONTROLLER_BETA, CONTROLLER_D_CUTOFF, CONTROLLER_DATA_RATE, @@ -146,32 +141,13 @@ def run_policy( # Log joint positions parallel gripper open amounts and RGB image to NeuraCore try: - # nc.log_joint_positions(joint_positions_dict) - # nc.log_gripper_data(open_amounts=gripper_open_amounts_dict) - # nc.log_rgb("camera", rgb_image) - - # create neuracore sync point - timestamp = time.time() - joint_positions = JointData( - values=joint_positions_dict, - timestamp=timestamp, - ) - parallel_gripper_open_amounts = ParallelGripperOpenAmountData( - open_amounts=gripper_open_amounts_dict, - timestamp=timestamp, - ) - cam_data = CameraData(frame=rgb_image, timestamp=timestamp) - sync_point = SyncPoint( - joint_positions=joint_positions, - parallel_gripper_open_amounts=parallel_gripper_open_amounts, - rgb_images={"camera": cam_data}, - timestamp=timestamp, - ) + nc.log_joint_positions(joint_positions_dict) + nc.log_gripper_data(open_amounts=gripper_open_amounts_dict) + nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image) # Get policy prediction start_time = time.time() - # predicted_sync_points = policy.predict(timeout=5) - predicted_sync_points = policy.predict(sync_point, timeout=5) + predicted_sync_points = policy.predict(timeout=5) end_time = time.time() print( f" โœ“ Got {len(predicted_sync_points)} actions in {end_time - start_time:.3f} seconds" @@ -630,12 +606,30 @@ def update_visualization( ) # Load policy from either train run name or model path + model_input_order = { + "JOINT_POSITIONS": JOINT_NAMES, + "PARALLEL_GRIPPER_OPEN_AMOUNTS": [GRIPPER_LOGGING_NAME], + "RGB_IMAGES": [CAMERA_LOGGING_NAME], + } + model_output_order = { + "JOINT_TARGET_POSITIONS": JOINT_NAMES, + "PARALLEL_GRIPPER_OPEN_AMOUNTS": [GRIPPER_LOGGING_NAME], + } if 
args.train_run_name is not None: print(f"\n๐Ÿค– Loading policy from training run: {args.train_run_name}...") - policy = nc.policy(train_run_name=args.train_run_name) + policy = nc.policy( + train_run_name=args.train_run_name, + model_input_order=model_input_order, + model_output_order=model_output_order, + ) else: print(f"\n๐Ÿค– Loading policy from model file: {args.model_path}...") - policy = nc.policy(model_file=args.model_path, device="cuda") + policy = nc.policy( + model_file=args.model_path, + device="cuda", + model_input_order=model_input_order, + model_output_order=model_output_order, + ) print(" โœ“ Policy loaded successfully") # Initialize policy state diff --git a/examples/5_rollout_neuracore_policy_minimal.py b/examples/5_rollout_neuracore_policy_minimal.py index ca0f263..114f477 100644 --- a/examples/5_rollout_neuracore_policy_minimal.py +++ b/examples/5_rollout_neuracore_policy_minimal.py @@ -84,16 +84,6 @@ def run_policy( nc.log_gripper_data(open_amounts=gripper_open_amounts_dict) nc.log_rgb("camera", rgb_image) - # timestamp = time.time() - # sync_point = SyncPoint( - # joint_positions=JointData(values=joint_positions_dict, timestamp=timestamp), - # parallel_gripper_open_amounts=ParallelGripperOpenAmountData( - # open_amounts=gripper_open_amounts_dict, timestamp=timestamp - # ), - # rgb_images={"camera": CameraData(frame=rgb_image, timestamp=timestamp)}, - # timestamp=timestamp, - # ) - # Get policy prediction try: start_time = time.time() diff --git a/examples/common/configs.py b/examples/common/configs.py index 32050d5..06057a5 100644 --- a/examples/common/configs.py +++ b/examples/common/configs.py @@ -59,4 +59,5 @@ TARGETING_POSE_TIME_THRESHOLD = 1.0 # seconds GRIPPER_LOGGING_NAME = "gripper" +CAMERA_LOGGING_NAME = "camera" JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"] From 0cbe687120d0e9072b136150eb426f526b7d8e50 Mon Sep 17 00:00:00 2001 From: mark Date: Fri, 19 Dec 2025 13:11:55 +0000 Subject: [PATCH 4/5] changed script to use new dayatypes --- examples/4_rollout_neuracore_policy.py | 152 +++++++++++++++---------- examples/common/configs.py | 2 +- examples/common/policy_state.py | 69 +++++++---- 3 files changed, 144 insertions(+), 79 deletions(-) diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py index 07244e8..7c66ca6 100644 --- a/examples/4_rollout_neuracore_policy.py +++ b/examples/4_rollout_neuracore_policy.py @@ -15,6 +15,11 @@ import neuracore as nc import numpy as np +from neuracore_types import ( + BatchedJointData, + BatchedParallelGripperOpenAmountData, + DataType, +) # Add parent directory to path to import pink_ik_solver and piper_controller sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -65,6 +70,34 @@ from piper_controller import PiperController +def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[float]]: + """Convert predictions dict to horizon dict format.""" + horizon: dict[str, list[float]] = {} + + # Extract joint target positions + if DataType.JOINT_TARGET_POSITIONS in predictions: + joint_data = predictions[DataType.JOINT_TARGET_POSITIONS] + for joint_name in JOINT_NAMES: + if joint_name in joint_data: + batched = joint_data[joint_name] + if isinstance(batched, BatchedJointData): + # Extract values: (B, T, 1) -> list[float], taking B=0 + values = batched.value[0, :, 0].cpu().numpy().tolist() + horizon[joint_name] = values + + # Extract gripper open amounts + if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions: + gripper_data = 
predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] + if GRIPPER_LOGGING_NAME in gripper_data: + batched = gripper_data[GRIPPER_LOGGING_NAME] + if isinstance(batched, BatchedParallelGripperOpenAmountData): + # Extract values: (B, T, 1) -> list[float], taking B=0 + values = batched.open_amount[0, :, 0].cpu().numpy().tolist() + horizon[GRIPPER_LOGGING_NAME] = values + + return horizon + + def toggle_robot_enabled_status( data_manager: DataManager, robot_controller: PiperController, @@ -120,8 +153,8 @@ def run_policy( print("โš ๏ธ No current joint angles available") return False - # Get current gripper open value - gripper_open_value = data_manager.get_current_gripper_open_value() + # Get target gripper open value because this is how the policy was trained + gripper_open_value = data_manager.get_target_gripper_open_value() if gripper_open_value is None: print("โš ๏ธ No gripper open value available") return False @@ -137,34 +170,33 @@ def run_policy( joint_positions_dict = { JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) } - gripper_open_amounts_dict = {GRIPPER_LOGGING_NAME: gripper_open_value} # Log joint positions parallel gripper open amounts and RGB image to NeuraCore try: nc.log_joint_positions(joint_positions_dict) - nc.log_gripper_data(open_amounts=gripper_open_amounts_dict) + nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_open_value) nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image) # Get policy prediction start_time = time.time() - predicted_sync_points = policy.predict(timeout=5) + predictions = policy.predict(timeout=5) + prediction_horizon = convert_predictions_to_horizon_dict(predictions) end_time = time.time() + horizon_length = policy_state.get_prediction_horizon_length() print( - f" โœ“ Got {len(predicted_sync_points)} actions in {end_time - start_time:.3f} seconds" + f" โœ“ Got {horizon_length} actions in {end_time - start_time:.3f} seconds" ) prediction_ratio = visualizer.get_prediction_ratio() policy_state.set_execution_ratio(prediction_ratio) - # Save full prediction horizon (clipping happens when locking for execution) - prediction_horizon_sync_points = predicted_sync_points - # Set policy inputs policy_state.set_policy_rgb_image_input(rgb_image) policy_state.set_policy_state_input(current_joint_angles) # Store prediction horizon actions in policy state - policy_state.set_prediction_horizon_sync_points(prediction_horizon_sync_points) + policy_state.set_prediction_horizon(prediction_horizon) + visualizer.update_ghost_robot_visibility(True) policy_state.set_ghost_robot_playing(True) policy_state.reset_ghost_action_index() @@ -194,13 +226,14 @@ def start_policy_execution( return False # Get prediction horizon - prediction_horizon_sync_points = policy_state.get_prediction_horizon_sync_points() - prediction_horizon_length = len(prediction_horizon_sync_points) + prediction_horizon = policy_state.get_prediction_horizon() + prediction_horizon_length = policy_state.get_prediction_horizon_length() if prediction_horizon_length == 0: print("โš ๏ธ No prediction horizon available. 
Make sure policy was run first.") return False - first_sync_point = prediction_horizon_sync_points[0] - if first_sync_point.joint_target_positions is None: + + # Check that we have joint data for all joints + if not all(joint_name in prediction_horizon for joint_name in JOINT_NAMES): print("โš ๏ธ First prediction in horizon has no joint targets") return False @@ -209,15 +242,13 @@ def start_policy_execution( if current_joint_angles is None: print("โš ๏ธ Cannot execute policy: No current joint angles available") return False - - current_joint_target_positions_rad = first_sync_point.joint_target_positions.numpy( - order=JOINT_NAMES + # Get first action from horizon (index 0 for each joint) + current_joint_target_positions_rad = np.array( + [prediction_horizon[joint_name][0] for joint_name in JOINT_NAMES] ) - joint_differences = np.abs( current_joint_angles - np.degrees(current_joint_target_positions_rad) ) - if np.any(joint_differences > MAX_SAFETY_THRESHOLD): print("โš ๏ธ Cannot execute policy: Robot too far from first action") print(f" Differences: {[f'{d:.3f}' for d in joint_differences]}") @@ -334,13 +365,10 @@ def policy_execution_thread( data_manager.get_robot_activity_state() == RobotActivityState.POLICY_CONTROLLED ): - locked_horizon_sync_points = ( - policy_state.get_locked_prediction_horizon_sync_points() - ) + locked_horizon = policy_state.get_locked_prediction_horizon() execution_index = policy_state.get_execution_action_index() locked_horizon_length = policy_state.get_locked_prediction_horizon_length() if execution_index < locked_horizon_length: - current_sync_point = locked_horizon_sync_points[execution_index] # Check if previous goal was achieved, if any current_joint_angles = data_manager.get_current_joint_angles() if ( @@ -354,15 +382,16 @@ def policy_execution_thread( time.time() - targeting_pose_start_time < TARGETING_POSE_TIME_THRESHOLD ): - previous_sync_point = locked_horizon_sync_points[ - execution_index - 1 - ] - if previous_sync_point.joint_target_positions is None: + # Get previous action from horizon + if not all( + joint_name in locked_horizon for joint_name in JOINT_NAMES + ): break - previous_joint_target_positions_rad = ( - previous_sync_point.joint_target_positions.numpy( - order=JOINT_NAMES - ) + previous_joint_target_positions_rad = np.array( + [ + locked_horizon[joint_name][execution_index - 1] + for joint_name in JOINT_NAMES + ] ) previous_joint_target_positions_deg = np.degrees( previous_joint_target_positions_rad @@ -375,11 +404,12 @@ def policy_execution_thread( time.sleep(0.001) # Send current action to robot (if available) - if current_sync_point.joint_target_positions is not None: - current_joint_target_positions_rad = ( - current_sync_point.joint_target_positions.numpy( - order=JOINT_NAMES - ) + if all(joint_name in locked_horizon for joint_name in JOINT_NAMES): + current_joint_target_positions_rad = np.array( + [ + locked_horizon[joint_name][execution_index] + for joint_name in JOINT_NAMES + ] ) current_joint_target_positions_deg = np.degrees( current_joint_target_positions_rad @@ -389,16 +419,10 @@ def policy_execution_thread( ) # Send current gripper open value to robot (if available) - if ( - current_sync_point.parallel_gripper_open_amounts - is not None - in current_sync_point.parallel_gripper_open_amounts.open_amounts - ): - current_gripper_open_value = ( - current_sync_point.parallel_gripper_open_amounts.open_amounts[ - GRIPPER_LOGGING_NAME - ] - ) + if GRIPPER_LOGGING_NAME in locked_horizon: + current_gripper_open_value = 
locked_horizon[GRIPPER_LOGGING_NAME][ + execution_index + ] robot_controller.set_gripper_open_value(current_gripper_open_value) # Update execution index @@ -476,8 +500,8 @@ def update_visualization( visualizer.update_robot_pose(joint_config_rad) # Get policy state for ghost robot - prediction_horizon_sync_points = policy_state.get_prediction_horizon_sync_points() - prediction_horizon_length = len(prediction_horizon_sync_points) + prediction_horizon = policy_state.get_prediction_horizon() + prediction_horizon_length = policy_state.get_prediction_horizon_length() ghost_robot_playing = policy_state.get_ghost_robot_playing() ghost_action_index = policy_state.get_ghost_action_index() @@ -515,10 +539,13 @@ def update_visualization( visualizer.update_ghost_robot_visibility(True) # Update ghost robot with prediction horizon actions (preview mode) if ghost_action_index < prediction_horizon_length: - ghost_sync_point = prediction_horizon_sync_points[ghost_action_index] - if ghost_sync_point.joint_target_positions is not None: - ghost_joint_config = ghost_sync_point.joint_target_positions.numpy( - order=JOINT_NAMES + # Get ghost action from horizon + if all(joint_name in prediction_horizon for joint_name in JOINT_NAMES): + ghost_joint_config = np.array( + [ + prediction_horizon[joint_name][ghost_action_index] + for joint_name in JOINT_NAMES + ] ) visualizer.update_ghost_robot_pose(ghost_joint_config) next_index = (ghost_action_index + 1) % prediction_horizon_length @@ -606,15 +633,26 @@ def update_visualization( ) # Load policy from either train run name or model path + # NOTE: The model_output_order MUST match the exact order used during training + # This order is determined by the output_robot_data_spec in the training config. + # The order here should match the order in your training config's output_robot_data_spec. 
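+    # Illustration (an assumption about the training setup, not read from this repo):
+    # if output_robot_data_spec listed joint1..joint6 followed by the gripper during
+    # training, the raw model output is ordered
+    #   [joint1, joint2, joint3, joint4, joint5, joint6, gripper]
+    # and model_output_order below must repeat that ordering so each predicted value
+    # is decoded against the correct joint or gripper name.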
     model_input_order = {
-        "JOINT_POSITIONS": JOINT_NAMES,
-        "PARALLEL_GRIPPER_OPEN_AMOUNTS": [GRIPPER_LOGGING_NAME],
-        "RGB_IMAGES": [CAMERA_LOGGING_NAME],
+        DataType.JOINT_POSITIONS: JOINT_NAMES,
+        DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
+        DataType.RGB_IMAGES: [CAMERA_LOGGING_NAME],
     }
     model_output_order = {
-        "JOINT_TARGET_POSITIONS": JOINT_NAMES,
-        "PARALLEL_GRIPPER_OPEN_AMOUNTS": [GRIPPER_LOGGING_NAME],
+        DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
+        DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
     }
+
+    print("\n📋 Model input order:")
+    for data_type, names in model_input_order.items():
+        print(f" {data_type.name}: {names}")
+    print("\n📋 Model output order:")
+    for data_type, names in model_output_order.items():
+        print(f" {data_type.name}: {names}")
+
     if args.train_run_name is not None:
         print(f"\n🤖 Loading policy from training run: {args.train_run_name}...")
         policy = nc.policy(
diff --git a/examples/common/configs.py b/examples/common/configs.py
index 06057a5..836d0dc 100644
--- a/examples/common/configs.py
+++ b/examples/common/configs.py
@@ -59,5 +59,5 @@ TARGETING_POSE_TIME_THRESHOLD = 1.0  # seconds
 
 GRIPPER_LOGGING_NAME = "gripper"
-CAMERA_LOGGING_NAME = "camera"
+CAMERA_LOGGING_NAME = "rgb"
 JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
diff --git a/examples/common/policy_state.py b/examples/common/policy_state.py
index e9a36db..565430a 100644
--- a/examples/common/policy_state.py
+++ b/examples/common/policy_state.py
@@ -2,7 +2,6 @@
 
 import threading
 from enum import Enum
-from typing import Any
 
 import numpy as np
 
@@ -18,7 +17,8 @@ class ExecutionMode(Enum):
 
     def __init__(self) -> None:
         """Initialize PolicyState with default values."""
-        self._prediction_horizon_sync_points: list[Any] = []
+        # Prediction horizon stored as dict[str, list[float]] where keys are joint/gripper names
+        self._prediction_horizon: dict[str, list[float]] = {}
         self._prediction_horizon_lock = threading.Lock()
 
         self._execution_ratio: float = 1.0
@@ -33,7 +33,7 @@ def __init__(self) -> None:
 
         # Policy execution state
         self._policy_inputs_locked: bool = False
-        self._locked_prediction_horizon_sync_points: list[Any] = []
+        self._locked_prediction_horizon: dict[str, list[float]] = {}
         self._execution_action_index: int = 0
         self._execution_lock = threading.Lock()
 
@@ -46,17 +46,27 @@ def __init__(self) -> None:
     def get_prediction_horizon_length(self) -> int:
         """Get prediction horizon length (thread-safe)."""
         with self._prediction_horizon_lock:
-            return len(self._prediction_horizon_sync_points)
-
-    def get_prediction_horizon_sync_points(self) -> list[Any]:
-        """Get prediction horizon sync points (thread-safe)."""
+            if not self._prediction_horizon:
+                return 0
+            # Get length from first list (all should have same length)
+            first_key = next(iter(self._prediction_horizon.keys()))
+            return len(self._prediction_horizon[first_key])
+
+    def get_prediction_horizon(self) -> dict[str, list[float]]:
+        """Get prediction horizon (thread-safe)."""
         with self._prediction_horizon_lock:
-            return list(self._prediction_horizon_sync_points)
+            # Return a deep copy to prevent external modifications
+            return {
+                key: list(values) for key, values in self._prediction_horizon.items()
+            }
 
-    def set_prediction_horizon_sync_points(self, sync_points: list[Any]) -> None:
-        """Set prediction horizon sync points (thread-safe)."""
+    def set_prediction_horizon(self, horizon: dict[str, list[float]]) -> None:
+        """Set prediction horizon (thread-safe)."""
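+        # Expected layout (illustrative values): one list per joint/gripper name,
+        # e.g. {"joint1": [0.10, 0.12, ...], ..., "gripper": [0.85, 0.80, ...]},
+        # with every list sharing the same length (the prediction horizon).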
         with self._prediction_horizon_lock:
-            self._prediction_horizon_sync_points = list(sync_points)
+            # Store a deep copy to prevent external modifications
+            self._prediction_horizon = {
+                key: list(values) for key, values in horizon.items()
+            }
 
     def set_execution_ratio(self, ratio: float) -> None:
         """Set execution ratio used when locking prediction horizon."""
@@ -136,35 +146,52 @@ def reset_ghost_action_index(self) -> None:
 
     def start_policy_execution(self) -> None:
         """Start policy execution by locking inputs and storing horizon (thread-safe)."""
         with self._prediction_horizon_lock:
-            source_sync_points = list(self._prediction_horizon_sync_points)
-            total = len(source_sync_points)
+            source_horizon = {
+                key: list(values) for key, values in self._prediction_horizon.items()
+            }
+            # Compute the length from the local copy; calling
+            # get_prediction_horizon_length() here would try to re-acquire the
+            # (non-reentrant) prediction horizon lock and deadlock.
+            total = len(next(iter(source_horizon.values()), []))
             if total == 0:
-                locked_sync_points = []
+                locked_horizon = {}
             else:
                 num_actions = int(total * self._execution_ratio)
                 num_actions = max(1, min(num_actions, total))
-                locked_sync_points = source_sync_points[:num_actions]
+                # Slice each list in the horizon
+                locked_horizon = {
+                    key: values[:num_actions] for key, values in source_horizon.items()
+                }
 
         with self._execution_lock:
             self._policy_inputs_locked = True
             self._execution_action_index = 0
-            self._locked_prediction_horizon_sync_points = locked_sync_points
+            self._locked_prediction_horizon = locked_horizon
 
     def end_policy_execution(self) -> None:
         """Stop policy execution and unlock inputs (thread-safe)."""
         with self._execution_lock:
             self._policy_inputs_locked = False
-            self._locked_prediction_horizon_sync_points = []
+            self._locked_prediction_horizon = {}
             self._execution_action_index = 0
 
-    def get_locked_prediction_horizon_sync_points(self) -> list[Any]:
-        """Get locked prediction horizon sync points (thread-safe)."""
+    def get_locked_prediction_horizon(self) -> dict[str, list[float]]:
+        """Get locked prediction horizon (thread-safe)."""
         with self._execution_lock:
-            return list(self._locked_prediction_horizon_sync_points)
+            # Return a deep copy to prevent external modifications
+            return {
+                key: list(values)
+                for key, values in self._locked_prediction_horizon.items()
+            }
 
     def get_locked_prediction_horizon_length(self) -> int:
         """Get locked prediction horizon length (thread-safe)."""
         with self._execution_lock:
-            return len(self._locked_prediction_horizon_sync_points)
+            if not self._locked_prediction_horizon:
+                return 0
+            # Get length from first list (all should have same length)
+            first_key = next(iter(self._locked_prediction_horizon.keys()))
+            return len(self._locked_prediction_horizon[first_key])
+
+    def get_locked_prediction_horizon_sync_points(self) -> dict[str, list[float]]:
+        """Get locked prediction horizon (legacy method name, calls get_locked_prediction_horizon)."""
+        return self.get_locked_prediction_horizon()
 
     def get_execution_action_index(self) -> int:
         """Get current execution action index (thread-safe)."""

From 25063991f0ffda66e0c0ebe7fb6037b9cb2757f8 Mon Sep 17 00:00:00 2001
From: mark
Date: Fri, 19 Dec 2025 16:20:36 +0000
Subject: [PATCH 5/5] changed minimal scripts to latest data types

---
 examples/4_rollout_neuracore_policy.py        |   3 +
 .../5_rollout_neuracore_policy_minimal.py     | 155 +++++++++++++-----
 2 files changed, 114 insertions(+), 44 deletions(-)

diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py
index 7c66ca6..a0a755a 100644
--- a/examples/4_rollout_neuracore_policy.py
+++ b/examples/4_rollout_neuracore_policy.py
@@ -657,6 +657,7 @@ def update_visualization(
         print(f"\n🤖 Loading policy from training run: {args.train_run_name}...")
         policy = nc.policy(
             train_run_name=args.train_run_name,
+            device="cuda",
             model_input_order=model_input_order,
             model_output_order=model_output_order,
         )
@@ -681,6 +682,8 @@ def update_visualization(
         CONTROLLER_BETA,
         CONTROLLER_D_CUTOFF,
     )
+    # Setting the target gripper so policy doesn't crash first time it runs
+    data_manager.set_target_gripper_open_value(1.0)
 
     # Initialize robot controller
     print("\n🤖 Initializing Piper robot controller...")
diff --git a/examples/5_rollout_neuracore_policy_minimal.py b/examples/5_rollout_neuracore_policy_minimal.py
index 114f477..66d415f 100644
--- a/examples/5_rollout_neuracore_policy_minimal.py
+++ b/examples/5_rollout_neuracore_policy_minimal.py
@@ -17,11 +17,17 @@
 
 import neuracore as nc
 import numpy as np
+from neuracore_types import (
+    BatchedJointData,
+    BatchedParallelGripperOpenAmountData,
+    DataType,
+)
 
 # Add parent directory to path
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
 from common.configs import (
+    CAMERA_LOGGING_NAME,
     GRIPPER_LOGGING_NAME,
     JOINT_NAMES,
     NEUTRAL_JOINT_ANGLES,
@@ -38,6 +44,34 @@
 from piper_controller import PiperController
 
 
+def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[float]]:
+    """Convert predictions dict to horizon dict format."""
+    horizon: dict[str, list[float]] = {}
+
+    # Extract joint target positions
+    if DataType.JOINT_TARGET_POSITIONS in predictions:
+        joint_data = predictions[DataType.JOINT_TARGET_POSITIONS]
+        for joint_name in JOINT_NAMES:
+            if joint_name in joint_data:
+                batched = joint_data[joint_name]
+                if isinstance(batched, BatchedJointData):
+                    # Extract values: (B, T, 1) -> list[float], taking B=0
+                    values = batched.value[0, :, 0].cpu().numpy().tolist()
+                    horizon[joint_name] = values
+
+    # Extract gripper open amounts
+    if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
+        gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
+        if GRIPPER_LOGGING_NAME in gripper_data:
+            batched = gripper_data[GRIPPER_LOGGING_NAME]
+            if isinstance(batched, BatchedParallelGripperOpenAmountData):
+                # Extract values: (B, T, 1) -> list[float], taking B=0
+                values = batched.open_amount[0, :, 0].cpu().numpy().tolist()
+                horizon[GRIPPER_LOGGING_NAME] = values
+
+    return horizon
+
+
 def run_policy(
     data_manager: DataManager,
     policy: nc.policy,
@@ -50,18 +84,8 @@ def run_policy(
         print("⚠️ No joint angles available")
         return False
 
-    gripper_open_value = data_manager.get_current_gripper_open_value()
-    if gripper_open_value is None:
-        print("⚠️ No gripper value available")
-        return False
-
-    rgb_image = data_manager.get_rgb_image()
-    if rgb_image is None:
-        print("⚠️ No RGB image available")
-        return False
-
-    # Get current gripper open value
-    gripper_open_value = data_manager.get_current_gripper_open_value()
+    # Get target gripper open value because this is how the policy was trained
+    gripper_open_value = data_manager.get_target_gripper_open_value()
     if gripper_open_value is None:
         print("⚠️ No gripper open value available")
         return False
@@ -77,23 +101,26 @@ def run_policy(
     joint_positions_dict = {
         JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad)
     }
-    gripper_open_amounts_dict = {GRIPPER_LOGGING_NAME: gripper_open_value}
 
-    # Log joint positions parallel gripper open amounts and RGB image to NeuraCore
+    # Log joint positions, parallel gripper open amounts, and RGB image to NeuraCore
     nc.log_joint_positions(joint_positions_dict)
-    nc.log_gripper_data(open_amounts=gripper_open_amounts_dict)
-    nc.log_rgb("camera", rgb_image)
+    nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_open_value)
+    nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
 
     # Get policy prediction
     try:
         start_time = time.time()
-        predicted_sync_points = policy.predict(timeout=5)
+        predictions = policy.predict(timeout=5)
+        prediction_horizon = convert_predictions_to_horizon_dict(predictions)
         elapsed = time.time() - start_time
-        print(f"✓ Got {len(predicted_sync_points)} actions in {elapsed:.3f}s")
+
+        # Get horizon length from the first joint (all lists share the same length);
+        # compute it from the freshly converted horizon, since it has not been
+        # stored in policy_state yet at this point.
+        horizon_length = len(next(iter(prediction_horizon.values()), []))
+        print(f"✓ Got {horizon_length} actions in {elapsed:.3f}s")
 
-        # Save full horizon and set execution ratio (clipping occurs on lock)
+        # Set execution ratio and save prediction horizon
        policy_state.set_execution_ratio(PREDICTION_HORIZON_EXECUTION_RATIO)
-        policy_state.set_prediction_horizon_sync_points(predicted_sync_points)
+        policy_state.set_prediction_horizon(prediction_horizon)
 
         return True
     except Exception as e:
@@ -111,30 +138,25 @@ def execute_horizon(
     policy_state.start_policy_execution()
     data_manager.set_robot_activity_state(RobotActivityState.POLICY_CONTROLLED)
 
-    locked_horizon_sync_points = (
-        policy_state.get_locked_prediction_horizon_sync_points()
-    )
+    locked_horizon = policy_state.get_locked_prediction_horizon()
     horizon_length = policy_state.get_locked_prediction_horizon_length()
 
     dt = 1.0 / POLICY_EXECUTION_RATE
     for i in range(horizon_length):
-        sync_point = locked_horizon_sync_points[i]
-        if sync_point.joint_target_positions is not None:
-            joint_targets_rad = sync_point.joint_target_positions.numpy(
-                order=JOINT_NAMES
+        # Send current action to robot (if available)
+        if all(joint_name in locked_horizon for joint_name in JOINT_NAMES):
+            current_joint_target_positions_rad = np.array(
+                [locked_horizon[joint_name][i] for joint_name in JOINT_NAMES]
             )
-            joint_targets_deg = np.degrees(joint_targets_rad)
-            robot_controller.set_target_joint_angles(joint_targets_deg)
+            current_joint_target_positions_deg = np.degrees(
+                current_joint_target_positions_rad
+            )
+            robot_controller.set_target_joint_angles(current_joint_target_positions_deg)
 
-        if (
-            sync_point.parallel_gripper_open_amounts is not None
-            and GRIPPER_LOGGING_NAME
-            in sync_point.parallel_gripper_open_amounts.open_amounts
-        ):
-            gripper_value = sync_point.parallel_gripper_open_amounts.open_amounts[
-                GRIPPER_LOGGING_NAME
-            ]
-            robot_controller.set_gripper_open_value(gripper_value)
+        # Send current gripper open value to robot (if available)
+        if GRIPPER_LOGGING_NAME in locked_horizon:
+            current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
+            robot_controller.set_gripper_open_value(current_gripper_open_value)
 
         # Sleep to maintain rate
         time.sleep(dt)
@@ -147,13 +169,25 @@ def execute_horizon(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Minimal Piper Policy Test")
     parser.add_argument(
-        "--model-file",
+        "--train-run-name",
+        type=str,
+        default=None,
+        help="Name of the training run to load policy from (for cloud training). Mutually exclusive with --model-path.",
+    )
+    parser.add_argument(
+        "--model-path",
         type=str,
-        required=True,
-        help="Path to model file (.nc.zip)",
+        default=None,
+        help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.",
     )
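+    # Example invocations (illustrative run name and path):
+    #   python examples/5_rollout_neuracore_policy_minimal.py --train-run-name my_piper_run
+    #   python examples/5_rollout_neuracore_policy_minimal.py --model-path ./my_policy.nc.zip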
 
     args = parser.parse_args()
 
+    # Validate that exactly one of train-run-name or model-path is provided
+    if (args.train_run_name is None) == (args.model_path is None):
+        parser.error(
+            "Exactly one of --train-run-name or --model-path must be provided (not both, not neither)"
+        )
+
     print("=" * 60)
     print("PIPER POLICY ROLLOUT")
     print("=" * 60)
@@ -168,9 +202,42 @@ def execute_horizon(
     )
 
     # Load policy
-    print(f"\n🤖 Loading policy from: {args.model_file}...")
-    policy = nc.policy(model_file=args.model_file, device="cuda")
-    print("✓ Policy loaded")
+    # NOTE: The model_output_order MUST match the exact order used during training.
+    # This order is determined by the output_robot_data_spec in the training config.
+    model_input_order = {
+        DataType.JOINT_POSITIONS: JOINT_NAMES,
+        DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
+        DataType.RGB_IMAGES: [CAMERA_LOGGING_NAME],
+    }
+    model_output_order = {
+        DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
+        DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
+    }
+
+    print("\n📋 Model input order:")
+    for data_type, names in model_input_order.items():
+        print(f" {data_type.name}: {names}")
+    print("\n📋 Model output order:")
+    for data_type, names in model_output_order.items():
+        print(f" {data_type.name}: {names}")
+
+    if args.train_run_name is not None:
+        print(f"\n🤖 Loading policy from training run: {args.train_run_name}...")
+        policy = nc.policy(
+            train_run_name=args.train_run_name,
+            model_input_order=model_input_order,
+            model_output_order=model_output_order,
+        )
+    else:
+        print(f"\n🤖 Loading policy from model file: {args.model_path}...")
+        policy = nc.policy(
+            model_file=args.model_path,
+            device="cuda",
+            model_input_order=model_input_order,
+            model_output_order=model_output_order,
+        )
+    print(" ✓ Policy loaded successfully")
 
     # Initialize state
     data_manager = DataManager()