
Flow matching smc #236

Merged
kazewong merged 27 commits into main from flowMatchingSMC on Aug 18, 2025

Conversation


@kazewong kazewong commented Aug 18, 2025

The following summary is generated by Copilot.

This pull request introduces a new flow matching model implementation, improves dependency management, and refactors the resource and model directory structure for better organization and maintainability. The most significant changes include the addition of a flow matching model under resource/model/flowmatching, migration of normalizing flow model base classes, and various updates to imports and dependencies to reflect the new structure.

Major Feature Addition

Flow Matching Model Implementation

  • Added a new FlowMatchingModel and supporting classes (Solver, Scheduler, Path, etc.) in src/flowMC/resource/model/flowmatching/base.py to enable flow matching-based generative modeling. This includes training, sampling, and log-probability computation methods leveraging ODE solvers from Diffrax.
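The sampling path of such a model (noise at t=0 integrated forward to data at t=1) can be illustrated without Diffrax or a trained network. A minimal sketch in plain Python, where a closed-form conditional-OT velocity between two 1-D Gaussians stands in for the learned network and forward Euler stands in for Diffrax's ODE solvers; none of these names are flowMC's API:

```python
import random

# Illustrative only: flow matching samples by integrating dx/dt = v(t, x)
# from t=0 (noise) to t=1 (data). Here v is the closed-form velocity for the
# OT path x_t = m(t) + s(t) * z between N(0, 1) and N(mu, sigma^2).
mu, sigma = 3.0, 0.5

def velocity(t: float, x: float) -> float:
    s = (1.0 - t) + t * sigma   # interpolated scale along the path
    m = t * mu                  # interpolated mean
    # d/dt of x_t = m(t) + s(t) * z, with z recovered as (x - m) / s
    return mu + (sigma - 1.0) * (x - m) / s

def sample(rng: random.Random, dt: float = 0.01) -> float:
    x = rng.gauss(0.0, 1.0)     # z0 ~ N(0, 1)
    t = 0.0
    while t < 1.0 - 1e-9:       # forward Euler in place of an adaptive solver
        x += velocity(t, x) * dt
        t += dt
    return x

rng = random.Random(0)
xs = [sample(rng) for _ in range(2000)]
mean = sum(xs) / len(xs)        # should land near mu
```

With exact integration this transports N(0, 1) onto N(mu, sigma^2); Euler with dt=0.01 gets within a percent or so, which is the same trade-off the `dt` argument controls in the real sampler.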

Refactoring and Directory Structure

Normalizing Flow Model Refactor

  • Moved abstract base classes Bijection and Distribution from src/flowMC/resource/nf_model/base.py to src/flowMC/resource/model/common.py, and updated all relevant imports to use the new location. This change centralizes shared abstractions and reduces duplication. [1] [2] [3] [4]

Kernel and Model Directory Reorganization

  • Renamed and moved local_kernel modules (Gaussian_random_walk.py, HMC.py, MALA.py, base.py) to resource/kernel/ and updated all references throughout the codebase and tutorials. [1] [2] [3] [4] [5] [6]
  • Moved nf_model modules under resource/model/nf_model/, updating all imports accordingly. [1] [2] [3] [4] [5]

Dependency and Configuration Updates

Dependency Additions

  • Added diffrax as a required dependency in pyproject.toml and as an additional dependency for the pyright pre-commit hook, enabling ODE-based flow matching methods. [1] [2]
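In pyproject.toml terms the addition looks roughly like this (a sketch; the surrounding entries and comments are assumptions, only the diffrax>=0.7.0 pin appears in the change summary):

```toml
[project]
dependencies = [
    # ... existing runtime dependencies ...
    "diffrax>=0.7.0",   # ODE solvers used by the flow matching model
]
```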

Development and Code QA Dependencies

  • Consolidated the codeqa dependency group into the dev group in pyproject.toml, and added ipython for improved development experience.

Most important changes, grouped by theme:

1. New Feature: Flow Matching Model

  • Added FlowMatchingModel and supporting classes in src/flowMC/resource/model/flowmatching/base.py for ODE-based generative modeling using Diffrax.

2. Refactoring & Structure

  • Moved Bijection and Distribution base classes to src/flowMC/resource/model/common.py and updated all related imports. [1] [2] [3] [4]
  • Relocated local_kernel modules to resource/kernel/ and updated all usages and tutorial references. [1] [2] [3] [4] [5] [6]
  • Moved nf_model modules under resource/model/nf_model/ and updated all imports. [1] [2] [3] [4] [5]

3. Dependency Management

  • Added diffrax as a required dependency and to the pre-commit pyright hook for ODE support. [1] [2]
  • Merged codeqa dependencies into the dev group in pyproject.toml and added ipython.

Summary by CodeRabbit

  • New Features
    • Introduced a flow-matching model with sampling, log-probability, saving/loading, and training utilities.
    • Added a Sequential Monte Carlo strategy interface.
  • Documentation
    • Corrected import paths across tutorials to match the new module layout.
  • Refactor
    • Reorganized modules: kernels under resource.kernel; models under resource.model.*; consolidated common base classes.
  • Tests
    • Added extensive flow-matching tests; updated imports in existing tests.
  • Chores
    • Added diffrax as a runtime dependency; expanded dev dependencies; updated pre-commit configuration; removed deprecated optional dependency group.

@coderabbitai

coderabbitai bot commented Aug 18, 2025

Walkthrough

Project-wide module reorganization and dependency updates. Kernel imports moved from resource.local_kernel to resource.kernel; NF model imports consolidated under resource.model.*. Bijection/Distribution moved to model.common. Added a new flow-matching model module and an abstract SequentialMonteCarlo strategy. Updated docs/tests accordingly. Added diffrax as a runtime dependency and updated dev tooling.

Changes

  • Packaging & Tooling (pyproject.toml, .pre-commit-config.yaml): Add runtime dep diffrax>=0.7.0; refresh dev deps (coveralls, pre-commit, pyright, pytest, ruff, ipython); remove optional group codeqa; include diffrax in the Pyright pre-commit hook.
  • Kernel path reorg (src/flowMC/resource/kernel/Gaussian_random_walk.py, .../kernel/HMC.py, .../kernel/MALA.py, .../kernel/NF_proposal.py): Switch imports to the new module locations: ProposalBase from resource.kernel.base; NFModel from resource.model.nf_model.base. No logic changes.
  • Strategy updates, imports (src/flowMC/strategy/parallel_tempering.py, .../strategy/take_steps.py, .../strategy/train_model.py): Update imports to the new kernel/model paths; no behavioral changes.
  • Strategy bundles, imports (src/flowMC/resource_strategy_bundle/RQSpline_MALA.py, .../RQSpline_MALA_PT.py): Update MALA, NFProposal, and MaskedCouplingRQSpline import paths to kernel/model.*.
  • Docs, tutorials (docs/tutorials/custom_strategy.ipynb, .../parallel_tempering.ipynb, .../train_normalizing_flow.ipynb): Update import paths for GaussianRandomWalk, MALA, RealNVP, and MaskedCouplingRQSpline to the new modules.
  • NF model reorg (src/flowMC/resource/model/nf_model/realNVP.py, .../nf_model/rqSpline.py, .../kernel/NF_proposal.py): Point NFModel and related components to resource.model.*; adjust Bijection/Distribution imports to model.common.
  • Common interfaces moved (src/flowMC/resource/model/common.py, .../model/nf_model/base.py): Add the Bijection and Distribution abstract base classes to model.common; remove them from nf_model/base.
  • New flow-matching module (src/flowMC/resource/model/flowmatching/base.py): Add Solver, Scheduler/CondOTScheduler, Path, and FlowMatchingModel with sampling, log_prob, save/load, and training utilities (JAX/Equinox/Diffrax/Optax).
  • New SMC strategy interface (src/flowMC/strategy/sequential_monte_carlo.py): Add abstract SequentialMonteCarlo(Resource) with NotImplemented init/call.
  • Tests, kernel/model imports (test/integration/test_HMC.py, .../test_MALA.py, .../test_RWMCMC.py, .../test_normalizingFlow.py, test/unit/test_kernels.py, .../test_nf.py, .../test_resources.py, .../test_strategies.py): Update imports to the new kernel/model paths; no test logic changes.
  • New tests for flow matching (test/unit/test_flowmatching.py): Add unit tests for Solver, Path/Scheduler, and FlowMatchingModel, covering sampling, log_prob, save/load, and basic training steps.
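The one-line description of the abstract SequentialMonteCarlo interface can be fleshed out as a sketch (the Resource stand-in and signatures are assumptions for illustration, not flowMC's exact code):

```python
# Sketch of an abstract strategy whose __init__/__call__ raise until a
# concrete SMC implementation overrides them.
class Resource:
    """Stand-in for flowMC's Resource base class."""

class SequentialMonteCarlo(Resource):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError("Concrete SMC strategies must implement __init__")

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("Concrete SMC strategies must implement __call__")
```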

Sequence Diagram(s)

sequenceDiagram
  actor User
  participant FM as FlowMatchingModel
  participant Solver
  participant Diffrax as ODE Solver
  participant Path
  participant Sched as Scheduler

  User->>FM: sample(rng, n, dt)
  FM->>Solver: sample(rng, n, dt)
  Solver->>Diffrax: solve ODE (t:0→1) on z0~N(0,I)
  Diffrax-->>Solver: x1 samples
  Solver-->>FM: x1
  FM->>FM: unwhiten (cov, mean)
  FM-->>User: samples

  User->>FM: log_prob(x)
  FM->>FM: whiten x
  FM->>Solver: log_prob(x_whitened, dt)
  Solver->>Diffrax: reverse-time ODE + jacobian trace
  Diffrax-->>Solver: log_prob term
  Solver-->>FM: log_prob(z1)
  FM->>FM: adjust by log|det(cov)|
  FM-->>User: log_prob(x)
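The whiten/unwhiten bookkeeping in the diagram follows the standard change-of-variables rule: with z = L^{-1}(x - mean) and cov = L L^T, log p_x(x) = log p_z(z) - 0.5 log|det cov|. A small self-contained check (illustrative names, not flowMC's API):

```python
import numpy as np

# If the model's density p_z were exactly N(0, I), adjusting its log-prob by
# the whitening log-determinant recovers the N(mean, cov) density of x.
rng = np.random.default_rng(0)
mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.3], [0.3, 0.5]])
L = np.linalg.cholesky(cov)          # cov = L @ L.T

x = rng.normal(size=2)
z = np.linalg.solve(L, x - mean)     # whiten

def std_normal_logpdf(v):
    return -0.5 * (v @ v) - 0.5 * len(v) * np.log(2 * np.pi)

log_px = std_normal_logpdf(z) - 0.5 * np.log(np.linalg.det(cov))
```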

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes





@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🔭 Outside diff range comments (11)
src/flowMC/strategy/take_steps.py (1)

134-145: Fix buffer cursor advancement with thinning (off-by-one bug).

You thin the arrays before writing to buffers, but then advance current_position by self.n_steps // self.thinning. When self.n_steps is not divisible by self.thinning, this undercounts (e.g., 5 steps with thinning=2 yields 3 writes but only +2 advancement), causing overlapping writes or misalignment across buffers.

Advance by the actual number of thinned steps written.

Apply:

         position_buffer.update_buffer(positions, self.current_position)
         log_prob_buffer.update_buffer(log_probs, self.current_position)
         acceptance_buffer.update_buffer(do_accepts, self.current_position)
-        self.current_position += self.n_steps // self.thinning
+        steps_written = int(positions.shape[1])  # number of thinned steps written
+        self.current_position += steps_written
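The counting argument can be checked in isolation. A toy sketch, assuming thinning keeps steps 0, thinning, 2*thinning, ... (consistent with the 5-steps/thinning=2 example above):

```python
import math

# Number of thinned steps actually written vs. the naive cursor advancement.
def thinned_count(n_steps: int, thinning: int) -> int:
    # keeps steps 0, k, 2k, ... < n_steps
    return math.ceil(n_steps / thinning)

writes = thinned_count(5, 2)   # 3 samples written (steps 0, 2, 4)
naive = 5 // 2                 # 2: the cursor falls one slot behind per batch
```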
src/flowMC/strategy/train_model.py (1)

66-72: Boolean indexing + reshape can corrupt chain alignment or fail; select valid timesteps across all chains instead.

Indexing with a 2D boolean mask flattens across (chain, time) pairs; reshaping back to (n_chains, -1, n_dims) assumes identical counts of valid timesteps per chain, which is not guaranteed and may error or silently misalign samples.

Filter by timesteps that are valid across all chains to preserve shape.

Apply:

-        n_chains = data_resource.data.shape[0]
-        n_dims = data_resource.data.shape[-1]
-        training_data = data_resource.data[
-            jnp.isfinite(data_resource.data).all(axis=-1)
-        ].reshape(n_chains, -1, n_dims)
-        training_data = training_data[:, -self.history_window :].reshape(-1, n_dims)
+        data_array = data_resource.data  # (n_chains, T, n_dims)
+        n_chains = data_array.shape[0]
+        n_dims = data_array.shape[-1]
+        # Keep only timesteps where all chains are finite across all dims to preserve structure
+        valid_t = jnp.all(jnp.isfinite(data_array).all(axis=-1), axis=0)  # (T,)
+        training_data = data_array[:, valid_t, :]  # (n_chains, T_valid, n_dims)
+        training_data = training_data[:, -self.history_window :].reshape(-1, n_dims)
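A small NumPy demonstration of why the 2-D boolean mask is unsafe and why the all-chains filter preserves alignment (toy shapes, not flowMC data):

```python
import numpy as np

# 2 chains, 3 timesteps, 1 dim, with NaNs in different places per chain.
data = np.array([[[0.1], [np.nan], [0.3]],
                 [[1.1], [1.2], [1.3]]])
mask = np.isfinite(data).all(axis=-1)   # shape (2, 3)
flat = data[mask]                        # shape (5, 1): chain structure lost
# flat.reshape(2, -1, 1) would raise: 5 rows cannot split evenly into 2 chains.

# The fix keeps only timesteps that are finite across *all* chains:
valid_t = mask.all(axis=0)               # shape (3,): [True, False, True]
safe = data[:, valid_t, :]               # shape (2, 2, 1), alignment preserved
```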
pyproject.toml (1)

4-4: Fix typo in project description (“exhanced” → “enhanced”).

This is user-facing metadata on PyPI.

-description = "Normalizing flow exhanced sampler in jax"
+description = "Normalizing-flow-enhanced sampler in JAX"
src/flowMC/resource/model/common.py (3)

285-293: Align Gaussian.sample annotation with base and actual shape.

Currently annotated as " n_samples n_features" but returns (n_samples, n_dim).

Apply this diff:

     def sample(
         self, rng_key: PRNGKeyArray, n_samples: int
-    ) -> Float[Array, " n_samples n_features"]:
+    ) -> Float[Array, " n_samples n_dim"]:
         return jax.random.multivariate_normal(
             rng_key, self.mean, self.cov, (n_samples,)
         )

114-121: Avoid relying on eqx.nn.Linear in_features/out_features (may not exist).

Equinox Linear reliably exposes weight (shape: out_features x in_features), but not all versions expose in_features/out_features. Derive from weight.shape to be robust.

Apply this diff:

     @property
     def n_input(self) -> int:
-        return self.layers[0].in_features
+        return int(self.layers[0].weight.shape[1])
@@
     @property
     def n_output(self) -> int:
-        return self.layers[-1].out_features
+        return int(self.layers[-1].weight.shape[0])

310-317: Composable.sample should return a single concatenated array to honor the base interface.

Returning a dict violates the Distribution.sample contract and forces callers to be special-cased. Concatenating the parts along feature axis keeps a clean API.

Apply this diff:

     def sample(
-        self, rng_key: PRNGKeyArray, n_samples: int
-    ) -> Float[Array, " n_samples n_features"]:
-        samples = {}
-        for dist, (key, _) in zip(self.distributions, self.partitions.items()):
-            rng_key, sub_key = jax.random.split(rng_key)
-            samples[key] = dist.sample(sub_key, n_samples=n_samples)
-        return samples  # type: ignore
+        self, rng_key: PRNGKeyArray, n_samples: int
+    ) -> Float[Array, " n_samples n_dim"]:
+        parts: list[Array] = []
+        for dist, (key, (start, end)) in zip(self.distributions, self.partitions.items()):
+            rng_key, sub_key = jax.random.split(rng_key)
+            part = dist.sample(sub_key, n_samples=n_samples)  # (n_samples, end-start)
+            parts.append(part)
+        return jnp.concatenate(parts, axis=1)

Note: This still inherits the ordering caveat from zip(...)—see prior comment.

src/flowMC/resource/kernel/HMC.py (2)

118-121: Kinetic energy uses incorrect metric shape; fix for full/diag metrics.

kinetic currently multiplies p**2 * metric assuming a diagonal metric but elsewhere a full matrix is passed (e.g., self.condition_matrix). This yields wrong energies and gradients under a full metric and also mismatches the type annotations.

Apply this diff to support both scalar/diagonal and full metrics:

-        def kinetic(
-            p: Float[Array, " n_dim"], metric: Float[Array, " n_dim"]
-        ) -> Float[Array, "1"]:
-            return 0.5 * (p**2 * metric).sum()
+        def kinetic(
+            p: Float[Array, " n_dim"],
+            metric: Float[Array, " n_dim"] | Float[Array, " n_dim n_dim"],
+        ) -> Float[Array, "1"]:
+            # Support scalar/diagonal (vector) or full metric
+            if jnp.ndim(metric) == 0 or (hasattr(metric, "ndim") and metric.ndim == 1):
+                return 0.5 * (p**2 * metric).sum()
+            else:
+                return 0.5 * jnp.dot(p, metric @ p)
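A quick numerical check that the two branches agree where they overlap: for a diagonal metric, the elementwise form equals the full quadratic form 0.5 p^T M p (plain NumPy, values chosen arbitrarily):

```python
import numpy as np

p = np.array([0.5, -1.0, 2.0])
m_diag = np.array([1.0, 2.0, 0.5])

k_vec = 0.5 * (p**2 * m_diag).sum()        # vector fast path
k_mat = 0.5 * p @ np.diag(m_diag) @ p      # full-matrix path
```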

130-136: Duplicate momentum sampling and reusing the same PRNG key; sample once and respect metric.

Momentum is sampled twice with the same key and the first sample is immediately overwritten. Also, using the same key twice for different draws is an anti-pattern in JAX. Sample once and transform by the (inverse) metric properly.

Apply this diff to sample momentum correctly and deterministically:

-        momentum: Float[Array, " n_dim"] = (
-            jax.random.normal(key1, shape=position.shape) * self.condition_matrix**-0.5
-        )
-        momentum = jnp.dot(
-            jax.random.normal(key1, shape=position.shape),
-            jnp.linalg.cholesky(jnp.linalg.inv(self.condition_matrix)).T,
-        )
+        key_mom, _ = jax.random.split(key1)
+        # Sample p ~ N(0, Metric^{-1}), supporting scalar/diag/full metrics.
+        if jnp.ndim(self.condition_matrix) == 0 or (
+            hasattr(self.condition_matrix, "ndim") and self.condition_matrix.ndim == 1
+        ):
+            std = jnp.sqrt(1.0 / self.condition_matrix)
+            momentum: Float[Array, " n_dim"] = jax.random.normal(
+                key_mom, shape=position.shape
+            ) * std
+        else:
+            z = jax.random.normal(key_mom, shape=position.shape)
+            L = jnp.linalg.cholesky(jnp.linalg.inv(self.condition_matrix))
+            momentum = jnp.dot(z, L.T)
src/flowMC/resource/kernel/MALA.py (3)

71-73: Same PRNG key used twice in scan; split distinct keys to avoid correlation

keys fed into lax.scan are constructed as [key1, key1], causing identical noise across both iterations. This reduces stochasticity and is likely unintended.

Apply this diff to use distinct keys:

-        _, (proposal, logprob, d_logprob) = jax.lax.scan(
-            body, (position, dt, data), jnp.array([key1, key1])
-        )
+        keys = jax.random.split(key1, 2)
+        _, (proposal, logprob, d_logprob) = jax.lax.scan(
+            body, (position, dt, data), keys
+        )
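The key-reuse hazard is easy to demonstrate with Python's stdlib RNG standing in for JAX PRNG keys (illustrative only; jax.random.split is the real fix):

```python
import random

# Seeding a fresh generator from a "key" mimics how a JAX key determines draws.
def draw(key: int) -> float:
    return random.Random(key).gauss(0.0, 1.0)

key1 = 42
reused = [draw(key1), draw(key1)]          # identical noise, as with [key1, key1]

# split-style fix: derive distinct child keys deterministically from the parent.
child_a, child_b = hash((key1, 0)), hash((key1, 1))
split = [draw(child_a), draw(child_b)]     # distinct draws per iteration
```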

76-81: Covariance passed as a scalar; use an isotropic covariance matrix

multivariate_normal.logpdf expects a covariance array. Passing dt2 (scalar) relies on undocumented broadcasting and can break with shape/type changes. Use dt2 * I.

Apply this diff:

-        ratio -= multivariate_normal.logpdf(
-            proposal[0], position + jnp.dot(dt2, d_logprob[0]) / 2, dt2
-        )
+        cov = dt2 * jnp.eye(position.shape[-1], dtype=position.dtype)
+        ratio -= multivariate_normal.logpdf(
+            proposal[0], position + jnp.dot(dt2, d_logprob[0]) / 2, cov
+        )
@@
-        ratio += multivariate_normal.logpdf(
-            position, proposal[0] + jnp.dot(dt2, d_logprob[1]) / 2, dt2
-        )
+        ratio += multivariate_normal.logpdf(
+            position, proposal[0] + jnp.dot(dt2, d_logprob[1]) / 2, cov
+        )

71-81: Algorithmic inconsistency: acceptance ratio mixes states from different sub-steps

ratio uses logprob[1] - logprob[0] (log p at proposal_1 minus at x) but proposal[0] and gradients d_logprob[0]/[1] are used inconsistently. If the intent is a single-step MALA, the current two-step scan is unnecessary and error-prone.

As a robust fix, simplify to a single-step MALA kernel (no scan), computing gradients at x and at the proposal for the corrected MH ratio:

def kernel(
    self,
    rng_key: PRNGKeyArray,
    position: Float[Array, " n_dim"],
    log_prob: Float[Array, "1"],
    logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
    data: PyTree,
) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
    dt: Float = self.step_size
    dt2 = dt * dt
    key1, key2 = jax.random.split(rng_key)

    logp_x, grad_x = jax.value_and_grad(logpdf)(position, data)
    noise = jax.random.normal(key1, shape=position.shape)
    proposal = position + 0.5 * dt2 * grad_x + dt * noise

    logp_prop, grad_prop = jax.value_and_grad(logpdf)(proposal, data)
    cov = dt2 * jnp.eye(position.shape[-1], dtype=position.dtype)

    log_q_x_to_prop = multivariate_normal.logpdf(
        proposal, position + 0.5 * dt2 * grad_x, cov
    )
    log_q_prop_to_x = multivariate_normal.logpdf(
        position, proposal + 0.5 * dt2 * grad_prop, cov
    )

    ratio = (logp_prop - logp_x) + (log_q_prop_to_x - log_q_x_to_prop)
    do_accept = jnp.log(jax.random.uniform(key2)) < ratio

    new_position = jnp.where(do_accept, proposal, position)
    new_log_prob = jnp.where(do_accept, logp_prop, logp_x)
    return new_position, new_log_prob, do_accept

If multi-step integration is desired, rework the scan to propagate a single proposal and use consistent logp/grad pairs for x and x'.

♻️ Duplicate comments (1)
test/integration/test_HMC.py (1)

19-19: Same note about noisy print as in the MALA test.

The print("compile count") in dual_moon_pe is likely to clutter logs.

🧹 Nitpick comments (37)
src/flowMC/strategy/train_model.py (1)

73-80: Guard against empty training data and sample a bounded number of examples.

If training_data.shape[0] == 0, jax.random.choice will error. Also, requesting self.n_max_examples regardless of available data is wasteful.

Apply:

-        rng_key, subkey = jax.random.split(rng_key)
-        training_data = training_data[
-            jax.random.choice(
-                subkey,
-                jnp.arange(training_data.shape[0]),
-                shape=(self.n_max_examples,),
-                replace=True,
-            )
-        ]
+        rng_key, subkey = jax.random.split(rng_key)
+        num = min(self.n_max_examples, int(training_data.shape[0]))
+        if num == 0:
+            raise ValueError("No finite training samples available for NF training.")
+        idx = jax.random.choice(
+            subkey,
+            jnp.arange(training_data.shape[0]),
+            shape=(num,),
+            replace=training_data.shape[0] < num,
+        )
+        training_data = training_data[idx]
src/flowMC/strategy/parallel_tempering.py (3)

53-57: Minor: fix return type annotation to match input dim name.

Return annotation uses n_dim while the input uses n_dims. Aligning helps jaxtyping/static checks.

-        Float[Array, "n_chains n_dim"],
+        Float[Array, "n_chains n_dims"],

205-217: Typo in method name: _individal_step → _individual_step.

Renaming improves readability and avoids future search/grep confusion.

-def _individal_step(
+def _individual_step(
@@
-        positions, log_probs, do_accept = jax.vmap(
-            self._individal_step, in_axes=(None, 0, 0, None, 0, None)
+        positions, log_probs, do_accept = jax.vmap(
+            self._individual_step, in_axes=(None, 0, 0, None, 0, None)

Also applies to: 285-287


213-217: Align do_accept shape with its annotation (or update annotation).

_individual_step currently returns do_accept from lax.scan, which has shape (n_steps, 1), but the type annotation says Int[Array, "1"]. This mismatches annotations here and in _ensemble_step (which declares Int[Array, " n_temps"]).

Two options:

  • Return only the last acceptance per temperature (matches many PT diagnostics).
  • Or update annotations to reflect full (n_steps, 1) per temperature.

Option A (return last acceptance flag):

-        return position, log_prob, do_accept
+        return position, log_prob, do_accept[-1].squeeze()

Option B (update _individual_step return annotation):

-        Int[Array, "1"],
+        Int[Array, " n_steps 1"],

If you pick Option B, also update _ensemble_step’s return annotation to Int[Array, " n_temps n_steps 1"] or compute an aggregate (e.g., mean) before returning.

Also applies to: 239-249, 285-290

src/flowMC/strategy/sequential_monte_carlo.py (1)

12-18: Prefer Mapping to express read-only inputs; document mutation semantics.

If resources is not mutated in place, accept Mapping[str, Resource] and return a new dict[str, Resource]. If it is mutated, prefer MutableMapping[str, Resource] and drop returning a new dict to avoid ambiguity. Add a short docstring to specify expectations (e.g., particle resampling, adaptation state updates).

Would you like me to update the signature and add a docstring template for implementers?

src/flowMC/resource/model/common.py (4)

20-26: Consider making condition optional in the base interface.

Some bijectors won’t need a conditioner; others (e.g., coupling layers) will. Using condition: Optional[Float[Array, " n_condition"]] = None in Bijection.__call__/forward/inverse accommodates both, reducing boilerplate no-ops in simple bijectors.


150-159: MaskedCouplingLayer.forward: condition argument is ignored.

You always condition the inner bijector on x * self.mask and drop the provided condition. Either:

  • Document that condition is intentionally unused, or
  • Thread it through to the bijector (e.g., concatenate with masked inputs if the bijector expects it).

If you want to pass it through, a safe pattern is:

cond = x * self.mask if condition is None else jnp.concatenate([x * self.mask, condition])
y, log_det = self.bijector(x, cond)

Pick one approach and keep it consistent in inverse too.

Do you intend this layer to support external conditioning (e.g., context variables)? If yes, I can propose a concrete signature and conditioner wiring pattern.


191-197: Nit: fix comment wording for clarity.

Change “this note output” to “this returns”.

Apply this diff:

-        # Note that this note output log_det as an array instead of a number.
+        # Note: this returns log_det as an array (per-dimension) instead of a scalar.

304-309: Composable.log_prob relies on dict iteration order matching distributions list.

zip(self.distributions, self.partitions.items()) assumes the insertion order of partitions aligns with distributions. That is brittle and can silently misalign slices. Prefer an explicit mapping from key -> Distribution or store an ordered list of keys.

Minimal change: make distributions a dict keyed by the same keys as partitions and iterate keys.

I can refactor Composable to distributions: dict[str, Distribution] and update both log_prob and sample accordingly. Want me to send a patch?

docs/tutorials/parallel_tempering.ipynb (1)

22-22: Optional: shorten the import if the package re-exports MALA

If flowMC/resource/kernel/__init__.py re-exports MALA, consider importing from the package to reduce churn if filenames change. Otherwise ignore.

-from flowMC.resource.kernel.MALA import MALA
+from flowMC.resource.kernel import MALA
src/flowMC/resource/kernel/HMC.py (2)

90-96: Avoid print during compute-intensive paths; prefer logging or remove.

print("Compiling leapfrog step") will spam stdout and interfere under JIT/vmap. Replace with a logger at debug level or remove.

-        print("Compiling leapfrog step")
+        # Consider using a logger at debug level or remove in production.

153-157: Avoid noisy prints in library code; use logging.

These parameter prints are useful for debugging but should not run on every call in library code.

-    def print_parameters(self):
-        print("HMC parameters:")
-        print(f"step_size: {self.step_size}")
-        print(f"n_leapfrog: {self.n_leapfrog}")
-        print(f"condition_matrix: {self.condition_matrix}")
+    def print_parameters(self):
+        # Prefer a logger.debug/info here instead of print to avoid noisy stdout.
+        pass
test/unit/test_nf.py (1)

56-64: Minor typo in local variable name (hidden_layes).

Harmless, but consider renaming to hidden_layers for clarity.

-    hidden_layes = [16, 16]
+    hidden_layers = [16, 16]
@@
-    model = MaskedCouplingRQSpline(
-        n_features, n_layers, hidden_layes, n_bins, jax.random.PRNGKey(10)
-    )
+    model = MaskedCouplingRQSpline(
+        n_features, n_layers, hidden_layers, n_bins, jax.random.PRNGKey(10)
+    )
test/integration/test_MALA.py (1)

19-19: Avoid noisy prints in tests’ target logpdf (use debug-print or guard).

print("compile count") will spam test output on JIT compile/eval and can lead to flaky golden-output checks. Consider removing it or wrapping it via jax.debug.print under a debug flag.

test/unit/test_kernels.py (1)

137-141: The “close_gaussian” tests are very long-running; consider marking as slow or reducing steps.

n_local_steps = 30000/50000 will slow CI. Either mark these tests with @pytest.mark.slow or reduce steps and relax tolerances.

Example:

import pytest

@pytest.mark.slow
def test_HMC_close_gaussian():
    # possibly reduce n_local_steps to 5_000 and relax atol
    ...

Also applies to: 245-246, 346-347

src/flowMC/resource/kernel/Gaussian_random_walk.py (1)

53-61: Type annotations don’t match actual values: proposal_log_prob is scalar and do_accept is boolean.

  • proposal_log_prob should be a scalar (shape "1"), not " n_dim".
  • do_accept is a boolean but the return annotation declares Int.

These mismatches can confuse type checkers and readers.

Suggested adjustments:

# In the return annotation of `kernel`:
# -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Bool[Array, "1"]]

# For proposal_log_prob:
proposal_log_prob: Float[Array, "1"] = logpdf(proposal, data)

# Optionally, annotate do_accept explicitly:
do_accept: Bool = log_uniform < (proposal_log_prob - log_prob)

Also consider normalizing jaxtyping shapes to remove leading spaces, e.g., "n_dim" instead of " n_dim", for consistency across the codebase.

test/integration/test_normalizingFlow.py (1)

20-20: Clarify optimizer hyperparameters: “momentum” isn’t an Adam parameter.

optax.adam doesn’t have a “momentum” parameter; the second positional arg maps to b1 (default 0.9). To avoid confusion, pass it by name or rename the variable.

For clarity:

optim = optax.adam(learning_rate, b1=0.9)
# or rename the variable:
beta1 = 0.9
optim = optax.adam(learning_rate, b1=beta1)

Also applies to: 53-53

src/flowMC/resource/kernel/MALA.py (2)

84-84: Acceptance mask type annotation is incorrect

do_accept is a scalar boolean but annotated as Bool[Array, " n_dim"]. This is misleading and may trip static/type checks.

Apply this diff:

-        do_accept: Bool[Array, " n_dim"] = log_uniform < ratio
+        do_accept: Bool = log_uniform < ratio

56-56: Avoid print inside JIT/scan body

print("Compiling MALA body") inside a JAX scan creates noisy logs and can interact poorly with JIT. Gate behind a verbosity flag or remove.

src/flowMC/resource/kernel/NF_proposal.py (5)

27-36: Return type hints don’t match actual shapes; and they diverge from other kernels

kernel returns (positions, log_prob, do_accept) with shapes (n_steps, n_dim), (n_steps,), (n_steps,), but the annotations say "n_step 1" for the last two. This causes inconsistency with ProposalBase and other kernels (which typically return scalars per step).

Apply this diff to correct local annotations (or update calling code to expect 2D trailing dims consistently):

-    ) -> tuple[
-        Float[Array, "n_step n_dim"], Float[Array, "n_step 1"], Int[Array, "n_step 1"]
-    ]:
+    ) -> tuple[
+        Float[Array, "n_step n_dim"], Float[Array, "n_step"], Int[Array, "n_step"]
+    ]:

Longer-term, consider distinguishing single-step and multi-step proposals with separate base interfaces (e.g., LocalProposalBase vs GroupProposalBase) to avoid shape drift across kernels.


38-41: Remove or gate debug prints inside JIT code paths

print("Compiling NF proposal kernel") will fire during JIT tracing and can be noisy. Gate it behind a verbosity flag or remove.


43-50: JIT inside hot path can cause repeated recompiles

eqx.filter_jit(self.model.log_prob) and eqx.filter_jit(self.sample_flow) are rebuilt at call sites. Prefer pre-binding jitted callables in init (or module-level) to reduce compilation overhead.
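The pre-binding pattern, sketched with `jax.jit` on a plain function for brevity (the same idea applies to `eqx.filter_jit` on module methods):

```python
import jax
import jax.numpy as jnp

class Proposal:
    """Hypothetical proposal that binds its jitted callable once."""

    def __init__(self, log_prob_fn):
        # jit at construction time: every later call reuses the same
        # compiled function instead of rebuilding a jit wrapper (and
        # re-checking its cache) on each kernel invocation.
        self._log_prob = jax.jit(log_prob_fn)

    def __call__(self, x):
        return self._log_prob(x)

proposal = Proposal(lambda x: -0.5 * jnp.sum(x**2))
lp = proposal(jnp.array([1.0, 1.0]))  # -0.5 * (1 + 1) = -1.0
```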


136-136: No-op assignment

rng_key = rng_key is a no-op and can be removed.

Apply this diff:

-            rng_key = rng_key

141-153: Avoid print during scan/JIT

print("Compiling sample_flow") inside lax.scan is noisy. Remove or guard with a flag.

src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (2)

127-128: Prefer jnp alias consistently

Use the jnp alias already imported for np ops to keep style consistent.

Apply this diff:

-        temperatures.update_buffer(
-            jax.numpy.linspace(1.0, max_temperature, n_temperatures)
-        )
+        temperatures.update_buffer(
+            jnp.linspace(1.0, max_temperature, n_temperatures)
+        )

271-274: Avoid printing in library code; consider logger or verbose flag

Printing when n_tempered_steps <= 0 is surprising for library consumers. Either raise a ValueError or guard with a verbose flag.

src/flowMC/resource/model/nf_model/realNVP.py (1)

18-100: Unused AffineCoupling implementation

AffineCoupling appears unused in this module (RealNVP uses MaskedCouplingLayer + MLPAffine). Consider removing or moving it to a dedicated bijection module to avoid duplication and dead code.

test/unit/test_flowmatching.py (3)

44-49: Potential flakiness: log_prob scalar shape vs vmap expectation

You assert logp.shape == (n_samples, 1) after eqx.filter_vmap(model.log_prob)(samples). Today FlowMatchingModel.log_prob returns a scalar; vmap over a scalar-returning function produces shape (n_samples,), not (n_samples, 1). If you intend a column vector, ensure log_prob returns shape (1,). See my suggested change in FlowMatchingModel.log_prob to wrap the scalar into a length-1 array.
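The shape difference in isolation (toy log-prob functions, not the flowMC model):

```python
import jax
import jax.numpy as jnp

def log_prob_scalar(x):
    return -0.5 * jnp.sum(x**2)  # shape ()

def log_prob_column(x):
    # Wrapping the scalar in a length-1 array yields (n_samples, 1) under vmap.
    return jnp.atleast_1d(log_prob_scalar(x))  # shape (1,)

xs = jnp.ones((5, 3))
flat = jax.vmap(log_prob_scalar)(xs)    # shape (5,)
column = jax.vmap(log_prob_column)(xs)  # shape (5, 1)
```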


84-91: Scheduler tuple contents: rely only on array-like outputs

Good coverage to assert the 4-tuple contract. Casting each element to a Python float forces a device-to-host transfer and will fail for non-scalar arrays (e.g., if t is batched). If you plan to extend this to batched t, check for JAX array-likeness via hasattr(x, "dtype") or jnp.shape(x) rather than float(x).


174-191: train_epoch uses integer batch_size; keep annotation/types consistent

batch_size is used for indexing and reshaping; tests pass an int. Ensure the production signature treats it as int (not Float). I’ve suggested the fix in the implementation.

src/flowMC/resource/model/flowmatching/base.py (7)

44-52: Optional: make final-state saving explicit in sample

For sampling you only need the final state. diffeqsolve's default saveat is already SaveAt(t1=True), so only the t1 state is stored and sol.ys[-1] selects it along the (length-1) save axis. If you want to make that intent explicit:

             sol = diffeqsolve(
                 term,
                 self.method,
                 t0=0.0,
                 t1=1.0,
                 dt0=dt,
                 y0=y0,
+                saveat=SaveAt(t1=True),
             )
             return sol.ys[-1]  # type: ignore

92-98: Nit: simplify base Gaussian logpdf mean construction

self.model.n_output * jnp.zeros(self.model.n_output) is redundant. Use jnp.zeros(self.model.n_output) for clarity.

-            logpdf(
-                x1,
-                mean=self.model.n_output * jnp.zeros(self.model.n_output),
-                cov=jnp.eye(self.model.n_output),
-            )
+            logpdf(x1, mean=jnp.zeros(self.model.n_output), cov=jnp.eye(self.model.n_output))

204-218: Don’t print from a jitted training step

print("Compiling training step") inside a JIT-ted function will spam stdout on recompilations and can interfere with benchmarks. Remove or guard behind a verbosity flag.

-        print("Compiling training step")

219-231: Type of batch_size should be int

It’s used for slicing/reshaping. Annotate as int to avoid confusion and potential issues with floor division producing floats.

-        batch_size: Float,
+        batch_size: int,

289-319: Minor: progress bar UX and description text

  • Description says “Training NF” while this is a Flow Matching model; rename for clarity.
  • epoch == num_epochs is never true; last epoch is num_epochs - 1.
  • If num_epochs <= 10, the description will never update; consider always updating on the last iteration.
-            pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
+            pbar = trange(num_epochs, desc="Training FlowMatching", miniters=max(1, int(num_epochs / 10)))
@@
-            if verbose:
-                assert isinstance(pbar, tqdm)
+            if verbose:
+                assert isinstance(pbar, tqdm)
                 if num_epochs > 10:
                     if epoch % int(num_epochs / 10) == 0:
-                        pbar.set_description(f"Training NF, current loss: {value:.3f}")
+                        pbar.set_description(f"Training FlowMatching, current loss: {value:.3f}")
                 else:
-                    if epoch == num_epochs:
-                        pbar.set_description(f"Training NF, current loss: {value:.3f}")
+                    if epoch == num_epochs - 1:
+                        pbar.set_description(f"Training FlowMatching, current loss: {value:.3f}")

297-299: Consistency: covariance update vs diagonal-only code paths

You set _data_cov = jnp.cov(data[1].T) (a full covariance), but sampling/log_prob as written use only its diagonal. With the suggested Cholesky refactor this inconsistency is resolved. If you decide to keep diagonal-only statistics instead, compute and store only the diagonal explicitly to avoid misleading consumers.


139-150: Return types for properties

Annotate property return types for better tooling support:

  • def n_features(self) -> int:
  • def data_mean(self) -> Float[Array, " n_dim"]:
  • def data_cov(self) -> Float[Array, " n_dim n_dim"]:

No behavioral change; improves readability and static checks.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 046c78c and 922efe4.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (29)
  • .pre-commit-config.yaml (1 hunks)
  • docs/tutorials/custom_strategy.ipynb (1 hunks)
  • docs/tutorials/parallel_tempering.ipynb (1 hunks)
  • docs/tutorials/train_normalizing_flow.ipynb (1 hunks)
  • pyproject.toml (2 hunks)
  • src/flowMC/resource/kernel/Gaussian_random_walk.py (1 hunks)
  • src/flowMC/resource/kernel/HMC.py (1 hunks)
  • src/flowMC/resource/kernel/MALA.py (1 hunks)
  • src/flowMC/resource/kernel/NF_proposal.py (1 hunks)
  • src/flowMC/resource/model/common.py (1 hunks)
  • src/flowMC/resource/model/flowmatching/base.py (1 hunks)
  • src/flowMC/resource/model/nf_model/base.py (0 hunks)
  • src/flowMC/resource/model/nf_model/realNVP.py (1 hunks)
  • src/flowMC/resource/model/nf_model/rqSpline.py (1 hunks)
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (1 hunks)
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (1 hunks)
  • src/flowMC/strategy/parallel_tempering.py (1 hunks)
  • src/flowMC/strategy/sequential_monte_carlo.py (1 hunks)
  • src/flowMC/strategy/take_steps.py (1 hunks)
  • src/flowMC/strategy/train_model.py (1 hunks)
  • test/integration/test_HMC.py (1 hunks)
  • test/integration/test_MALA.py (1 hunks)
  • test/integration/test_RWMCMC.py (1 hunks)
  • test/integration/test_normalizingFlow.py (1 hunks)
  • test/unit/test_flowmatching.py (1 hunks)
  • test/unit/test_kernels.py (1 hunks)
  • test/unit/test_nf.py (1 hunks)
  • test/unit/test_resources.py (1 hunks)
  • test/unit/test_strategies.py (1 hunks)
💤 Files with no reviewable changes (1)
  • src/flowMC/resource/model/nf_model/base.py
👮 Files not reviewed due to content moderation or server errors (1)
  • docs/tutorials/train_normalizing_flow.ipynb
🧰 Additional context used
🧬 Code Graph Analysis (20)
test/unit/test_kernels.py (4)
src/flowMC/resource/kernel/Gaussian_random_walk.py (2)
  • kernel (25-61)
  • GaussianRandomWalk (10-71)
src/flowMC/resource/kernel/MALA.py (2)
  • kernel (26-89)
  • MALA (11-99)
src/flowMC/resource/kernel/HMC.py (2)
  • kernel (98-151)
  • HMC (11-163)
src/flowMC/resource/kernel/base.py (1)
  • kernel (17-27)
test/integration/test_MALA.py (2)
src/flowMC/resource/kernel/MALA.py (2)
  • kernel (26-89)
  • MALA (11-99)
src/flowMC/resource/kernel/base.py (1)
  • kernel (17-27)
src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (4)
src/flowMC/resource/kernel/NF_proposal.py (2)
  • kernel (27-128)
  • NFProposal (15-184)
src/flowMC/resource/kernel/MALA.py (2)
  • kernel (26-89)
  • MALA (11-99)
src/flowMC/resource/kernel/base.py (1)
  • kernel (17-27)
src/flowMC/resource/model/nf_model/rqSpline.py (1)
  • MaskedCouplingRQSpline (364-507)
test/unit/test_nf.py (2)
src/flowMC/resource/model/nf_model/realNVP.py (2)
  • AffineCoupling (18-99)
  • RealNVP (102-228)
src/flowMC/resource/model/nf_model/rqSpline.py (1)
  • MaskedCouplingRQSpline (364-507)
src/flowMC/resource/kernel/HMC.py (1)
src/flowMC/resource/kernel/base.py (2)
  • kernel (17-27)
  • ProposalBase (9-27)
src/flowMC/strategy/parallel_tempering.py (1)
src/flowMC/resource/kernel/base.py (2)
  • kernel (17-27)
  • ProposalBase (9-27)
src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (4)
src/flowMC/resource/kernel/NF_proposal.py (2)
  • kernel (27-128)
  • NFProposal (15-184)
src/flowMC/resource/kernel/MALA.py (2)
  • kernel (26-89)
  • MALA (11-99)
src/flowMC/resource/kernel/base.py (1)
  • kernel (17-27)
src/flowMC/resource/model/nf_model/rqSpline.py (1)
  • MaskedCouplingRQSpline (364-507)
src/flowMC/strategy/take_steps.py (1)
src/flowMC/resource/kernel/base.py (2)
  • kernel (17-27)
  • ProposalBase (9-27)
test/unit/test_resources.py (2)
src/flowMC/resource/kernel/MALA.py (2)
  • kernel (26-89)
  • MALA (11-99)
src/flowMC/resource/kernel/base.py (1)
  • kernel (17-27)
test/integration/test_RWMCMC.py (1)
src/flowMC/resource/kernel/Gaussian_random_walk.py (2)
  • kernel (25-61)
  • GaussianRandomWalk (10-71)
src/flowMC/resource/kernel/NF_proposal.py (2)
src/flowMC/resource/model/nf_model/base.py (1)
  • NFModel (14-245)
src/flowMC/resource/kernel/base.py (1)
  • ProposalBase (9-27)
src/flowMC/strategy/sequential_monte_carlo.py (1)
src/flowMC/resource/base.py (1)
  • Resource (6-38)
src/flowMC/resource/kernel/Gaussian_random_walk.py (1)
src/flowMC/resource/kernel/base.py (2)
  • kernel (17-27)
  • ProposalBase (9-27)
test/integration/test_normalizingFlow.py (2)
src/flowMC/resource/model/nf_model/realNVP.py (1)
  • RealNVP (102-228)
src/flowMC/resource/model/nf_model/rqSpline.py (1)
  • MaskedCouplingRQSpline (364-507)
test/unit/test_flowmatching.py (2)
src/flowMC/resource/model/flowmatching/base.py (17)
  • FlowMatchingModel (132-328)
  • Solver (15-98)
  • Path (117-129)
  • CondOTScheduler (108-114)
  • sample (24-56)
  • sample (124-129)
  • sample (171-178)
  • log_prob (58-98)
  • log_prob (180-184)
  • data_mean (144-145)
  • data_cov (148-149)
  • save_model (186-187)
  • load_model (189-190)
  • n_features (140-141)
  • print_parameters (325-328)
  • train_step (205-217)
  • train_epoch (219-257)
src/flowMC/resource/model/common.py (5)
  • MLP (68-124)
  • n_input (115-116)
  • n_output (119-120)
  • mean (262-266)
  • cov (269-273)
test/integration/test_HMC.py (2)
src/flowMC/resource/kernel/HMC.py (2)
  • kernel (98-151)
  • HMC (11-163)
src/flowMC/resource/kernel/base.py (1)
  • kernel (17-27)
src/flowMC/resource/kernel/MALA.py (3)
src/flowMC/resource/kernel/Gaussian_random_walk.py (1)
  • kernel (25-61)
src/flowMC/resource/kernel/HMC.py (1)
  • kernel (98-151)
src/flowMC/resource/kernel/base.py (2)
  • kernel (17-27)
  • ProposalBase (9-27)
src/flowMC/resource/model/common.py (4)
src/flowMC/resource/model/nf_model/base.py (4)
  • forward (63-75)
  • inverse (78-90)
  • log_prob (55-56)
  • sample (59-60)
src/flowMC/resource/model/nf_model/realNVP.py (6)
  • forward (63-82)
  • forward (174-190)
  • inverse (84-99)
  • inverse (192-208)
  • log_prob (216-223)
  • sample (210-214)
src/flowMC/resource/model/nf_model/rqSpline.py (6)
  • forward (347-353)
  • forward (450-468)
  • inverse (355-361)
  • inverse (470-488)
  • log_prob (498-504)
  • sample (490-496)
src/flowMC/resource/model/flowmatching/base.py (5)
  • log_prob (58-98)
  • log_prob (180-184)
  • sample (24-56)
  • sample (124-129)
  • sample (171-178)
src/flowMC/resource/model/nf_model/rqSpline.py (2)
src/flowMC/resource/model/nf_model/base.py (1)
  • NFModel (14-245)
src/flowMC/resource/model/common.py (2)
  • Distribution (44-65)
  • Bijection (10-41)
src/flowMC/resource/model/flowmatching/base.py (4)
src/flowMC/resource/base.py (1)
  • Resource (6-38)
test/unit/test_flowmatching.py (2)
  • model (100-116)
  • solver (34-41)
src/flowMC/resource/model/common.py (8)
  • MLP (68-124)
  • sample (62-65)
  • sample (288-293)
  • sample (310-317)
  • n_input (115-116)
  • mean (262-266)
  • n_output (119-120)
  • cov (269-273)
test/unit/test_bundle.py (1)
  • logpdf (6-7)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: pre-commit (3.12)
  • GitHub Check: build (3.12)
  • GitHub Check: build (3.11)
🔇 Additional comments (25)
src/flowMC/strategy/take_steps.py (1)

2-2: Verification complete: no remaining flowMC.resource.local_kernel imports

Ran the provided ripgrep command across the repository and found zero occurrences of the old local_kernel path. The import in src/flowMC/strategy/take_steps.py and elsewhere correctly points to flowMC.resource.kernel.base.ProposalBase. No further updates are required.

src/flowMC/strategy/train_model.py (1)

4-4: NFModel import path update is consistent with the new module layout.

src/flowMC/strategy/parallel_tempering.py (1)

2-2: Import path migration to kernel.base looks correct.

Matches the refactor; ProposalBase is sourced from the new canonical location.

.pre-commit-config.yaml (1)

15-15: Adding diffrax to Pyright’s additional_dependencies is appropriate.

This ensures type stubs are available for the new Diffrax usage during pre-commit checks.

pyproject.toml (1)

27-27: Dependency compatibility confirmed for diffrax >= 0.7.0

Verified that Diffrax 0.7.0 is fully compatible with:

  • JAX 0.4.x–0.5.x and 0.6.1 (note: avoid JAX 0.6.0 due to a known upstream bug)
  • Equinox 0.11.x
  • Jaxtyping 0.2.x

No further changes to pyproject.toml are needed. Proceed with the addition of "diffrax>=0.7.0".

docs/tutorials/parallel_tempering.ipynb (1)

22-22: Import path update LGTM and aligns with kernel reorg

Switching to from flowMC.resource.kernel.MALA import MALA matches the new package layout and keeps the tutorial consistent with the refactor. No logic changes implied.

test/integration/test_RWMCMC.py (1)

6-6: Import path update to resource.kernel looks correct.

This aligns with the re-org. No issues spotted in this file related to the change.

src/flowMC/resource/kernel/HMC.py (1)

7-7: Import path change to kernel.base matches the re-org.

No functional impact from this import relocation. Good to go.

test/unit/test_resources.py (1)

5-5: Import path updated to resource.kernel is consistent with the refactor.

No additional changes needed here.

docs/tutorials/custom_strategy.ipynb (1)

32-32: Notebook import path correction looks good.

This matches the new module layout. Consider re-running the notebook to refresh outputs after the refactor.

Would you like me to open a follow-up to re-execute the tutorial notebooks in CI to ensure they still run end-to-end after the import path changes?

test/unit/test_nf.py (1)

4-5: NF model import paths updated correctly.

RealNVP and RQSpline imports now target the reorganized resource.model.nf_model modules. Tests should continue to pass unchanged.

test/integration/test_MALA.py (1)

6-6: Import Path Verification Complete: No Deprecated References Found

All occurrences of the old import paths (flowMC.resource.local_kernel, flowMC.resource.nf_model) have been removed. The update to flowMC.resource.kernel.MALA is correct—approving these changes.

test/integration/test_HMC.py (1)

6-6: Import path update matches the new kernel namespace (LGTM).

Consistent with the project-wide reorganization.

test/unit/test_kernels.py (1)

5-7: Kernel import path updates are correct and consistent (LGTM).

All three kernels now come from flowMC.resource.kernel.*, matching the refactor.

src/flowMC/resource/kernel/Gaussian_random_walk.py (1)

6-6: Import path fix to ProposalBase is correct (LGTM).

This aligns the kernel with the new base location.

test/integration/test_normalizingFlow.py (1)

6-7: NF model import path updates are correct (LGTM).

Matches the new flowMC.resource.model.nf_model.* layout.

src/flowMC/resource/kernel/MALA.py (1)

8-8: Import path update looks correct

Switching to flowMC.resource.kernel.base aligns with the new module layout. No issues spotted here.

src/flowMC/resource/kernel/NF_proposal.py (2)

10-11: Import path updates look correct

NFModel and ProposalBase now reference the reorganized modules under resource.model.nf_model and resource.kernel respectively.


27-36: No downstream code relies on a trailing singleton dimension for log_prob/do_accept

I’ve searched for any indexing or buffer updates that assume a final dimension of size 1 (e.g. [..., :, 1] or buffer shapes of (…, n_steps, 1)) and found none. All TakeSerialSteps/TakeGroupSteps pipelines and Buffer.update_buffer calls consume 1D or 2D arrays (e.g. (n_steps,) or (n_chains, n_steps)) as intended. The kernel annotations still mention "n_step 1", but at runtime log_prob and do_accept are produced as 1D arrays and fed into buffers with matching shapes.

No changes are needed to the strategy or buffer code.

src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (1)

12-14: Import path updates look correct

MALA, NFProposal, and MaskedCouplingRQSpline now point to the reorganized modules. Matches the PR’s restructuring.

src/flowMC/resource/model/nf_model/realNVP.py (1)

8-15: Import path consolidation looks good

NFModel and common building blocks (Distribution, MLP, Gaussian, MaskedCouplingLayer) now come from the reorganized modules. No issues.

test/unit/test_strategies.py (1)

6-11: Verified correct test imports — no legacy paths detected

File: test/unit/test_strategies.py
Lines: 6–11

from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
from flowMC.resource.optimizer import Optimizer
from flowMC.resource.kernel.NF_proposal import NFProposal
from flowMC.resource.kernel.MALA import MALA
from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
from flowMC.resource.kernel.HMC import HMC

• Ran RG searches across all .py files for old patterns (flowMC.resource.nf_model, flowMC.resource.local_kernel); found no matches.
• Test imports fully align with the reorganized module structure, confirming the strategy layer remains intact.

src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (1)

11-13: LGTM: New imports and package structure verified

  • Confirmed class MALA in src/flowMC/resource/kernel/MALA.py
  • Confirmed class NFProposal in src/flowMC/resource/kernel/NF_proposal.py
  • Confirmed class MaskedCouplingRQSpline in src/flowMC/resource/model/nf_model/rqSpline.py
  • All necessary __init__.py files are present for package importability
src/flowMC/resource/model/nf_model/rqSpline.py (1)

9-12: Imports consolidated correctly—common.py exports all symbols

Confirmed that src/flowMC/resource/model/common.py defines or re-exports all of:

  • Distribution
  • Bijection
  • MLP
  • Gaussian
  • MaskedCouplingLayer
  • ScalarAffine

The consolidated imports in src/flowMC/resource/model/nf_model/rqSpline.py are valid.

src/flowMC/resource/model/flowmatching/base.py (1)

64-75: Ignore the sign-flip suggestion – the current implementation is correct
Because you integrate backward (t0=1 → t1=0 with dt0 = -dt), returning +tr(∂f/∂x) yields
d log p = tr(∂f/∂x)·dt = tr(∂f/∂x)·(-|dt|) = -tr(∂f/∂x)·|dt|,
which matches the standard CNF ODE d(log p)/dt = -div f, since div f = tr(∂f/∂x).
No change needed here.

Likely an incorrect or invalid review comment.
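A one-dimensional sanity check of that sign convention, independent of the flowMC code, using a linear field f(x) = a·x (so div f = a):

```python
import numpy as np

a = 0.7   # divergence of the linear field f(x) = a * x
x0 = 0.5  # base-space point at t = 0

# Forward flow t=0 -> 1 maps x0 to x1 = x0 * exp(a).
x1 = x0 * np.exp(a)

def log_p0(x):  # standard-normal base density at t = 0
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)

# Instantaneous change of variables: d(log p)/dt = -div f = -a,
# so integrating one unit of time gives log p1(x1) = log p0(x0) - a.
log_p1_cnf = log_p0(x0) - a

# Cross-check with the ordinary change of variables:
# p1(x1) = p0(x1 * exp(-a)) * |dx0/dx1| = p0(x0) * exp(-a)
log_p1_cov = log_p0(x1 * np.exp(-a)) - a
```

Both routes agree, confirming the minus sign attaches to div f in forward time (equivalently, +tr of the Jacobian when integrating backward).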

"from typing import Any\n",
"\n",
"from flowMC.resource.local_kernel.MALA import MALA\n",
"from flowMC.resource.kernel.MALA import MALA\n",

💡 Verification agent

🧩 Analysis chain

Verify no stale local_kernel imports remain across the repo

To avoid runtime/import errors in docs or examples, verify all references were updated from resource.local_kernel to resource.kernel, and that MALA.py exists at the new location.

Run:


🏁 Script executed:

#!/bin/bash
# Verify the new MALA module exists under the reorganized path.
fd -HI 'MALA.py' | rg -nP 'flowMC/.*/resource/.*/kernel/.*/MALA\.py|flowMC/.*/resource/kernel/MALA\.py' || true

# Find any lingering old imports anywhere (expect: no matches).
rg -nP -C2 '\bflowMC\.resource\.local_kernel\b' || true

# Confirm new-style imports are present (expect: matches in updated notebooks/tests).
rg -nP -C2 '\bfrom\s+flowMC\.resource\.kernel\.MALA\s+import\s+MALA\b' || true

Length of output: 3343


Missing MALA.py in flowMC/resource/kernel

The new imports reference flowMC.resource.kernel.MALA, but there is no MALA.py file under src/flowMC/resource/kernel/. Please add or move MALA.py (containing the MALA class) into that directory so all imports resolve.

All stale flowMC.resource.local_kernel references have been removed, and new-style imports are present in:

  • docs/tutorials/parallel_tempering.ipynb (line 22)
  • test/unit/test_strategies.py (line 9)
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (line 11)
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (line 12)
  • test/unit/test_resources.py (line 5)
  • test/unit/test_kernels.py (line 7)
  • test/integration/test_MALA.py (line 6)
🤖 Prompt for AI Agents
In docs/tutorials/parallel_tempering.ipynb around line 22, the import references
flowMC.resource.kernel.MALA but there is no src/flowMC/resource/kernel/MALA.py;
add (or move) a MALA.py file into src/flowMC/resource/kernel/ that defines the
MALA class with the expected public API used by the listed files, and ensure the
module is importable (add/update __init__.py if necessary) so all imports in the
notebook and tests resolve.

Comment on lines +7 to +19
from abc import abstractmethod

from flowMC.resource.nf_model.base import Bijection, Distribution

class Bijection(eqx.Module):
"""Base class for bijective transformations.

This is an abstract template that should not be directly used.
"""

@abstractmethod
def __init__(self):
raise NotImplementedError


⚠️ Potential issue

abstractmethod without ABC won’t enforce abstractness.

Bijection and Distribution don’t inherit from ABC, so @abstractmethod has no effect at class level. This allows accidental instantiation at runtime.

Apply this diff to enforce the contract:

-from abc import abstractmethod
+from abc import ABC, abstractmethod
@@
-class Bijection(eqx.Module):
+class Bijection(eqx.Module, ABC):
@@
-class Distribution(eqx.Module):
+class Distribution(eqx.Module, ABC):
🤖 Prompt for AI Agents
In src/flowMC/resource/model/common.py around lines 7 to 19, the @abstractmethod
has no effect because Bijection doesn't inherit from ABC; update the class to
inherit from abc.ABC (import ABC from abc) in addition to eqx.Module (e.g.,
class Bijection(ABC, eqx.Module)): this will enforce the abstract contract so
__init__ cannot be instantiated, and apply the same pattern for Distribution if
present elsewhere.
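A minimal demonstration of the difference (plain classes for brevity; the same applies when mixing in eqx.Module, assuming its metaclass is ABC-compatible):

```python
from abc import ABC, abstractmethod

class LooseBase:  # no ABC: the decorator is inert at instantiation time
    @abstractmethod
    def forward(self):
        raise NotImplementedError

class StrictBase(ABC):  # ABCMeta refuses to instantiate abstract classes
    @abstractmethod
    def forward(self):
        raise NotImplementedError

loose = LooseBase()  # silently succeeds despite the "abstract" method

try:
    StrictBase()
    enforced = False
except TypeError:  # "Can't instantiate abstract class StrictBase ..."
    enforced = True
```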

Comment on lines +62 to 65
self, rng_key: PRNGKeyArray, n_samples: int
) -> Float[Array, " n_samples n_features"]:
raise NotImplementedError


🛠️ Refactor suggestion

Unify sample return shape variable: use n_dim consistently.

The base Distribution.sample returns shape " n_samples n_features" while the rest of the module uses " n_dim". This inconsistency will trip pyright/jaxtyping checks and downstream type expectations.

Apply this diff:

     def sample(
         self, rng_key: PRNGKeyArray, n_samples: int
-    ) -> Float[Array, " n_samples n_features"]:
+    ) -> Float[Array, " n_samples n_dim"]:
         raise NotImplementedError

You’ll also want to align concrete implementations—see the follow-up snippet for Gaussian.sample.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> Float[Array, " n_samples n_features"]:
raise NotImplementedError
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> Float[Array, " n_samples n_dim"]:
raise NotImplementedError
🤖 Prompt for AI Agents
In src/flowMC/resource/model/common.py around lines 62-65, the sample return
shape annotation uses " n_samples n_features" but the project standard is
"n_dim"; change the type string to " n_samples n_dim" and update any matching
type alias usage. Also update all concrete implementations (e.g.,
Gaussian.sample) to use the same " n_samples n_dim" return annotation so
pyright/jaxtyping types match end-to-end.

import jax.numpy as jnp
import jax
from jax.scipy.stats.multivariate_normal import logpdf
from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver

⚠️ Potential issue

Missing import: SaveAt

You use SaveAt in diffeqsolve for efficient final-state saving. Import it.

-from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver
+from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver, SaveAt
📝 Committable suggestion


Suggested change
from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver
from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver, SaveAt
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around line 11, the diffrax
import list is missing SaveAt which is used with diffeqsolve; update the import
statement to also import SaveAt from diffrax (i.e., add SaveAt to the existing
import list) so SaveAt can be used when calling diffeqsolve.

Comment on lines +76 to +89
def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree:
"""Solve the ODE with initial condition y0."""
term = ODETerm(model_wrapper)
y_init = jax.tree.map(jnp.asarray, [y0, 0.0])
sol = diffeqsolve(
term,
self.method,
t0=1.0,
t1=0.0,
dt0=-dt,
y0=y_init,
)
return sol.ys


⚠️ Potential issue

Bug: solve_ode returns sol.ys with its leading save axis instead of the final state

Two issues:

  • jax.tree.map is the supported spelling on JAX ≥ 0.4.25 (jax.tree_map is deprecated there and removed in JAX 0.6.0), so with the diffrax >= 0.7.0 pin this line needs no change.
  • solve_ode returns sol.ys, whose leaves keep a leading save axis (length 1 under the default saveat=SaveAt(t1=True)). Callers unpack x0, log_p = solve_ode(...) and treat log_p as the final scalar, but they receive arrays of shape (1, n_dims) and (1,). Strip the save axis before returning to avoid shape/broadcast bugs.

Apply:

-            return sol.ys
+            return jax.tree.map(lambda y: y[-1], sol.ys)  # final state only

📝 Committable suggestion


Suggested change
def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree:
"""Solve the ODE with initial condition y0."""
term = ODETerm(model_wrapper)
y_init = jax.tree.map(jnp.asarray, [y0, 0.0])
sol = diffeqsolve(
term,
self.method,
t0=1.0,
t1=0.0,
dt0=-dt,
y0=y_init,
)
return sol.ys
def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree:
"""Solve the ODE with initial condition y0."""
term = ODETerm(model_wrapper)
y_init = jax.tree.map(jnp.asarray, [y0, 0.0])
sol = diffeqsolve(
term,
self.method,
t0=1.0,
t1=0.0,
dt0=-dt,
y0=y_init,
)
return jax.tree.map(lambda y: y[-1], sol.ys)  # final state only
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 76–89, solve_ode
returns sol.ys, whose leaves keep a leading save axis (length 1 under the
default saveat), so callers like x0, log_p = solve_ode(...) receive arrays with
an extra leading dimension; strip it before returning, e.g. return
jax.tree.map(lambda y: y[-1], sol.ys), so the final (x0, log_p) pair comes back
with the expected shapes.

Comment on lines +124 to +130
def sample(self, x0: Float, x1: Float, t: Float) -> Float:
"""Sample a point along the path between x0 and x1 at time t."""
alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t)
x_t = sigma_t * x0 + alpha_t * x1
dx_t = d_sigma_t * x0 + d_alpha_t * x1
return x_t, dx_t


⚠️ Potential issue

Incorrect return type annotation for Path.sample

The method returns a tuple (x_t, dx_t) but is annotated as returning a single Float. Fix the signature:

-    def sample(self, x0: Float, x1: Float, t: Float) -> Float:
+    def sample(self, x0: Float, x1: Float, t: Float) -> tuple[Float, Float]:
📝 Committable suggestion


Suggested change
def sample(self, x0: Float, x1: Float, t: Float) -> Float:
"""Sample a point along the path between x0 and x1 at time t."""
alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t)
x_t = sigma_t * x0 + alpha_t * x1
dx_t = d_sigma_t * x0 + d_alpha_t * x1
return x_t, dx_t
def sample(self, x0: Float, x1: Float, t: Float) -> tuple[Float, Float]:
"""Sample a point along the path between x0 and x1 at time t."""
alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t)
x_t = sigma_t * x0 + alpha_t * x1
dx_t = d_sigma_t * x0 + d_alpha_t * x1
return x_t, dx_t
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 124 to 130, the
sample method is annotated to return a single Float but actually returns a tuple
(x_t, dx_t); update the signature to return a tuple type, e.g. -> tuple[Float,
Float] (the builtin generic needs no typing import on Python 3.9+; use
typing.Tuple only on older versions), and keep the implementation unchanged.
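As a sanity check on the tuple contract, the path sample can be exercised standalone. The sketch below uses a hypothetical linear scheduler (alpha_t = t, sigma_t = 1 - t) standing in for the PR's Scheduler, in NumPy rather than jnp for brevity; the derivative returned in the second slot should match a finite-difference estimate of the first:

```python
import numpy as np

def scheduler(t):
    # Hypothetical linear path: alpha_t = t, sigma_t = 1 - t
    return t, 1.0, 1.0 - t, -1.0  # alpha_t, d_alpha_t, sigma_t, d_sigma_t

def sample(x0, x1, t):
    alpha_t, d_alpha_t, sigma_t, d_sigma_t = scheduler(t)
    x_t = sigma_t * x0 + alpha_t * x1
    dx_t = d_sigma_t * x0 + d_alpha_t * x1
    return x_t, dx_t  # a tuple, hence the tuple[Float, Float] annotation

x0, x1 = np.array([0.0, 0.0]), np.array([2.0, -2.0])
x_t, dx_t = sample(x0, x1, 0.25)

# dx_t should agree with a central finite difference of x_t in t
h = 1e-5
fd = (sample(x0, x1, 0.25 + h)[0] - sample(x0, x1, 0.25 - h)[0]) / (2 * h)
print(x_t, dx_t, np.max(np.abs(dx_t - fd)))
```

For the linear path the velocity is simply x1 - x0, constant in t, which the check confirms.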

Comment on lines +171 to +179
def sample(
self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1
) -> Float[Array, " n_dim"]:
rng_key, subkey = jax.random.split(rng_key)
samples = self.solver.sample(subkey, num_samples, dt=dt)
std = jnp.sqrt(jnp.diag(self.data_cov))
samples = samples * std + self.data_mean
return samples

🛠️ Refactor suggestion

Essential: sampling with full covariance (currently only uses diagonal)

data_cov is treated as diagonal (via sqrt(diag(...))), but train sets it to jnp.cov(data[1].T), which is full. Sampling should apply the full scale using the Cholesky factor.

-        std = jnp.sqrt(jnp.diag(self.data_cov))
-        samples = samples * std + self.data_mean
+        L = jnp.linalg.cholesky(self.data_cov)
+        samples = samples @ L.T + self.data_mean

This fixes geometry when the data have correlations.


Suggested change

     def sample(
         self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1
     ) -> Float[Array, " n_dim"]:
         rng_key, subkey = jax.random.split(rng_key)
         samples = self.solver.sample(subkey, num_samples, dt=dt)
-        std = jnp.sqrt(jnp.diag(self.data_cov))
-        samples = samples * std + self.data_mean
+        L = jnp.linalg.cholesky(self.data_cov)
+        samples = samples @ L.T + self.data_mean
         return samples
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 171-179, sampling
currently rescales using only the diagonal (sqrt(diag(self.data_cov))) even
though train sets self.data_cov to the full covariance matrix; replace that
diagonal-only scaling with a full-covariance transform by computing the Cholesky
factor L of self.data_cov and applying the linear map defined by L to the
standard samples before adding self.data_mean (i.e., transform samples with L so
Cov(samples) = self.data_cov then add mean); include a safe fallback (e.g., use
diagonal sqrt) if Cholesky fails or if self.data_cov is already diagonal to
avoid runtime errors.
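To see why the Cholesky transform is the right coloring map, a quick NumPy sketch (standing in for the jnp calls, with made-up numbers) confirms that z @ L.T + mean reproduces the full covariance, correlations included, while the diagonal-only rescale does not:

```python
import numpy as np

rng = np.random.default_rng(0)
data_cov = np.array([[2.0, 0.8],
                     [0.8, 1.0]])  # full covariance with off-diagonal terms
data_mean = np.array([1.0, -1.0])

z = rng.standard_normal((200_000, 2))  # whitened base samples ~ N(0, I)
L = np.linalg.cholesky(data_cov)

colored = z @ L.T + data_mean                           # Cholesky coloring
diag_only = z * np.sqrt(np.diag(data_cov)) + data_mean  # the buggy rescale

err_full = np.abs(np.cov(colored.T) - data_cov).max()
err_diag = np.abs(np.cov(diag_only.T) - data_cov).max()
print(err_full, err_diag)  # err_full is small; err_diag misses the 0.8 covariance
```

The diagonal rescale produces independent marginals with the right variances but zero correlation, so the 0.8 off-diagonal term is lost entirely.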

Comment on lines +180 to +185
def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
std = jnp.sqrt(jnp.diag(self.data_cov))
x_whitened = (x - self.data_mean) / std
log_det = -jnp.sum(jnp.log(std))
return self.solver.log_prob(x_whitened) + log_det

🛠️ Refactor suggestion

⚠️ Potential issue

Align whitening and log-det with full covariance and return shape (1,)

  • Whitening currently uses only the diagonal and returns a scalar of shape (). Use Cholesky-based whitening consistent with the previous comment and return a (1,) vector to satisfy the tests’ vmap shape.
-    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
-        std = jnp.sqrt(jnp.diag(self.data_cov))
-        x_whitened = (x - self.data_mean) / std
-        log_det = -jnp.sum(jnp.log(std))
-        return self.solver.log_prob(x_whitened) + log_det
+    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
+        L = jnp.linalg.cholesky(self.data_cov)
+        # Solve L y = (x - mean) for y (lower-triangular solve)
+        y = jax.scipy.linalg.solve_triangular(L, x - self.data_mean, lower=True)
+        log_det = -jnp.log(jnp.diag(L)).sum()
+        return jnp.atleast_1d(self.solver.log_prob(y) + log_det)

Note: add from jax.scipy.linalg import solve_triangular if you prefer a direct import instead of jax.scipy.linalg.solve_triangular.


Suggested change

-    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
-        std = jnp.sqrt(jnp.diag(self.data_cov))
-        x_whitened = (x - self.data_mean) / std
-        log_det = -jnp.sum(jnp.log(std))
-        return self.solver.log_prob(x_whitened) + log_det
+    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
+        L = jnp.linalg.cholesky(self.data_cov)
+        # Solve L y = (x - mean) for y (lower-triangular solve)
+        y = jax.scipy.linalg.solve_triangular(L, x - self.data_mean, lower=True)
+        log_det = -jnp.log(jnp.diag(L)).sum()
+        return jnp.atleast_1d(self.solver.log_prob(y) + log_det)
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 180 to 185, the
whitening and log-det computation currently use only the diagonal of the
covariance and return a scalar; replace this with Cholesky-based whitening using
the full covariance: compute L = jnp.linalg.cholesky(self.data_cov), whiten x by
solving L y = (x - self.data_mean) with a single lower-triangular
solve_triangular call (lower=True, forward substitution), compute log_det as
-jnp.sum(jnp.log(jnp.diag(L))) (the Cholesky diagonal gives half the
log-determinant of the covariance), pass the whitened y to
self.solver.log_prob, and ensure the function returns a 1D array of shape (1,)
(e.g., wrap the scalar log-prob + log_det in jnp.atleast_1d or
jnp.reshape(..., (1,))). Also add the import for solve_triangular at the top of
the file.
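The change-of-variables bookkeeping here can be verified numerically. A NumPy sketch with made-up numbers (mirroring the jnp code; np.linalg.solve plays the role of the triangular solve) checks that the whiten-plus-log-det route matches the closed-form multivariate-normal density:

```python
import numpy as np

data_cov = np.array([[2.0, 0.8],
                     [0.8, 1.0]])
data_mean = np.array([1.0, -1.0])
x = np.array([1.5, -0.5])
d = x.size

# Whiten: solve L y = x - mean once (lower-triangular forward substitution)
L = np.linalg.cholesky(data_cov)
y = np.linalg.solve(L, x - data_mean)

log_base = -0.5 * (d * np.log(2 * np.pi) + y @ y)  # log N(y; 0, I)
log_det = -np.sum(np.log(np.diag(L)))              # -log|det L|
lhs = log_base + log_det

# Direct multivariate-normal log density for comparison
diff = x - data_mean
rhs = -0.5 * (d * np.log(2 * np.pi)
              + np.log(np.linalg.det(data_cov))
              + diff @ np.linalg.solve(data_cov, diff))
print(lhs, rhs)  # the two routes agree
```

Since y.T @ y equals the Mahalanobis distance and log det(cov) = 2 * sum(log diag L), the two expressions coincide term by term.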

Comment on lines +5 to +20
class SequentialMonteCarlo(Resource):
def __init__(self):
raise NotImplementedError

def __call__(
self,
rng_key: PRNGKeyArray,
resources: dict[str, Resource],
initial_position: Float[Array, "n_chains n_dim"],
data: dict,
) -> tuple[
PRNGKeyArray,
dict[str, Resource],
Float[Array, "n_chains n_dim"],
]:
raise NotImplementedError
🛠️ Refactor suggestion

Mark interface methods as abstract (don’t rely on raising at runtime).

Declare __init__ and __call__ as @abstractmethod to make the contract explicit and let static/type checkers catch instantiation of incomplete subclasses at import time. Since Resource already uses ABCMeta, the decorators will be honored.

Apply this diff:

 from flowMC.resource.base import Resource
 from jaxtyping import Array, Float, PRNGKeyArray
+from abc import abstractmethod
 
 
 class SequentialMonteCarlo(Resource):
-    def __init__(self):
-        raise NotImplementedError
+    @abstractmethod
+    def __init__(self):
+        raise NotImplementedError
 
-    def __call__(
+    @abstractmethod
+    def __call__(
         self,
         rng_key: PRNGKeyArray,
         resources: dict[str, Resource],
         initial_position: Float[Array, "n_chains n_dim"],
         data: dict,
     ) -> tuple[
         PRNGKeyArray,
         dict[str, Resource],
         Float[Array, "n_chains n_dim"],
     ]:
         raise NotImplementedError

Suggested change

 from flowMC.resource.base import Resource
 from jaxtyping import Array, Float, PRNGKeyArray
+from abc import abstractmethod


 class SequentialMonteCarlo(Resource):
+    @abstractmethod
     def __init__(self):
         raise NotImplementedError

+    @abstractmethod
     def __call__(
         self,
         rng_key: PRNGKeyArray,
         resources: dict[str, Resource],
         initial_position: Float[Array, "n_chains n_dim"],
         data: dict,
     ) -> tuple[
         PRNGKeyArray,
         dict[str, Resource],
         Float[Array, "n_chains n_dim"],
     ]:
         raise NotImplementedError
🤖 Prompt for AI Agents
In src/flowMC/strategy/sequential_monte_carlo.py around lines 5 to 20, the
interface methods __init__ and __call__ currently raise NotImplementedError at
runtime; mark them as abstract by importing and applying @abstractmethod to both
methods so the class is treated as an abstract base class and incomplete
subclasses cannot be instantiated at import time; remove the raise-only bodies
(you may keep a pass or leave the docstring) and ensure the module imports
abstractmethod from abc if not already present.
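The effect of the decorator is easy to demonstrate in isolation. The class below is a toy stand-in for the Resource/SequentialMonteCarlo hierarchy (the real base uses ABCMeta), not flowMC code:

```python
from abc import ABC, abstractmethod

class KernelBase(ABC):  # toy stand-in for a Resource subclass
    @abstractmethod
    def __call__(self, state):
        raise NotImplementedError

class Incomplete(KernelBase):
    pass  # forgot to implement __call__

class Complete(KernelBase):
    def __call__(self, state):
        return state + 1

try:
    Incomplete()  # abstract method not implemented
except TypeError as e:
    print("rejected:", e)

print(Complete()(41))  # → 42
```

Without @abstractmethod, Incomplete() would instantiate silently and only fail when called; with it, the TypeError surfaces at construction time, which is what lets type checkers and tests catch incomplete subclasses early.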

from jaxtyping import Array, Float

-from flowMC.resource.local_kernel.Gaussian_random_walk import GaussianRandomWalk
+from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
💡 Verification agent

🧩 Analysis chain

Scan the repo for any stale local_kernel or old NF-model import paths.

To avoid stragglers after the module move, run this scan.


🏁 Script executed:

#!/bin/bash
# Find any remaining old import paths after the refactor.

set -euo pipefail

echo "Searching for stale local_kernel imports..."
rg -nP '\bflowMC\.resource\.local_kernel\b' -C2 || true

echo
echo "Searching for old nf_model root path imports..."
rg -nP '\bflowMC\.resource\.nf_model\b' -C2 || true

echo
echo "Searching for old model base locations that moved to model.common..."
rg -nP '\bflowMC\.resource\.model\.nf_model\.base\b' -C2 || true

Length of output: 2204


Stale NFModel Imports Need Updating

The scan uncovered leftover imports from the old flowMC.resource.model.nf_model.base path. Please update these to the new location in flowMC.resource.model.common:

• src/flowMC/strategy/train_model.py (line 4)

  • Change
    from flowMC.resource.model.nf_model.base import NFModel
    to
    from flowMC.resource.model.common import NFModel

• src/flowMC/resource/model/nf_model/realNVP.py (line 8)

  • Change
    from flowMC.resource.model.nf_model.base import NFModel
    to
    from flowMC.resource.model.common import NFModel

• src/flowMC/resource/model/nf_model/rqSpline.py (line 9)

  • Change
    from flowMC.resource.model.nf_model.base import NFModel
    to
    from flowMC.resource.model.common import NFModel

• src/flowMC/resource/kernel/NF_proposal.py (line 10)

  • Change
    from flowMC.resource.model.nf_model.base import NFModel
    to
    from flowMC.resource.model.common import NFModel

After adjusting these, rerun the import-scan to confirm no stale references remain.

🤖 Prompt for AI Agents
In test/integration/test_RWMCMC.py around lines 6 to 6 and the referenced source
files, there are stale imports of NFModel from
flowMC.resource.model.nf_model.base; update each occurrence to import NFModel
from flowMC.resource.model.common instead (specifically fix
src/flowMC/strategy/train_model.py line 4,
src/flowMC/resource/model/nf_model/realNVP.py line 8,
src/flowMC/resource/model/nf_model/rqSpline.py line 9, and
src/flowMC/resource/kernel/NF_proposal.py line 10), then run the import-scan to
verify no remaining stale references.

@kazewong kazewong merged commit a0e8795 into main Aug 18, 2025
6 of 9 checks passed
@kazewong kazewong deleted the flowMatchingSMC branch August 18, 2025 15:56