From aad68c41e0e9d97a5118b69a61b195992d9134ac Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Sun, 16 Feb 2025 17:10:31 -0500 Subject: [PATCH 1/9] Expose optional s_noise parameter --- pixi.lock | 2 +- spring2025-course/localization-tutorial.py | 26 +++++++++++++--------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pixi.lock b/pixi.lock index 610d259..3525588 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2038,7 +2038,7 @@ packages: - pypi: . name: localization-tutorial version: 0.1.0 - sha256: 33e938ca55162ae9ce3a35f7ce2bb1cf5f91aabf0dad8d020256eb42b1fdcb28 + sha256: 079b995864eef3e1bb9b86ec0c6b7ff059bda536c7d85fed88207b7c1ca7cac5 requires_dist: - genstudio>=2025.2.2,<2026 - genjax==0.9.1 diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 8f01c8a..bd83a1d 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -10,7 +10,7 @@ # format_version: '1.3' # jupytext_version: 1.11.2 # kernelspec: -# display_name: .venv +# display_name: default # language: python # name: python3 # --- @@ -484,11 +484,11 @@ def update_ideal_sensors(widget, _, label="pose"): # %% @genjax.gen -def sensor_model_one(pose, angle): +def sensor_model_one(pose, angle, s_noise): return ( genjax.normal( sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]), - sensor_settings["s_noise"], + s_noise, ) @ "distance" ) @@ -503,7 +503,7 @@ def sensor_model_one(pose, angle): sensor_settings["s_noise"] = 0.10 key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[1])) +cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[1], sensor_settings["s_noise"])) retval # %% [markdown] @@ -516,11 +516,11 @@ def sensor_model_one(pose, angle): # We are interested in the related model whose *single* draw consists of a *vector* of the sensor distances computed across the vector of sensor angles. This is exactly what we get using the GenJAX `vmap` combinator on GFs. # %% -sensor_model = sensor_model_one.vmap(in_axes=(None, 0)) +sensor_model = sensor_model_one.vmap(in_axes=(None, 0, None)) # %% key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles)) +cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles, sensor_settings["s_noise"])) retval # %% [markdown] @@ -585,7 +585,7 @@ def update_noisy_sensors(widget, _, label="pose"): # The construction of a log density function is automated by the `assess` semantics for generative functions. This method is passed a choice map and a tuple of arguments, and it returns the log score plus the return value. # %% -score, retval = sensor_model.assess(cm, (some_pose, sensor_angles)) +score, retval = sensor_model.assess(cm, (some_pose, sensor_angles, sensor_settings["s_noise"])) jnp.exp(score) @@ -610,8 +610,10 @@ def update_noisy_sensors(widget, _, label="pose"): guess_pose = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) target_pose = Pose(jnp.array([15.0, 4.0]), jnp.array(-1.6)) -def likelihood_function(cm, pose): - return sensor_model.assess(cm, (pose, sensor_angles))[0] +def likelihood_function(cm, pose, s_noise=None): + if s_noise is None: + s_noise = sensor_settings["s_noise"] + return sensor_model.assess(cm, (pose, sensor_angles, s_noise))[0] def on_guess_pose_chage(widget, _): update_ideal_sensors(widget, None, label="guess") @@ -1363,9 +1365,11 @@ def plot_path_with_confidence(path, step): # %% @genjax.gen -def full_model_kernel(motion_settings, state, control): +def full_model_kernel(motion_settings, state, control, s_noise=None): + if s_noise is None: + s_noise = sensor_settings["s_noise"] pose = step_model(motion_settings, state, control) @ "pose" - sensor_model(pose, sensor_angles) @ "sensor" + sensor_model(pose, sensor_angles, s_noise) @ "sensor" return pose @genjax.gen From 4b13efb5f77f66f2d062709a48aafb0e09ba3861 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Sun, 16 Feb 2025 17:10:52 -0500 Subject: [PATCH 2/9] First try at sensor noise slider --- spring2025-course/localization-tutorial.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index bd83a1d..4612d1f 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -536,14 +536,16 @@ def sensor_model_one(pose, angle, s_noise): # With a little wrapping, one gets a function of the same type as `ideal_sensor`, ignoring the PRNG key. # %% -def noisy_sensor(key, pose): - return sensor_model.propose(key, (pose, sensor_angles))[2] +def noisy_sensor(key, pose, s_noise=None): + if s_noise is None: + s_noise = sensor_settings["s_noise"] + return sensor_model.propose(key, (pose, sensor_angles, s_noise))[2] # %% def update_noisy_sensors(widget, _, label="pose"): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) - readings = noisy_sensor(k1, pose_at(widget.state, label)) + readings = noisy_sensor(k1, pose_at(widget.state, label), s_noise=widget.state.noise_slider) widget.state.update({ "k": jax.random.key_data(k2), (label + "_readings"): readings @@ -557,12 +559,20 @@ def update_noisy_sensors(widget, _, label="pose"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) + | Plot.Slider( + key="noise_slider", + label="Sensor noise:", + showValue=True, + range=[0.01, 2.5], + step=0.01, + init=0.1, + ) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), "pose_readings": noisy_sensor(k2, some_pose) - }, sync={"k"}) - | Plot.onChange({"pose": update_noisy_sensors}) + }, sync={"k", "noise_slider"}) + | Plot.onChange({"pose": update_noisy_sensors, "noise_slider": update_noisy_sensors}) ) # %% [markdown] From 7c16be41a3332ab66e1d6b15acb26ab19b39f796 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 10:48:55 -0500 Subject: [PATCH 3/9] Avoid Python optional arguments --- spring2025-course/localization-tutorial.py | 108 +++++++++++---------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 4612d1f..b42134f 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -483,8 +483,14 @@ def update_ideal_sensors(widget, _, label="pose"): # Its declarative model in `Gen` starts with the case of just one sensor reading: # %% +sensor_settings["s_noise"] = 0.10 + @genjax.gen def sensor_model_one(pose, angle, s_noise): + # Handle optional/default values in this admittedly hacky way, + # so as not to confuse JAX's JIT. + if s_noise is None: + s_noise = sensor_settings["s_noise"] return ( genjax.normal( sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]), @@ -500,10 +506,8 @@ def sensor_model_one(pose, angle, s_noise): # We draw samples from `sensor_model_one` with `propose` semantics. Since this operation is stochastic, the method is called with a PRNG key in addition to a tuple of model arguments. The code is then run, performing the required draws from the sampling operations. The random draws get organized according to their addresses, forming a *choice map* data structure. This choice map, a score (to be discussed below), and the return value are all returned by `propose`. # %% -sensor_settings["s_noise"] = 0.10 - key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[1], sensor_settings["s_noise"])) +cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[0], None)) retval # %% [markdown] @@ -520,7 +524,7 @@ def sensor_model_one(pose, angle, s_noise): # %% key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles, sensor_settings["s_noise"])) +cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles, None)) retval # %% [markdown] @@ -536,16 +540,14 @@ def sensor_model_one(pose, angle, s_noise): # With a little wrapping, one gets a function of the same type as `ideal_sensor`, ignoring the PRNG key. # %% -def noisy_sensor(key, pose, s_noise=None): - if s_noise is None: - s_noise = sensor_settings["s_noise"] +def noisy_sensor(key, pose, s_noise): return sensor_model.propose(key, (pose, sensor_angles, s_noise))[2] # %% def update_noisy_sensors(widget, _, label="pose"): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) - readings = noisy_sensor(k1, pose_at(widget.state, label), s_noise=widget.state.noise_slider) + readings = noisy_sensor(k1, pose_at(widget.state, label), widget.state.noise_slider) widget.state.update({ "k": jax.random.key_data(k2), (label + "_readings"): readings @@ -570,7 +572,7 @@ def update_noisy_sensors(widget, _, label="pose"): | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), - "pose_readings": noisy_sensor(k2, some_pose) + "pose_readings": noisy_sensor(k2, some_pose, None) }, sync={"k", "noise_slider"}) | Plot.onChange({"pose": update_noisy_sensors, "noise_slider": update_noisy_sensors}) ) @@ -595,7 +597,7 @@ def update_noisy_sensors(widget, _, label="pose"): # The construction of a log density function is automated by the `assess` semantics for generative functions. This method is passed a choice map and a tuple of arguments, and it returns the log score plus the return value. # %% -score, retval = sensor_model.assess(cm, (some_pose, sensor_angles, sensor_settings["s_noise"])) +score, retval = sensor_model.assess(cm, (some_pose, sensor_angles, None)) jnp.exp(score) @@ -620,9 +622,7 @@ def update_noisy_sensors(widget, _, label="pose"): guess_pose = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) target_pose = Pose(jnp.array([15.0, 4.0]), jnp.array(-1.6)) -def likelihood_function(cm, pose, s_noise=None): - if s_noise is None: - s_noise = sensor_settings["s_noise"] +def likelihood_function(cm, pose, s_noise): return sensor_model.assess(cm, (pose, sensor_angles, s_noise))[0] def on_guess_pose_chage(widget, _): @@ -630,7 +630,8 @@ def on_guess_pose_chage(widget, _): widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), - pose_at(widget.state, "guess") + pose_at(widget.state, "guess"), + None ) }) @@ -639,7 +640,8 @@ def on_target_pose_chage(widget, _): widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), - pose_at(widget.state, "guess") + pose_at(widget.state, "guess"), + None ) }) @@ -716,8 +718,8 @@ def on_target_pose_chage(widget, _): { "k": jax.random.key_data(k1), "guess_readings": ideal_sensor(sensor_angles, guess_pose), - "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose)), - "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose), + "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose, None)), + "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose, None), "show_target_pose": False, }, sync={"k", "target_readings"}) | Plot.onChange({ @@ -1375,17 +1377,18 @@ def plot_path_with_confidence(path, step): # %% @genjax.gen -def full_model_kernel(motion_settings, state, control, s_noise=None): - if s_noise is None: - s_noise = sensor_settings["s_noise"] +def full_model_kernel(motion_settings, s_noise, state, control): pose = step_model(motion_settings, state, control) @ "pose" sensor_model(pose, sensor_angles, s_noise) @ "sensor" return pose +# Recall that supplying `None` for `s_noise` ultimately leads to the value `sensor_settings["s_niose"]` being used. @genjax.gen -def full_model(motion_settings): +def full_model(motion_settings, s_noise): return ( - full_model_kernel.partial_apply(motion_settings) + full_model_kernel + .partial_apply(motion_settings) + .partial_apply(s_noise) .map(diag) .scan()(robot_inputs["start"], robot_inputs["controls"]) @ "steps" @@ -1404,7 +1407,7 @@ def full_model(motion_settings): # %% key, sub_key = jax.random.split(key) -cm, score, retval = full_model.propose(sub_key, (default_motion_settings,)) +cm, score, retval = full_model.propose(sub_key, (default_motion_settings, None)) cm @@ -1566,7 +1569,7 @@ def animate_full_trace(trace, frame_key=None): ) key, sub_key = jax.random.split(key) -tr = full_model.simulate(sub_key, (default_motion_settings,)) +tr = full_model.simulate(sub_key, (default_motion_settings, None)) animate_full_trace(tr) @@ -1585,8 +1588,8 @@ def animate_full_trace(trace, frame_key=None): } key, k_low, k_high = jax.random.split(key, 3) -trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,)) -trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,)) +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, None)) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, None)) animate_full_trace(trace_low_deviation) # %% @@ -1656,13 +1659,13 @@ def plt(readings): key, sub_key = jax.random.split(key) sample, log_weight = model_importance( - sub_key, constraints_low_deviation, (motion_settings_low_deviation,) + sub_key, constraints_low_deviation, (motion_settings_low_deviation, None) ) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% key, sub_key = jax.random.split(key) sample, log_weight = model_importance( - sub_key, constraints_high_deviation, (motion_settings_high_deviation,) + sub_key, constraints_high_deviation, (motion_settings_high_deviation, None) ) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% [markdown] @@ -1726,13 +1729,13 @@ def constraint_from_path(path): trace_path_integrated_observations_low_deviation, w_low = model_importance( sub_key, constraints_path_integrated_observations_low_deviation, - (motion_settings_low_deviation,), + (motion_settings_low_deviation, None), ) key, sub_key = jax.random.split(key) trace_path_integrated_observations_high_deviation, w_high = model_importance( sub_key, constraints_path_integrated_observations_high_deviation, - (motion_settings_high_deviation,), + (motion_settings_high_deviation, None), ) Plot.Row(*[ @@ -1770,7 +1773,7 @@ def constraint_from_path(path): )( jax.random.split(sub_key, N_samples), constraints_low_deviation, - (motion_settings_low_deviation,), + (motion_settings_low_deviation, None), ) traces_generated_high_deviation, high_weights = jax.vmap( @@ -1778,7 +1781,7 @@ def constraint_from_path(path): )( jax.random.split(sub_key, N_samples), constraints_high_deviation, - (motion_settings_high_deviation,), + (motion_settings_high_deviation, None), ) low_deviation_paths = jax.vmap(get_path)(traces_generated_low_deviation) @@ -1880,14 +1883,14 @@ def constraint_from_path(path): N_samples = 20 def importance_sample( - key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int + key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, s_noise, N: int, K: int ): """Produce K importance samples of depth N from the model. That is, K times, we generate N importance samples conditioned by the constraints, and categorically select one of them.""" key1, key2 = jax.random.split(key) samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( - jax.random.split(key1, N * K), constraints, (motion_settings,) + jax.random.split(key1, N * K), constraints, (motion_settings, s_noise) ) winners = jax.vmap(genjax.categorical.sampler)( jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) @@ -1899,15 +1902,15 @@ def importance_sample( return selected -jit_resample = jax.jit(importance_sample, static_argnums=(3, 4)) +jit_resample = jax.jit(importance_sample, static_argnums=(4, 5)) key, sub_key = jax.random.split(key) low_posterior = jit_resample( - sub_key, constraints_low_deviation, motion_settings_low_deviation, N_presamples, N_samples + sub_key, constraints_low_deviation, motion_settings_low_deviation, None, N_presamples, N_samples ) key, sub_key = jax.random.split(key) high_posterior = jit_resample( - sub_key, constraints_high_deviation, motion_settings_high_deviation, N_presamples, N_samples + sub_key, constraints_high_deviation, motion_settings_high_deviation, None, N_presamples, N_samples ) @@ -2122,7 +2125,7 @@ def step(state, update): return SISwithRejuvenation.Result(N, end, samples, indices) # %% -def localization_sis(motion_settings, observations): +def localization_sis(motion_settings, s_noise, observations): return SISwithRejuvenation( robot_inputs["start"], robot_inputs["controls"], @@ -2130,7 +2133,7 @@ def localization_sis(motion_settings, observations): lambda key, pose, control, observation: full_model_kernel.importance( key, C["sensor", "distance"].set(observation), - (motion_settings, pose, control), + (motion_settings, s_noise, pose, control), ), ) @@ -2140,8 +2143,8 @@ def localization_sis(motion_settings, observations): # Rerun it to try out the SMC examples on a fresh instance of the problem. key, k_low, k_high = jax.random.split(key, 3) -trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,)) -trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,)) +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, None)) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, None)) path_low_deviation = get_path(trace_low_deviation) path_high_deviation = get_path(trace_high_deviation) observations_low_deviation = get_sensors(trace_low_deviation) @@ -2158,7 +2161,7 @@ def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: key, sub_key = jax.random.split(key) smc_result = localization_sis( - motion_settings_high_deviation, observations_high_deviation + motion_settings_high_deviation, None, observations_high_deviation ).run(sub_key, N_particles) def plot_sis_result(ground_truth, smc_result): @@ -2177,7 +2180,7 @@ def plot_sis_result(ground_truth, smc_result): key, sub_key = jax.random.split(key) low_smc_result = localization_sis( - motion_settings_low_deviation, observations_low_deviation + motion_settings_low_deviation, None, observations_low_deviation ).run(sub_key, N_particles) plot_sis_result(path_low_deviation, low_smc_result) @@ -2208,7 +2211,7 @@ def run_SMCP3_step(fwd_proposal, bwd_proposal, key, sample, proposal_args): # the unnormalized posterior density over steps. @genjax.gen def grid_fwd_proposal(sample, args): - base_grid, observation, model_args = args + base_grid, observation, full_model_args = args observation_cm = C["sensor", "distance"].set(observation) log_weights = jax.vmap( @@ -2217,7 +2220,7 @@ def grid_fwd_proposal(sample, args): observation_cm | C["pose", "p"].set(p + sample.get_retval().p) | C["pose", "hd"].set(hd + sample.get_retval().hd), - model_args + full_model_args )[0] )(*base_grid) fwd_index = genjax.categorical(log_weights) @ "fwd_index" @@ -2233,14 +2236,15 @@ def grid_fwd_proposal(sample, args): # Backwards proposal simply guesses according to the prior over steps, nothing fancier. @genjax.gen def grid_bwd_proposal(new_sample, args): - base_grid, _, model_args = args + base_grid, _, full_model_args = args + step_model_args = (full_model_args[0], full_model_args[2], full_model_args[3]) log_weights = jax.vmap( lambda p, hd: step_model.assess( C["p"].set(p + new_sample.get_retval().p) | C["hd"].set(hd + new_sample.get_retval().hd), - model_args + step_model_args )[0] )(*base_grid) @@ -2250,7 +2254,7 @@ def grid_bwd_proposal(new_sample, args): # %% -def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observations): +def localization_sis_plus_grid_rejuv(motion_settings, s_noise, M_grid, N_grid, observations): base_grid = make_poses_grid_array( jnp.array([M_grid / 2.0, M_grid / 2.0]).T, N_grid @@ -2262,14 +2266,14 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio importance=lambda key, pose, control, observation: full_model_kernel.importance( key, C["sensor", "distance"].set(observation), - (motion_settings, pose, control), + (motion_settings, s_noise, pose, control), ), rejuvenate=lambda key, sample, pose, control, observation: run_SMCP3_step( grid_fwd_proposal, grid_bwd_proposal, key, sample, - (base_grid, observation, (motion_settings, pose, control)) + (base_grid, observation, (motion_settings, s_noise, pose, control)) ), ) @@ -2284,10 +2288,10 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio key, sub_key = jax.random.split(key) smc_result = localization_sis_plus_grid_rejuv( - motion_settings_high_deviation, M_grid, N_grid, observations_high_deviation + motion_settings_high_deviation, None, M_grid, N_grid, observations_high_deviation ).run(sub_key, N_particles) imp_result = localization_sis( - motion_settings_high_deviation, observations_high_deviation + motion_settings_high_deviation, None, observations_high_deviation ).run(sub_key, N_particles) plot_sis_result(path_high_deviation, smc_result) | plot_sis_result(path_high_deviation, imp_result) From 2d20a5418003399e74ca4fbd154519c2d1f37006 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 11:29:41 -0500 Subject: [PATCH 4/9] Make graph sensor labels optional --- spring2025-course/localization-tutorial.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index b42134f..fcb4ebf 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -410,7 +410,7 @@ def ideal_sensor(sensor_angles, pose): # %% # Plot sensor data. -def plot_sensors(pose, readings, sensor_angles): +def plot_sensors(pose, readings, sensor_angles, show_legend=False): return Plot.Import("""export const projections = (pose, readings, angles) => Array.from({length: readings.length}, (_, i) => { const angle = angles[i] + pose.hd const reading = readings[i] @@ -419,14 +419,14 @@ def plot_sensors(pose, readings, sensor_angles): refer=["projections"]) | ( Plot.line( js("projections(%1, %2, %3).flatMap((projection, i) => [%1.p, projection, i])", pose, readings, sensor_angles), - stroke=Plot.constantly("sensor rays"), + opacity=0.1, ) + Plot.dot( js("projections(%1, %2, %3)", pose, readings, sensor_angles), r=2.75, - fill=Plot.constantly("sensor readings"), + fill="#f80" ) + - Plot.colorMap({"sensor rays": "rgba(0,0,0,0.1)", "sensor readings": "#f80"}) + Plot.cond(show_legend, Plot.colorMap({"sensor rays": "rgb(0,0,0,0.1)", "sensor readings": "#f80"}) | Plot.colorLegend()) ) @@ -443,7 +443,7 @@ def update_ideal_sensors(widget, _, label="pose"): ( ( world_plot - + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles, show_legend=True) + pose_widget("pose", some_pose, color="blue") ) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) @@ -462,7 +462,7 @@ def update_ideal_sensors(widget, _, label="pose"): Plot.Frames([ ( world_plot - + plot_sensors(pose, some_readings[i], sensor_angles) + + plot_sensors(pose, some_readings[i], sensor_angles, show_legend=True) + pose_plots(pose) ) for i, pose in enumerate(some_poses) @@ -1620,7 +1620,7 @@ def frame(pose, readings1, readings2): def plt(readings): return Plot.new( plot_base or Plot.domain([0, 20]), - plot_sensors(pose, readings, sensor_angles), + plot_sensors(pose, readings, sensor_angles, show_legend=True), {"width": 400, "height": 400}, ) From 4b732a79f4379253d9b0be5fd0125440194c9ad4 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 11:30:05 -0500 Subject: [PATCH 5/9] Abstract out slider --- spring2025-course/localization-tutorial.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index fcb4ebf..af0e9d0 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -543,11 +543,22 @@ def sensor_model_one(pose, angle, s_noise): def noisy_sensor(key, pose, s_noise): return sensor_model.propose(key, (pose, sensor_angles, s_noise))[2] +def noise_slider(s_noise): + if s_noise is None: + s_noise = sensor_settings["s_noise"] + return Plot.Slider( + key="noise_slider", + label="Sensor noise:", + showValue=True, + range=[0.01, 2.5], + step=0.01, + ) | Plot.initialState({"noise_slider": s_noise}, sync={"noise_slider"}) + # %% def update_noisy_sensors(widget, _, label="pose"): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) - readings = noisy_sensor(k1, pose_at(widget.state, label), widget.state.noise_slider) + readings = noisy_sensor(k1, pose_at(widget.state, label), float(widget.state.noise_slider)) widget.state.update({ "k": jax.random.key_data(k2), (label + "_readings"): readings @@ -561,19 +572,12 @@ def update_noisy_sensors(widget, _, label="pose"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) - | Plot.Slider( - key="noise_slider", - label="Sensor noise:", - showValue=True, - range=[0.01, 2.5], - step=0.01, - init=0.1, - ) + | noise_slider(None) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), "pose_readings": noisy_sensor(k2, some_pose, None) - }, sync={"k", "noise_slider"}) + }, sync={"k"}) | Plot.onChange({"pose": update_noisy_sensors, "noise_slider": update_noisy_sensors}) ) From 373bdc36745ae2bde97ead58653e72200c39963d Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 12:19:44 -0500 Subject: [PATCH 6/9] Both data and model noise sliders for single pose --- spring2025-course/localization-tutorial.py | 46 +++++++++++++--------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index af0e9d0..bfad74f 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -543,28 +543,28 @@ def sensor_model_one(pose, angle, s_noise): def noisy_sensor(key, pose, s_noise): return sensor_model.propose(key, (pose, sensor_angles, s_noise))[2] -def noise_slider(s_noise): - if s_noise is None: - s_noise = sensor_settings["s_noise"] +def noise_slider(key="noise_slider", label="Sensor noise =", init=None): + if init is None: + init = sensor_settings["s_noise"] return Plot.Slider( - key="noise_slider", - label="Sensor noise:", + key=key, + label=label, showValue=True, range=[0.01, 2.5], step=0.01, - ) | Plot.initialState({"noise_slider": s_noise}, sync={"noise_slider"}) - + ) | Plot.initialState({key: init}, sync={key}) -# %% -def update_noisy_sensors(widget, _, label="pose"): +def update_noisy_sensors(widget, _, pose_key="pose", slider_key="noise_slider"): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) - readings = noisy_sensor(k1, pose_at(widget.state, label), float(widget.state.noise_slider)) + readings = noisy_sensor(k1, pose_at(widget.state, pose_key), float(getattr(widget.state, slider_key))) widget.state.update({ "k": jax.random.key_data(k2), - (label + "_readings"): readings + (pose_key + "_readings"): readings }) return readings + +# %% key, k1, k2 = jax.random.split(key, 3) ( ( @@ -572,7 +572,7 @@ def update_noisy_sensors(widget, _, label="pose"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) - | noise_slider(None) + | noise_slider() | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), @@ -640,7 +640,7 @@ def on_guess_pose_chage(widget, _): }) def on_target_pose_chage(widget, _): - update_noisy_sensors(widget, None, label="target") + update_noisy_sensors(widget, None, pose_key="target") widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), @@ -697,6 +697,7 @@ def on_target_pose_chage(widget, _): ), cols=2 ) + | noise_slider() | ( Plot.html([ "div", @@ -728,7 +729,8 @@ def on_target_pose_chage(widget, _): }, sync={"k", "target_readings"}) | Plot.onChange({ "guess": on_guess_pose_chage, - "target": on_target_pose_chage + "target": on_target_pose_chage, + "noise_slider": on_target_pose_chage, }) ) @@ -746,7 +748,7 @@ def handler(widget, _): "k": jax.random.key_data(k1), "target": widget.state.camera, }) - readings = update_noisy_sensors(widget, None, label="target") + readings = update_noisy_sensors(widget, None, pose_key="target", slider_key="world_noise") button_handler(widget, k2, readings) widget.state.update({ "target_exists": True, @@ -770,6 +772,8 @@ def camera_widget( ) + pose_widget("camera", camera_pose, color="blue") ) + | noise_slider(key="world_noise", label="World/data noise = ") + | noise_slider(key="model_noise", label="Model/inference noise = ") | ( Plot.html([ "div", @@ -834,8 +838,9 @@ def make_poses_grid(bounds, ns): grid_poses = make_poses_grid(world["bounding_box"], N_grid) def grid_search_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) likelihoods = jax.vmap(jitted_likelihood)(grid_poses) best = jnp.argsort(likelihoods, descending=True)[0:N_keep] @@ -891,8 +896,9 @@ def grid_search_handler(widget, k, readings): grid_poses = make_poses_grid(world["bounding_box"], N_grid) def grid_approximation_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) likelihoods = jax.vmap(jitted_likelihood)(grid_poses) @@ -932,8 +938,9 @@ def grid_sample_one(k): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) def importance_resampling_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) def importance_resample_one(k): @@ -970,8 +977,9 @@ def importance_resample_one(k): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) def MCMC_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) def do_MH_step(pose_likelihood, k): From 7e49d2ead4713685eea5b51d2bd41725cf973e7b Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 12:25:06 -0500 Subject: [PATCH 7/9] Final tweaks --- spring2025-course/localization-tutorial.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index bfad74f..4622d3a 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -550,7 +550,7 @@ def noise_slider(key="noise_slider", label="Sensor noise =", init=None): key=key, label=label, showValue=True, - range=[0.01, 2.5], + range=[0.01, 5.0], step=0.01, ) | Plot.initialState({key: init}, sync={key}) @@ -1590,6 +1590,9 @@ def animate_full_trace(trace, frame_key=None): # # Let us generate some fixed synthetic motion data that, for pedagogical purposes, we will work with as if it were the actual path of the robot. We will generate two versions, one each with low or high motion deviation. # %% +# HERE is a great place to update `sensor_settings["s_noise"]` if you wish, +# about to construct some "data" using it. + motion_settings_low_deviation = { "p_noise": 0.05, "hd_noise": (1 / 10.0) * 2 * jnp.pi / 360, @@ -2154,6 +2157,8 @@ def localization_sis(motion_settings, s_noise, observations): # This cell is included for convenience: # Rerun it to try out the SMC examples on a fresh instance of the problem. +# Set `sensor_settings["s_noise"]` here, if you wish. + key, k_low, k_high = jax.random.split(key, 3) trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, None)) trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, None)) From b3b70c3a85dd86e9f1628ffb1a2597f4b21be393 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 16:11:00 -0500 Subject: [PATCH 8/9] Undo ersatz optional argument logic --- spring2025-course/localization-tutorial.py | 57 ++++++++++------------ 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 4622d3a..95c3730 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -487,10 +487,6 @@ def update_ideal_sensors(widget, _, label="pose"): @genjax.gen def sensor_model_one(pose, angle, s_noise): - # Handle optional/default values in this admittedly hacky way, - # so as not to confuse JAX's JIT. - if s_noise is None: - s_noise = sensor_settings["s_noise"] return ( genjax.normal( sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]), @@ -507,7 +503,7 @@ def sensor_model_one(pose, angle, s_noise): # %% key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[0], None)) +cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[0], sensor_settings["s_noise"])) retval # %% [markdown] @@ -524,7 +520,7 @@ def sensor_model_one(pose, angle, s_noise): # %% key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles, None)) +cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles, sensor_settings["s_noise"])) retval # %% [markdown] @@ -576,7 +572,7 @@ def update_noisy_sensors(widget, _, pose_key="pose", slider_key="noise_slider"): | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), - "pose_readings": noisy_sensor(k2, some_pose, None) + "pose_readings": noisy_sensor(k2, some_pose, sensor_settings["s_noise"]) }, sync={"k"}) | Plot.onChange({"pose": update_noisy_sensors, "noise_slider": update_noisy_sensors}) ) @@ -601,7 +597,7 @@ def update_noisy_sensors(widget, _, pose_key="pose", slider_key="noise_slider"): # The construction of a log density function is automated by the `assess` semantics for generative functions. This method is passed a choice map and a tuple of arguments, and it returns the log score plus the return value. # %% -score, retval = sensor_model.assess(cm, (some_pose, sensor_angles, None)) +score, retval = sensor_model.assess(cm, (some_pose, sensor_angles, sensor_settings["s_noise"])) jnp.exp(score) @@ -635,7 +631,7 @@ def on_guess_pose_chage(widget, _): likelihood_function( C["distance"].set(widget.state.target_readings), pose_at(widget.state, "guess"), - None + sensor_settings["s_noise"] ) }) @@ -645,7 +641,7 @@ def on_target_pose_chage(widget, _): likelihood_function( C["distance"].set(widget.state.target_readings), pose_at(widget.state, "guess"), - None + sensor_settings["s_noise"] ) }) @@ -723,8 +719,8 @@ def on_target_pose_chage(widget, _): { "k": jax.random.key_data(k1), "guess_readings": ideal_sensor(sensor_angles, guess_pose), - "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose, None)), - "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose, None), + "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose, sensor_settings["s_noise"])), + "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose, sensor_settings["s_noise"]), "show_target_pose": False, }, sync={"k", "target_readings"}) | Plot.onChange({ @@ -1394,7 +1390,6 @@ def full_model_kernel(motion_settings, s_noise, state, control): sensor_model(pose, sensor_angles, s_noise) @ "sensor" return pose -# Recall that supplying `None` for `s_noise` ultimately leads to the value `sensor_settings["s_niose"]` being used. @genjax.gen def full_model(motion_settings, s_noise): return ( @@ -1419,7 +1414,7 @@ def full_model(motion_settings, s_noise): # %% key, sub_key = jax.random.split(key) -cm, score, retval = full_model.propose(sub_key, (default_motion_settings, None)) +cm, score, retval = full_model.propose(sub_key, (default_motion_settings, sensor_settings["s_noise"])) cm @@ -1581,7 +1576,7 @@ def animate_full_trace(trace, frame_key=None): ) key, sub_key = jax.random.split(key) -tr = full_model.simulate(sub_key, (default_motion_settings, None)) +tr = full_model.simulate(sub_key, (default_motion_settings, sensor_settings["s_noise"])) animate_full_trace(tr) @@ -1603,8 +1598,8 @@ def animate_full_trace(trace, frame_key=None): } key, k_low, k_high = jax.random.split(key, 3) -trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, None)) -trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, None)) +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, sensor_settings["s_noise"])) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, sensor_settings["s_noise"])) animate_full_trace(trace_low_deviation) # %% @@ -1674,13 +1669,13 @@ def plt(readings): key, sub_key = jax.random.split(key) sample, log_weight = model_importance( - sub_key, constraints_low_deviation, (motion_settings_low_deviation, None) + sub_key, constraints_low_deviation, (motion_settings_low_deviation, sensor_settings["s_noise"]) ) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% key, sub_key = jax.random.split(key) sample, log_weight = model_importance( - sub_key, constraints_high_deviation, (motion_settings_high_deviation, None) + sub_key, constraints_high_deviation, (motion_settings_high_deviation, sensor_settings["s_noise"]) ) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% [markdown] @@ -1744,13 +1739,13 @@ def constraint_from_path(path): trace_path_integrated_observations_low_deviation, w_low = model_importance( sub_key, constraints_path_integrated_observations_low_deviation, - (motion_settings_low_deviation, None), + (motion_settings_low_deviation, sensor_settings["s_noise"]), ) key, sub_key = jax.random.split(key) trace_path_integrated_observations_high_deviation, w_high = model_importance( sub_key, constraints_path_integrated_observations_high_deviation, - (motion_settings_high_deviation, None), + (motion_settings_high_deviation, sensor_settings["s_noise"]), ) Plot.Row(*[ @@ -1788,7 +1783,7 @@ def constraint_from_path(path): )( jax.random.split(sub_key, N_samples), constraints_low_deviation, - (motion_settings_low_deviation, None), + (motion_settings_low_deviation, sensor_settings["s_noise"]), ) traces_generated_high_deviation, high_weights = jax.vmap( @@ -1796,7 +1791,7 @@ def constraint_from_path(path): )( jax.random.split(sub_key, N_samples), constraints_high_deviation, - (motion_settings_high_deviation, None), + (motion_settings_high_deviation, sensor_settings["s_noise"]), ) low_deviation_paths = jax.vmap(get_path)(traces_generated_low_deviation) @@ -1921,11 +1916,11 @@ def importance_sample( key, sub_key = jax.random.split(key) low_posterior = jit_resample( - sub_key, constraints_low_deviation, motion_settings_low_deviation, None, N_presamples, N_samples + sub_key, constraints_low_deviation, motion_settings_low_deviation, sensor_settings["s_noise"], N_presamples, N_samples ) key, sub_key = jax.random.split(key) high_posterior = jit_resample( - sub_key, constraints_high_deviation, motion_settings_high_deviation, None, N_presamples, N_samples + sub_key, constraints_high_deviation, motion_settings_high_deviation, sensor_settings["s_noise"], N_presamples, N_samples ) @@ -2160,8 +2155,8 @@ def localization_sis(motion_settings, s_noise, observations): # Set `sensor_settings["s_noise"]` here, if you wish. key, k_low, k_high = jax.random.split(key, 3) -trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, None)) -trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, None)) +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, sensor_settings["s_noise"])) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, sensor_settings["s_noise"])) path_low_deviation = get_path(trace_low_deviation) path_high_deviation = get_path(trace_high_deviation) observations_low_deviation = get_sensors(trace_low_deviation) @@ -2178,7 +2173,7 @@ def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: key, sub_key = jax.random.split(key) smc_result = localization_sis( - motion_settings_high_deviation, None, observations_high_deviation + motion_settings_high_deviation, sensor_settings["s_noise"], observations_high_deviation ).run(sub_key, N_particles) def plot_sis_result(ground_truth, smc_result): @@ -2197,7 +2192,7 @@ def plot_sis_result(ground_truth, smc_result): key, sub_key = jax.random.split(key) low_smc_result = localization_sis( - motion_settings_low_deviation, None, observations_low_deviation + motion_settings_low_deviation, sensor_settings["s_noise"], observations_low_deviation ).run(sub_key, N_particles) plot_sis_result(path_low_deviation, low_smc_result) @@ -2305,10 +2300,10 @@ def localization_sis_plus_grid_rejuv(motion_settings, s_noise, M_grid, N_grid, o key, sub_key = jax.random.split(key) smc_result = localization_sis_plus_grid_rejuv( - motion_settings_high_deviation, None, M_grid, N_grid, observations_high_deviation + motion_settings_high_deviation, sensor_settings["s_noise"], M_grid, N_grid, observations_high_deviation ).run(sub_key, N_particles) imp_result = localization_sis( - motion_settings_high_deviation, None, observations_high_deviation + motion_settings_high_deviation, sensor_settings["s_noise"], observations_high_deviation ).run(sub_key, N_particles) plot_sis_result(path_high_deviation, smc_result) | plot_sis_result(path_high_deviation, imp_result) From a6b95014c9e0e740afbaad23972f02d7deb0b9a0 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 17 Feb 2025 16:30:02 -0500 Subject: [PATCH 9/9] Remove unhelpful optional arguments --- spring2025-course/localization-tutorial.py | 34 +++++++++++++--------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 95c3730..88d0347 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -435,11 +435,14 @@ def pose_at(state, label): pose_dict = getattr(state, label) return Pose(jnp.array(pose_dict["p"]), jnp.array(pose_dict["hd"])) -def update_ideal_sensors(widget, _, label="pose"): +def update_ideal_sensors(widget, label): widget.state.update({ (label + "_readings"): ideal_sensor(sensor_angles, pose_at(widget.state, label)) }) +def on_pose_change(widget, _): + update_ideal_sensors(widget, "pose") + ( ( world_plot @@ -450,7 +453,7 @@ def update_ideal_sensors(widget, _, label="pose"): | Plot.initialState({ "pose_readings": ideal_sensor(sensor_angles, some_pose) }) - | Plot.onChange({"pose": update_ideal_sensors}) + | Plot.onChange({"pose": on_pose_change}) ) # %% [markdown] @@ -539,9 +542,9 @@ def sensor_model_one(pose, angle, s_noise): def noisy_sensor(key, pose, s_noise): return sensor_model.propose(key, (pose, sensor_angles, s_noise))[2] -def noise_slider(key="noise_slider", label="Sensor noise =", init=None): - if init is None: - init = sensor_settings["s_noise"] + +# %% +def noise_slider(key, label, init): return Plot.Slider( key=key, label=label, @@ -550,7 +553,7 @@ def noise_slider(key="noise_slider", label="Sensor noise =", init=None): step=0.01, ) | Plot.initialState({key: init}, sync={key}) -def update_noisy_sensors(widget, _, pose_key="pose", slider_key="noise_slider"): +def update_noisy_sensors(widget, pose_key, slider_key): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) readings = noisy_sensor(k1, pose_at(widget.state, pose_key), float(getattr(widget.state, slider_key))) widget.state.update({ @@ -561,6 +564,9 @@ def update_noisy_sensors(widget, _, pose_key="pose", slider_key="noise_slider"): # %% +def on_slider_change(widget, _): + update_noisy_sensors(widget, "pose", "noise_slider") + key, k1, k2 = jax.random.split(key, 3) ( ( @@ -568,13 +574,13 @@ def update_noisy_sensors(widget, _, pose_key="pose", slider_key="noise_slider"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) - | noise_slider() + | noise_slider("noise_slider", "Sensor noise =", sensor_settings["s_noise"]) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), "pose_readings": noisy_sensor(k2, some_pose, sensor_settings["s_noise"]) }, sync={"k"}) - | Plot.onChange({"pose": update_noisy_sensors, "noise_slider": update_noisy_sensors}) + | Plot.onChange({"pose": on_slider_change, "noise_slider": on_slider_change}) ) # %% [markdown] @@ -626,7 +632,7 @@ def likelihood_function(cm, pose, s_noise): return sensor_model.assess(cm, (pose, sensor_angles, s_noise))[0] def on_guess_pose_chage(widget, _): - update_ideal_sensors(widget, None, label="guess") + update_ideal_sensors(widget, "guess") widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), @@ -636,7 +642,7 @@ def on_guess_pose_chage(widget, _): }) def on_target_pose_chage(widget, _): - update_noisy_sensors(widget, None, pose_key="target") + update_noisy_sensors(widget, "target", "noise_slider") widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), @@ -693,7 +699,7 @@ def on_target_pose_chage(widget, _): ), cols=2 ) - | noise_slider() + | noise_slider("noise_slider", "Sensor noise =", sensor_settings["s_noise"]) | ( Plot.html([ "div", @@ -744,7 +750,7 @@ def handler(widget, _): "k": jax.random.key_data(k1), "target": widget.state.camera, }) - readings = update_noisy_sensors(widget, None, pose_key="target", slider_key="world_noise") + readings = update_noisy_sensors(widget, "target", "world_noise") button_handler(widget, k2, readings) widget.state.update({ "target_exists": True, @@ -768,8 +774,8 @@ def camera_widget( ) + pose_widget("camera", camera_pose, color="blue") ) - | noise_slider(key="world_noise", label="World/data noise = ") - | noise_slider(key="model_noise", label="Model/inference noise = ") + | noise_slider("world_noise", "World/data noise = ", sensor_settings["s_noise"]) + | noise_slider("model_noise", "Model/inference noise = ", sensor_settings["s_noise"]) | ( Plot.html([ "div",