Walkthrough

Project-wide module reorganization and dependency updates. Kernel imports moved from `resource.local_kernel` to `resource.kernel`; NF model imports consolidated under `resource.model.*`. `Bijection`/`Distribution` moved to `model.common`. Added a new flow-matching model module and an abstract `SequentialMonteCarlo` strategy. Updated docs/tests accordingly. Added `diffrax` as a runtime dependency and updated dev tooling.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant FM as FlowMatchingModel
    participant Solver
    participant Diffrax as ODE Solver
    participant Path
    participant Sched as Scheduler
    User->>FM: sample(rng, n, dt)
    FM->>Solver: sample(rng, n, dt)
    Solver->>Diffrax: solve ODE (t:0→1) on z0~N(0,I)
    Diffrax-->>Solver: x1 samples
    Solver-->>FM: x1
    FM->>FM: unwhiten (cov, mean)
    FM-->>User: samples
    User->>FM: log_prob(x)
    FM->>FM: whiten x
    FM->>Solver: log_prob(x_whitened, dt)
    Solver->>Diffrax: reverse-time ODE + jacobian trace
    Diffrax-->>Solver: log_prob term
    Solver-->>FM: log_prob(z1)
    FM->>FM: adjust by log|det(cov)|
    FM-->>User: log_prob(x)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 10
🔭 Outside diff range comments (11)
src/flowMC/strategy/take_steps.py (1)
134-145: Fix buffer cursor advancement with thinning (off-by-one bug).

You thin the arrays before writing to buffers, but then advance `current_position` by `self.n_steps // self.thinning`. When `self.n_steps` is not divisible by `self.thinning`, this undercounts (e.g., 5 steps with thinning=2 yields 3 writes but only +2 advancement), causing overlapping writes or misalignment across buffers. Advance by the actual number of thinned steps written.
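The off-by-one is plain integer arithmetic; a minimal NumPy sketch (hypothetical shapes, assuming the strategy thins with a `[::thinning]` slice):

```python
import numpy as np

# Hypothetical buffer shapes: 4 chains, 5 sampler steps, 3 dims, thinning=2.
n_steps, thinning = 5, 2
positions = np.zeros((4, n_steps, 3))
thinned = positions[:, ::thinning]  # what actually gets written to the buffer

assert thinned.shape[1] == 3   # three thinned steps written
assert n_steps // thinning == 2  # but the cursor only advances by two
```

Advancing by `positions.shape[1]` after thinning keeps the cursor in lockstep with the writes.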
Apply:
```diff
 position_buffer.update_buffer(positions, self.current_position)
 log_prob_buffer.update_buffer(log_probs, self.current_position)
 acceptance_buffer.update_buffer(do_accepts, self.current_position)
-self.current_position += self.n_steps // self.thinning
+steps_written = int(positions.shape[1])  # number of thinned steps written
+self.current_position += steps_written
```
src/flowMC/strategy/train_model.py (1)
66-72: Boolean indexing + reshape can corrupt chain alignment or fail; select valid timesteps across all chains instead.

Indexing with a 2D boolean mask flattens across (chain, time) pairs; reshaping back to `(n_chains, -1, n_dims)` assumes identical counts of valid timesteps per chain, which is not guaranteed and may error or silently misalign samples. Filter by timesteps that are valid across all chains to preserve shape.
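The failure mode is easy to reproduce with a toy array (NumPy sketch, hypothetical shapes; one NaN timestep in a single chain):

```python
import numpy as np

# (n_chains=2, T=3, n_dims=1), with chain 0 missing one timestep.
data = np.arange(6, dtype=float).reshape(2, 3, 1)
data[0, 1, 0] = np.nan

mask2d = np.isfinite(data).all(axis=-1)  # (2, 3) boolean mask
flat = data[mask2d]                      # flattens (chain, time) pairs
assert flat.shape == (5, 1)              # 5 rows cannot reshape to (2, -1, 1)

valid_t = mask2d.all(axis=0)             # timesteps finite in every chain
kept = data[:, valid_t, :]
assert kept.shape == (2, 2, 1)           # chain structure preserved
```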
Apply:
```diff
-n_chains = data_resource.data.shape[0]
-n_dims = data_resource.data.shape[-1]
-training_data = data_resource.data[
-    jnp.isfinite(data_resource.data).all(axis=-1)
-].reshape(n_chains, -1, n_dims)
-training_data = training_data[:, -self.history_window :].reshape(-1, n_dims)
+data_array = data_resource.data  # (n_chains, T, n_dims)
+n_chains = data_array.shape[0]
+n_dims = data_array.shape[-1]
+# Keep only timesteps where all chains are finite across all dims to preserve structure
+valid_t = jnp.all(jnp.isfinite(data_array).all(axis=-1), axis=0)  # (T,)
+training_data = data_array[:, valid_t, :]  # (n_chains, T_valid, n_dims)
+training_data = training_data[:, -self.history_window :].reshape(-1, n_dims)
```
pyproject.toml (1)
4-4: Fix typo in project description (“exhanced” → “enhanced”).This is user-facing metadata on PyPI.
```diff
-description = "Normalizing flow exhanced sampler in jax"
+description = "Normalizing-flow-enhanced sampler in JAX"
```
src/flowMC/resource/model/common.py (3)
285-293: Align `Gaussian.sample` annotation with the base interface and the actual shape.

Currently annotated as `" n_samples n_features"` but returns `(n_samples, n_dim)`.

Apply this diff:

```diff
 def sample(
     self, rng_key: PRNGKeyArray, n_samples: int
-) -> Float[Array, " n_samples n_features"]:
+) -> Float[Array, " n_samples n_dim"]:
     return jax.random.multivariate_normal(
         rng_key, self.mean, self.cov, (n_samples,)
     )
```
114-121: Avoid relying on `eqx.nn.Linear` `in_features`/`out_features` (may not exist).

Equinox `Linear` reliably exposes `weight` (shape: out_features x in_features), but not all versions expose `in_features`/`out_features`. Derive from `weight.shape` to be robust.

Apply this diff:

```diff
 @property
 def n_input(self) -> int:
-    return self.layers[0].in_features
+    return int(self.layers[0].weight.shape[1])
@@
 @property
 def n_output(self) -> int:
-    return self.layers[-1].out_features
+    return int(self.layers[-1].weight.shape[0])
```
310-317: `Composable.sample` should return a single concatenated array to honor the base interface.

Returning a dict violates the `Distribution.sample` contract and forces callers to be special-cased. Concatenating the parts along the feature axis keeps a clean API.

Apply this diff:

```diff
 def sample(
-    self, rng_key: PRNGKeyArray, n_samples: int
-) -> Float[Array, " n_samples n_features"]:
-    samples = {}
-    for dist, (key, _) in zip(self.distributions, self.partitions.items()):
-        rng_key, sub_key = jax.random.split(rng_key)
-        samples[key] = dist.sample(sub_key, n_samples=n_samples)
-    return samples  # type: ignore
+    self, rng_key: PRNGKeyArray, n_samples: int
+) -> Float[Array, " n_samples n_dim"]:
+    parts: list[Array] = []
+    for dist, (key, (start, end)) in zip(self.distributions, self.partitions.items()):
+        rng_key, sub_key = jax.random.split(rng_key)
+        part = dist.sample(sub_key, n_samples=n_samples)  # (n_samples, end-start)
+        parts.append(part)
+    return jnp.concatenate(parts, axis=1)
```

Note: this still inherits the ordering caveat from `zip(...)` — see the related comment on `Composable.log_prob`.
src/flowMC/resource/kernel/HMC.py (2)
118-121: Kinetic energy uses incorrect metric shape; fix for full/diag metrics.

`kinetic` currently multiplies `p**2 * metric`, assuming a diagonal metric, but elsewhere a full matrix is passed (e.g., `self.condition_matrix`). This yields wrong energies and gradients under a full metric and also mismatches the type annotations.

Apply this diff to support both scalar/diagonal and full metrics:

```diff
-def kinetic(
-    p: Float[Array, " n_dim"], metric: Float[Array, " n_dim"]
-) -> Float[Array, "1"]:
-    return 0.5 * (p**2 * metric).sum()
+def kinetic(
+    p: Float[Array, " n_dim"],
+    metric: Float[Array, " n_dim"] | Float[Array, " n_dim n_dim"],
+) -> Float[Array, "1"]:
+    # Support scalar/diagonal (vector) or full metric
+    if jnp.ndim(metric) == 0 or (hasattr(metric, "ndim") and metric.ndim == 1):
+        return 0.5 * (p**2 * metric).sum()
+    else:
+        return 0.5 * jnp.dot(p, metric @ p)
```
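To see why the diagonal form is only a special case, compare the two quadratic forms numerically (NumPy sketch with made-up values):

```python
import numpy as np

p = np.array([1.0, -2.0, 0.5])
diag = np.array([2.0, 0.5, 1.0])

ke_diag = 0.5 * (p**2 * diag).sum()    # diagonal-metric (vector) form
ke_full = 0.5 * p @ np.diag(diag) @ p  # general quadratic form
assert np.isclose(ke_diag, ke_full)    # identical for a diagonal metric

# Add off-diagonal terms and the vector form silently drops them:
M = np.diag(diag) + 0.3 * (np.ones((3, 3)) - np.eye(3))
assert not np.isclose(0.5 * (p**2 * np.diag(M)).sum(), 0.5 * p @ M @ p)
```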
130-136: Duplicate momentum sampling and reuse of the same PRNG key; sample once and respect the metric.

Momentum is sampled twice with the same key and the first sample is immediately overwritten. Also, using the same key twice for different draws is an anti-pattern in JAX. Sample once and transform by the (inverse) metric properly.

Apply this diff to sample momentum correctly and deterministically:

```diff
-momentum: Float[Array, " n_dim"] = (
-    jax.random.normal(key1, shape=position.shape) * self.condition_matrix**-0.5
-)
-momentum = jnp.dot(
-    jax.random.normal(key1, shape=position.shape),
-    jnp.linalg.cholesky(jnp.linalg.inv(self.condition_matrix)).T,
-)
+key_mom, _ = jax.random.split(key1)
+# Sample p ~ N(0, Metric^{-1}), supporting scalar/diag/full metrics.
+if jnp.ndim(self.condition_matrix) == 0 or (
+    hasattr(self.condition_matrix, "ndim") and self.condition_matrix.ndim == 1
+):
+    std = jnp.sqrt(1.0 / self.condition_matrix)
+    momentum: Float[Array, " n_dim"] = (
+        jax.random.normal(key_mom, shape=position.shape) * std
+    )
+else:
+    z = jax.random.normal(key_mom, shape=position.shape)
+    L = jnp.linalg.cholesky(jnp.linalg.inv(self.condition_matrix))
+    momentum = jnp.dot(z, L.T)
```
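The Cholesky transform in the suggested fix can be checked analytically: if `z ~ N(0, I)` then `z @ L.T` has covariance `L @ L.T`, so choosing `L = chol(M^{-1})` targets `N(0, M^{-1})` exactly (NumPy sketch with a made-up metric):

```python
import numpy as np

M = np.array([[2.0, 0.5],
              [0.5, 1.0]])  # hypothetical full metric
L = np.linalg.cholesky(np.linalg.inv(M))

# Covariance of z @ L.T is L @ L.T, which is exactly M^{-1}.
assert np.allclose(L @ L.T, np.linalg.inv(M))
```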
71-73: Same PRNG key used twice in scan; split distinct keys to avoid correlation.

The keys fed into `lax.scan` are constructed as `[key1, key1]`, causing identical noise across both iterations. This reduces stochasticity and is likely unintended.
Apply this diff to use distinct keys:
```diff
-_, (proposal, logprob, d_logprob) = jax.lax.scan(
-    body, (position, dt, data), jnp.array([key1, key1])
-)
+keys = jax.random.split(key1, 2)
+_, (proposal, logprob, d_logprob) = jax.lax.scan(
+    body, (position, dt, data), keys
+)
```
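The correlation the comment warns about is deterministic in JAX: the same key always reproduces the same draw, while split keys give independent streams. A minimal sketch:

```python
import jax

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key -> identical "random" noise
k1, k2 = jax.random.split(key)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))

assert bool((a == b).all())       # reused key: perfectly correlated
assert not bool((c == d).all())   # split keys: distinct draws
```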
76-81: Covariance passed as a scalar; use an isotropic covariance matrix.

`multivariate_normal.logpdf` expects a covariance array. Passing `dt2` (a scalar) relies on undocumented broadcasting and can break with shape/type changes. Use `dt2 * I`.
Apply this diff:
```diff
-ratio -= multivariate_normal.logpdf(
-    proposal[0], position + jnp.dot(dt2, d_logprob[0]) / 2, dt2
-)
+cov = dt2 * jnp.eye(position.shape[-1], dtype=position.dtype)
+ratio -= multivariate_normal.logpdf(
+    proposal[0], position + jnp.dot(dt2, d_logprob[0]) / 2, cov
+)
@@
-ratio += multivariate_normal.logpdf(
-    position, proposal[0] + jnp.dot(dt2, d_logprob[1]) / 2, dt2
-)
+ratio += multivariate_normal.logpdf(
+    position, proposal[0] + jnp.dot(dt2, d_logprob[1]) / 2, cov
+)
```
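For an isotropic Gaussian, the explicit `dt2 * I` covariance is equivalent to a product of independent 1-D normals, which makes the intended semantics easy to verify (SciPy sketch with made-up values):

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

x = np.array([0.3, -1.2])
mean = np.zeros(2)
dt2 = 0.25
cov = dt2 * np.eye(2)  # explicit isotropic covariance

lp_mvn = multivariate_normal.logpdf(x, mean, cov)
lp_sum = norm.logpdf(x, loc=0.0, scale=np.sqrt(dt2)).sum()
assert np.isclose(lp_mvn, lp_sum)  # same density, no broadcasting tricks
```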
71-81: Algorithmic inconsistency: acceptance ratio mixes states from different sub-steps.

`ratio` uses `logprob[1] - logprob[0]` (log p at proposal_1 minus at x), but `proposal[0]` and the gradients `d_logprob[0]`/`d_logprob[1]` are used inconsistently. If the intent is a single-step MALA, the current two-step scan is unnecessary and error-prone.
As a robust fix, simplify to a single-step MALA kernel (no scan), computing gradients at x and at the proposal for the corrected MH ratio:
```python
def kernel(
    self,
    rng_key: PRNGKeyArray,
    position: Float[Array, " n_dim"],
    log_prob: Float[Array, "1"],
    logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
    data: PyTree,
) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
    dt: Float = self.step_size
    dt2 = dt * dt
    key1, key2 = jax.random.split(rng_key)

    logp_x, grad_x = jax.value_and_grad(logpdf)(position, data)
    noise = jax.random.normal(key1, shape=position.shape)
    proposal = position + 0.5 * dt2 * grad_x + dt * noise

    logp_prop, grad_prop = jax.value_and_grad(logpdf)(proposal, data)

    cov = dt2 * jnp.eye(position.shape[-1], dtype=position.dtype)
    log_q_x_to_prop = multivariate_normal.logpdf(
        proposal, position + 0.5 * dt2 * grad_x, cov
    )
    log_q_prop_to_x = multivariate_normal.logpdf(
        position, proposal + 0.5 * dt2 * grad_prop, cov
    )

    ratio = (logp_prop - logp_x) + (log_q_prop_to_x - log_q_x_to_prop)
    do_accept = jnp.log(jax.random.uniform(key2)) < ratio

    new_position = jnp.where(do_accept, proposal, position)
    new_log_prob = jnp.where(do_accept, logp_prop, logp_x)
    return new_position, new_log_prob, do_accept
```

If multi-step integration is desired, rework the scan to propagate a single proposal and use consistent logp/grad pairs for x and x'.
♻️ Duplicate comments (1)
test/integration/test_HMC.py (1)
19-19: Same note about the noisy print as in the MALA test.

The `print("compile count")` in `dual_moon_pe` is likely to clutter logs.
🧹 Nitpick comments (37)
src/flowMC/strategy/train_model.py (1)
73-80: Guard against empty training data and sample a bounded number of examples.

If `training_data.shape[0] == 0`, `jax.random.choice` will error. Also, requesting `self.n_max_examples` regardless of available data is wasteful.

Apply:

```diff
-rng_key, subkey = jax.random.split(rng_key)
-training_data = training_data[
-    jax.random.choice(
-        subkey,
-        jnp.arange(training_data.shape[0]),
-        shape=(self.n_max_examples,),
-        replace=True,
-    )
-]
+rng_key, subkey = jax.random.split(rng_key)
+num = min(self.n_max_examples, int(training_data.shape[0]))
+if num == 0:
+    raise ValueError("No finite training samples available for NF training.")
+idx = jax.random.choice(
+    subkey,
+    jnp.arange(training_data.shape[0]),
+    shape=(num,),
+    replace=training_data.shape[0] < num,
+)
+training_data = training_data[idx]
```
src/flowMC/strategy/parallel_tempering.py (3)
53-57: Minor: fix return type annotation to match the input dim name.

The return annotation uses `n_dim` while the input uses `n_dims`. Aligning helps jaxtyping/static checks.

```diff
-    Float[Array, "n_chains n_dim"],
+    Float[Array, "n_chains n_dims"],
```
205-217: Typo in method name: `_individal_step` → `_individual_step`.

Renaming improves readability and avoids future search/grep confusion.

```diff
-def _individal_step(
+def _individual_step(
@@
-positions, log_probs, do_accept = jax.vmap(
-    self._individal_step, in_axes=(None, 0, 0, None, 0, None)
+positions, log_probs, do_accept = jax.vmap(
+    self._individual_step, in_axes=(None, 0, 0, None, 0, None)
```

Also applies to: 285-287
213-217: Align `do_accept`'s shape with its annotation (or update the annotation).

`_individual_step` currently returns `do_accept` from `lax.scan`, which has shape `(n_steps, 1)`, but the type annotation says `Int[Array, "1"]`. This mismatches annotations here and in `_ensemble_step` (which declares `Int[Array, " n_temps"]`).

Two options:

- Return only the last acceptance per temperature (matches many PT diagnostics).
- Or update annotations to reflect the full `(n_steps, 1)` shape per temperature.

Option A (return the last acceptance flag):

```diff
-return position, log_prob, do_accept
+return position, log_prob, do_accept[-1].squeeze()
```

Option B (update `_individual_step`'s return annotation):

```diff
-    Int[Array, "1"],
+    Int[Array, " n_steps 1"],
```

If you pick Option B, also update `_ensemble_step`'s return annotation to `Int[Array, " n_temps n_steps 1"]` or compute an aggregate (e.g., mean) before returning.

Also applies to: 239-249, 285-290
src/flowMC/strategy/sequential_monte_carlo.py (1)
12-18: Prefer Mapping to express read-only inputs; document mutation semantics.

If `resources` is not mutated in place, accept `Mapping[str, Resource]` and return a new `dict[str, Resource]`. If it is mutated, prefer `MutableMapping[str, Resource]` and drop returning a new dict to avoid ambiguity. Add a short docstring to specify expectations (e.g., particle resampling, adaptation state updates).

Would you like me to update the signature and add a docstring template for implementers?
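A minimal sketch of the read-only variant, with hypothetical names (`run`, `NoOpSMC`; the actual abstract method in flowMC may differ):

```python
from abc import ABC, abstractmethod
from collections.abc import Mapping


class Resource:  # stand-in for flowMC's Resource base class
    pass


class SequentialMonteCarloSketch(ABC):
    """Accept a read-only Mapping and return a fresh dict of resources."""

    @abstractmethod
    def run(self, resources: Mapping[str, Resource]) -> dict[str, Resource]: ...


class NoOpSMC(SequentialMonteCarloSketch):
    def run(self, resources: Mapping[str, Resource]) -> dict[str, Resource]:
        # Build a new dict rather than mutating the caller's mapping.
        return dict(resources)


out = NoOpSMC().run({"particles": Resource()})
assert isinstance(out, dict) and "particles" in out
```

The `Mapping` annotation signals to readers and type checkers that the strategy never writes through its input.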
src/flowMC/resource/model/common.py (4)
20-26: Consider making `condition` optional in the base interface.

Some bijectors won't need a conditioner; others (e.g., coupling layers) will. Using `condition: Optional[Float[Array, " n_condition"]] = None` in `Bijection.__call__`/`forward`/`inverse` accommodates both, reducing boilerplate no-ops in simple bijectors.
150-159: `MaskedCouplingLayer.forward`: the `condition` argument is ignored.

You always condition the inner bijector on `x * self.mask` and drop the provided `condition`. Either:

- Document that `condition` is intentionally unused, or
- Thread it through to the bijector (e.g., concatenate with the masked inputs if the bijector expects it).

If you want to pass it through, a safe pattern is:

```python
cond = x * self.mask if condition is None else jnp.concatenate([x * self.mask, condition])
y, log_det = self.bijector(x, cond)
```

Pick one approach and keep it consistent in `inverse` too.

Do you intend this layer to support external conditioning (e.g., context variables)? If yes, I can propose a concrete signature and conditioner wiring pattern.
191-197: Nit: fix comment wording for clarity.

Change “this note output” to “this returns”.

Apply this diff:

```diff
-# Note that this note output log_det as an array instead of a number.
+# Note: this returns log_det as an array (per-dimension) instead of a scalar.
```
304-309: `Composable.log_prob` relies on dict iteration order matching the distributions list.

`zip(self.distributions, self.partitions.items())` assumes the insertion order of `partitions` aligns with `distributions`. That is brittle and can silently misalign slices. Prefer an explicit mapping from key to Distribution, or store an ordered list of keys.

Minimal change: make `distributions` a dict keyed by the same keys as `partitions` and iterate over the keys.

I can refactor `Composable` to `distributions: dict[str, Distribution]` and update both `log_prob` and `sample` accordingly. Want me to send a patch?
docs/tutorials/parallel_tempering.ipynb (1)
22-22: Optional: shorten the import if the package re-exports MALA.

If `flowMC/resource/kernel/__init__.py` re-exports `MALA`, consider importing from the package to reduce churn if filenames change. Otherwise ignore.

```diff
-from flowMC.resource.kernel.MALA import MALA
+from flowMC.resource.kernel import MALA
```
src/flowMC/resource/kernel/HMC.py (2)
90-96: Avoid print during compute-intensive paths; prefer logging or remove.

`print("Compiling leapfrog step")` will spam stdout and interfere under JIT/vmap. Replace with a logger at debug level or remove.

```diff
-print("Compiling leapfrog step")
+# Consider using a logger at debug level or remove in production.
```
153-157: Avoid noisy prints in library code; use logging.

These parameter prints are useful for debugging but should not run on every call in library code.

```diff
-def print_parameters(self):
-    print("HMC parameters:")
-    print(f"step_size: {self.step_size}")
-    print(f"n_leapfrog: {self.n_leapfrog}")
-    print(f"condition_matrix: {self.condition_matrix}")
+def print_parameters(self):
+    # Prefer a logger.debug/info here instead of print to avoid noisy stdout.
+    pass
```
test/unit/test_nf.py (1)
56-64: Minor typo in local variable name (`hidden_layes`).

Harmless, but consider renaming to `hidden_layers` for clarity.

```diff
-hidden_layes = [16, 16]
+hidden_layers = [16, 16]
@@
-model = MaskedCouplingRQSpline(
-    n_features, n_layers, hidden_layes, n_bins, jax.random.PRNGKey(10)
-)
+model = MaskedCouplingRQSpline(
+    n_features, n_layers, hidden_layers, n_bins, jax.random.PRNGKey(10)
+)
```
test/integration/test_MALA.py (1)
19-19: Avoid noisy prints in the tests' target logpdf (use debug-print or guard).

`print("compile count")` will spam test output on JIT compile/eval and can lead to flaky golden-output checks. Consider removing it or wrapping it via `jax.debug.print` under a debug flag.
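A sketch of the guarded variant (with a hypothetical `DEBUG` flag): `jax.debug.print` works under `jit`, unlike a bare `print` that fires at trace time, and a Python-level guard keeps it out of the trace entirely when disabled:

```python
import jax
import jax.numpy as jnp

DEBUG = False  # hypothetical flag; enable only when diagnosing recompiles

@jax.jit
def logpdf(x):
    if DEBUG:  # Python-level guard, resolved once at trace time
        jax.debug.print("evaluating logpdf, x[0]={v}", v=x[0])
    return -0.5 * jnp.sum(x**2)

out = logpdf(jnp.ones(3))
assert jnp.isclose(out, -1.5)
```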
137-141: The “close_gaussian” tests are very long-running; consider marking them as slow or reducing steps.

`n_local_steps = 30000/50000` will slow CI. Either mark these tests with `@pytest.mark.slow` or reduce the step counts and relax tolerances.

Example:

```python
import pytest

@pytest.mark.slow
def test_HMC_close_gaussian():
    # possibly reduce n_local_steps to 5_000 and relax atol
    ...
```

Also applies to: 245-246, 346-347
src/flowMC/resource/kernel/Gaussian_random_walk.py (1)
53-61: Type annotations don't match actual values: `proposal_log_prob` is scalar and `do_accept` is boolean.

- `proposal_log_prob` should be a scalar (shape `"1"`), not `" n_dim"`.
- `do_accept` is a boolean but the return annotation declares `Int`.

These mismatches can confuse type checkers and readers.

Suggested adjustments:

```python
# In the return annotation of `kernel`:
# -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Bool[Array, "1"]]

# For proposal_log_prob:
proposal_log_prob: Float[Array, "1"] = logpdf(proposal, data)

# Optionally, annotate do_accept explicitly:
do_accept: Bool = log_uniform < (proposal_log_prob - log_prob)
```

Also consider normalizing jaxtyping shapes to remove leading spaces, e.g., `"n_dim"` instead of `" n_dim"`, for consistency across the codebase.
test/integration/test_normalizingFlow.py (1)
20-20: Clarify optimizer hyperparameters: “momentum” isn't an Adam parameter.

`optax.adam` doesn't have a “momentum” parameter; the second positional arg maps to `b1` (default 0.9). To avoid confusion, pass it by name or rename the variable.

For clarity:

```python
optim = optax.adam(learning_rate, b1=0.9)
# or rename the variable:
beta1 = 0.9
optim = optax.adam(learning_rate, b1=beta1)
```

Also applies to: 53-53
src/flowMC/resource/kernel/MALA.py (2)
84-84: Acceptance mask type annotation is incorrect.

`do_accept` is a scalar boolean but annotated as `Bool[Array, " n_dim"]`. This is misleading and may trip static/type checks.
Apply this diff:
```diff
-do_accept: Bool[Array, " n_dim"] = log_uniform < ratio
+do_accept: Bool = log_uniform < ratio
```
56-56: Avoid print inside the JIT/scan body.

`print("Compiling MALA body")` inside a JAX scan creates noisy logs and can interact poorly with JIT. Gate it behind a verbosity flag or remove it.
src/flowMC/resource/kernel/NF_proposal.py (5)
27-36: Return type hints don't match actual shapes, and they diverge from other kernels.

`kernel` returns `(positions, log_prob, do_accept)` with shapes `(n_steps, n_dim)`, `(n_steps,)`, `(n_steps,)`, but the annotations say `"n_step 1"` for the last two. This causes inconsistency with `ProposalBase` and other kernels (which typically return scalars per step).
Apply this diff to correct local annotations (or update calling code to expect 2D trailing dims consistently):
```diff
-) -> tuple[
-    Float[Array, "n_step n_dim"], Float[Array, "n_step 1"], Int[Array, "n_step 1"]
-]:
+) -> tuple[
+    Float[Array, "n_step n_dim"], Float[Array, "n_step"], Int[Array, "n_step"]
+]:
```

Longer-term, consider distinguishing single-step and multi-step proposals with separate base interfaces (e.g., LocalProposalBase vs GroupProposalBase) to avoid shape drift across kernels.
38-41: Remove or gate debug prints inside JIT code paths.

`print("Compiling NF proposal kernel")` will fire during JIT tracing and can be noisy. Gate it behind a verbosity flag or remove it.
43-50: JIT inside the hot path can cause repeated recompiles.

`eqx.filter_jit(self.model.log_prob)` and `eqx.filter_jit(self.sample_flow)` are rebuilt at call sites. Prefer pre-binding jitted callables in `__init__` (or at module level) to reduce compilation overhead.
136-136: No-op assignment.

`rng_key = rng_key` is a no-op and can be removed.
Apply this diff:
```diff
-rng_key = rng_key
```
141-153: Avoid print during scan/JIT.

`print("Compiling sample_flow")` inside `lax.scan` is noisy. Remove it or guard it with a flag.
src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (2)
127-128: Prefer the `jnp` alias consistently.

Use the `jnp` alias already imported instead of `jax.numpy` to keep the style consistent.
Apply this diff:
```diff
-temperatures.update_buffer(
-    jax.numpy.linspace(1.0, max_temperature, n_temperatures)
-)
+temperatures.update_buffer(
+    jnp.linspace(1.0, max_temperature, n_temperatures)
+)
```
271-274: Avoid printing in library code; consider a logger or verbose flag.

Printing when `n_tempered_steps <= 0` is surprising for library consumers. Either raise a `ValueError` or guard with a verbose flag.
src/flowMC/resource/model/nf_model/realNVP.py (1)
18-100: Unused AffineCoupling implementation.

`AffineCoupling` appears unused in this module (`RealNVP` uses `MaskedCouplingLayer` + `MLPAffine`). Consider removing it or moving it to a dedicated bijection module to avoid duplication and dead code.
test/unit/test_flowmatching.py (3)
44-49: Potential flakiness: log_prob scalar shape vs vmap expectation.

You assert `logp.shape == (n_samples, 1)` after `eqx.filter_vmap(model.log_prob)(samples)`. Today `FlowMatchingModel.log_prob` returns a scalar; vmap over a scalar typically produces shape `(n_samples,)`, not `(n_samples, 1)`. If you intend a column vector, ensure `log_prob` returns shape `(1,)`. See my suggested change in `FlowMatchingModel.log_prob` to wrap the scalar into a length-1 array.
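The shape behavior is easy to confirm: `vmap` over a scalar-returning function yields `(n,)`, and only a length-1 output produces the `(n, 1)` column (minimal JAX sketch):

```python
import jax
import jax.numpy as jnp

def scalar_logp(x):
    return -0.5 * jnp.sum(x**2)  # 0-d scalar output

samples = jnp.ones((5, 2))
out = jax.vmap(scalar_logp)(samples)
assert out.shape == (5,)         # not (5, 1)

def column_logp(x):
    return jnp.reshape(-0.5 * jnp.sum(x**2), (1,))  # length-1 array output

out2 = jax.vmap(column_logp)(samples)
assert out2.shape == (5, 1)
```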
84-91: Scheduler tuple contents: rely only on array-like outputs.

Good coverage to assert the 4-tuple contract. Casting each element to a Python `float` forces materialization at the device-host boundary and will fail for shaped arrays (e.g., if `t` is batched). If you plan to extend this to batched `t`, consider checking JAX array-likeness via `hasattr(x, "dtype")` or `jnp.shape(x)` rather than `float(x)`.
174-191: train_epoch uses an integer batch_size; keep the annotation/types consistent.

`batch_size` is used for indexing and reshaping; tests pass an `int`. Ensure the production signature treats it as `int` (not `Float`). I've suggested the fix in the implementation.
src/flowMC/resource/model/flowmatching/base.py (7)
44-52: Optional: avoid storing full ODE trajectories in `sample`.

For sampling you only need the final state. Save only `t1` to reduce memory/time, and drop the `[-1]` indexing.

```diff
 sol = diffeqsolve(
     term,
     self.method,
     t0=0.0,
     t1=1.0,
     dt0=dt,
-    y0=y0,
+    y0=y0,
+    saveat=SaveAt(t1=True),
 )
-return sol.ys[-1]  # type: ignore
+return sol.ys  # type: ignore
```
92-98: Nit: simplify the base Gaussian logpdf mean construction.

`self.model.n_output * jnp.zeros(self.model.n_output)` is redundant. Use `jnp.zeros(self.model.n_output)` for clarity.

```diff
-logpdf(
-    x1,
-    mean=self.model.n_output * jnp.zeros(self.model.n_output),
-    cov=jnp.eye(self.model.n_output),
-)
+logpdf(x1, mean=jnp.zeros(self.model.n_output), cov=jnp.eye(self.model.n_output))
```
204-218: Don't print from a jitted training step.

`print("Compiling training step")` inside a JIT-ted function will spam stdout on recompilations and can interfere with benchmarks. Remove it or guard it behind a verbosity flag.

```diff
-print("Compiling training step")
```
219-231: The type of batch_size should be int.

It's used for slicing/reshaping. Annotate it as `int` to avoid confusion and potential issues with floor division producing floats.

```diff
-batch_size: Float,
+batch_size: int,
```
289-319: Minor: progress bar UX and description text.

- The description says “Training NF” while this is a flow-matching model; rename for clarity.
- `epoch == num_epochs` is never true; the last epoch is `num_epochs - 1`.
- If `num_epochs <= 10`, the description will never update; consider always updating on the last iteration.

```diff
-pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
+pbar = trange(num_epochs, desc="Training FlowMatching", miniters=max(1, int(num_epochs / 10)))
@@
 if verbose:
     assert isinstance(pbar, tqdm)
     if num_epochs > 10:
         if epoch % int(num_epochs / 10) == 0:
             pbar.set_description(f"Training NF, current loss: {value:.3f}")
     else:
-        if epoch == num_epochs:
-            pbar.set_description(f"Training NF, current loss: {value:.3f}")
+        if epoch == num_epochs - 1:
+            pbar.set_description(f"Training FlowMatching, current loss: {value:.3f}")
```
297-299: Consistency: covariance update vs diagonal-only code paths.

You set `_data_cov = jnp.cov(data[1].T)` (full covariance), but the pre-change sampling/log_prob used only the diagonal. With the suggested Cholesky refactor, this inconsistency is resolved. If you decide to keep diagonal-only statistics, compute and store only the diagonal explicitly to avoid misleading consumers.
139-150: Return types for properties.

Annotate property return types for better tooling support:

- `def n_features(self) -> int:`
- `def data_mean(self) -> Float[Array, " n_dim"]:`
- `def data_cov(self) -> Float[Array, " n_dim n_dim"]:`

No behavioral change; improves readability and static checks.
⛔ Files ignored due to path filters (1)
- `uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (29)
- `.pre-commit-config.yaml` (1 hunks)
- `docs/tutorials/custom_strategy.ipynb` (1 hunks)
- `docs/tutorials/parallel_tempering.ipynb` (1 hunks)
- `docs/tutorials/train_normalizing_flow.ipynb` (1 hunks)
- `pyproject.toml` (2 hunks)
- `src/flowMC/resource/kernel/Gaussian_random_walk.py` (1 hunks)
- `src/flowMC/resource/kernel/HMC.py` (1 hunks)
- `src/flowMC/resource/kernel/MALA.py` (1 hunks)
- `src/flowMC/resource/kernel/NF_proposal.py` (1 hunks)
- `src/flowMC/resource/model/common.py` (1 hunks)
- `src/flowMC/resource/model/flowmatching/base.py` (1 hunks)
- `src/flowMC/resource/model/nf_model/base.py` (0 hunks)
- `src/flowMC/resource/model/nf_model/realNVP.py` (1 hunks)
- `src/flowMC/resource/model/nf_model/rqSpline.py` (1 hunks)
- `src/flowMC/resource_strategy_bundle/RQSpline_MALA.py` (1 hunks)
- `src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py` (1 hunks)
- `src/flowMC/strategy/parallel_tempering.py` (1 hunks)
- `src/flowMC/strategy/sequential_monte_carlo.py` (1 hunks)
- `src/flowMC/strategy/take_steps.py` (1 hunks)
- `src/flowMC/strategy/train_model.py` (1 hunks)
- `test/integration/test_HMC.py` (1 hunks)
- `test/integration/test_MALA.py` (1 hunks)
- `test/integration/test_RWMCMC.py` (1 hunks)
- `test/integration/test_normalizingFlow.py` (1 hunks)
- `test/unit/test_flowmatching.py` (1 hunks)
- `test/unit/test_kernels.py` (1 hunks)
- `test/unit/test_nf.py` (1 hunks)
- `test/unit/test_resources.py` (1 hunks)
- `test/unit/test_strategies.py` (1 hunks)
💤 Files with no reviewable changes (1)
- src/flowMC/resource/model/nf_model/base.py
👮 Files not reviewed due to content moderation or server errors (1)
- docs/tutorials/train_normalizing_flow.ipynb
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: pre-commit (3.12)
- GitHub Check: build (3.12)
- GitHub Check: build (3.11)
🔇 Additional comments (25)
src/flowMC/strategy/take_steps.py (1)
2-2: Verification complete: no remaining `flowMC.resource.local_kernel` imports

Ran the provided ripgrep command across the repository and found zero occurrences of the old
`local_kernel` path. The import in `src/flowMC/strategy/take_steps.py` and elsewhere correctly points to `flowMC.resource.kernel.base.ProposalBase`. No further updates are required.

src/flowMC/strategy/train_model.py (1)
4-4: NFModel import path update is consistent with the new module layout.

src/flowMC/strategy/parallel_tempering.py (1)
2-2: Import path migration to kernel.base looks correct.

Matches the refactor; ProposalBase is sourced from the new canonical location.
.pre-commit-config.yaml (1)
15-15: Adding diffrax to Pyright's additional_dependencies is appropriate.

This ensures type stubs are available for the new Diffrax usage during pre-commit checks.
pyproject.toml (1)
27-27: Dependency compatibility confirmed for diffrax >= 0.7.0

Verified that Diffrax 0.7.0 is fully compatible with:
- JAX 0.4.x–0.5.x and 0.6.1 (note: avoid JAX 0.6.0 due to a known upstream bug)
- Equinox 0.11.x
- Jaxtyping 0.2.x
No further changes to
`pyproject.toml` are needed. Proceed with the addition of `"diffrax>=0.7.0"`.

docs/tutorials/parallel_tempering.ipynb (1)
22-22: Import path update LGTM and aligns with kernel reorg

Switching to
`from flowMC.resource.kernel.MALA import MALA` matches the new package layout and keeps the tutorial consistent with the refactor. No logic changes implied.
6-6: Import path update toresource.kernellooks correct.This aligns with the re-org. No issues spotted in this file related to the change.
src/flowMC/resource/kernel/HMC.py (1)
7-7: Import path change tokernel.basematches the re-org.No functional impact from this import relocation. Good to go.
test/unit/test_resources.py (1)
5-5: Import path updated toresource.kernelis consistent with the refactor.No additional changes needed here.
docs/tutorials/custom_strategy.ipynb (1)
32-32: Notebook import path correction looks good.This matches the new module layout. Consider re-running the notebook to refresh outputs after the refactor.
Would you like me to open a follow-up to re-execute the tutorial notebooks in CI to ensure they still run end-to-end after the import path changes?
test/unit/test_nf.py (1)
4-5: NF model import paths updated correctly.RealNVP and RQSpline imports now target the reorganized
resource.model.nf_modelmodules. Tests should continue to pass unchanged.test/integration/test_MALA.py (1)
6-6: Import Path Verification Complete: No Deprecated References FoundAll occurrences of the old import paths (
flowMC.resource.local_kernel,flowMC.resource.nf_model) have been removed. The update toflowMC.resource.kernel.MALAis correct—approving these changes.test/integration/test_HMC.py (1)
6-6: Import path update matches the new kernel namespace (LGTM).Consistent with the project-wide reorganization.
test/unit/test_kernels.py (1)
5-7: Kernel import path updates are correct and consistent (LGTM).All three kernels now come from
flowMC.resource.kernel.*, matching the refactor.src/flowMC/resource/kernel/Gaussian_random_walk.py (1)
6-6: Import path fix to ProposalBase is correct (LGTM).This aligns the kernel with the new base location.
test/integration/test_normalizingFlow.py (1)
6-7: NF model import path updates are correct (LGTM).Matches the new
flowMC.resource.model.nf_model.*layout.src/flowMC/resource/kernel/MALA.py (1)
8-8: Import path update looks correctSwitching to flowMC.resource.kernel.base aligns with the new module layout. No issues spotted here.
src/flowMC/resource/kernel/NF_proposal.py (2)
10-11: Import path updates look correctNFModel and ProposalBase now reference the reorganized modules under resource.model.nf_model and resource.kernel respectively.
27-36: No downstream code relies on a trailing singleton dimension for log_prob/do_acceptI’ve searched for any indexing or buffer updates that assume a final dimension of size 1 (e.g.
[..., :, 1]or buffer shapes of(…, n_steps, 1)) and found none. AllTakeSerialSteps/TakeGroupStepspipelines andBuffer.update_buffercalls consume 1D or 2D arrays (e.g.(n_steps,)or(n_chains, n_steps)) as intended. The kernel annotations still mention"n_step 1", but at runtimelog_probanddo_acceptare produced as 1D arrays and fed into buffers with matching shapes.No changes are needed to the strategy or buffer code.
src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (1)
12-14: Import path updates look correct

MALA, NFProposal, and MaskedCouplingRQSpline now point to the reorganized modules. Matches the PR's restructuring.
src/flowMC/resource/model/nf_model/realNVP.py (1)
8-15: Import path consolidation looks good

NFModel and common building blocks (Distribution, MLP, Gaussian, MaskedCouplingLayer) now come from the reorganized modules. No issues.
test/unit/test_strategies.py (1)
6-11: Verified correct test imports; no legacy paths detected

File: test/unit/test_strategies.py
Lines: 6–11

from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
from flowMC.resource.optimizer import Optimizer
from flowMC.resource.kernel.NF_proposal import NFProposal
from flowMC.resource.kernel.MALA import MALA
from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
from flowMC.resource.kernel.HMC import HMC

• Ran RG searches across all
`.py` files for old patterns (`flowMC.resource.nf_model`, `flowMC.resource.local_kernel`); found no matches.
• Test imports fully align with the reorganized module structure, confirming the strategy layer remains intact.

src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (1)
11-13: LGTM: New imports and package structure verified
- Confirmed `class MALA` in `src/flowMC/resource/kernel/MALA.py`
- Confirmed `class NFProposal` in `src/flowMC/resource/kernel/NF_proposal.py`
- Confirmed `class MaskedCouplingRQSpline` in `src/flowMC/resource/model/nf_model/rqSpline.py`
- All necessary `__init__.py` files are present for package importability
9-12: Imports consolidated correctly; common.py exports all symbols

Confirmed that
`src/flowMC/resource/model/common.py` defines or re-exports all of:
- Distribution
- Bijection
- MLP
- Gaussian
- MaskedCouplingLayer
- ScalarAffine
The consolidated imports in
`src/flowMC/resource/model/nf_model/rqSpline.py` are valid.

src/flowMC/resource/model/flowmatching/base.py (1)
64-75: Ignore the sign-flip suggestion; the current implementation is correct

Because you integrate backward (t0 = 1 → t1 = 0 with dt0 = -dt), returning +trace(div) yields

d log p = trace(div) · dt = trace(div) · (-|dt|) = -trace(div) · |dt|,

which matches the standard CNF ODE d(log p)/dt = -div f. No change needed here.

Likely an incorrect or invalid review comment.
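The sign convention debated above can be checked on a toy one-dimensional flow. The sketch below is not flowMC or Diffrax code; it integrates the instantaneous change-of-variables equation d(log p)/dt = -div(v) forward (t: 0 → 1) for the hypothetical constant-divergence field v(x) = a·x, and compares against the exact pushforward density:

```python
import math

a = 0.7   # divergence of the field v(x) = a * x (constant in x)
z = 0.3   # initial point, base density N(0, 1)

def log_normal(x, var):
    """Log density of N(0, var) at x."""
    return -0.5 * (x * x / var + math.log(2 * math.pi * var))

# Euler-integrate dx/dt = a*x and d(log p)/dt = -div(v) = -a from t=0 to t=1.
steps = 10_000
x, log_p = z, log_normal(z, 1.0)
for _ in range(steps):
    x += (a * x) / steps
    log_p += -a / steps

# Exact pushforward: x(1) = e^a * z, whose density is N(0, e^{2a}).
exact = log_normal(math.exp(a) * z, math.exp(2 * a))
print(abs(log_p - exact) < 1e-3)  # → True
```

Reversing the integration direction flips the sign of dt, which is exactly the cancellation the comment above relies on.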
| "from typing import Any\n", | ||
| "\n", | ||
| "from flowMC.resource.local_kernel.MALA import MALA\n", | ||
| "from flowMC.resource.kernel.MALA import MALA\n", |
💡 Verification agent
🧩 Analysis chain
Verify no stale local_kernel imports remain across the repo
To avoid runtime/import errors in docs or examples, verify all references were updated from resource.local_kernel to resource.kernel, and that MALA.py exists at the new location.
Run:
🏁 Script executed:
```bash
#!/bin/bash
# Verify the new MALA module exists under the reorganized path.
fd -HI 'MALA.py' | rg -nP 'flowMC/.*/resource/.*/kernel/.*/MALA\.py|flowMC/.*/resource/kernel/MALA\.py' || true
# Find any lingering old imports anywhere (expect: no matches).
rg -nP -C2 '\bflowMC\.resource\.local_kernel\b' || true
# Confirm new-style imports are present (expect: matches in updated notebooks/tests).
rg -nP -C2 '\bfrom\s+flowMC\.resource\.kernel\.MALA\s+import\s+MALA\b' || true
```

Length of output: 3343
Missing MALA.py in flowMC/resource/kernel
The new imports reference flowMC.resource.kernel.MALA, but there is no MALA.py file under src/flowMC/resource/kernel/. Please add or move MALA.py (containing the MALA class) into that directory so all imports resolve.
All stale flowMC.resource.local_kernel references have been removed, and new-style imports are present in:
- docs/tutorials/parallel_tempering.ipynb (line 22)
- test/unit/test_strategies.py (line 9)
- src/flowMC/resource_strategy_bundle/RQSpline_MALA.py (line 11)
- src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py (line 12)
- test/unit/test_resources.py (line 5)
- test/unit/test_kernels.py (line 7)
- test/integration/test_MALA.py (line 6)
🤖 Prompt for AI Agents
In docs/tutorials/parallel_tempering.ipynb around line 22, the import references
flowMC.resource.kernel.MALA but there is no src/flowMC/resource/kernel/MALA.py;
add (or move) a MALA.py file into src/flowMC/resource/kernel/ that defines the
MALA class with the expected public API used by the listed files, and ensure the
module is importable (add/update __init__.py if necessary) so all imports in the
notebook and tests resolve.
```diff
 from abc import abstractmethod
 
-from flowMC.resource.nf_model.base import Bijection, Distribution
+
+class Bijection(eqx.Module):
+    """Base class for bijective transformations.
+
+    This is an abstract template that should not be directly used.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        raise NotImplementedError
```
abstractmethod without ABC won’t enforce abstractness.
Bijection and Distribution don’t inherit from ABC, so @abstractmethod has no effect at class level. This allows accidental instantiation at runtime.
Apply this diff to enforce the contract:

```diff
-from abc import abstractmethod
+from abc import ABC, abstractmethod
@@
-class Bijection(eqx.Module):
+class Bijection(eqx.Module, ABC):
@@
-class Distribution(eqx.Module):
+class Distribution(eqx.Module, ABC):
```

🤖 Prompt for AI Agents
In src/flowMC/resource/model/common.py around lines 7 to 19, the @abstractmethod
has no effect because Bijection doesn't inherit from ABC; update the class to
inherit from abc.ABC (import ABC from abc) in addition to eqx.Module (e.g.,
class Bijection(ABC, eqx.Module)): this will enforce the abstract contract so
__init__ cannot be instantiated, and apply the same pattern for Distribution if
present elsewhere.
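The behavior the comment describes is easy to demonstrate with the standard library alone (the equinox `eqx.Module` metaclass is left out of this sketch, so the exact interaction in flowMC may differ): `@abstractmethod` only blocks instantiation when the class uses the `ABCMeta` machinery, e.g. by inheriting from `ABC`.

```python
from abc import ABC, abstractmethod

class NoEnforce:  # plain class: @abstractmethod alone is NOT enforced
    @abstractmethod
    def __call__(self):
        raise NotImplementedError

class Enforced(ABC):  # ABC supplies ABCMeta, so instantiation is blocked
    @abstractmethod
    def __call__(self):
        raise NotImplementedError

obj = NoEnforce()  # silently succeeds despite the abstract method
try:
    Enforced()
    blocked = False
except TypeError:  # "Can't instantiate abstract class Enforced ..."
    blocked = True

print(isinstance(obj, NoEnforce), blocked)  # → True True
```

This is why the suggested diff adds `ABC` to the base-class list rather than relying on the `raise NotImplementedError` bodies at runtime.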
```python
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> Float[Array, " n_samples n_features"]:
        raise NotImplementedError
```
🛠️ Refactor suggestion
Unify sample return shape variable: use n_dim consistently.
The base Distribution.sample returns shape " n_samples n_features" while the rest of the module uses " n_dim". This inconsistency will trip pyright/jaxtyping checks and downstream type expectations.
Apply this diff:

```diff
 def sample(
     self, rng_key: PRNGKeyArray, n_samples: int
-) -> Float[Array, " n_samples n_features"]:
+) -> Float[Array, " n_samples n_dim"]:
     raise NotImplementedError
```

You'll also want to align concrete implementations; see the follow-up snippet for `Gaussian.sample`.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 def sample(
     self, rng_key: PRNGKeyArray, n_samples: int
-) -> Float[Array, " n_samples n_features"]:
+) -> Float[Array, " n_samples n_dim"]:
     raise NotImplementedError
```
🤖 Prompt for AI Agents
In src/flowMC/resource/model/common.py around lines 62-65, the sample return
shape annotation uses " n_samples n_features" but the project standard is
"n_dim"; change the type string to " n_samples n_dim" and update any matching
type alias usage. Also update all concrete implementations (e.g.,
Gaussian.sample) to use the same " n_samples n_dim" return annotation so
pyright/jaxtyping types match end-to-end.
```python
import jax.numpy as jnp
import jax
from jax.scipy.stats.multivariate_normal import logpdf
from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver
```
Missing import: SaveAt
You use `SaveAt` in `diffeqsolve` for efficient final-state saving. Import it.

```diff
-from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver
+from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver, SaveAt
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver
+from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver, SaveAt
```
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around line 11, the diffrax
import list is missing SaveAt which is used with diffeqsolve; update the import
statement to also import SaveAt from diffrax (i.e., add SaveAt to the existing
import list) so SaveAt can be used when calling diffeqsolve.
```python
def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree:
    """Solve the ODE with initial condition y0."""
    term = ODETerm(model_wrapper)
    y_init = jax.tree.map(jnp.asarray, [y0, 0.0])
    sol = diffeqsolve(
        term,
        self.method,
        t0=1.0,
        t1=0.0,
        dt0=-dt,
        y0=y_init,
    )
    return sol.ys
```
Bug: building augmented initial state and returning the full trajectory

Two issues:

1. `jax.tree.map` only exists on recent JAX (≥ 0.4.25); if older JAX versions must remain supported, use the long-standing `jax.tree_util.tree_map` spelling instead.
2. `solve_ode` returns `sol.ys` (the whole trajectory), but later you unpack `x0, log_p = solve_ode(...)` and then immediately use `log_p`, which should be the final scalar. Return only the final state to avoid shape/broadcast bugs and excess memory.

Apply:

```diff
-    y_init = jax.tree.map(jnp.asarray, [y0, 0.0])
+    y_init = jax.tree_util.tree_map(jnp.asarray, [y0, 0.0])
     sol = diffeqsolve(
         term,
         self.method,
         t0=1.0,
         t1=0.0,
         dt0=-dt,
         y0=y_init,
-    )
-    return sol.ys
+        saveat=SaveAt(t1=True),
+    )
+    return sol.ys  # final state only
```

Also add the missing import (see separate comment on Line 11).
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree:
    """Solve the ODE with initial condition y0."""
    term = ODETerm(model_wrapper)
    y_init = jax.tree_util.tree_map(jnp.asarray, [y0, 0.0])
    sol = diffeqsolve(
        term,
        self.method,
        t0=1.0,
        t1=0.0,
        dt0=-dt,
        y0=y_init,
        saveat=SaveAt(t1=True),
    )
    return sol.ys  # final state only
```
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 76–89, replace
the invalid jax.tree.map usage with the correct jax.tree_map (or
jax.tree_util.tree_map) when building the augmented initial state, and ensure
the module import for jax (or tree_util) is added as noted on line 11; also
change the function to return only the final ODE state (not sol.ys/the full
trajectory) so callers like x0, log_p = solve_ode(...) receive the final pair
(avoid returning the whole trajectory which causes shape/broadcast and memory
issues).
```python
def sample(self, x0: Float, x1: Float, t: Float) -> Float:
    """Sample a point along the path between x0 and x1 at time t."""
    alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t)
    x_t = sigma_t * x0 + alpha_t * x1
    dx_t = d_sigma_t * x0 + d_alpha_t * x1
    return x_t, dx_t
```
Incorrect return type annotation for Path.sample
The method returns a tuple (x_t, dx_t) but is annotated as returning a single Float. Fix the signature:

```diff
-    def sample(self, x0: Float, x1: Float, t: Float) -> Float:
+    def sample(self, x0: Float, x1: Float, t: Float) -> tuple[Float, Float]:
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def sample(self, x0: Float, x1: Float, t: Float) -> tuple[Float, Float]:
    """Sample a point along the path between x0 and x1 at time t."""
    alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t)
    x_t = sigma_t * x0 + alpha_t * x1
    dx_t = d_sigma_t * x0 + d_alpha_t * x1
    return x_t, dx_t
```
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 124 to 130, the
sample method is annotated to return a single Float but actually returns a tuple
(x_t, dx_t); update the signature to return a tuple type (e.g., -> Tuple[Float,
Float]) and ensure Tuple (and Float if necessary) is imported from typing or
your project's type aliases; keep the implementation unchanged.
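The scheduler/path relation quoted above is easy to sanity-check numerically. The sketch below is standalone (not the flowMC classes) and assumes a hypothetical conditional-OT scheduler with alpha_t = t and sigma_t = 1 - t, so the regression target dx_t reduces to x1 - x0:

```python
def scheduler(t):
    """Hypothetical conditional-OT schedule: alpha_t = t, sigma_t = 1 - t."""
    alpha_t, d_alpha_t = t, 1.0
    sigma_t, d_sigma_t = 1.0 - t, -1.0
    return alpha_t, d_alpha_t, sigma_t, d_sigma_t

def path_sample(x0, x1, t):
    """Mirror of Path.sample: interpolant x_t and its time derivative dx_t."""
    alpha_t, d_alpha_t, sigma_t, d_sigma_t = scheduler(t)
    x_t = sigma_t * x0 + alpha_t * x1
    dx_t = d_sigma_t * x0 + d_alpha_t * x1
    return x_t, dx_t

x_t, dx_t = path_sample(2.0, -1.0, 0.25)
print(x_t, dx_t)  # → 1.25 -3.0  (0.75*2 + 0.25*(-1), and x1 - x0)
```

The tuple return here is exactly why the annotation fix above uses `tuple[Float, Float]`.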
```python
def sample(
    self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1
) -> Float[Array, " n_dim"]:
    rng_key, subkey = jax.random.split(rng_key)
    samples = self.solver.sample(subkey, num_samples, dt=dt)
    std = jnp.sqrt(jnp.diag(self.data_cov))
    samples = samples * std + self.data_mean
    return samples
```
🛠️ Refactor suggestion
Essential: sampling with full covariance (currently only uses diagonal)
data_cov is treated as diagonal (via sqrt(diag(...))), but train sets it to jnp.cov(data[1].T), which is full. Sampling should apply the full scale using the Cholesky factor.
```diff
-        std = jnp.sqrt(jnp.diag(self.data_cov))
-        samples = samples * std + self.data_mean
+        L = jnp.linalg.cholesky(self.data_cov)
+        samples = samples @ L.T + self.data_mean
```

This fixes geometry when the data have correlations.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def sample(
    self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1
) -> Float[Array, " n_dim"]:
    rng_key, subkey = jax.random.split(rng_key)
    samples = self.solver.sample(subkey, num_samples, dt=dt)
    L = jnp.linalg.cholesky(self.data_cov)
    samples = samples @ L.T + self.data_mean
    return samples
```
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 171-179, sampling
currently rescales using only the diagonal (sqrt(diag(self.data_cov))) even
though train sets self.data_cov to the full covariance matrix; replace that
diagonal-only scaling with a full-covariance transform by computing the Cholesky
factor L of self.data_cov and applying the linear map defined by L to the
standard samples before adding self.data_mean (i.e., transform samples with L so
Cov(samples) = self.data_cov then add mean); include a safe fallback (e.g., use
diagonal sqrt) if Cholesky fails or if self.data_cov is already diagonal to
avoid runtime errors.
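The geometric difference between diagonal-only and Cholesky scaling can be seen empirically. This NumPy sketch (not flowMC code; the covariance values are made up) draws whitened samples and applies both transforms; only the Cholesky version recovers the off-diagonal covariance:

```python
import numpy as np

rng = np.random.default_rng(0)
cov = np.array([[2.0, 1.2], [1.2, 1.5]])  # full (correlated) data covariance
z = rng.standard_normal((100_000, 2))     # whitened base samples, Cov ~ I

# Diagonal-only scaling keeps the marginal variances but drops the correlation.
diag_scaled = z * np.sqrt(np.diag(cov))

# Cholesky scaling reproduces the full covariance: Cov(z @ L.T) = L @ L.T = cov.
L = np.linalg.cholesky(cov)
chol_scaled = z @ L.T

print(np.round(np.cov(diag_scaled.T), 2))  # off-diagonal near 0
print(np.round(np.cov(chol_scaled.T), 2))  # off-diagonal near 1.2
```

This is the "geometry" the review comment refers to: with diagonal scaling, samples from a correlated target come out axis-aligned.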
```python
def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
    std = jnp.sqrt(jnp.diag(self.data_cov))
    x_whitened = (x - self.data_mean) / std
    log_det = -jnp.sum(jnp.log(std))
    return self.solver.log_prob(x_whitened) + log_det
```
🛠️ Refactor suggestion
Align whitening and log-det with full covariance and return shape (1,)
- Whitening currently uses only the diagonal and returns a scalar of shape `()`. Use Cholesky-based whitening consistent with the previous comment and return a `(1,)` vector to satisfy the tests' vmap shape.

```diff
-    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
-        std = jnp.sqrt(jnp.diag(self.data_cov))
-        x_whitened = (x - self.data_mean) / std
-        log_det = -jnp.sum(jnp.log(std))
-        return self.solver.log_prob(x_whitened) + log_det
+    def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
+        L = jnp.linalg.cholesky(self.data_cov)
+        # Solve L y = (x - mean) for y (lower-triangular solve)
+        y = jax.scipy.linalg.solve_triangular(L, x - self.data_mean, lower=True)
+        log_det = -jnp.log(jnp.diag(L)).sum()
+        return jnp.atleast_1d(self.solver.log_prob(y) + log_det)
```

Note: add `from jax.scipy.linalg import solve_triangular` if you prefer a direct import instead of `jax.scipy.linalg.solve_triangular`.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
    L = jnp.linalg.cholesky(self.data_cov)
    # Solve L y = (x - mean) for y (lower-triangular solve)
    y = jax.scipy.linalg.solve_triangular(L, x - self.data_mean, lower=True)
    log_det = -jnp.log(jnp.diag(L)).sum()
    return jnp.atleast_1d(self.solver.log_prob(y) + log_det)
```
🤖 Prompt for AI Agents
In src/flowMC/resource/model/flowmatching/base.py around lines 180 to 185, the
whitening and log-det computation currently use only the diagonal of the
covariance and return a scalar; replace this with Cholesky-based whitening using
the full covariance: compute L = jnp.linalg.cholesky(self.data_cov), compute
whitened x by solving L y = (x - self.data_mean) with solve_triangular (or
jax.scipy.linalg.solve_triangular) twice as needed for forward/back
substitution, compute log_det as -jnp.sum(jnp.log(jnp.diag(L))) (accounting for
the determinant via the Cholesky), pass the whitened x to self.solver.log_prob,
and ensure the function returns a 1D array of shape (1,) (e.g., wrap the scalar
log-prob + log_det in jnp.array([...]) or jnp.reshape(..., (1,))). Also add the
import for solve_triangular at the top of the file.
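The correctness of the Cholesky-whitening identity behind this suggestion is worth a numeric check. This NumPy sketch (illustrative values only) verifies that the whitened standard-normal log density plus the log-det correction equals the direct multivariate-normal log density:

```python
import numpy as np

mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.8], [0.8, 1.0]])
x = np.array([0.5, -1.0])
d = x.shape[0]
diff = x - mean

# Direct MVN log density: -0.5 * (diff^T cov^{-1} diff + d log 2pi + log|cov|)
direct = -0.5 * (diff @ np.linalg.solve(cov, diff)
                 + d * np.log(2 * np.pi) + np.log(np.linalg.det(cov)))

# Whitened route: y = L^{-1}(x - mean), change-of-variables log|det| = -sum(log diag L)
Lc = np.linalg.cholesky(cov)
y = np.linalg.solve(Lc, diff)
std_normal_logpdf = -0.5 * (y @ y + d * np.log(2 * np.pi))
log_det = -np.log(np.diag(Lc)).sum()
whitened = std_normal_logpdf + log_det

print(np.isclose(direct, whitened))  # → True
```

The same identity underlies the suggested `log_prob`: one lower-triangular solve for the whitening plus `-sum(log(diag(L)))` accounts for the full `-0.5 * log|det(cov)|`.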
```python
class SequentialMonteCarlo(Resource):
    def __init__(self):
        raise NotImplementedError

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        raise NotImplementedError
```
🛠️ Refactor suggestion
Mark interface methods as abstract (don’t rely on raising at runtime).
Declare __init__ and __call__ as @abstractmethod to make the contract explicit and let static/type checkers catch instantiation of incomplete subclasses at import time. Since Resource already uses ABCMeta, the decorators will be honored.
Apply this diff:

```diff
 from flowMC.resource.base import Resource
 from jaxtyping import Array, Float, PRNGKeyArray
+from abc import abstractmethod
 
 
 class SequentialMonteCarlo(Resource):
-    def __init__(self):
-        raise NotImplementedError
+    @abstractmethod
+    def __init__(self):
+        raise NotImplementedError
 
-    def __call__(
+    @abstractmethod
+    def __call__(
         self,
         rng_key: PRNGKeyArray,
         resources: dict[str, Resource],
         initial_position: Float[Array, "n_chains n_dim"],
         data: dict,
     ) -> tuple[
         PRNGKeyArray,
         dict[str, Resource],
         Float[Array, "n_chains n_dim"],
     ]:
         raise NotImplementedError
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from flowMC.resource.base import Resource
from jaxtyping import Array, Float, PRNGKeyArray
from abc import abstractmethod


class SequentialMonteCarlo(Resource):
    @abstractmethod
    def __init__(self):
        raise NotImplementedError

    @abstractmethod
    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        raise NotImplementedError
```
🤖 Prompt for AI Agents
In src/flowMC/strategy/sequential_monte_carlo.py around lines 5 to 20, the
interface methods __init__ and __call__ currently raise NotImplementedError at
runtime; mark them as abstract by importing and applying @abstractmethod to both
methods so the class is treated as an abstract base class and incomplete
subclasses cannot be instantiated at import time; remove the raise-only bodies
(you may keep a pass or leave the docstring) and ensure the module imports
abstractmethod from abc if not already present.
```diff
 from jaxtyping import Array, Float
 
-from flowMC.resource.local_kernel.Gaussian_random_walk import GaussianRandomWalk
+from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
```
💡 Verification agent
🧩 Analysis chain
Scan the repo for any stale local_kernel or old NF-model import paths.
To avoid stragglers after the module move, run this scan.
🏁 Script executed:
```bash
#!/bin/bash
# Find any remaining old import paths after the refactor.
set -euo pipefail

echo "Searching for stale local_kernel imports..."
rg -nP '\bflowMC\.resource\.local_kernel\b' -C2 || true
echo
echo "Searching for old nf_model root path imports..."
rg -nP '\bflowMC\.resource\.nf_model\b' -C2 || true
echo
echo "Searching for old model base locations that moved to model.common..."
rg -nP '\bflowMC\.resource\.model\.nf_model\.base\b' -C2 || true
```

Length of output: 2204
Stale NFModel Imports Need Updating
The scan uncovered leftover imports from the old flowMC.resource.model.nf_model.base path. Please update these to the new location in flowMC.resource.model.common:
• src/flowMC/strategy/train_model.py (line 4): change `from flowMC.resource.model.nf_model.base import NFModel` to `from flowMC.resource.model.common import NFModel`
• src/flowMC/resource/model/nf_model/realNVP.py (line 8): change `from flowMC.resource.model.nf_model.base import NFModel` to `from flowMC.resource.model.common import NFModel`
• src/flowMC/resource/model/nf_model/rqSpline.py (line 9): change `from flowMC.resource.model.nf_model.base import NFModel` to `from flowMC.resource.model.common import NFModel`
• src/flowMC/resource/kernel/NF_proposal.py (line 10): change `from flowMC.resource.model.nf_model.base import NFModel` to `from flowMC.resource.model.common import NFModel`
After adjusting these, rerun the import-scan to confirm no stale references remain.
🤖 Prompt for AI Agents
In test/integration/test_RWMCMC.py around lines 6 to 6 and the referenced source
files, there are stale imports of NFModel from
flowMC.resource.model.nf_model.base; update each occurrence to import NFModel
from flowMC.resource.model.common instead (specifically fix
src/flowMC/strategy/train_model.py line 4,
src/flowMC/resource/model/nf_model/realNVP.py line 8,
src/flowMC/resource/model/nf_model/rqSpline.py line 9, and
src/flowMC/resource/kernel/NF_proposal.py line 10), then run the import-scan to
verify no remaining stale references.
The following summary is generated by copilot
This pull request introduces a new flow matching model implementation, improves dependency management, and refactors the resource and model directory structure for better organization and maintainability. The most significant changes include the addition of a flow matching model under resource/model/flowmatching, migration of normalizing flow model base classes, and various updates to imports and dependencies to reflect the new structure.

Major Feature Addition

Flow Matching Model Implementation

Added a new `FlowMatchingModel` and supporting classes (`Solver`, `Scheduler`, `Path`, etc.) in `src/flowMC/resource/model/flowmatching/base.py` to enable flow matching-based generative modeling. This includes training, sampling, and log-probability computation methods leveraging ODE solvers from Diffrax.

Refactoring and Directory Structure
Normalizing Flow Model Refactor

Moved `Bijection` and `Distribution` from `src/flowMC/resource/nf_model/base.py` to `src/flowMC/resource/model/common.py`, and updated all relevant imports to use the new location. This change centralizes shared abstractions and reduces duplication. [1] [2] [3] [4]

Kernel and Model Directory Reorganization

Moved the `local_kernel` modules (Gaussian_random_walk.py, HMC.py, MALA.py, base.py) to `resource/kernel/` and updated all references throughout the codebase and tutorials. [1] [2] [3] [4] [5] [6]

Consolidated the `nf_model` modules under `resource/model/nf_model/`, updating all imports accordingly. [1] [2] [3] [4] [5]

Dependency and Configuration Updates
diffraxas a required dependency inpyproject.tomland as an additional dependency for thepyrightpre-commit hook, enabling ODE-based flow matching methods. [1] [2]Development and Code QA Dependencies
codeqadependency group into thedevgroup inpyproject.toml, and addedipythonfor improved development experience.Most important changes, grouped by theme:
1. New Feature: Flow Matching Model
FlowMatchingModeland supporting classes insrc/flowMC/resource/model/flowmatching/base.pyfor ODE-based generative modeling using Diffrax.2. Refactoring & Structure
BijectionandDistributionbase classes tosrc/flowMC/resource/model/common.pyand updated all related imports. [1] [2] [3] [4]local_kernelmodules toresource/kernel/and updated all usages and tutorial references. [1] [2] [3] [4] [5] [6]nf_modelmodules underresource/model/nf_model/and updated all imports. [1] [2] [3] [4] [5]3. Dependency Management
diffraxas a required dependency and to the pre-commitpyrighthook for ODE support. [1] [2]codeqadependencies into thedevgroup inpyproject.tomland addedipython.Summary by CodeRabbit