diff --git a/pixi.lock b/pixi.lock index 610d259..3525588 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2038,7 +2038,7 @@ packages: - pypi: . name: localization-tutorial version: 0.1.0 - sha256: 33e938ca55162ae9ce3a35f7ce2bb1cf5f91aabf0dad8d020256eb42b1fdcb28 + sha256: 079b995864eef3e1bb9b86ec0c6b7ff059bda536c7d85fed88207b7c1ca7cac5 requires_dist: - genstudio>=2025.2.2,<2026 - genjax==0.9.1 diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 8f01c8a..88d0347 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -10,7 +10,7 @@ # format_version: '1.3' # jupytext_version: 1.11.2 # kernelspec: -# display_name: .venv +# display_name: default # language: python # name: python3 # --- @@ -410,7 +410,7 @@ def ideal_sensor(sensor_angles, pose): # %% # Plot sensor data. -def plot_sensors(pose, readings, sensor_angles): +def plot_sensors(pose, readings, sensor_angles, show_legend=False): return Plot.Import("""export const projections = (pose, readings, angles) => Array.from({length: readings.length}, (_, i) => { const angle = angles[i] + pose.hd const reading = readings[i] @@ -419,14 +419,14 @@ def plot_sensors(pose, readings, sensor_angles): refer=["projections"]) | ( Plot.line( js("projections(%1, %2, %3).flatMap((projection, i) => [%1.p, projection, i])", pose, readings, sensor_angles), - stroke=Plot.constantly("sensor rays"), + opacity=0.1, ) + Plot.dot( js("projections(%1, %2, %3)", pose, readings, sensor_angles), r=2.75, - fill=Plot.constantly("sensor readings"), + fill="#f80" ) + - Plot.colorMap({"sensor rays": "rgba(0,0,0,0.1)", "sensor readings": "#f80"}) + Plot.cond(show_legend, Plot.colorMap({"sensor rays": "rgb(0,0,0,0.1)", "sensor readings": "#f80"}) | Plot.colorLegend()) ) @@ -435,22 +435,25 @@ def pose_at(state, label): pose_dict = getattr(state, label) return Pose(jnp.array(pose_dict["p"]), jnp.array(pose_dict["hd"])) -def update_ideal_sensors(widget, _, label="pose"): +def update_ideal_sensors(widget, label): widget.state.update({ (label + "_readings"): ideal_sensor(sensor_angles, pose_at(widget.state, label)) }) +def on_pose_change(widget, _): + update_ideal_sensors(widget, "pose") + ( ( world_plot - + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles, show_legend=True) + pose_widget("pose", some_pose, color="blue") ) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "pose_readings": ideal_sensor(sensor_angles, some_pose) }) - | Plot.onChange({"pose": update_ideal_sensors}) + | Plot.onChange({"pose": on_pose_change}) ) # %% [markdown] @@ -462,7 +465,7 @@ def update_ideal_sensors(widget, _, label="pose"): Plot.Frames([ ( world_plot - + plot_sensors(pose, some_readings[i], sensor_angles) + + plot_sensors(pose, some_readings[i], sensor_angles, show_legend=True) + pose_plots(pose) ) for i, pose in enumerate(some_poses) @@ -483,12 +486,14 @@ def update_ideal_sensors(widget, _, label="pose"): # Its declarative model in `Gen` starts with the case of just one sensor reading: # %% +sensor_settings["s_noise"] = 0.10 + @genjax.gen -def sensor_model_one(pose, angle): +def sensor_model_one(pose, angle, s_noise): return ( genjax.normal( sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]), - sensor_settings["s_noise"], + s_noise, ) @ "distance" ) @@ -500,10 +505,8 @@ def sensor_model_one(pose, angle): # We draw samples from `sensor_model_one` with `propose` semantics. Since this operation is stochastic, the method is called with a PRNG key in addition to a tuple of model arguments. The code is then run, performing the required draws from the sampling operations. The random draws get organized according to their addresses, forming a *choice map* data structure. This choice map, a score (to be discussed below), and the return value are all returned by `propose`. # %% -sensor_settings["s_noise"] = 0.10 - key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[1])) +cm, score, retval = sensor_model_one.propose(sub_key, (some_pose, sensor_angles[0], sensor_settings["s_noise"])) retval # %% [markdown] @@ -516,11 +519,11 @@ def sensor_model_one(pose, angle): # We are interested in the related model whose *single* draw consists of a *vector* of the sensor distances computed across the vector of sensor angles. This is exactly what we get using the GenJAX `vmap` combinator on GFs. # %% -sensor_model = sensor_model_one.vmap(in_axes=(None, 0)) +sensor_model = sensor_model_one.vmap(in_axes=(None, 0, None)) # %% key, sub_key = jax.random.split(key) -cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles)) +cm, score, retval = sensor_model.propose(sub_key, (some_pose, sensor_angles, sensor_settings["s_noise"])) retval # %% [markdown] @@ -536,20 +539,34 @@ def sensor_model_one(pose, angle): # With a little wrapping, one gets a function of the same type as `ideal_sensor`, ignoring the PRNG key. # %% -def noisy_sensor(key, pose): - return sensor_model.propose(key, (pose, sensor_angles))[2] +def noisy_sensor(key, pose, s_noise): + return sensor_model.propose(key, (pose, sensor_angles, s_noise))[2] # %% -def update_noisy_sensors(widget, _, label="pose"): +def noise_slider(key, label, init): + return Plot.Slider( + key=key, + label=label, + showValue=True, + range=[0.01, 5.0], + step=0.01, + ) | Plot.initialState({key: init}, sync={key}) + +def update_noisy_sensors(widget, pose_key, slider_key): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) - readings = noisy_sensor(k1, pose_at(widget.state, label)) + readings = noisy_sensor(k1, pose_at(widget.state, pose_key), float(getattr(widget.state, slider_key))) widget.state.update({ "k": jax.random.key_data(k2), - (label + "_readings"): readings + (pose_key + "_readings"): readings }) return readings + +# %% +def on_slider_change(widget, _): + update_noisy_sensors(widget, "pose", "noise_slider") + key, k1, k2 = jax.random.split(key, 3) ( ( @@ -557,12 +574,13 @@ def update_noisy_sensors(widget, _, label="pose"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) + | noise_slider("noise_slider", "Sensor noise =", sensor_settings["s_noise"]) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ "k": jax.random.key_data(k1), - "pose_readings": noisy_sensor(k2, some_pose) + "pose_readings": noisy_sensor(k2, some_pose, sensor_settings["s_noise"]) }, sync={"k"}) - | Plot.onChange({"pose": update_noisy_sensors}) + | Plot.onChange({"pose": on_slider_change, "noise_slider": on_slider_change}) ) # %% [markdown] @@ -585,7 +603,7 @@ def update_noisy_sensors(widget, _, label="pose"): # The construction of a log density function is automated by the `assess` semantics for generative functions. This method is passed a choice map and a tuple of arguments, and it returns the log score plus the return value. # %% -score, retval = sensor_model.assess(cm, (some_pose, sensor_angles)) +score, retval = sensor_model.assess(cm, (some_pose, sensor_angles, sensor_settings["s_noise"])) jnp.exp(score) @@ -610,24 +628,26 @@ def update_noisy_sensors(widget, _, label="pose"): guess_pose = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) target_pose = Pose(jnp.array([15.0, 4.0]), jnp.array(-1.6)) -def likelihood_function(cm, pose): - return sensor_model.assess(cm, (pose, sensor_angles))[0] +def likelihood_function(cm, pose, s_noise): + return sensor_model.assess(cm, (pose, sensor_angles, s_noise))[0] def on_guess_pose_chage(widget, _): - update_ideal_sensors(widget, None, label="guess") + update_ideal_sensors(widget, "guess") widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), - pose_at(widget.state, "guess") + pose_at(widget.state, "guess"), + sensor_settings["s_noise"] ) }) def on_target_pose_chage(widget, _): - update_noisy_sensors(widget, None, label="target") + update_noisy_sensors(widget, "target", "noise_slider") widget.state.update({"likelihood": likelihood_function( C["distance"].set(widget.state.target_readings), - pose_at(widget.state, "guess") + pose_at(widget.state, "guess"), + sensor_settings["s_noise"] ) }) @@ -679,6 +699,7 @@ def on_target_pose_chage(widget, _): ), cols=2 ) + | noise_slider("noise_slider", "Sensor noise =", sensor_settings["s_noise"]) | ( Plot.html([ "div", @@ -704,13 +725,14 @@ def on_target_pose_chage(widget, _): { "k": jax.random.key_data(k1), "guess_readings": ideal_sensor(sensor_angles, guess_pose), - "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose)), - "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose), + "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose, sensor_settings["s_noise"])), + "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose, sensor_settings["s_noise"]), "show_target_pose": False, }, sync={"k", "target_readings"}) | Plot.onChange({ "guess": on_guess_pose_chage, - "target": on_target_pose_chage + "target": on_target_pose_chage, + "noise_slider": on_target_pose_chage, }) ) @@ -728,7 +750,7 @@ def handler(widget, _): "k": jax.random.key_data(k1), "target": widget.state.camera, }) - readings = update_noisy_sensors(widget, None, label="target") + readings = update_noisy_sensors(widget, "target", "world_noise") button_handler(widget, k2, readings) widget.state.update({ "target_exists": True, @@ -752,6 +774,8 @@ def camera_widget( ) + pose_widget("camera", camera_pose, color="blue") ) + | noise_slider("world_noise", "World/data noise = ", sensor_settings["s_noise"]) + | noise_slider("model_noise", "Model/inference noise = ", sensor_settings["s_noise"]) | ( Plot.html([ "div", @@ -816,8 +840,9 @@ def make_poses_grid(bounds, ns): grid_poses = make_poses_grid(world["bounding_box"], N_grid) def grid_search_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) likelihoods = jax.vmap(jitted_likelihood)(grid_poses) best = jnp.argsort(likelihoods, descending=True)[0:N_keep] @@ -873,8 +898,9 @@ def grid_search_handler(widget, k, readings): grid_poses = make_poses_grid(world["bounding_box"], N_grid) def grid_approximation_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) likelihoods = jax.vmap(jitted_likelihood)(grid_poses) @@ -914,8 +940,9 @@ def grid_sample_one(k): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) def importance_resampling_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) def importance_resample_one(k): @@ -952,8 +979,9 @@ def importance_resample_one(k): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) def MCMC_handler(widget, k, readings): + model_noise = float(getattr(widget.state, "model_noise")) jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose) + lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) ) def do_MH_step(pose_likelihood, k): @@ -1363,15 +1391,17 @@ def plot_path_with_confidence(path, step): # %% @genjax.gen -def full_model_kernel(motion_settings, state, control): +def full_model_kernel(motion_settings, s_noise, state, control): pose = step_model(motion_settings, state, control) @ "pose" - sensor_model(pose, sensor_angles) @ "sensor" + sensor_model(pose, sensor_angles, s_noise) @ "sensor" return pose @genjax.gen -def full_model(motion_settings): +def full_model(motion_settings, s_noise): return ( - full_model_kernel.partial_apply(motion_settings) + full_model_kernel + .partial_apply(motion_settings) + .partial_apply(s_noise) .map(diag) .scan()(robot_inputs["start"], robot_inputs["controls"]) @ "steps" @@ -1390,7 +1420,7 @@ def full_model(motion_settings): # %% key, sub_key = jax.random.split(key) -cm, score, retval = full_model.propose(sub_key, (default_motion_settings,)) +cm, score, retval = full_model.propose(sub_key, (default_motion_settings, sensor_settings["s_noise"])) cm @@ -1552,7 +1582,7 @@ def animate_full_trace(trace, frame_key=None): ) key, sub_key = jax.random.split(key) -tr = full_model.simulate(sub_key, (default_motion_settings,)) +tr = full_model.simulate(sub_key, (default_motion_settings, sensor_settings["s_noise"])) animate_full_trace(tr) @@ -1561,6 +1591,9 @@ def animate_full_trace(trace, frame_key=None): # # Let us generate some fixed synthetic motion data that, for pedagogical purposes, we will work with as if it were the actual path of the robot. We will generate two versions, one each with low or high motion deviation. # %% +# HERE is a great place to update `sensor_settings["s_noise"]` if you wish, +# about to construct some "data" using it. + motion_settings_low_deviation = { "p_noise": 0.05, "hd_noise": (1 / 10.0) * 2 * jnp.pi / 360, @@ -1571,8 +1604,8 @@ def animate_full_trace(trace, frame_key=None): } key, k_low, k_high = jax.random.split(key, 3) -trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,)) -trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,)) +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, sensor_settings["s_noise"])) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, sensor_settings["s_noise"])) animate_full_trace(trace_low_deviation) # %% @@ -1603,7 +1636,7 @@ def frame(pose, readings1, readings2): def plt(readings): return Plot.new( plot_base or Plot.domain([0, 20]), - plot_sensors(pose, readings, sensor_angles), + plot_sensors(pose, readings, sensor_angles, show_legend=True), {"width": 400, "height": 400}, ) @@ -1642,13 +1675,13 @@ def plt(readings): key, sub_key = jax.random.split(key) sample, log_weight = model_importance( - sub_key, constraints_low_deviation, (motion_settings_low_deviation,) + sub_key, constraints_low_deviation, (motion_settings_low_deviation, sensor_settings["s_noise"]) ) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% key, sub_key = jax.random.split(key) sample, log_weight = model_importance( - sub_key, constraints_high_deviation, (motion_settings_high_deviation,) + sub_key, constraints_high_deviation, (motion_settings_high_deviation, sensor_settings["s_noise"]) ) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% [markdown] @@ -1712,13 +1745,13 @@ def constraint_from_path(path): trace_path_integrated_observations_low_deviation, w_low = model_importance( sub_key, constraints_path_integrated_observations_low_deviation, - (motion_settings_low_deviation,), + (motion_settings_low_deviation, sensor_settings["s_noise"]), ) key, sub_key = jax.random.split(key) trace_path_integrated_observations_high_deviation, w_high = model_importance( sub_key, constraints_path_integrated_observations_high_deviation, - (motion_settings_high_deviation,), + (motion_settings_high_deviation, sensor_settings["s_noise"]), ) Plot.Row(*[ @@ -1756,7 +1789,7 @@ def constraint_from_path(path): )( jax.random.split(sub_key, N_samples), constraints_low_deviation, - (motion_settings_low_deviation,), + (motion_settings_low_deviation, sensor_settings["s_noise"]), ) traces_generated_high_deviation, high_weights = jax.vmap( @@ -1764,7 +1797,7 @@ def constraint_from_path(path): )( jax.random.split(sub_key, N_samples), constraints_high_deviation, - (motion_settings_high_deviation,), + (motion_settings_high_deviation, sensor_settings["s_noise"]), ) low_deviation_paths = jax.vmap(get_path)(traces_generated_low_deviation) @@ -1866,14 +1899,14 @@ def constraint_from_path(path): N_samples = 20 def importance_sample( - key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int + key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, s_noise, N: int, K: int ): """Produce K importance samples of depth N from the model. That is, K times, we generate N importance samples conditioned by the constraints, and categorically select one of them.""" key1, key2 = jax.random.split(key) samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( - jax.random.split(key1, N * K), constraints, (motion_settings,) + jax.random.split(key1, N * K), constraints, (motion_settings, s_noise) ) winners = jax.vmap(genjax.categorical.sampler)( jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) @@ -1885,15 +1918,15 @@ def importance_sample( return selected -jit_resample = jax.jit(importance_sample, static_argnums=(3, 4)) +jit_resample = jax.jit(importance_sample, static_argnums=(4, 5)) key, sub_key = jax.random.split(key) low_posterior = jit_resample( - sub_key, constraints_low_deviation, motion_settings_low_deviation, N_presamples, N_samples + sub_key, constraints_low_deviation, motion_settings_low_deviation, sensor_settings["s_noise"], N_presamples, N_samples ) key, sub_key = jax.random.split(key) high_posterior = jit_resample( - sub_key, constraints_high_deviation, motion_settings_high_deviation, N_presamples, N_samples + sub_key, constraints_high_deviation, motion_settings_high_deviation, sensor_settings["s_noise"], N_presamples, N_samples ) @@ -2108,7 +2141,7 @@ def step(state, update): return SISwithRejuvenation.Result(N, end, samples, indices) # %% -def localization_sis(motion_settings, observations): +def localization_sis(motion_settings, s_noise, observations): return SISwithRejuvenation( robot_inputs["start"], robot_inputs["controls"], @@ -2116,7 +2149,7 @@ def localization_sis(motion_settings, observations): lambda key, pose, control, observation: full_model_kernel.importance( key, C["sensor", "distance"].set(observation), - (motion_settings, pose, control), + (motion_settings, s_noise, pose, control), ), ) @@ -2125,9 +2158,11 @@ def localization_sis(motion_settings, observations): # This cell is included for convenience: # Rerun it to try out the SMC examples on a fresh instance of the problem. +# Set `sensor_settings["s_noise"]` here, if you wish. + key, k_low, k_high = jax.random.split(key, 3) -trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,)) -trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,)) +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation, sensor_settings["s_noise"])) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation, sensor_settings["s_noise"])) path_low_deviation = get_path(trace_low_deviation) path_high_deviation = get_path(trace_high_deviation) observations_low_deviation = get_sensors(trace_low_deviation) @@ -2144,7 +2179,7 @@ def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: key, sub_key = jax.random.split(key) smc_result = localization_sis( - motion_settings_high_deviation, observations_high_deviation + motion_settings_high_deviation, sensor_settings["s_noise"], observations_high_deviation ).run(sub_key, N_particles) def plot_sis_result(ground_truth, smc_result): @@ -2163,7 +2198,7 @@ def plot_sis_result(ground_truth, smc_result): key, sub_key = jax.random.split(key) low_smc_result = localization_sis( - motion_settings_low_deviation, observations_low_deviation + motion_settings_low_deviation, sensor_settings["s_noise"], observations_low_deviation ).run(sub_key, N_particles) plot_sis_result(path_low_deviation, low_smc_result) @@ -2194,7 +2229,7 @@ def run_SMCP3_step(fwd_proposal, bwd_proposal, key, sample, proposal_args): # the unnormalized posterior density over steps. @genjax.gen def grid_fwd_proposal(sample, args): - base_grid, observation, model_args = args + base_grid, observation, full_model_args = args observation_cm = C["sensor", "distance"].set(observation) log_weights = jax.vmap( @@ -2203,7 +2238,7 @@ def grid_fwd_proposal(sample, args): observation_cm | C["pose", "p"].set(p + sample.get_retval().p) | C["pose", "hd"].set(hd + sample.get_retval().hd), - model_args + full_model_args )[0] )(*base_grid) fwd_index = genjax.categorical(log_weights) @ "fwd_index" @@ -2219,14 +2254,15 @@ def grid_fwd_proposal(sample, args): # Backwards proposal simply guesses according to the prior over steps, nothing fancier. @genjax.gen def grid_bwd_proposal(new_sample, args): - base_grid, _, model_args = args + base_grid, _, full_model_args = args + step_model_args = (full_model_args[0], full_model_args[2], full_model_args[3]) log_weights = jax.vmap( lambda p, hd: step_model.assess( C["p"].set(p + new_sample.get_retval().p) | C["hd"].set(hd + new_sample.get_retval().hd), - model_args + step_model_args )[0] )(*base_grid) @@ -2236,7 +2272,7 @@ def grid_bwd_proposal(new_sample, args): # %% -def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observations): +def localization_sis_plus_grid_rejuv(motion_settings, s_noise, M_grid, N_grid, observations): base_grid = make_poses_grid_array( jnp.array([M_grid / 2.0, M_grid / 2.0]).T, N_grid @@ -2248,14 +2284,14 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio importance=lambda key, pose, control, observation: full_model_kernel.importance( key, C["sensor", "distance"].set(observation), - (motion_settings, pose, control), + (motion_settings, s_noise, pose, control), ), rejuvenate=lambda key, sample, pose, control, observation: run_SMCP3_step( grid_fwd_proposal, grid_bwd_proposal, key, sample, - (base_grid, observation, (motion_settings, pose, control)) + (base_grid, observation, (motion_settings, s_noise, pose, control)) ), ) @@ -2270,10 +2306,10 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio key, sub_key = jax.random.split(key) smc_result = localization_sis_plus_grid_rejuv( - motion_settings_high_deviation, M_grid, N_grid, observations_high_deviation + motion_settings_high_deviation, sensor_settings["s_noise"], M_grid, N_grid, observations_high_deviation ).run(sub_key, N_particles) imp_result = localization_sis( - motion_settings_high_deviation, observations_high_deviation + motion_settings_high_deviation, sensor_settings["s_noise"], observations_high_deviation ).run(sub_key, N_particles) plot_sis_result(path_high_deviation, smc_result) | plot_sis_result(path_high_deviation, imp_result)