diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 3f223b3de..dd8ce04f5 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -29,6 +29,7 @@ jobs: group: - InterfaceI - InterfaceII + - ThreadSafety - QA exclude: - version: "pre" diff --git a/.github/workflows/ThreadSafety.yml b/.github/workflows/ThreadSafety.yml new file mode 100644 index 000000000..b82eda2db --- /dev/null +++ b/.github/workflows/ThreadSafety.yml @@ -0,0 +1,65 @@ +name: "Thread Safety Tests" + +on: + push: + branches: + - master + paths-ignore: + - 'docs/**' + pull_request: + branches: + - master + paths-ignore: + - 'docs/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref_name != github.event.repository.default_branch || github.ref != 'refs/tags/v*' }} + +jobs: + thread-safety: + name: "Thread Safety" + strategy: + fail-fast: false + matrix: + version: + - "1" + - "lts" + - "pre" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: "Setup Julia ${{ matrix.version }}" + uses: julia-actions/setup-julia@v2 + with: + version: "${{ matrix.version }}" + + - name: "Cache Julia packages" + uses: julia-actions/cache@v2 + with: + token: "${{ secrets.GITHUB_TOKEN }}" + + - name: "Build package" + uses: julia-actions/julia-buildpkg@v1 + + - name: "Run thread safety tests (4 threads)" + run: | + julia --threads=4 --code-coverage=user --check-bounds=yes --compiled-modules=yes \ + --project=@. --color=yes -e ' + using Pkg + Pkg.test() + ' + env: + GROUP: ThreadSafety + + - name: "Process Coverage" + uses: julia-actions/julia-processcoverage@v1 + + - name: "Report Coverage" + uses: codecov/codecov-action@v5 + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + with: + files: lcov.info + token: "${{ secrets.CODECOV_TOKEN }}" + fail_ci_if_error: false diff --git a/HISTORY.md b/HISTORY.md index 1851dbd63..d010778af 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,26 @@ # Breaking updates and feature summaries across releases -## JumpProcesses unreleased (master branch) +## 10.0 (Breaking) + + - **Breaking**: The `rng` keyword argument has been removed from + `JumpProblem`. Pass `rng` to `solve` or `init` instead: + ```julia + # Before (no longer works): + jprob = JumpProblem(dprob, Direct(), jump; rng = Xoshiro(1234)) + sol = solve(jprob, SSAStepper()) + + # After: + jprob = JumpProblem(dprob, Direct(), jump) + sol = solve(jprob, SSAStepper(); rng = Xoshiro(1234)) + ``` + - RNG state is now owned by the integrator, not the aggregator. This + eliminates data races when sharing a `JumpProblem` across threads and + ensures a single, consistent RNG priority across all solver pathways: + `rng` > `seed` > `Random.default_rng()`. + - `rng` and `seed` kwargs are fully supported on `solve`/`init` for all + solver pathways (SSAStepper, ODE, SDE, tau-leaping). + - `SSAIntegrator` now supports the `SciMLBase` RNG interface (`has_rng`, + `get_rng`, `set_rng!`). ## 9.14 diff --git a/Project.toml b/Project.toml index 4f65c5c16..242da5aee 100644 --- a/Project.toml +++ b/Project.toml @@ -45,22 +45,25 @@ KernelAbstractions = "0.9" LinearAlgebra = "1" LinearSolve = "3" OrdinaryDiffEq = "6" -OrdinaryDiffEqCore = "1.32.0" +OrdinaryDiffEqCore = "3.11" Pkg = "1" PoissonRandom = "0.4" Random = "1" RecursiveArrayTools = "3.35" Reexport = "1.2" SafeTestsets = "0.1" -SciMLBase = "2.115" +SciMLBase = "2.147" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" -StochasticDiffEq = "6.82" +StochasticDiffEq = "6.95" SymbolicIndexingInterface = "0.3.36" Test = "1" julia = "1.10" +[sources] +SciMLBase = {url = "https://github.com/isaacsas/SciMLBase.jl.git", rev = "ensemble_rng_redesign"} + [extras] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/docs/src/api.md b/docs/src/api.md index d8fd944b2..4a0db690b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -40,9 +40,91 @@ RSSACR SortingDirect ``` -# Private API Functions +## Random Number Generator Control + +JumpProcesses supports controlling the random number generator (RNG) used for +jump sampling via the `rng` and `seed` keyword arguments to `solve` or `init`. + +### `rng` keyword argument + +Pass any `AbstractRNG` to `solve` or `init`: + +```julia +using Random, StableRNGs + +# Using a StableRNG for cross-version reproducibility +sol = solve(jprob, SSAStepper(); rng = StableRNG(1234)) + +# Using Julia's built-in Xoshiro +sol = solve(jprob, Tsit5(); rng = Xoshiro(42)) +``` + +### `seed` keyword argument + +As a shorthand, pass an integer `seed` to create a `Xoshiro` generator: + +```julia +sol = solve(jprob, SSAStepper(); seed = 1234) +# equivalent to: solve(jprob, SSAStepper(); rng = Xoshiro(1234)) +``` + +### Resolution priority + +When both `rng` and `seed` are passed to the same `solve`/`init` call, `rng` +takes priority: + +| User provides | Result | +|---|---| +| `rng` via `solve`/`init` | Uses that `rng` | +| `seed` via `solve`/`init` | Creates `Xoshiro(seed)` | +| Nothing | Uses `Random.default_rng()` (SSAStepper, ODE, tau-leaping) or a randomly-seeded `Xoshiro` (SDE) | + +### Behavior by solver pathway + +| Solver | Default RNG (nothing passed) | `rng` / `seed` support | +|---|---|---| +| `SSAStepper` | `Random.default_rng()` | Full support via `solve`/`init` kwargs | +| ODE solvers (e.g., `Tsit5`) | `Random.default_rng()` | Full support via `solve`/`init` kwargs | +| SDE solvers (e.g., `SRIW1`) | Randomly-seeded `Xoshiro` | Full support; `TaskLocalRNG` is auto-converted to `Xoshiro` | +| `SimpleTauLeaping` | `Random.default_rng()` | Full support via `solve` kwargs | + +!!! note + For reproducible simulations, always pass an explicit `rng` or `seed`. + The default RNG is shared global state and may produce different results + depending on prior usage. + +# Private / Developer API ```@docs ExtendedJumpArray SSAIntegrator ``` + +## Internal Dispatch Pathways + +The following table documents which code handles `solve`/`init` for each solver +type. This is relevant for developers working on JumpProcesses or its solver +backends. + +| Solver type | `__solve` handled by | `__init` handled by | Uses `__jump_init`? | +|---|---|---|---| +| `SSAStepper` | JumpProcesses (`solve.jl`) | JumpProcesses (`SSA_stepper.jl`) | No | +| ODE (e.g., `Tsit5`) | JumpProcesses (`solve.jl`) | JumpProcesses (`solve.jl`) → OrdinaryDiffEq | Yes | +| SDE (e.g., `SRIW1`) | StochasticDiffEq | StochasticDiffEq | No | +| `SimpleTauLeaping` | JumpProcesses (`simple_regular_solve.jl`, custom `DiffEqBase.solve`) | N/A | No | + +For **SSAStepper**, `rng` is resolved via `resolve_rng` in `SSA_stepper.jl`'s +`__init` and stored on the [`SSAIntegrator`](@ref). + +For **ODE solvers**, `rng` is resolved via `resolve_rng` in `__jump_init` +(`solve.jl`) and forwarded to OrdinaryDiffEq's `init`, which stores it on the +`ODEIntegrator`. + +For **SDE solvers**, StochasticDiffEq handles the full solve/init pathway +directly (JumpProcesses' ambiguity-fix `__solve` method is never dispatched to). +StochasticDiffEq has its own `_resolve_rng` that additionally handles +`TaskLocalRNG` conversion and the problem's stored seed. + +For **tau-leaping**, JumpProcesses defines a custom `DiffEqBase.solve` that +bypasses the standard `__solve`/`__init` pathway. It calls `resolve_rng` +directly with the `rng` and `seed` kwargs from the `solve` call. diff --git a/docs/src/applications/advanced_point_process.md b/docs/src/applications/advanced_point_process.md index 976770017..87e397d86 100644 --- a/docs/src/applications/advanced_point_process.md +++ b/docs/src/applications/advanced_point_process.md @@ -387,10 +387,10 @@ function Base.rand(rng::AbstractRNG, out = Array{History, 1}(undef, n) p = params(pp) dprob = DiscreteProblem([0], tspan, p) - jprob = JumpProblem(dprob, Coevolve(), jumps...; dep_graph = pp.g, save_positions, rng) + jprob = JumpProblem(dprob, Coevolve(), jumps...; dep_graph = pp.g, save_positions) for i in 1:n params!(pp, p) - solve(jprob, SSAStepper()) + solve(jprob, SSAStepper(); rng) out[i] = deepcopy(p.h) end return out diff --git a/docs/src/faq.md b/docs/src/faq.md index d2d0be3be..47d77c3cb 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -55,21 +55,27 @@ jset = JumpSet(; constant_jumps = cjvec, variable_jumps = vjtuple, ## How can I set the random number generator used in the jump process sampling algorithms (SSAs)? -Random number generators can be passed to `JumpProblem` via the `rng` keyword +Random number generators can be passed to `solve` or `init` via the `rng` keyword argument. Continuing the previous example: ```julia -#] add RandomNumbers -using RandomNumbers -jprob = JumpProblem(dprob, Direct(), maj, - rng = Xorshifts.Xoroshiro128Star(rand(UInt64))) +using Random +jprob = JumpProblem(dprob, Direct(), maj) +sol = solve(jprob, SSAStepper(); rng = Xoshiro(1234)) +``` + +Any `AbstractRNG` can be used. For example, to use a generator from +[StableRNGs.jl](https://github.com/JuliaRandom/StableRNGs.jl): + +```julia +using StableRNGs +sol = solve(jprob, SSAStepper(); rng = StableRNG(1234)) ``` -uses the `Xoroshiro128Star` generator from -[RandomNumbers.jl](https://github.com/JuliaRandom/RandomNumbers.jl). +A `seed` keyword argument is also supported as a shorthand for creating a `Xoshiro` +generator: `solve(jprob, SSAStepper(); seed = 1234)`. -On version 1.7 and up, JumpProcesses uses Julia's built-in random number generator by -default. On versions below 1.7 it uses `Xoroshiro128Star`. +By default, JumpProcesses uses Julia's built-in `Random.default_rng()`. ## What are these aggregators and aggregations in JumpProcesses? diff --git a/docs/src/tutorials/simple_poisson_process.md b/docs/src/tutorials/simple_poisson_process.md index ae8e0955e..11f5c5f0e 100644 --- a/docs/src/tutorials/simple_poisson_process.md +++ b/docs/src/tutorials/simple_poisson_process.md @@ -371,11 +371,10 @@ with ``N(t)`` a Poisson counting process with constant transition rate ``\lambda``, and the ``C_i`` independent and identical samples from a uniform distribution over ``\{-1,1\}``. We can simulate such a process as follows. -We first ensure that we use the same random number generator as JumpProcesses. We -can either pass one as an input to [`JumpProblem`](@ref) via the `rng` keyword -argument, and make sure it is the same one we use in our `affect!` function, or -we can just use the default generator chosen by JumpProcesses if one is not -specified, `JumpProcesses.DEFAULT_RNG`. Let's do the latter +We first ensure that we use the same random number generator as JumpProcesses. +Custom RNGs can be passed to `solve` or `init` via the `rng` keyword argument. +If no RNG is specified, JumpProcesses uses `Random.default_rng()`, which is also +available as `JumpProcesses.DEFAULT_RNG`. Let's use the default ```@example tut1 rng = JumpProcesses.DEFAULT_RNG diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c26..e32c4fb8e 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -40,7 +40,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction, ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, deleteat!, isinplace, remake, savevalues!, step!, u_modified! -using SciMLBase: SciMLBase, DEIntegrator +using SciMLBase: SciMLBase, DEIntegrator, has_rng, get_rng, set_rng! abstract type AbstractJump end abstract type AbstractMassActionJump <: AbstractJump end diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 3403b0602..a3e4096ef 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -13,6 +13,9 @@ Highly efficient integrator for pure jump problems that involve only `ConstantRa `SSAStepper`. - Only supports a limited subset of the output controls from the common solver interface, specifically `save_start`, `save_end`, and `saveat`. + - Supports `rng` and `seed` keyword arguments in `solve`/`init` to control the random + number generator used for jump sampling. `rng` accepts any `AbstractRNG`, while `seed` + creates a `Xoshiro` generator. `rng` takes priority over `seed`. - As when using jumps with ODEs and SDEs, saving controls for whether to save each time a jump occurs are via the `save_positions` keyword argument to `JumpProblem`. Note that when choosing `SSAStepper` as the timestepper, `save_positions = (true,true)`, `(true,false)`, @@ -58,17 +61,18 @@ for details. """ struct SSAStepper <: DiffEqBase.DEAlgorithm end SciMLBase.allows_late_binding_tstops(::SSAStepper) = true +SciMLBase.supports_solve_rng(::JumpProblem, ::SSAStepper) = true """ $(TYPEDEF) -Solution objects for pure jump problems solved via `SSAStepper`. +Integrator for pure jump problems solved via `SSAStepper`. ## Fields $(FIELDS) """ -mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} <: +mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS, R} <: AbstractSSAIntegrator{SSAStepper, Nothing, uType, tType} """The underlying `prob.f` function. Not currently used.""" f::F @@ -108,6 +112,23 @@ mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} < alias_tstops::Bool """If true indicates we have already allocated the tstops array""" copied_tstops::Bool + """The random number generator.""" + rng::R +end + +SciMLBase.has_rng(::SSAIntegrator) = true +SciMLBase.get_rng(integrator::SSAIntegrator) = integrator.rng +function SciMLBase.set_rng!(integrator::SSAIntegrator, rng) + R = typeof(integrator.rng) + if !isa(rng, R) + throw(ArgumentError( + "Cannot set RNG of type $(typeof(rng)) on an integrator " * + "whose RNG type parameter is $R. " * + "Construct a new integrator via `init(prob, alg; rng = your_rng)` instead." + )) + end + integrator.rng = rng + nothing end (integrator::SSAIntegrator)(t) = copy(integrator.u) @@ -198,6 +219,7 @@ function DiffEqBase.__init(jump_prob::JumpProblem, save_start = true, save_end = true, seed = nothing, + rng = nothing, alias_jump = Threads.threadid() == 1, saveat = nothing, callback = nothing, @@ -219,19 +241,13 @@ function DiffEqBase.__init(jump_prob::JumpProblem, # Check for continuous callbacks passed via kwargs (from JumpProblem constructor or solve) check_continuous_callback_error(callback) + + _rng = resolve_rng(rng, seed) + if alias_jump cb = jump_prob.jump_callback.discrete_callbacks[end] - if seed !== nothing - Random.seed!(cb.condition.rng, seed) - end else cb = deepcopy(jump_prob.jump_callback.discrete_callbacks[end]) - # Only reseed if an explicit seed is provided. This respects the user's RNG choice - # and enables reproducibility. For EnsembleProblems, use prob_func to set unique seeds - # for each trajectory if different results are needed. - if seed !== nothing - Random.seed!(cb.condition.rng, seed) - end end opts = (callback = CallbackSet(callback),) @@ -286,7 +302,7 @@ function DiffEqBase.__init(jump_prob::JumpProblem, integrator = SSAIntegrator(prob.f, copy(prob.u0), prob.tspan[1], prob.tspan[1], tdir, prob.p, sol, 1, prob.tspan[1], cb, _saveat, save_everystep, - save_end, cur_saveat, opts, _tstops, 1, false, true, alias_tstops, false) + save_end, cur_saveat, opts, _tstops, 1, false, true, alias_tstops, false, _rng) cb.initialize(cb, integrator.u, prob.tspan[1], integrator) DiffEqBase.initialize!(opts.callback, integrator.u, prob.tspan[1], integrator) if save_start diff --git a/src/aggregators/ccnrm.jl b/src/aggregators/ccnrm.jl index 379ad642a..720868c50 100644 --- a/src/aggregators/ccnrm.jl +++ b/src/aggregators/ccnrm.jl @@ -3,8 +3,8 @@ # algorithm with optimal binning, Journal of Chemical Physics 143, 074108 # (2015). doi: 10.1063/1.4928635. -mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct CCNRMJumpAggregation{T, S, F1, F2, DEPGR, PT} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -15,15 +15,14 @@ mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR ptt::PT end function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -45,30 +44,31 @@ function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, binwidthconst = binwidthconst, numbinsconst = numbinsconst) # We will re-initialize this in initialize!() affecttype = F2 <: Tuple ? F2 : Any - CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}( + CCNRMJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(ptt)}( nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, - rng, dg, ptt) + dg, ptt) end -+############################# Required Functions ############################## +############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::CCNRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(CCNRMJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::CCNRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] - initialize_rates_and_times!(p, u, params, t) + rng = get_rng(integrator) + initialize_rates_and_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) nothing end @@ -79,7 +79,8 @@ function execute_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t, affec u = update_state!(p, integrator, u, affects!) # update current jump rates and times - update_dependent_rates!(p, u, params, t) + rng = get_rng(integrator) + update_dependent_rates!(p, u, params, t, rng) nothing end @@ -88,7 +89,7 @@ end function generate_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t) p.next_jump, p.next_jump_time = getfirst(p.ptt) - # Rebuild the table if no next jump is found. + # Rebuild the table if no next jump is found. if p.next_jump == 0 timestep = 1 / sum(p.cur_rates) min_time = minimum(p.ptt.times) @@ -102,7 +103,7 @@ end ######################## SSA specific helper routines ######################## # Recalculate jump rates for jumps that depend on the just executed jump (p.next_jump) -function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) +function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t, rng) @inbounds dep_rxs = p.dep_gr[p.next_jump] (; ptt, cur_rates, rates, ma_jumps, end_time) = p num_majumps = get_num_majumps(ma_jumps) @@ -125,7 +126,7 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) end else if cur_rates[rx] > zero(eltype(cur_rates)) - update!(ptt, rx, oldtime, t + randexp(p.rng) / cur_rates[rx]) + update!(ptt, rx, oldtime, t + randexp(rng) / cur_rates[rx]) else update!(ptt, rx, oldtime, floatmax(typeof(t))) end @@ -134,15 +135,15 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) nothing end -# Evaluate all the rates and initialize the times in the priority table. -function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) +# Evaluate all the rates and initialize the times in the priority table. +function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t, rng) # Initialize next-reaction times for the mass action jumps majumps = p.ma_jumps cur_rates = p.cur_rates pttdata = Vector{typeof(t)}(undef, length(cur_rates)) @inbounds for i in 1:get_num_majumps(majumps) cur_rates[i] = evalrxrate(u, i, majumps) - pttdata[i] = t + randexp(p.rng) / cur_rates[i] + pttdata[i] = t + randexp(rng) / cur_rates[i] end # Initialize next-reaction times for the constant rates @@ -150,11 +151,11 @@ function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) idx = get_num_majumps(majumps) + 1 @inbounds for rate in rates cur_rates[idx] = rate(u, params, t) - pttdata[idx] = t + randexp(p.rng) / cur_rates[idx] + pttdata[idx] = t + randexp(rng) / cur_rates[idx] idx += 1 end - # Build the priority time table with the times and bin width. + # Build the priority time table with the times and bin width. timestep = 1 / sum(cur_rates) p.ptt.times = pttdata rebuild!(p.ptt, t, timestep) diff --git a/src/aggregators/coevolve.jl b/src/aggregators/coevolve.jl index dc48a5e21..ab5dc5f4f 100644 --- a/src/aggregators/coevolve.jl +++ b/src/aggregators/coevolve.jl @@ -1,8 +1,8 @@ """ Queue method. This method handles variable intensity rates. """ -mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct CoevolveJumpAggregation{T, S, F1, F2, GR, PQ} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int # the next jump to execute prev_jump::Int # the previous jump that was executed next_jump_time::T # the time of the next jump @@ -13,7 +13,6 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: rates::F1 # vector of rate functions affects!::F2 # vector of affect functions for VariableRateJumps save_positions::Tuple{Bool, Bool} # tuple for whether to save the jumps before and/or after event - rng::RNG # random number generator dep_gr::GR # map from jumps to jumps depending on it pq::PQ # priority queue of next time lrates::F1 # vector of rate lower bound functions @@ -24,10 +23,10 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: end function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; u::U, dep_graph = nothing, lrates, urates, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + u::U, dep_graph = nothing, lrates, urates, rateintervals, haslratevec, - cur_lrates::Vector{T}) where {T, S, F1, F2, RNG, U} + cur_lrates::Vector{T}) where {T, S, F1, F2, U} if dep_graph === nothing if (get_num_majumps(maj) == 0) || !isempty(urates) error("To use Coevolve a dependency graph between jumps must be supplied.") @@ -49,9 +48,9 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any - CoevolveJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), + CoevolveJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(pq)}(nj, nj, njt, et, crs, sr, maj, - rs, affs!, sps, rng, dg, pq, + rs, affs!, sps, dg, pq, lrates, urates, rateintervals, haslratevec, cur_lrates) end @@ -98,7 +97,7 @@ end # creating the JumpAggregation structure (tuple-based variable jumps) function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; dep_graph = nothing, + ma_jumps, save_positions; dep_graph = nothing, variable_jumps = nothing, kwargs...) RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), Tuple{typeof(u), typeof(p), typeof(t)}} @@ -141,7 +140,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, next_jump = 0 next_jump_time = typemax(t) CoevolveJumpAggregation(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - ma_jumps, rates, affects!, save_positions, rng; + ma_jumps, rates, affects!, save_positions; u, dep_graph, lrates, urates, rateintervals, haslratevec, cur_lrates) end @@ -149,7 +148,8 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::CoevolveJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] - fill_rates_and_get_times!(p, u, params, t) + rng = get_rng(integrator) + fill_rates_and_get_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) nothing end @@ -160,7 +160,8 @@ function execute_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t, # execute jump update_state!(p, integrator, u, affects!) # update current jump rates and times - update_dependent_rates!(p, integrator.u, integrator.p, t) + rng = get_rng(integrator) + update_dependent_rates!(p, integrator.u, integrator.p, t, rng) nothing end @@ -178,7 +179,8 @@ function accept_next_jump!(p::CoevolveJumpAggregation, integrator, u, params, t) (next_jump <= num_majumps) && return true - (; cur_rates, rates, rng, urates, cur_lrates) = p + (; cur_rates, rates, urates, cur_lrates) = p + rng = get_rng(integrator) num_cjumps = length(urates) - length(rates) uidx = next_jump - num_majumps lidx = uidx - num_cjumps @@ -225,11 +227,11 @@ function accept_next_jump!(p::CoevolveJumpAggregation, integrator, u, params, t) return false end -function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t) +function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t, rng) @inbounds deps = p.dep_gr[p.next_jump] (; cur_rates, pq) = p for (ix, i) in enumerate(deps) - ti, urate_i = next_time(p, u, params, t, i) + ti, urate_i = next_time(p, u, params, t, i, rng) update!(pq, i, ti) @inbounds cur_rates[i] = urate_i end @@ -256,8 +258,8 @@ end @inbounds return p.rates[lidx](u, params, t) end -function next_time(p::CoevolveJumpAggregation, u, params, t, i) - (; next_jump, cur_rates, ma_jumps, rates, rng, pq, urates) = p +function next_time(p::CoevolveJumpAggregation, u, params, t, i, rng) + (; next_jump, cur_rates, ma_jumps, rates, pq, urates) = p num_majumps = get_num_majumps(ma_jumps) num_cjumps = length(urates) - length(rates) uidx = i - num_majumps @@ -300,12 +302,12 @@ function next_candidate_time!(p::CoevolveJumpAggregation, u, params, t, s, lidx) end # re-evaluates all rates, recalculate all jump times, and reinit the priority queue -function fill_rates_and_get_times!(p::CoevolveJumpAggregation, u, params, t) +function fill_rates_and_get_times!(p::CoevolveJumpAggregation, u, params, t, rng) num_jumps = get_num_majumps(p.ma_jumps) + length(p.urates) p.cur_rates = zeros(typeof(t), num_jumps) jump_times = Vector{typeof(t)}(undef, num_jumps) @inbounds for i in 1:num_jumps - jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i) + jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i, rng) end p.pq = MutableBinaryMinHeap(jump_times) nothing diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index ab33dd842..b51e18f7c 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -1,5 +1,5 @@ -mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct DirectJumpAggregation{T, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -10,38 +10,37 @@ mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG end function DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; - kwargs...) where {T, S, F1, F2, RNG} + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + kwargs...) where {T, S, F1, F2} affecttype = F2 <: Tuple ? F2 : Any - DirectJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng) + DirectJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, rs, + affs!, sps) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; kwargs...) + rates, affects!, save_positions; kwargs...) end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; kwargs...) + rates, affects!, save_positions; kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -60,9 +59,10 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectJumpAggregation, integrator, u, params, t) - p.sum_rate, ttnj = time_to_next_jump(p, u, params, t) + rng = get_rng(integrator) + p.sum_rate, ttnj = time_to_next_jump(p, u, params, t, rng) p.next_jump_time = add_fast(t, ttnj) - @inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(p.rng) * p.sum_rate) + @inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(rng) * p.sum_rate) nothing end @@ -70,7 +70,7 @@ end # tuple-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: Tuple} + t, rng) where {T, S, F1 <: Tuple} prev_rate = zero(t) new_rate = zero(t) cur_rates = p.cur_rates @@ -96,7 +96,7 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, end @inbounds sum_rate = cur_rates[end] - sum_rate, randexp(p.rng) / sum_rate + sum_rate, randexp(rng) / sum_rate end @inline function fill_cur_rates(u, p, t, cur_rates, idx, rate, rates...) @@ -112,7 +112,7 @@ end # function wrapper-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: AbstractArray} + t, rng) where {T, S, F1 <: AbstractArray} prev_rate = zero(t) new_rate = zero(t) cur_rates = p.cur_rates @@ -137,5 +137,5 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, end @inbounds sum_rate = cur_rates[end] - sum_rate, randexp(p.rng) / sum_rate + sum_rate, randexp(rng) / sum_rate end diff --git a/src/aggregators/directcr.jl b/src/aggregators/directcr.jl index 41d079fe8..636300bb3 100644 --- a/src/aggregators/directcr.jl +++ b/src/aggregators/directcr.jl @@ -10,9 +10,9 @@ by S. Mauch and M. Stalzer, ACM Trans. Comp. Biol. and Bioinf., 8, No. 1, 27-35 const MINJUMPRATE = 2.0^exponent(1e-12) -mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTable, +mutable struct DirectCRJumpAggregation{T, S, F1, F2, DEPGR, U <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -23,7 +23,6 @@ mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTa rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR minrate::T maxrate::T # initial maxrate only, table can increase beyond it! @@ -32,11 +31,11 @@ mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTa end function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, minrate = convert(T, MINJUMPRATE), maxrate = convert(T, Inf), - kwargs...) where {T, S, F1, F2, RNG} + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -63,9 +62,9 @@ function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) affecttype = F2 <: Tuple ? F2 : Any - DirectCRJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), + DirectCRJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crs, sr, maj, - rs, affs!, sps, rng, dg, + rs, affs!, sps, dg, minrate, maxrate, rt, ratetogroup) end @@ -74,13 +73,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::DirectCR, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(DirectCRJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end @@ -113,10 +112,11 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectCRJumpAggregation, integrator, u, params, t) - p.next_jump_time = t + randexp(p.rng) / p.sum_rate + rng = get_rng(integrator) + p.next_jump_time = t + randexp(rng) / p.sum_rate if p.next_jump_time < p.end_time - p.next_jump = sample(p.rt, p.cur_rates, p.rng) + p.next_jump = sample(p.rt, p.cur_rates, rng) end nothing end diff --git a/src/aggregators/frm.jl b/src/aggregators/frm.jl index 94baed375..0db30113b 100644 --- a/src/aggregators/frm.jl +++ b/src/aggregators/frm.jl @@ -1,5 +1,5 @@ -mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct FRMJumpAggregation{T, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -10,40 +10,39 @@ mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG end function FRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, - affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; - kwargs...) where {T, S, F1, F2, RNG} + affs!::F2, sps::Tuple{Bool, Bool}; + kwargs...) where {T, S, F1, F2} affecttype = F2 <: Tuple ? F2 : Any - FRMJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng) + FRMJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, rs, + affs!, sps) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) function aggregate(aggregator::FRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) build_jump_aggregation( FRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, - save_positions, rng; kwargs...) + save_positions; kwargs...) end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::FRMFW, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation( FRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, - save_positions, rng; kwargs...) + save_positions; kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -62,8 +61,9 @@ end # calculate the next jump / jump time function generate_jumps!(p::FRMJumpAggregation, integrator, u, params, t) - nextmaj, ttnmaj = next_ma_jump(p, u, params, t) - nextcrj, ttncrj = next_constant_rate_jump(p, u, params, t) + rng = get_rng(integrator) + nextmaj, ttnmaj = next_ma_jump(p, u, params, t, rng) + nextcrj, ttncrj = next_constant_rate_jump(p, u, params, t, rng) # execute reaction with minimal time if ttnmaj < ttncrj @@ -79,13 +79,13 @@ end ######################## SSA specific helper routines ######################## # mass action jumps -function next_ma_jump(p::FRMJumpAggregation, u, params, t) +function next_ma_jump(p::FRMJumpAggregation, u, params, t, rng) ttnj = typemax(typeof(t)) nextrx = zero(Int) majumps = p.ma_jumps @inbounds for i in 1:get_num_majumps(majumps) p.cur_rates[i] = evalrxrate(u, i, majumps) - dt = randexp(p.rng) / p.cur_rates[i] + dt = randexp(rng) / p.cur_rates[i] if dt < ttnj ttnj = dt nextrx = i @@ -95,15 +95,15 @@ function next_ma_jump(p::FRMJumpAggregation, u, params, t) end # tuple-based constant jumps -function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params, - t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG} +function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2}, u, params, + t, rng) where {T, S, F1 <: Tuple, F2 <: Tuple} ttnj = typemax(typeof(t)) nextrx = zero(Int) if !isempty(p.rates) idx = get_num_majumps(p.ma_jumps) + 1 fill_cur_rates(u, params, t, p.cur_rates, idx, p.rates...) @inbounds for i in idx:length(p.cur_rates) - dt = randexp(p.rng) / p.cur_rates[i] + dt = randexp(rng) / p.cur_rates[i] if dt < ttnj ttnj = dt nextrx = i @@ -115,14 +115,14 @@ end # function wrapper-based constant jumps function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: AbstractArray} + t, rng) where {T, S, F1 <: AbstractArray} ttnj = typemax(typeof(t)) nextrx = zero(Int) if !isempty(p.rates) idx = get_num_majumps(p.ma_jumps) + 1 @inbounds for i in 1:length(p.rates) p.cur_rates[idx] = p.rates[i](u, params, t) - dt = randexp(p.rng) / p.cur_rates[idx] + dt = randexp(rng) / p.cur_rates[idx] if dt < ttnj ttnj = dt nextrx = idx diff --git a/src/aggregators/nrm.jl b/src/aggregators/nrm.jl index 7fbcd5964..934540448 100644 --- a/src/aggregators/nrm.jl +++ b/src/aggregators/nrm.jl @@ -1,8 +1,8 @@ # Implementation the original Next Reaction Method # Gibson and Bruck, J. Phys. Chem. A, 104 (9), (2000) -mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct NRMJumpAggregation{T, S, F1, F2, DEPGR, PQ} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -13,15 +13,14 @@ mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR pq::PQ end function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -40,29 +39,30 @@ function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any - NRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(pq)}(nj, nj, njt, et, + NRMJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, - rng, dg, pq) + dg, pq) end -+############################# Required Functions ############################## +############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::NRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(NRMJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::NRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] - fill_rates_and_get_times!(p, u, params, t) + rng = get_rng(integrator) + fill_rates_and_get_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) nothing end @@ -73,7 +73,8 @@ function execute_jumps!(p::NRMJumpAggregation, integrator, u, params, t, affects u = update_state!(p, integrator, u, affects!) # update current jump rates and times - update_dependent_rates!(p, u, params, t) + rng = get_rng(integrator) + update_dependent_rates!(p, u, params, t, rng) nothing end @@ -87,7 +88,7 @@ end ######################## SSA specific helper routines ######################## # recalculate jump rates for jumps that depend on the just executed jump (p.next_jump) -function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) +function update_dependent_rates!(p::NRMJumpAggregation, u, params, t, rng) @inbounds dep_rxs = p.dep_gr[p.next_jump] (; cur_rates, rates, ma_jumps) = p num_majumps = get_num_majumps(ma_jumps) @@ -108,7 +109,7 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) end else if cur_rates[rx] > zero(eltype(cur_rates)) - update!(p.pq, rx, t + randexp(p.rng) / cur_rates[rx]) + update!(p.pq, rx, t + randexp(rng) / cur_rates[rx]) else update!(p.pq, rx, typemax(t)) end @@ -118,7 +119,7 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) end # reevaluate all rates, recalculate all jump times, and reinit the priority queue -function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) +function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t, rng) # mass action jumps majumps = p.ma_jumps @@ -126,7 +127,7 @@ function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) pqdata = Vector{typeof(t)}(undef, length(cur_rates)) @inbounds for i in 1:get_num_majumps(majumps) cur_rates[i] = evalrxrate(u, i, majumps) - pqdata[i] = t + randexp(p.rng) / cur_rates[i] + pqdata[i] = t + randexp(rng) / cur_rates[i] end # constant rates @@ -134,7 +135,7 @@ function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) idx = get_num_majumps(majumps) + 1 @inbounds for rate in rates cur_rates[idx] = rate(u, params, t) - pqdata[idx] = t + randexp(p.rng) / cur_rates[idx] + pqdata[idx] = t + randexp(rng) / cur_rates[idx] idx += 1 end diff --git a/src/aggregators/rdirect.jl b/src/aggregators/rdirect.jl index 8376b71b9..eea283160 100644 --- a/src/aggregators/rdirect.jl +++ b/src/aggregators/rdirect.jl @@ -2,8 +2,8 @@ Direct with rejection sampling """ -mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct RDirectJumpAggregation{T, S, F1, F2, DEPGR} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -14,7 +14,6 @@ mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR max_rate::T counter::Int @@ -22,10 +21,10 @@ mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: end function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; num_specs, counter_threshold = length(crs), dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -42,9 +41,9 @@ function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, m max_rate = maximum(crs) affecttype = F2 <: Tuple ? F2 : Any - return RDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et, + return RDirectJumpAggregation{T, S, F1, affecttype, typeof(dg)}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng, + affs!, sps, dg, max_rate, 0, counter_threshold) end @@ -53,13 +52,13 @@ end # creating the JumpAggregation structure (tuple-based constant jumps) function aggregate(aggregator::RDirect, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(RDirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end @@ -93,7 +92,8 @@ function generate_jumps!(p::RDirectJumpAggregation, integrator, u, params, t) if nomorejumps!(p, sum_rate) return nothing end - (; rng, cur_rates, max_rate) = p + rng = get_rng(integrator) + (; cur_rates, max_rate) = p num_rxs = length(cur_rates) counter = 0 @@ -105,7 +105,7 @@ function generate_jumps!(p::RDirectJumpAggregation, integrator, u, params, t) p.counter = counter p.next_jump = rx - p.next_jump_time = t + randexp(p.rng) / sum_rate + p.next_jump_time = t + randexp(rng) / sum_rate nothing end diff --git a/src/aggregators/rssa.jl b/src/aggregators/rssa.jl index 2943bf88b..f86e0b444 100644 --- a/src/aggregators/rssa.jl +++ b/src/aggregators/rssa.jl @@ -4,8 +4,8 @@ # functions of the current population sizes (i.e. u) # requires vartojumps_map and fluct_rates as JumpProblem keywords -mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct RSSAJumpAggregation{T, S, F1, F2, VJMAP, JVMAP, BD, U} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -17,7 +17,6 @@ mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG vartojumps_map::VJMAP jumptovars_map::JVMAP bracket_data::BD @@ -26,10 +25,10 @@ mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: end function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; u::U, vartojumps_map = nothing, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + u::U, vartojumps_map = nothing, jumptovars_map = nothing, - bracket_data = nothing, kwargs...) where {T, S, F1, F2, RNG, U} + bracket_data = nothing, kwargs...) where {T, S, F1, F2, U} # a dependency graph is needed and must be provided if there are constant rate jumps if vartojumps_map === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -63,10 +62,10 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, uhigh = similar(u) affecttype = F2 <: Tuple ? F2 : Any - RSSAJumpAggregation{T, S, F1, affecttype, RNG, typeof(vtoj_map), + RSSAJumpAggregation{T, S, F1, affecttype, typeof(vtoj_map), typeof(jtov_map), typeof(bd), U}(nj, nj, njt, et, crl_bnds, crh_bnds, sr, maj, rs, affs!, sps, - rng, vtoj_map, jtov_map, bd, ulow, + vtoj_map, jtov_map, bd, ulow, uhigh) end @@ -74,13 +73,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::RSSA, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(RSSAJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; u = u, + rates, affects!, save_positions; u, kwargs...) end @@ -108,7 +107,8 @@ function generate_jumps!(p::RSSAJumpAggregation, integrator, u, params, t) return nothing end # next jump type - (; ma_jumps, rates, cur_rate_high, cur_rate_low, rng) = p + (; ma_jumps, rates, cur_rate_high, cur_rate_low) = p + rng = get_rng(integrator) num_majumps = get_num_majumps(ma_jumps) rerl = zero(sum_rate) diff --git a/src/aggregators/rssacr.jl b/src/aggregators/rssacr.jl index 1caf3be53..9cc47c3b8 100644 --- a/src/aggregators/rssacr.jl +++ b/src/aggregators/rssacr.jl @@ -4,9 +4,9 @@ Composition-Rejection with Rejection sampling method (RSSA-CR) const MINJUMPRATE = 2.0^exponent(1e-12) -mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, +mutable struct RSSACRJumpAggregation{F, S, F1, F2, U, VJMAP, JVMAP, BD, P <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{F, S, F1, F2, RNG} + AbstractSSAJumpAggregator{F, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::F @@ -18,7 +18,6 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG vartojumps_map::VJMAP jumptovars_map::JVMAP bracket_data::BD @@ -31,11 +30,11 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, end function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U, + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; u::U, vartojumps_map = nothing, jumptovars_map = nothing, bracket_data = nothing, minrate = convert(F, MINJUMPRATE), maxrate = convert(F, Inf), - kwargs...) where {F, S, F1, F2, RNG, U} + kwargs...) where {F, S, F1, F2, U} # a dependency graph is needed and must be provided if there are constant rate jumps if vartojumps_map === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -80,10 +79,10 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate: rt = PriorityTable(ratetogroup, zeros(F, 1), minrate, 2 * minrate) affecttype = F2 <: Tuple ? F2 : Any - RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, RNG, U, typeof(vtoj_map), + RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, U, typeof(vtoj_map), typeof(jtov_map), typeof(bd), typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crl_bnds, crh_bnds, - sum_rate, maj, rs, affs!, sps, rng, vtoj_map, + sum_rate, maj, rs, affs!, sps, vtoj_map, jtov_map, bd, ulow, uhigh, minrate, maxrate, rt, ratetogroup) end @@ -92,13 +91,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::RSSACR, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(RSSACRJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; u = u, kwargs...) + rates, affects!, save_positions; u, kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -134,7 +133,8 @@ function generate_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t) return nothing end - (; rt, ma_jumps, rates, cur_rate_high, cur_rate_low, rng) = p + (; rt, ma_jumps, rates, cur_rate_high, cur_rate_low) = p + rng = get_rng(integrator) num_majumps = get_num_majumps(ma_jumps) rerl = zero(sum_rate) diff --git a/src/aggregators/sortingdirect.jl b/src/aggregators/sortingdirect.jl index f9048f039..c20ea1ae4 100644 --- a/src/aggregators/sortingdirect.jl +++ b/src/aggregators/sortingdirect.jl @@ -2,8 +2,8 @@ # "The sorting direct method for stochastic simulation of biochemical systems with varying reaction execution behavior" # Comp. Bio. and Chem., 30, pg. 39-49 (2006). -mutable struct SortingDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct SortingDirectJumpAggregation{T, S, F1, F2, DEPGR} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -14,16 +14,15 @@ mutable struct SortingDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR jump_search_order::Vector{Int} jump_search_idx::Int end function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -42,9 +41,9 @@ function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr # map jump idx to idx in cur_rates jtoidx = collect(1:length(crs)) affecttype = F2 <: Tuple ? F2 : Any - SortingDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et, + SortingDirectJumpAggregation{T, S, F1, affecttype, typeof(dg)}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng, + affs!, sps, dg, jtoidx, zero(Int)) end @@ -53,13 +52,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::SortingDirect, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(SortingDirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end @@ -92,14 +91,15 @@ end # calculate the next jump / jump time function generate_jumps!(p::SortingDirectJumpAggregation, integrator, u, params, t) - p.next_jump_time = t + randexp(p.rng) / p.sum_rate + rng = get_rng(integrator) + p.next_jump_time = t + randexp(rng) / p.sum_rate # search for next jump if p.next_jump_time < p.end_time cur_rates = p.cur_rates numjumps = length(cur_rates) jso = p.jump_search_order - rn = p.sum_rate * rand(p.rng) + rn = p.sum_rate * rand(rng) @inbounds for idx in 1:numjumps rn -= cur_rates[jso[idx]] if rn < zero(rn) diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index 90c260c97..fec06c185 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -13,13 +13,11 @@ An aggregator interface for SSA-like algorithms. - `rates` # vector of rate functions for ConstantRateJumps - `affects!` # vector of affect functions for ConstantRateJumps - `save_positions` # tuple for whether to save the jumps before and/or after event - - `rng` # random number generator - ### Optional fields: - `dep_gr` # dependency graph, dep_gr[i] = indices of reactions that should be updated when rx i occurs. """ -abstract type AbstractSSAJumpAggregator{T, S, F1, F2, RNG} <: AbstractJumpAggregator end +abstract type AbstractSSAJumpAggregator{T, S, F1, F2} <: AbstractJumpAggregator end function DiscreteCallback(c::AbstractSSAJumpAggregator) DiscreteCallback(c, c, initialize = c, save_positions = c.save_positions) @@ -112,12 +110,12 @@ end """ build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, - affects!, save_positions, rng; kwargs...) + affects!, save_positions; kwargs...) Helper routine for setting up standard fields of SSA jump aggregations. """ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, - affects!, save_positions, rng; kwargs...) + affects!, save_positions; kwargs...) # mass action jumps majumps = ma_jumps @@ -134,7 +132,7 @@ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rate next_jump = 0 next_jump_time = typemax(typeof(t)) jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - majumps, rates, affects!, save_positions, rng; kwargs...) + majumps, rates, affects!, save_positions; kwargs...) end """ diff --git a/src/problem.jl b/src/problem.jl index 157c97a30..56d58da62 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -51,7 +51,6 @@ $(FIELDS) ## Keyword Arguments - - `rng`, the random number generator to use. Defaults to Julia's built-in generator. - `save_positions=(true,true)` when including variable rates and `(false,true)` for constant rates, specifies whether to save the system's state (before, after) the jump occurs. - `spatial_system`, for spatial problems the underlying spatial structure. @@ -61,14 +60,25 @@ $(FIELDS) integration interface, and treated like general `VariableRateJump`s. - `vr_aggregator`, indicates the aggregator to use for sampling variable rate jumps. Current default is `VR_FRM`. + - `tstops`, time stops to pass through to the solver. Can be an `AbstractVector` of times + or a callable `(p, tspan) -> times`. Please see the [tutorial page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and commonly asked questions. + +!!! warning "Thread Safety" + `JumpProblem` contains mutable state (aggregator data, callbacks) and is **not + thread-safe**. A single `JumpProblem` instance must not be solved concurrently from + multiple threads or tasks without first creating independent copies via `deepcopy`. + When running ensemble simulations via `EnsembleProblem`, this is handled automatically + — the `SciMLBase` ensemble layer provides per-task isolation and per-trajectory RNG + seeding. This warning only applies to manually parallelized `solve` calls outside the + ensemble interface. """ -mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, - J2, J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J} +mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, + J2, J3, J4, K} <: DiffEqBase.AbstractJumpProblem{P, J} """The type of problem to couple the jumps to. For a pure jump process use `DiscreteProblem`, to couple to ODEs, `ODEProblem`, etc.""" prob::P """The aggregator algorithm that determines the next jump times and types for `ConstantRateJump`s and `MassActionJump`s. Examples include `Direct`.""" @@ -85,26 +95,24 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega regular_jump::J3 """The `MassActionJump`s.""" massaction_jump::J4 - """The random number generator to use.""" - rng::R """kwargs to pass on to solve call.""" kwargs::K end function JumpProblem(p::P, a::A, dj::J, jc::C, cj::J1, vj::J2, rj::J3, mj::J4, - rng::R, kwargs::K) where {P, A, J, C, J1, J2, J3, J4, R, K} + kwargs::K) where {P, A, J, C, J1, J2, J3, J4, K} iip = isinplace_jump(p, rj) - JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, R, K}(p, a, dj, jc, cj, vj, rj, mj, - rng, kwargs) + JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, K}(p, a, dj, jc, cj, vj, rj, mj, + kwargs) end ######## remaking ###### # for a problem where prob.u0 is an ExtendedJumpArray, create an ExtendedJumpArray that # aliases and resets prob.u0.jump_u while having newu0 as the new u component. -function remake_extended_u0(prob, newu0, rng) +function remake_extended_u0(prob, newu0) jump_u = prob.u0.jump_u ttype = eltype(prob.tspan) - @. jump_u = -randexp(rng, ttype) + @. jump_u = zero(ttype) ExtendedJumpArray(newu0, jump_u) end @@ -142,7 +150,7 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, error("Passed in u0 is incompatible with current u0 which has type: $(typeof(prob.u0.u)).") end - final_u0 = remake_extended_u0(prob, state_vals, jprob.rng) + final_u0 = remake_extended_u0(prob, state_vals) end newprob = DiffEqBase.remake(prob; u0 = final_u0, p, interpret_symbolicmap, use_defaults, kwargs...) else @@ -169,8 +177,8 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, end T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, - jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump, - jprob.massaction_jump, jprob.rng, jprob.kwargs) + jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump, + jprob.massaction_jump, jprob.kwargs) end # for updating parameters in JumpProblems to update MassActionJumps @@ -239,10 +247,14 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS vr_aggregator::VariableRateAggregator = VR_FRM(), save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - rng = DEFAULT_RNG, scale_rates = true, useiszero = true, + scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, use_vrj_bounds = true, kwargs...) + if haskey(kwargs, :rng) + throw(ArgumentError("`rng` is no longer a keyword argument for `JumpProblem`. Pass `rng` to `solve` or `init` instead, e.g. `solve(jprob, SSAStepper(); rng = my_rng)`.")) + end + # initialize the MassActionJump rate constants with the user parameters if using_params(jumps.massaction_jump) rates = jumps.massaction_jump.param_mapper(prob.p) @@ -289,15 +301,15 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS constant_jump_callback = CallbackSet() else disc_agg = aggregate(aggregator, u, prob.p, t, end_time, crjs, maj, - save_positions, rng; kwargs...) + save_positions; kwargs...) constant_jump_callback = DiscreteCallback(disc_agg) end # handle any remaining vrjs if length(cvrjs) > 0 # Handle variable rate jumps based on vr_aggregator - new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator, - jumps, cvrjs; rng) + new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator, + jumps, cvrjs) else new_prob = prob variable_jump_callback = CallbackSet() @@ -310,18 +322,22 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump), - typeof(maj), typeof(rng), typeof(solkwargs)}(new_prob, aggregator, disc_agg, - jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs) + typeof(maj), typeof(solkwargs)}(new_prob, aggregator, disc_agg, + jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, solkwargs) end # Special dispatch for PureLeaping aggregator - bypasses all aggregation function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - rng = DEFAULT_RNG, scale_rates = true, useiszero = true, + scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, kwargs...) + if haskey(kwargs, :rng) + throw(ArgumentError("`rng` is no longer a keyword argument for `JumpProblem`. Pass `rng` to `solve` or `init` instead, e.g. `solve(jprob, SSAStepper(); rng = my_rng)`.")) + end + # Validate no spatial systems (not currently supported) (spatial_system !== nothing || hopping_constants !== nothing) && error("PureLeaping does not currently support spatial problems.") @@ -342,18 +358,18 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; # No discrete jump aggregation or variable rate callbacks are created disc_agg = nothing jump_cbs = CallbackSet() - + # Store all jump types for access by tau-leaping solver crjs = jumps.constant_jumps vrjs = jumps.variable_jumps - + iip = isinplace_jump(prob, jumps.regular_jump) solkwargs = tstops === nothing ? make_kwarg(; callback) : make_kwarg(; callback, tstops) JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump), - typeof(maj), typeof(rng), typeof(solkwargs)}(prob, aggregator, disc_agg, - jump_cbs, crjs, vrjs, jumps.regular_jump, maj, rng, solkwargs) + typeof(maj), typeof(solkwargs)}(prob, aggregator, disc_agg, + jump_cbs, crjs, vrjs, jumps.regular_jump, maj, solkwargs) end aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index da98ba1c4..f71351efc 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -6,6 +6,9 @@ end SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon) +SciMLBase.supports_solve_rng(::JumpProblem, ::SimpleTauLeaping) = true +SciMLBase.supports_solve_rng(::JumpProblem, ::SimpleExplicitTauLeaping) = true + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -70,13 +73,14 @@ function _process_saveat(saveat, tspan, save_start, save_end) end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; - seed = nothing, dt = error("dt is required for SimpleTauLeaping."), + seed = nothing, rng = nothing, + dt = error("dt is required for SimpleTauLeaping."), saveat = nothing, save_start = nothing, save_end = nothing) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only RegularJumps.") - (; prob, rng) = jump_prob - (seed !== nothing) && seed!(rng, seed) + prob = jump_prob.prob + _rng = resolve_rng(rng, seed) rj = jump_prob.regular_jump rate = rj.rate # rate function rate(out,u,p,t) @@ -117,7 +121,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; t_new = tprev + dt rate(rate_cache, uprev, p, tprev) rate_cache .*= dt - counts .= pois_rand.((rng,), rate_cache) + counts .= pois_rand.((_rng,), rate_cache) c(du, uprev, p, tprev, counts, mark) u_new .= du .+ uprev @@ -335,22 +339,20 @@ function simple_explicit_tau_leaping_loop!( end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; - seed = nothing, + seed = nothing, rng = nothing, dtmin = nothing, saveat = nothing, save_start = nothing, save_end = nothing) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleExplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") prob = jump_prob.prob - rng = jump_prob.rng + _rng = resolve_rng(rng, seed) tspan = prob.tspan if dtmin === nothing dtmin = 1e-10 * one(typeof(tspan[2])) end - (seed !== nothing) && seed!(rng, seed) - maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) rj = jump_prob.regular_jump @@ -394,7 +396,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; reactant_stoch, hor, length(u0), numjumps) simple_explicit_tau_leaping_loop!( - prob, alg, u_current, u_new, t_current, t_end, p, rng, + prob, alg, u_current, u_new, t_current, t_end, p, _rng, rate, c, nu, hor, max_hor, max_stoich, numjumps, epsilon, dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj, save_end) diff --git a/src/solve.jl b/src/solve.jl index 11bc16bb6..ffe3177dc 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,3 +1,23 @@ +""" + resolve_rng(rng, seed) + +Resolve which RNG to use for a jump simulation. + +Priority: `rng` > `seed` (creates `Xoshiro`) > `Random.default_rng()`. +""" +function resolve_rng(rng, seed) + if rng !== nothing + rng + elseif seed !== nothing + Random.Xoshiro(seed) + else + Random.default_rng() + end +end + +SciMLBase.supports_solve_rng(jprob::JumpProblem, alg::DiffEqBase.DEAlgorithm) = + SciMLBase.supports_solve_rng(jprob.prob, alg) + function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg::DiffEqBase.DEAlgorithm; merge_callbacks = true, kwargs...) where {P} @@ -21,6 +41,9 @@ function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, integrator.sol end +SciMLBase.supports_solve_rng(jprob::JumpProblem, ::Nothing) = + jprob.prob isa DiffEqBase.DiscreteProblem + # if passed a JumpProblem over a DiscreteProblem, and no aggregator is selected use # SSAStepper function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}; @@ -38,53 +61,33 @@ function DiffEqBase.__init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, kwargs = DiffEqBase.merge_problem_kwargs(_jump_prob; merge_callbacks, kwargs...) __jump_init(_jump_prob, alg; kwargs...) -end +end function __jump_init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg; - callback = nothing, seed = nothing, + callback = nothing, seed = nothing, rng = nothing, alias_jump = Threads.threadid() == 1, kwargs...) where {P} + + _rng = resolve_rng(rng, seed) + if alias_jump jump_prob = _jump_prob - reset_jump_problem!(jump_prob, seed) else - jump_prob = resetted_jump_problem(_jump_prob, seed) + jump_prob = resetted_jump_problem(_jump_prob) end - # DDEProblems do not have a recompile_flag argument - if jump_prob.prob isa DiffEqBase.AbstractDDEProblem - # callback comes after jump consistent with SSAStepper - integrator = init(jump_prob.prob, alg; - callback = CallbackSet(jump_prob.jump_callback, callback), - kwargs...) - else - # callback comes after jump consistent with SSAStepper - integrator = init(jump_prob.prob, alg; - callback = CallbackSet(jump_prob.jump_callback, callback), - kwargs...) - end + init(jump_prob.prob, alg; + callback = CallbackSet(jump_prob.jump_callback, callback), + rng = _rng, kwargs...) end -# Derive an independent seed from the caller's seed. When a caller (e.g. StochasticDiffEq) -# passes the same seed used for its noise process, we must produce a distinct seed for the -# jump aggregator's RNG. We cannot assume the JumpProblem's stored RNG is any particular -# type, so we pass the seed through `hash` (to decorrelate from the input) and then through -# a Xoshiro draw (to ensure strong mixing regardless of the target RNG's seeding quality). -const _JUMP_SEED_SALT = 0x4a756d7050726f63 # "JumPProc" in ASCII -_derive_jump_seed(seed) = rand(Random.Xoshiro(hash(seed, _JUMP_SEED_SALT)), UInt64) - -function resetted_jump_problem(_jump_prob, seed) - jump_prob = deepcopy(_jump_prob) - if seed !== nothing && !isempty(jump_prob.jump_callback.discrete_callbacks) - rng = jump_prob.jump_callback.discrete_callbacks[1].condition.rng - Random.seed!(rng, _derive_jump_seed(seed)) - end - jump_prob +# Keep function signatures for StochasticDiffEq backward compatibility. +# The seed argument is accepted but no longer used to reseed aggregator RNGs +# (RNG state is now managed by the integrator). +function resetted_jump_problem(_jump_prob, seed = nothing) + deepcopy(_jump_prob) end -function reset_jump_problem!(jump_prob, seed) - if seed !== nothing && !isempty(jump_prob.jump_callback.discrete_callbacks) - Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, - _derive_jump_seed(seed)) - end +function reset_jump_problem!(jump_prob, seed = nothing) + nothing end diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index f282bd1ae..0961ec25b 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -5,10 +5,10 @@ const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j #NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j -mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, +mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::SpatialJump{J} #some structure to identify the next event: reaction or hop prev_jump::SpatialJump{J} #some structure to identify the previous event: reaction or hop next_jump_time::T @@ -19,7 +19,6 @@ mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPG rates::F1 # legacy, not used affects!::F2 # legacy, not used save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR #dep graph is same for each site vartojumps_map::VJMAP #vartojumps_map is same for each site jumptovars_map::JVMAP #jumptovars_map is same for each site @@ -31,11 +30,11 @@ end function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop_rates::HOP, site_rates::Vector{T}, - sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; + sps::Tuple{Bool, Bool}, spatial_system::SS; num_specs, minrate = convert(T, MINJUMPRATE), vartojumps_map = nothing, jumptovars_map = nothing, dep_graph = nothing, - kwargs...) where {J, T, RX, HOP, RNG, SS} + kwargs...) where {J, T, RX, HOP, SS} # a dependency graph is needed if dep_graph === nothing @@ -69,12 +68,12 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - DirectCRDirectJumpAggregation{T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, + DirectCRDirectJumpAggregation{T, Nothing, Nothing, Nothing, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, rx_rates, hop_rates, site_rates, nothing, nothing, sps, - rng, dg, vtoj_map, + dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, ratetogroup) end @@ -82,7 +81,7 @@ end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, - constant_jumps, ma_jumps, save_positions, rng; hopping_constants, + constant_jumps, ma_jumps, save_positions; hopping_constants, spatial_system, kwargs...) num_species = size(starting_state, 1) majumps = ma_jumps @@ -99,7 +98,7 @@ function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, site_rates = zeros(typeof(end_time), num_sites(spatial_system)) DirectCRDirectJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, - site_rates, save_positions, rng, spatial_system; + site_rates, save_positions, spatial_system; num_specs = num_species, kwargs...) end @@ -113,10 +112,11 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t) - p.next_jump_time = t + randexp(p.rng) / p.rt.gsum + rng = get_rng(integrator) + p.next_jump_time = t + randexp(rng) / p.rt.gsum p.next_jump_time >= p.end_time && return nothing - site = sample(p.rt, p.site_rates, p.rng) - p.next_jump = sample_jump_direct(p, site) + site = sample(p.rt, p.site_rates, rng) + p.next_jump = sample_jump_direct(p, site, rng) nothing end diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 07c48da6e..acf6b56eb 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -3,9 +3,9 @@ ############################ NSM ################################### #NOTE state vector u is a matrix. u[i,j] is species i, site j #NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j -mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, +mutable struct NSMJumpAggregation{T, S, F1, F2, J, RX, HOP, DEPGR, VJMAP, JVMAP, PQ, SS} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::SpatialJump{J} #some structure to identify the next event: reaction or hop prev_jump::SpatialJump{J} #some structure to identify the previous event: reaction or hop next_jump_time::T @@ -15,7 +15,6 @@ mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, J rates::F1 # legacy, not used affects!::F2 # legacy, not used save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR #dep graph is same for each site vartojumps_map::VJMAP #vartojumps_map is same for each site jumptovars_map::JVMAP #jumptovars_map is same for each site @@ -27,9 +26,9 @@ end function NSMJumpAggregation( nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop_rates::HOP, sps::Tuple{Bool, Bool}, - rng::RNG, spatial_system::SS; num_specs, + spatial_system::SS; num_specs, vartojumps_map = nothing, jumptovars_map = nothing, - dep_graph = nothing, kwargs...) where {J, T, RX, HOP, RNG, SS} + dep_graph = nothing, kwargs...) where {J, T, RX, HOP, SS} # a dependency graph is needed if dep_graph === nothing @@ -55,13 +54,13 @@ function NSMJumpAggregation( pq = MutableBinaryMinHeap{T}() - NSMJumpAggregation{T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, typeof(dg), + NSMJumpAggregation{T, Nothing, Nothing, Nothing, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), typeof(pq), SS}(nj, nj, njt, et, rx_rates, hop_rates, nothing, nothing, - sps, rng, dg, + sps, dg, vtoj_map, jtov_map, pq, @@ -72,7 +71,7 @@ end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; hopping_constants, spatial_system, + ma_jumps, save_positions; hopping_constants, spatial_system, kwargs...) num_species = size(starting_state, 1) majumps = ma_jumps @@ -88,7 +87,7 @@ function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jum hop_rates = HopRates(hopping_constants, spatial_system) NSMJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, - save_positions, rng, spatial_system; num_specs = num_species, + save_positions, spatial_system; num_specs = num_species, kwargs...) end @@ -104,7 +103,8 @@ end function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing - p.next_jump = sample_jump_direct(p, site) + rng = get_rng(integrator) + p.next_jump = sample_jump_direct(p, site, rng) nothing end @@ -125,7 +125,8 @@ end reset all structs, reevaluate all rates, recalculate tentative site firing times, and reinit the priority queue """ function fill_rates_and_get_times!(aggregation::NSMJumpAggregation, integrator, t) - (; spatial_system, rx_rates, hop_rates, rng) = aggregation + (; spatial_system, rx_rates, hop_rates) = aggregation + rng = get_rng(integrator) u = integrator.u reset!(rx_rates) @@ -153,30 +154,31 @@ recalculate jump rates for jumps that depend on the just executed jump (p.prev_j """ function update_dependent_rates_and_firing_times!(p::NSMJumpAggregation, integrator, t) u = integrator.u + rng = get_rng(integrator) jump = p.prev_jump if is_hop(p, jump) source_site = jump.src target_site = jump.dst update_rates_after_hop!(p, integrator, source_site, target_site, jump.jidx) - update_site_time!(p, source_site, t) - update_site_time!(p, target_site, t) + update_site_time!(p, source_site, t, rng) + update_site_time!(p, target_site, t, rng) else site = jump.src update_rates_after_reaction!(p, integrator, site, reaction_id_from_jump(p, jump)) - update_site_time!(p, site, t) + update_site_time!(p, site, t, rng) end nothing end """ - update_site_time!(p::NSMJumpAggregation, site, t) + update_site_time!(p::NSMJumpAggregation, site, t, rng) update the time of site in the priority queue """ -function update_site_time!(p::NSMJumpAggregation, site, t) +function update_site_time!(p::NSMJumpAggregation, site, t, rng) site_rate = (total_site_rate(p.rx_rates, p.hop_rates, site)) if site_rate > zero(typeof(site_rate)) - update!(p.pq, site, t + randexp(p.rng) / site_rate) + update!(p.pq, site, t + randexp(rng) / site_rate) else update!(p.pq, site, typemax(t)) end diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index 370e86015..3be2b7378 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -23,18 +23,18 @@ end ######################## helper routines for all spatial SSAs ######################## """ - sample_jump_direct(p, site) + sample_jump_direct(p, site, rng) sample jump at site with direct method """ -function sample_jump_direct(p, site) - if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < +function sample_jump_direct(p, site, rng) + if rand(rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < total_site_rx_rate(p.rx_rates, site) - rx = sample_rx_at_site(p.rx_rates, site, p.rng) + rx = sample_rx_at_site(p.rx_rates, site, rng) return SpatialJump(site, rx + p.numspecies, site) else species_to_diffuse, - target_site = sample_hop_at_site(p.hop_rates, site, p.rng, + target_site = sample_hop_at_site(p.hop_rates, site, rng, p.spatial_system) return SpatialJump(site, species_to_diffuse, target_site) end diff --git a/src/variable_rate.jl b/src/variable_rate.jl index e78684fb6..01b6d2115 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -28,21 +28,21 @@ Simulating a birth-death process with `VR_FRM`: ```julia using JumpProcesses, OrdinaryDiffEq -u0 = [1.0] # Initial population -p = [10.0, 0.5] # [birth rate, death rate] +u0 = [1.0] # Initial population +p = [10.0, 0.5] # [birth rate, death rate] tspan = (0.0, 10.0) -# Birth jump: ∅ → X +# Birth jump: ∅ → X birth_rate(u, p, t) = p[1] birth_affect!(integrator) = (integrator.u[1] += 1; nothing) birth_jump = VariableRateJump(birth_rate, birth_affect!) -# Death jump: X → ∅ +# Death jump: X → ∅ death_rate(u, p, t) = p[2] * u[1] death_affect!(integrator) = (integrator.u[1] -= 1; nothing) death_jump = VariableRateJump(death_rate, death_affect!) -# Problem setup +# Problem setup oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_FRM()) sol = solve(jprob, Tsit5()) @@ -58,26 +58,25 @@ sol = solve(jprob, Tsit5()) """ struct VR_FRM <: VariableRateAggregator end -function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs; - rng = DEFAULT_RNG) - new_prob = extend_problem(prob, cvrjs; rng) - variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng) +function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs) + new_prob = extend_problem(prob, cvrjs) + variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...) return new_prob, variable_jump_callback end # extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values, # of type prob.tspan -function extend_u0(prob, Njumps, rng) +function extend_u0(prob, Njumps) ttype = eltype(prob.tspan) - u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps]) + u0 = ExtendedJumpArray(prob.u0, zeros(ttype, Njumps)) return u0 end -function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps) error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.") end -function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -97,13 +96,13 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, u0) end -function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -133,13 +132,13 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, g = jump_g, u0) end -function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -159,14 +158,14 @@ function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, u0) end # Not sure if the DAE one is correct: Should be a residual of sorts -function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -186,16 +185,15 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, u0) end -struct VR_FRMEventCallback{F, RNG} +struct VR_FRMEventCallback{F} idx::Int affect!::F - rng::RNG end # condition: (u, t, integrator) @@ -204,19 +202,21 @@ end # affect: (integrator) function (c::VR_FRMEventCallback)(integrator) c.affect!(integrator) - integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t)) + rng = get_rng(integrator) + integrator.u.jump_u[c.idx] = -randexp(rng, typeof(integrator.t)) nothing end # initialize: (cb, u, t, integrator) function (c::VR_FRMEventCallback)(cb, u, t, integrator) - integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t)) + rng = get_rng(integrator) + integrator.u.jump_u[c.idx] = -randexp(rng, typeof(integrator.t)) u_modified!(integrator, true) nothing end -function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) - cb_functor = VR_FRMEventCallback(idx, jump.affect!, rng) +function wrap_jump_in_callback(idx, jump) + cb_functor = VR_FRMEventCallback(idx, jump.affect!) ContinuousCallback(cb_functor, cb_functor; initialize = cb_functor, idxs = jump.idxs, @@ -227,15 +227,15 @@ function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) reltol = jump.reltol) end -function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) +function build_variable_callback(cb, idx, jump, jumps...) idx += 1 - new_cb = wrap_jump_in_callback(idx, jump; rng) - build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) + new_cb = wrap_jump_in_callback(idx, jump) + build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...) end -function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG) +function build_variable_callback(cb, idx, jump) idx += 1 - CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) + CallbackSet(cb, wrap_jump_in_callback(idx, jump)) end @inline function update_jumps!(du, u, p, t, idx, jump) @@ -267,21 +267,21 @@ Simulating a birth-death process with `VR_Direct` (default) and VR_DirectFW: ```julia using JumpProcesses, OrdinaryDiffEq -u0 = [1.0] # Initial population -p = [10.0, 0.5] # [birth rate, death rate coefficient] +u0 = [1.0] # Initial population +p = [10.0, 0.5] # [birth rate, death rate coefficient] tspan = (0.0, 10.0) -# Birth jump: ∅ → X +# Birth jump: ∅ → X birth_rate(u, p, t) = p[1] birth_affect!(integrator) = (integrator.u[1] += 1; nothing) birth_jump = VariableRateJump(birth_rate, birth_affect!) -# Death jump: X → ∅ +# Death jump: X → ∅ death_rate(u, p, t) = p[2] * u[1] death_affect!(integrator) = (integrator.u[1] -= 1; nothing) death_jump = VariableRateJump(death_rate, death_affect!) -# Problem setup +# Problem setup oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct()) sol = solve(jprob, Tsit5()) @@ -297,21 +297,19 @@ sol = solve(jprob, Tsit5()) struct VR_Direct <: VariableRateAggregator end struct VR_DirectFW <: VariableRateAggregator end -mutable struct VR_DirectEventCache{T, RNG, F1, F2} +mutable struct VR_DirectEventCache{T, F1, F2} prev_time::T prev_threshold::T current_time::T current_threshold::T total_rate::T - rng::RNG rate_funcs::F1 affect_funcs::F2 cum_rate_sum::Vector{T} end function VR_DirectEventCache( - jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG) where {T} - initial_threshold = randexp(rng, T) + jumps::JumpSet, ::VR_Direct, prob, ::Type{T}) where {T} vjumps = jumps.variable_jumps # handle vjumps using tuples @@ -319,14 +317,13 @@ function VR_DirectEventCache( cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), - initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, + VR_DirectEventCache{T, typeof(rate_funcs), typeof(affect_funcs)}(zero(T), + zero(T), zero(T), zero(T), zero(T), rate_funcs, affect_funcs, cum_rate_sum) end function VR_DirectEventCache( - jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG) where {T} - initial_threshold = randexp(rng, T) + jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}) where {T} vjumps = jumps.variable_jumps t, u = prob.tspan[1], prob.u0 @@ -336,16 +333,17 @@ function VR_DirectEventCache( cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}(zero(T), - initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, + VR_DirectEventCache{T, typeof(rate_funcs), Any}(zero(T), + zero(T), zero(T), zero(T), zero(T), rate_funcs, affect_funcs, cum_rate_sum) end # Initialization function for VR_DirectEventCache function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrator) + rng = get_rng(integrator) cache.prev_time = zero(integrator.t) cache.current_time = zero(integrator.t) - cache.prev_threshold = randexp(cache.rng, eltype(integrator.t)) + cache.prev_threshold = randexp(rng, eltype(integrator.t)) cache.current_threshold = cache.prev_threshold cache.total_rate = zero(integrator.t) cache.cum_rate_sum .= 0 @@ -363,8 +361,8 @@ end nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - ::I) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, F1, F2}, + ::I) where {T, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} nothing end @@ -392,16 +390,16 @@ function build_variable_integcallback(cache::VR_DirectEventCache, jumps) save_positions, abstol, reltol) end -function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_RNG) +function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs) new_prob = prob - cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng) + cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan)) variable_jump_callback = build_variable_integcallback(cache, cvrjs) return new_prob, variable_jump_callback end -function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG) +function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs) new_prob = prob - cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng) + cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan)) variable_jump_callback = build_variable_integcallback(cache, cvrjs) return new_prob, variable_jump_callback end @@ -434,8 +432,7 @@ end end function total_variable_rate( - cache::VR_DirectEventCache{ - T, RNG, F1, F2}, u, p, t) where {T, RNG, F1, F2} + cache::VR_DirectEventCache{T, F1, F2}, u, p, t) where {T, F1, F2} (; cum_rate_sum, rate_funcs) = cache sum_rate = cumsum_rates!(cum_rate_sum, u, p, t, rate_funcs) return sum_rate @@ -480,8 +477,8 @@ function (cache::VR_DirectEventCache)(u, t, integrator) return cache.current_threshold end -@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - integrator::I, idx) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} +@generated function execute_affect!(cache::VR_DirectEventCache{T, F1, F2}, + integrator::I, idx) where {T, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} quote (; affect_funcs) = cache Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount(F2)](integrator)) @@ -508,7 +505,7 @@ function (cache::VR_DirectEventCache)(integrator) end cache.total_rate = total_variable_rate_sum - rng = cache.rng + rng = get_rng(integrator) r = rand(rng) * total_variable_rate_sum @inbounds jump_idx = searchsortedfirst(cache.cum_rate_sum, r) execute_affect!(cache, integrator, jump_idx) diff --git a/test/allocations.jl b/test/allocations.jl index 64efb7582..25bd02f1e 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -35,17 +35,17 @@ let u₀ = [999, 10, 0] tspan = (0.0, 250.0) dprob = DiscreteProblem(u₀, tspan, p) - jprob = JumpProblem(dprob, Direct(), maj, jump, jump2; save_positions, rng) - sol = solve(jprob, SSAStepper()) + jprob = JumpProblem(dprob, Direct(), maj, jump, jump2; save_positions) + sol = solve(jprob, SSAStepper(); rng) - al1 = @allocations solve(jprob, SSAStepper()) + al1 = @allocations solve(jprob, SSAStepper(); rng) tspan2 = (0.0, 2500.0) dprob2 = DiscreteProblem(u₀, tspan2, p) - jprob2 = JumpProblem(dprob2, Direct(), maj, jump, jump2; save_positions, rng) - sol2 = solve(jprob2, SSAStepper()) + jprob2 = JumpProblem(dprob2, Direct(), maj, jump, jump2; save_positions) + sol2 = solve(jprob2, SSAStepper(); rng) - al2 = @allocations solve(jprob2, SSAStepper()) + al2 = @allocations solve(jprob2, SSAStepper(); rng) @test al1 == al2 end @@ -56,7 +56,7 @@ let end function makeprob(; T = 100.0, alg = Direct(), save_positions = (false, false), - graphkwargs = (;), rng) + graphkwargs = (;)) r1(u, p, t) = rate(p[1], u[1], u[2], p[2]) * u[1] r2(u, p, t) = rate(p[1], u[2], u[1], p[2]) * u[2] r3(u, p, t) = p[3] * u[1] @@ -86,7 +86,7 @@ let ConstantRateJump(r3, aff3!), ConstantRateJump(r4, aff4!), ConstantRateJump(r5, aff5!), ConstantRateJump(r6, aff6!); - save_positions, rng, graphkwargs...) + save_positions, graphkwargs...) return jprob end @@ -99,15 +99,15 @@ let graphkwargs = (; dep_graph, vartojumps_map, jumptovars_map) @testset "Allocations for $agg" for agg in JumpProcesses.JUMP_AGGREGATORS - jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs, rng = StableRNG(1234)) + jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs) stepper = SSAStepper() - sol1 = solve(jprob1, stepper) - sol1 = solve(jprob1, stepper) - al1 = @allocated solve(jprob1, stepper) - jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs, rng = StableRNG(1234)) - sol2 = solve(jprob2, stepper) - sol2 = solve(jprob2, stepper) - al2 = @allocated solve(jprob2, stepper) + sol1 = solve(jprob1, stepper; rng = StableRNG(1234)) + sol1 = solve(jprob1, stepper; rng = StableRNG(1234)) + al1 = @allocated solve(jprob1, stepper; rng = StableRNG(1234)) + jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs) + sol2 = solve(jprob2, stepper; rng = StableRNG(1234)) + sol2 = solve(jprob2, stepper; rng = StableRNG(1234)) + al2 = @allocated solve(jprob2, stepper; rng = StableRNG(1234)) @test al1 == al2 end end diff --git a/test/bimolerx_test.jl b/test/bimolerx_test.jl index befeefab1..05917bf5a 100644 --- a/test/bimolerx_test.jl +++ b/test/bimolerx_test.jl @@ -55,10 +55,14 @@ jump_to_dep_specs = [[1, 2], [1, 2], [1, 2, 3], [1, 2, 3], [1, 3]] majumps = MassActionJump(rates, reactstoch, netstoch) # average number of proteins in a simulation -function runSSAs(jump_prob; use_stepper = true) +function runSSAs(jump_prob; use_stepper = true, rng = nothing) Psamp = zeros(Int, Nsims) for i in 1:Nsims - sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob) + sol = if use_stepper + solve(jump_prob, SSAStepper(); rng) + else + solve(jump_prob; rng) + end Psamp[i] = sol[1, end] end mean(Psamp) @@ -72,8 +76,8 @@ if doplot for alg in SSAalgs local jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - local sol = solve(jump_prob, SSAStepper()) + jumptovars_map = jump_to_dep_specs) + local sol = solve(jump_prob, SSAStepper(); rng) local plothand = plot(sol, seriestype = :steppost, reuse = false) display(plothand) end @@ -84,15 +88,15 @@ if dotestmean for (i, alg) in enumerate(SSAalgs) local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_prob) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_prob; rng) relerr = abs(means - expected_avg) / expected_avg doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, ", rel err = ", relerr) @test abs(means - expected_avg) < reltol * expected_avg # test not specifying SSAStepper - means = runSSAs(jump_prob; use_stepper = false) + means = runSSAs(jump_prob; use_stepper = false, rng) relerr = abs(means - expected_avg) / expected_avg @test abs(means - expected_avg) < reltol * expected_avg end @@ -107,8 +111,8 @@ if dotestmean jset = JumpSet((), (), nothing, majump_vec) jump_prob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - meanval = runSSAs(jump_prob) + jumptovars_map = jump_to_dep_specs) + meanval = runSSAs(jump_prob; rng) relerr = abs(meanval - expected_avg) / expected_avg if doprintmeans println("Using individual MassActionJumps; Mean from method: ", typeof(Direct()), diff --git a/test/bracketing.jl b/test/bracketing.jl index 7a4776da4..182689432 100644 --- a/test/bracketing.jl +++ b/test/bracketing.jl @@ -49,7 +49,7 @@ t = 0.0 ### Aggregator ### mutable struct DummyAggregator{T, M, R, BD} <: - JP.AbstractSSAJumpAggregator{T, M, R, Nothing, Nothing} + JP.AbstractSSAJumpAggregator{T, M, R, Nothing} ulow::Vector{Int} uhigh::Vector{Int} cur_rate_low::Vector{T} diff --git a/test/callbacks.jl b/test/callbacks.jl index 6ac5f9547..ce45d0069 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -25,8 +25,8 @@ rng = StableRNG(12345) affect_cb!(integrator) = (cb_called[] = true) cb = ContinuousCallback(condition, affect_cb!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb) + sol = solve(jprob, Tsit5(); rng) @test cb_called[] @test sol.t[end] ≈ 10.0 @@ -37,8 +37,8 @@ rng = StableRNG(12345) affect_dcb!(integrator) = (dcb_called[] += 1) dcb = DiscreteCallback(condition_d, affect_dcb!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = dcb) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = dcb) + sol = solve(jprob, Tsit5(); rng) @test dcb_called[] > 0 # Should have fired multiple times @@ -47,8 +47,8 @@ rng = StableRNG(12345) affect_term!(integrator) = terminate!(integrator) cb_term = ContinuousCallback(condition_term, affect_term!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb_term) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb_term) + sol = solve(jprob, Tsit5(); rng) @test sol.t[end] ≈ 3.0 # Should terminate at t=3 @@ -57,8 +57,8 @@ rng = StableRNG(12345) affect_mod!(integrator) = (integrator.u[1] *= 2.0) cb_mod = ContinuousCallback(condition_mod, affect_mod!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb_mod) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb_mod) + sol = solve(jprob, Tsit5(); rng) # Check that state was modified at t=5 idx = findfirst(t -> t >= 5.0, sol.t) @@ -95,8 +95,8 @@ end # Test 1: Both callbacks should fire (default merge_callbacks = true) cb1_count[] = 0 cb2_count[] = 0 - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb1) - sol = solve(jprob, Tsit5(); callback = cb2) + jprob = JumpProblem(prob, Direct(), jump; callback = cb1) + sol = solve(jprob, Tsit5(); callback = cb2, rng) @test cb1_count[] > 0 @test cb2_count[] > 0 @@ -110,8 +110,8 @@ end # Test 2: Only solve callback should fire (merge_callbacks = false) cb1_count[] = 0 cb2_count[] = 0 - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb1) - sol = solve(jprob, Tsit5(); callback = cb2, merge_callbacks = false) + jprob = JumpProblem(prob, Direct(), jump; callback = cb1) + sol = solve(jprob, Tsit5(); callback = cb2, merge_callbacks = false, rng) @test cb1_count[] == 0 # Should not fire @test cb2_count[] == 1 # Should fire exactly once @@ -139,8 +139,8 @@ end cb = ContinuousCallback(condition, affect_cb!) # Callback in JumpProblem constructor - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb) - integrator = init(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb) + integrator = init(jprob, Tsit5(); rng) solve!(integrator) @test cb_called[] @@ -175,8 +175,8 @@ end # Create CallbackSet with both types cbset = CallbackSet(ccb, dcb) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cbset) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cbset) + sol = solve(jprob, Tsit5(); rng) @test ccb_called[] @test dcb_called[] > 0 @@ -202,8 +202,8 @@ end affect_term!(integrator) = (cb_called[] = true; terminate!(integrator)) dcb_term = DiscreteCallback(condition_term, affect_term!) - jprob = JumpProblem(dprob, Direct(), jump1; rng, callback = dcb_term) - sol = solve(jprob, SSAStepper()) + jprob = JumpProblem(dprob, Direct(), jump1; callback = dcb_term) + sol = solve(jprob, SSAStepper(); rng) @test cb_called[] # Should have fired @test sol.u[end][1] >= 5 # Should have reached threshold @@ -215,8 +215,8 @@ end affect_count!(integrator) = (dcb_counter[] += 1) dcb_count = DiscreteCallback(condition_count, affect_count!) - jprob2 = JumpProblem(dprob, Direct(), jump1; rng) - sol2 = solve(jprob2, SSAStepper(); callback = dcb_count) + jprob2 = JumpProblem(dprob, Direct(), jump1) + sol2 = solve(jprob2, SSAStepper(); callback = dcb_count, rng) @test dcb_counter[] > 0 # Should have fired at least once @@ -232,8 +232,8 @@ end affect_cb2!(integrator) = (cb2_count[] += 1) dcb2 = DiscreteCallback(condition2, affect_cb2!) - jprob3 = JumpProblem(dprob, Direct(), jump1; rng, callback = dcb1) - sol3 = solve(jprob3, SSAStepper(); callback = dcb2) + jprob3 = JumpProblem(dprob, Direct(), jump1; callback = dcb1) + sol3 = solve(jprob3, SSAStepper(); callback = dcb2, rng) @test cb1_count[] > 0 # First callback should fire @test cb2_count[] > 0 # Second callback should fire @@ -257,8 +257,8 @@ end affect_cb4!(integrator) = (cb4_called[] = true) dcb4 = DiscreteCallback(condition4, affect_cb4!) - jprob4 = JumpProblem(dprob, Direct(), jump1; rng, callback = dcb3) - sol4 = solve(jprob4, SSAStepper(); callback = dcb4, merge_callbacks = false) + jprob4 = JumpProblem(dprob, Direct(), jump1; callback = dcb3) + sol4 = solve(jprob4, SSAStepper(); callback = dcb4, merge_callbacks = false, rng) @test !cb3_called[] # First callback should NOT fire @test cb4_called[] # Second callback should fire @@ -288,8 +288,8 @@ end cb = ContinuousCallback(condition, affect_cb!) # This was broken in v9.17.0 - callback wouldn't fire - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb) + sol = solve(jprob, Tsit5(); rng) @test cb_called[] @test sol.t[end] ≈ 0.5 # Should terminate at 0.5, not run to 1.0 @@ -323,8 +323,8 @@ end affect_d!(integrator) = (dcb_called[] += 1) dcb = DiscreteCallback(condition_d, affect_d!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = ccb) - sol = solve(jprob, Tsit5(); callback = dcb) + jprob = JumpProblem(prob, Direct(), jump; callback = ccb) + sol = solve(jprob, Tsit5(); callback = dcb, rng) @test ccb_called[] # Continuous callback should fire @test dcb_called[] > 0 # Discrete callback should fire multiple times @@ -357,8 +357,8 @@ end affect_c!(integrator) = (ccb_called[] = true) ccb = ContinuousCallback(condition_c, affect_c!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = dcb) - sol = solve(jprob, Tsit5(); callback = ccb) + jprob = JumpProblem(prob, Direct(), jump; callback = dcb) + sol = solve(jprob, Tsit5(); callback = ccb, rng) @test dcb_called[] > 0 # Discrete callback should fire @test ccb_called[] # Continuous callback should fire @@ -398,8 +398,8 @@ end affect3!(integrator) = (cb3_called[] = true) ccb2 = ContinuousCallback(condition3, affect3!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cbset) - sol = solve(jprob, Tsit5(); callback = ccb2) + jprob = JumpProblem(prob, Direct(), jump; callback = cbset) + sol = solve(jprob, Tsit5(); callback = ccb2, rng) @test cb1_called[] # First continuous callback should fire @test cb2_called[] > 0 # Discrete callback should fire @@ -421,12 +421,12 @@ end affect_cb!(integrator) = nothing ccb = ContinuousCallback(condition, affect_cb!) - jprob_ccb = JumpProblem(dprob, Direct(), jump; rng, callback = ccb) - @test_throws ErrorException solve(jprob_ccb, SSAStepper()) + jprob_ccb = JumpProblem(dprob, Direct(), jump; callback = ccb) + @test_throws ErrorException solve(jprob_ccb, SSAStepper(); rng) # Test 2: ContinuousCallback passed to solve should error - jprob = JumpProblem(dprob, Direct(), jump; rng) - @test_throws ErrorException solve(jprob, SSAStepper(); callback = ccb) + jprob = JumpProblem(dprob, Direct(), jump) + @test_throws ErrorException solve(jprob, SSAStepper(); callback = ccb, rng) # Test 3: CallbackSet with continuous callbacks passed to JumpProblem should error on solve condition_d(u, t, integrator) = true @@ -434,19 +434,19 @@ end dcb = DiscreteCallback(condition_d, affect_dcb!) cbset_with_continuous = CallbackSet(ccb, dcb) - jprob_cbset = JumpProblem(dprob, Direct(), jump; rng, callback = cbset_with_continuous) - @test_throws ErrorException solve(jprob_cbset, SSAStepper()) + jprob_cbset = JumpProblem(dprob, Direct(), jump; callback = cbset_with_continuous) + @test_throws ErrorException solve(jprob_cbset, SSAStepper(); rng) # Test 4: CallbackSet with continuous callbacks passed to solve should error - @test_throws ErrorException solve(jprob, SSAStepper(); callback = cbset_with_continuous) + @test_throws ErrorException solve(jprob, SSAStepper(); callback = cbset_with_continuous, rng) # Test 5: CallbackSet with multiple continuous callbacks should error with correct count ccb2 = ContinuousCallback(condition, affect_cb!) cbset_multi = CallbackSet(ccb, ccb2, dcb) - jprob_multi = JumpProblem(dprob, Direct(), jump; rng, callback = cbset_multi) + jprob_multi = JumpProblem(dprob, Direct(), jump; callback = cbset_multi) err = try - solve(jprob_multi, SSAStepper()) + solve(jprob_multi, SSAStepper(); rng) nothing catch e e @@ -457,18 +457,18 @@ end # Test 6: DiscreteCallbacks should work fine (no error) dcb_only = DiscreteCallback(condition_d, affect_dcb!) - jprob_dcb = JumpProblem(dprob, Direct(), jump; rng, callback = dcb_only) - sol = solve(jprob_dcb, SSAStepper()) + jprob_dcb = JumpProblem(dprob, Direct(), jump; callback = dcb_only) + sol = solve(jprob_dcb, SSAStepper(); rng) @test sol.retcode == ReturnCode.Success # Test 7: CallbackSet with only discrete callbacks should work dcb2 = DiscreteCallback(condition_d, affect_dcb!) cbset_discrete = CallbackSet(dcb_only, dcb2) - jprob_dcb2 = JumpProblem(dprob, Direct(), jump; rng, callback = cbset_discrete) - sol2 = solve(jprob_dcb2, SSAStepper()) + jprob_dcb2 = JumpProblem(dprob, Direct(), jump; callback = cbset_discrete) + sol2 = solve(jprob_dcb2, SSAStepper(); rng) @test sol2.retcode == ReturnCode.Success # Test 8: Error should also be thrown with init - @test_throws ErrorException init(jprob_ccb, SSAStepper()) - @test_throws ErrorException init(jprob, SSAStepper(); callback = ccb) + @test_throws ErrorException init(jprob_ccb, SSAStepper(); rng) + @test_throws ErrorException init(jprob, SSAStepper(); callback = ccb, rng) end diff --git a/test/constant_rate.jl b/test/constant_rate.jl index 86c237c06..635be26f8 100644 --- a/test/constant_rate.jl +++ b/test/constant_rate.jl @@ -16,37 +16,37 @@ end jump2 = ConstantRateJump(rate, affect!) prob = DiscreteProblem(1.0, (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump) -sol = solve(jump_prob, FunctionMap()) +sol = solve(jump_prob, FunctionMap(); rng) # using Plots; plot(sol) prob = DiscreteProblem(10.0, (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) -sol = solve(jump_prob, FunctionMap()) +sol = solve(jump_prob, FunctionMap(); rng) # plot(sol) nums = Int[] @time for i in 1:10000 - local jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) - local sol = solve(jump_prob, FunctionMap()) + local jump_prob = JumpProblem(prob, Direct(), jump, jump2) + local sol = solve(jump_prob, FunctionMap(); rng) push!(nums, sol.u[end]) end @test mean(nums) - 45 < 1 prob = DiscreteProblem(1.0, (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) -sol = solve(jump_prob, FunctionMap()) +sol = solve(jump_prob, FunctionMap(); rng) nums = Int[] @time for i in 1:10000 - local jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) - local sol = solve(jump_prob, FunctionMap()) + local jump_prob = JumpProblem(prob, Direct(), jump, jump2) + local sol = solve(jump_prob, FunctionMap(); rng) push!(nums, sol.u[2]) end diff --git a/test/degenerate_rx_cases.jl b/test/degenerate_rx_cases.jl index 3b211decd..55b2985e7 100644 --- a/test/degenerate_rx_cases.jl +++ b/test/degenerate_rx_cases.jl @@ -20,8 +20,8 @@ rs = [[0 => 1]] ns = [[1 => 1]] jump = MassActionJump(rate, rs, ns) prob = DiscreteProblem([100], (0.0, 100.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using vectors of data: last val = ", sol[end, end]) end @@ -35,8 +35,8 @@ rate = 2.0 rs = [0 => 3] # stoich power should be ignored ns = [1 => 1] jump = MassActionJump(rate, rs, ns) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using scalar data: last val = ", sol[end, end]) end @@ -51,8 +51,8 @@ rs = [Vector{Pair{Int, Int}}()] ns = [[1 => 1]] jump = MassActionJump(rate, rs, ns) prob = DiscreteProblem([100], (0.0, 100.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using vector of Pair{Int,Int}: last val = ", sol[end, end]) end @@ -66,8 +66,8 @@ rate = 2.0 rs = Vector{Pair{Int, Int}}() ns = [1 => 1] jump = MassActionJump(rate, rs, ns) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using scalar Pair{Int,Int}: last val = ", sol[end, end]) end @@ -100,8 +100,8 @@ jump_to_dep_specs = [[1], [1]] namedpars = (dep_graph = dep_graph, vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs) for method in methods - local jump_prob = JumpProblem(prob, method, jump, jump2; rng = rng, namedpars...) - local sol = solve(jump_prob, SSAStepper()) + local jump_prob = JumpProblem(prob, method, jump, jump2; namedpars...) + local sol = solve(jump_prob, SSAStepper(); rng) if doplot plot!(plothand2, sol, label = ("A <-> 0, " * string(method))) diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index 75f82618c..da11b321c 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -6,30 +6,30 @@ using StableRNGs, Random # ========================================================================== # Constant-rate birth-death for SSAStepper / ODE-coupled tests -function make_ssa_jump_prob(; rng = StableRNG(12345)) +function make_ssa_jump_prob() j1 = ConstantRateJump((u, p, t) -> 10.0, integrator -> (integrator.u[1] += 1)) j2 = ConstantRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1)) dprob = DiscreteProblem([10], (0.0, 20.0)) - JumpProblem(dprob, Direct(), j1, j2; rng) + JumpProblem(dprob, Direct(), j1, j2) end # ODE + variable-rate jump -function make_vr_jump_prob(agg; rng = StableRNG(12345)) +function make_vr_jump_prob(agg) f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) oprob = ODEProblem(f!, [100.0], (0.0, 10.0)) vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1.0)) - JumpProblem(oprob, Direct(), vrj; vr_aggregator = agg, rng) + JumpProblem(oprob, Direct(), vrj; vr_aggregator = agg) end # SDE + variable-rate jump -function make_sde_vr_jump_prob(agg; rng = StableRNG(12345)) +function make_sde_vr_jump_prob(agg) f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) sprob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1.0)) - JumpProblem(sprob, Direct(), vrj; vr_aggregator = agg, rng) + JumpProblem(sprob, Direct(), vrj; vr_aggregator = agg) end # Helpers @@ -44,7 +44,7 @@ first_jump_time(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)] @testset "SSAStepper" begin jprob = make_ssa_jump_prob() sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleSerial(); - trajectories = 3) + trajectories = 3, rng = StableRNG(12345)) times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(times) end @@ -52,11 +52,9 @@ first_jump_time(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)] @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial(); - trajectories = 3) + trajectories = 3, rng = StableRNG(12345)) times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(times) - finals = [sol.u[i].u[end][1] for i in 1:3] - @test allunique(finals) end # EM() uses a fixed time grid so jump event times aren't directly visible @@ -77,109 +75,77 @@ end @testset "Sequential solves: different RNG streams" begin @testset "SSAStepper" begin jprob = make_ssa_jump_prob() - times = [first_jump_time(solve(jprob, SSAStepper())) for _ in 1:3] + rng = StableRNG(12345) + times = [first_jump_time(solve(jprob, SSAStepper(); rng)) for _ in 1:3] @test allunique(times) end @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) - sols = [solve(jprob, Tsit5()) for _ in 1:3] + rng = StableRNG(12345) + sols = [solve(jprob, Tsit5(); rng) for _ in 1:3] times = [first_jump_time(s) for s in sols] @test allunique(times) - finals = [s.u[end][1] for s in sols] - @test allunique(finals) end end # ========================================================================== -# 3. Threaded ensemble: no data race on the shared JumpProblem -# -# The ODE/SSA path through __jump_init receives seed=nothing from -# SciMLBase, so deepcopy'd problems on non-main threads start with -# identical RNG states. We only assert completion here — uniqueness -# requires explicit seeding (tested in section 4 below). -# -# The SDE path goes through StochasticDiffEq's __init which generates -# per-trajectory seeds, so we can additionally verify uniqueness there. +# 3. rng kwarg reproducibility: same rng seed → identical trajectory, +# different rng seeds → different trajectories # ========================================================================== -@testset "EnsembleThreads: no data race" begin - @testset "SSAStepper" begin +@testset "rng kwarg reproducibility" begin + @testset "SSAStepper: same seed → same trajectory" begin jprob = make_ssa_jump_prob() - sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); - trajectories = 4) - @test length(sol) == 4 + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(42)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(42)) + @test sol1.t == sol2.t + @test sol1.u == sol2.u end - @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + @testset "SSAStepper: different seeds → different trajectories" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(100)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(200)) + sol3 = solve(jprob, SSAStepper(); rng = StableRNG(300)) + times = [first_jump_time(sol1), first_jump_time(sol2), first_jump_time(sol3)] + @test allunique(times) + end + + @testset "ODE + VR ($agg): same seed → same trajectory" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) - # This path previously had a data race: resetted_jump_problem called - # randexp!(_jump_prob.rng, ...) on the shared original problem. - sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleThreads(); - trajectories = 4, save_everystep = false) - @test length(sol) == 4 + sol1 = solve(jprob, Tsit5(); rng = StableRNG(42)) + sol2 = solve(jprob, Tsit5(); rng = StableRNG(42)) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] end - @testset "SDE + VR (VR_FRM): unique trajectories" begin - jprob = make_sde_vr_jump_prob(VR_FRM()) - # StochasticDiffEq generates per-trajectory seeds and passes them to - # resetted_jump_problem, so trajectories should be distinct. - sol = solve(EnsembleProblem(jprob), EM(), EnsembleThreads(); - trajectories = 4, dt = 0.01, save_everystep = false) - @test length(sol) == 4 - finals = [sol.u[i].u[end][1] for i in 1:4] - @test length(unique(finals)) > 1 + @testset "ODE + VR ($agg): different seeds → different trajectories" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + jprob = make_vr_jump_prob(agg) + sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) end end # ========================================================================== -# 4. Seed-based stream independence: resetted_jump_problem and -# reset_jump_problem! produce distinct RNG streams for different seeds -# -# This tests the mechanism that EnsembleThreads relies on (when seeds are -# provided by the caller, e.g. StochasticDiffEq) to get independent streams -# on different threads. +# 4. has_rng / get_rng / set_rng! interface on SSAIntegrator # ========================================================================== -@testset "resetted_jump_problem: different seeds → different streams" begin +@testset "SSAIntegrator RNG interface" begin jprob = make_ssa_jump_prob() - seeds = UInt64[100, 200, 300] + integrator = init(jprob, SSAStepper(); rng = StableRNG(42)) - # Each seed should produce a distinct aggregator RNG state - rngs = map(seeds) do s - jp = JumpProcesses.resetted_jump_problem(jprob, s) - jp.jump_callback.discrete_callbacks[1].condition.rng - end - draws = [rand(rng) for rng in rngs] - @test allunique(draws) - - # Same seed should be deterministic - jp1 = JumpProcesses.resetted_jump_problem(jprob, UInt64(42)) - jp2 = JumpProcesses.resetted_jump_problem(jprob, UInt64(42)) - rng1 = jp1.jump_callback.discrete_callbacks[1].condition.rng - rng2 = jp2.jump_callback.discrete_callbacks[1].condition.rng - @test rand(rng1) == rand(rng2) -end + @test SciMLBase.has_rng(integrator) + rng = SciMLBase.get_rng(integrator) + @test rng isa StableRNG -@testset "reset_jump_problem!: different seeds → different streams" begin - seeds = UInt64[100, 200, 300] - draws = map(seeds) do s - jp = make_ssa_jump_prob() - JumpProcesses.reset_jump_problem!(jp, s) - rand(jp.jump_callback.discrete_callbacks[1].condition.rng) - end - @test allunique(draws) -end + new_rng = StableRNG(99) + SciMLBase.set_rng!(integrator, new_rng) + @test SciMLBase.get_rng(integrator) === new_rng -@testset "_derive_jump_seed: decorrelates from input seed" begin - seed = UInt64(12345) - derived = JumpProcesses._derive_jump_seed(seed) - # Derived seed should differ from input - @test derived != seed - # Should be deterministic - @test derived == JumpProcesses._derive_jump_seed(seed) - # Different inputs → different outputs - @test JumpProcesses._derive_jump_seed(UInt64(1)) != JumpProcesses._derive_jump_seed(UInt64(2)) + # mismatched RNG type should throw + @test_throws ArgumentError SciMLBase.set_rng!(integrator, Random.Xoshiro(123)) end # ========================================================================== @@ -187,13 +153,42 @@ end # # For VR_FRM, each trajectory's first jump time is determined by the initial # jump_u threshold (set to -randexp() by the VR_FRMEventCallback initialize). -# Distinct thresholds → distinct first event times. +# We verify both the thresholds (via init) and the resulting event times. # ========================================================================== @testset "VR_FRM: jump_u thresholds unique per trajectory (EnsembleSerial)" begin jprob = make_vr_jump_prob(VR_FRM()) + + # Check jump_u thresholds directly via init (callback sets them during initialization) + rng = StableRNG(12345) + thresholds = [begin + integrator = init(jprob, Tsit5(); rng) + integrator.u.jump_u[1] + end for _ in 1:3] + @test allunique(thresholds) + + # From a full ensemble solve, check both first event times and the + # initial jump_u thresholds (u[2] is the initialization save where + # jump_u has been set to -randexp() by the callback). sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial(); - trajectories = 3) + trajectories = 3, rng = StableRNG(12345)) event_times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(event_times) + init_thresholds = [sol.u[i].u[2].jump_u[1] for i in 1:3] + @test allunique(init_thresholds) +end + +# ========================================================================== +# 6. JumpProblem rng kwarg throws ArgumentError +# ========================================================================== + +@testset "JumpProblem rng kwarg throws ArgumentError" begin + j1 = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) + dprob = DiscreteProblem([10], (0.0, 10.0)) + j1_local = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) + dprob_local = DiscreteProblem([10], (0.0, 10.0)) + jprob = JumpProblem(dprob_local, Direct(), j1_local) + @test_throws ArgumentError JumpProblem(dprob_local, Direct(), j1_local; rng = StableRNG(1)) + sol = solve(jprob, SSAStepper(); rng = StableRNG(1)) + @test sol.retcode == ReturnCode.Success end diff --git a/test/ensemble_uniqueness.jl b/test/ensemble_uniqueness.jl index adb97df30..9861231cc 100644 --- a/test/ensemble_uniqueness.jl +++ b/test/ensemble_uniqueness.jl @@ -7,35 +7,19 @@ u0 = [0] dprob = DiscreteProblem(u0, (0.0, 100.0)) -# For EnsembleProblems, use prob_func to create a new JumpProblem with unique RNG per trajectory. -# This ensures different trajectories while maintaining reproducibility. -# Generate seeds from a seeded RNG for reproducibility of ensemble results. -function make_seeded_prob_func(dprob, aggregator, jumps, base_rng) - return function prob_func(prob, i, repeat) - seed = rand(base_rng, UInt64) - JumpProblem(dprob, aggregator, jumps...; rng = StableRNG(seed)) - end -end - -# Test with FunctionMap - use prob_func to create JumpProblems with unique RNGs -rng1 = StableRNG(12345) -jump_prob = JumpProblem(dprob, Direct(), j1, j2; rng = rng1) -ensemble_rng = StableRNG(99999) # separate RNG for generating trajectory seeds -ensemble_prob = EnsembleProblem(jump_prob; - prob_func = make_seeded_prob_func(dprob, Direct(), (j1, j2), ensemble_rng)) -sol = solve(ensemble_prob, FunctionMap(), trajectories = 3) +# Test with FunctionMap - pass rng to solve so trajectories get unique sequences +jump_prob = JumpProblem(dprob, Direct(), j1, j2) +ensemble_prob = EnsembleProblem(jump_prob) +sol = solve(ensemble_prob, FunctionMap(), trajectories = 3; rng = StableRNG(12345)) @test Array(sol.u[1]) !== Array(sol.u[2]) @test Array(sol.u[1]) !== Array(sol.u[3]) @test Array(sol.u[2]) !== Array(sol.u[3]) @test eltype(sol.u[1].u[1]) == Int -# Test with SSAStepper - use prob_func to create JumpProblems with unique RNGs -rng2 = StableRNG(12345) -jump_prob = JumpProblem(dprob, Direct(), j1, j2; rng = rng2) -ensemble_rng2 = StableRNG(99999) # separate RNG for generating trajectory seeds -ensemble_prob2 = EnsembleProblem(jump_prob; - prob_func = make_seeded_prob_func(dprob, Direct(), (j1, j2), ensemble_rng2)) -sol = solve(ensemble_prob2, SSAStepper(), trajectories = 3) +# Test with SSAStepper - pass rng to solve so trajectories get unique sequences +jump_prob = JumpProblem(dprob, Direct(), j1, j2) +ensemble_prob2 = EnsembleProblem(jump_prob) +sol = solve(ensemble_prob2, SSAStepper(), trajectories = 3; rng = StableRNG(12345)) @test Array(sol.u[1]) !== Array(sol.u[2]) @test Array(sol.u[1]) !== Array(sol.u[3]) @test Array(sol.u[2]) !== Array(sol.u[3]) diff --git a/test/extended_jump_array_remake.jl b/test/extended_jump_array_remake.jl index 7f2168d5e..7a32c64d7 100644 --- a/test/extended_jump_array_remake.jl +++ b/test/extended_jump_array_remake.jl @@ -5,8 +5,6 @@ using JumpProcesses, OrdinaryDiffEq, Test, SymbolicIndexingInterface using StableRNGs @testset "remake JumpProblem with VariableRateJumps (ExtendedJumpArray)" begin - rng = StableRNG(12345) - # Setup: Create an ODEProblem with SymbolCache for symbolic indexing f(du, u, p, t) = (du .= 0; nothing) g = ODEFunction(f; sys = SymbolCache([:X, :Y], [:k1, :k2], :t)) @@ -17,20 +15,25 @@ using StableRNGs vr_affect!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1) vrj = VariableRateJump(vr_rate, vr_affect!) - jprob = JumpProblem(oprob, vrj; rng) + jprob = JumpProblem(oprob, vrj) # Verify we have ExtendedJumpArray @test jprob.prob.u0 isa ExtendedJumpArray @test jprob.prob.u0.u == [10.0, 5.0] - @testset "remake with numeric Vector{Float64}" begin - original_jump_u = copy(jprob.prob.u0.jump_u) + # Solve original problem and capture jump_u after initialization + orig_integrator = init(jprob, Tsit5(); rng = StableRNG(42)) + orig_jump_u = copy(orig_integrator.u.jump_u) + @testset "remake with numeric Vector{Float64}" begin prob2 = remake(jprob; u0 = [20.0, 10.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u == [20.0, 10.0] - # jump_u should be resampled (different from original) - @test prob2.prob.u0.jump_u != original_jump_u + @test all(iszero, prob2.prob.u0.jump_u) + # After init, callback sets fresh jump_u thresholds (different RNG seed) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) + @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with ExtendedJumpArray (no resample)" begin @@ -47,34 +50,34 @@ using StableRNGs end @testset "remake with Symbol pairs" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - - # This was the FAILING case - should work after fix prob2 = remake(jprob; u0 = [:X => 25.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 25.0 - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + @test all(iszero, prob2.prob.u0.jump_u) + # After init, callback sets fresh jump_u thresholds (different RNG seed) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) + @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with multiple Symbol pairs" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = [:X => 35.0, :Y => 15.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u == [35.0, 15.0] - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) + @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with Dict" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = Dict(:X => 40.0)) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 40.0 - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) + @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with parameters only (u0 unchanged)" begin @@ -88,25 +91,31 @@ using StableRNGs end @testset "remake with both u0 and p" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = [:X => 50.0], p = [:k1 => 3.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 50.0 @test prob2.prob.p[1] == 3.0 - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) + @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake preserves problem solvability" begin - # Ensure remade problems can actually be solved + # Solve original, then remake and solve again — jump_u should differ + sol1 = solve(jprob, Tsit5(); rng = StableRNG(42)) + @test SciMLBase.successful_retcode(sol1) + prob2 = remake(jprob; u0 = [5.0, 2.0]) - sol = solve(prob2, Tsit5()) - @test SciMLBase.successful_retcode(sol) + sol2 = solve(prob2, Tsit5(); rng = StableRNG(99)) + @test SciMLBase.successful_retcode(sol2) + # Different RNG seeds → different jump_u thresholds after init + @test sol1.u[2].jump_u != sol2.u[2].jump_u - # With symbolic map (after fix) + # With symbolic map prob3 = remake(jprob; u0 = [:X => 8.0]) - sol3 = solve(prob3, Tsit5()) + sol3 = solve(prob3, Tsit5(); rng = StableRNG(77)) @test SciMLBase.successful_retcode(sol3) + @test sol1.u[2].jump_u != sol3.u[2].jump_u end end diff --git a/test/extinction_test.jl b/test/extinction_test.jl index 880254ba0..468c96743 100644 --- a/test/extinction_test.jl +++ b/test/extinction_test.jl @@ -21,9 +21,8 @@ algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) for n in 1:Nsims for ssa in algs - local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false), - rng = rng) - local sol = solve(jprob, SSAStepper()) + local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false)) + local sol = solve(jprob, SSAStepper(); rng) @test sol[1, end] == 0 @test sol.t[end] < Inf end @@ -33,9 +32,8 @@ u0 = SA[10] dprob = DiscreteProblem(u0, (0.0, 100.0), rates) for ssa in algs - local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false), - rng = rng) - local sol = solve(jprob, SSAStepper(), saveat = 100.0) + local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false)) + local sol = solve(jprob, SSAStepper(); saveat = 100.0, rng) @test sol[1, end] == 0 @test sol.t[end] < Inf end @@ -57,8 +55,8 @@ end et = ExtinctionTest() cb = DiscreteCallback(et, et, save_positions = (false, false)) dprob = DiscreteProblem(u0, (0.0, 1000.0), rates) -jprob = JumpProblem(dprob, Direct(), majump; save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), callback = cb, save_end = false) +jprob = JumpProblem(dprob, Direct(), majump; save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); callback = cb, save_end = false, rng) @test sol.t[end] < 1000.0 # test terminate @@ -73,8 +71,8 @@ end cb = DiscreteCallback(extinction_condition2, extinction_affect!2, save_positions = (false, false)) dprob = DiscreteProblem(u0, (0.0, 1000.0), rates) -jprob = JumpProblem(dprob, majump; save_positions = (false, false), rng) -sol = solve(jprob; callback = cb, save_end = false) +jprob = JumpProblem(dprob, majump; save_positions = (false, false)) +sol = solve(jprob; callback = cb, save_end = false, rng) @test sol[1, end] == 1 @test sol.retcode == ReturnCode.Terminated @test sol.t[end] < 1000.0 diff --git a/test/fp_unknowns.jl b/test/fp_unknowns.jl index 6a8b5f6d6..8e1ecd9d7 100644 --- a/test/fp_unknowns.jl +++ b/test/fp_unknowns.jl @@ -34,11 +34,11 @@ function test(rng) Xmeans = zeros(length(SSAalgs)) Ymeans = zeros(length(SSAalgs)) for (j, agg) in enumerate(SSAalgs) - jprob = JumpProblem(dprob, agg, maj; save_positions = (false, false), rng, + jprob = JumpProblem(dprob, agg, maj; save_positions = (false, false), vartojumps_map = vtoj, jumptovars_map = jtov, dep_graph = dg, scale_rates = false) for i in 1:Nsims - sol = solve(jprob, SSAStepper()) + sol = solve(jprob, SSAStepper(); rng) Xmeans[j] += sol[1, end] Ymeans[j] += sol[2, end] end diff --git a/test/functionwrappers.jl b/test/functionwrappers.jl index 2f009ead4..0c1bc0a0b 100644 --- a/test/functionwrappers.jl +++ b/test/functionwrappers.jl @@ -12,18 +12,18 @@ let rateinterval = (u, p, t) -> 0.1) prob = DiscreteProblem([0.0], (0.0, 2.0), [1.0]) - jprob = JumpProblem(prob, Coevolve(), jump; dep_graph = [[1]], rng) + jprob = JumpProblem(prob, Coevolve(), jump; dep_graph = [[1]]) agg = jprob.discrete_jump_aggregation @test agg.affects! isa Vector{Any} - integ = init(jprob, SSAStepper()) + integ = init(jprob, SSAStepper(); rng) T = Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{typeof(integ)}}} @test agg.affects! isa T affs = agg.affects! sol_c = solve!(integ) # check the affects vector is unchanged from a second call - integ = init(jprob, SSAStepper()) + integ = init(jprob, SSAStepper(); rng) sol_c = solve!(integ) @test affs === agg.affects! @@ -31,7 +31,7 @@ let terminate_condition(u, t, integrator) = (return u[1] >= 1) terminate_affect!(integrator) = terminate!(integrator) terminate_cb = DiscreteCallback(terminate_condition, terminate_affect!) - integ2 = init(jprob, SSAStepper(); callback = terminate_cb) + integ2 = init(jprob, SSAStepper(); rng, callback = terminate_cb) T2 = Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{typeof(integ2)}}} @test T2 !== T @test agg.affects! isa T2 @@ -42,7 +42,7 @@ let solve!(integ2) # check affs2 is unchanged when solving again now - integ2 = init(jprob, SSAStepper(); callback = terminate_cb) + integ2 = init(jprob, SSAStepper(); rng, callback = terminate_cb) solve!(integ2) @test affs2 === agg.affects! end diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index ea55c82f9..45e9a02b8 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -22,19 +22,23 @@ expected_avg = 5.926553750000000e+02 reltol = 0.01 # average number of proteins in a simulation -function runSSAs(jump_prob; use_stepper = true) +function runSSAs(jump_prob; use_stepper = true, rng = nothing) Psamp = zeros(Int, Nsims) for i in 1:Nsims - sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob) + sol = if use_stepper + solve(jump_prob, SSAStepper(); rng) + else + solve(jump_prob; rng) + end Psamp[i] = sol[3, end] end mean(Psamp) end -function runSSAs_ode(vrjprob) +function runSSAs_ode(vrjprob; rng = nothing) Psamp = zeros(Float64, Nsims) tsave = vrjprob.prob.tspan[2] - integrator = init(vrjprob, Tsit5(); saveat = tsave) + integrator = init(vrjprob, Tsit5(); saveat = tsave, rng) solve!(integrator) Psamp[1] = integrator.sol[3, end] for i in 2:Nsims @@ -94,8 +98,8 @@ if doplot for alg in SSAalgs local jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - local sol = solve(jump_prob, SSAStepper()) + jumptovars_map = jump_to_dep_specs) + local sol = solve(jump_prob, SSAStepper(); rng) plot!(plothand, sol.t, sol[3, :], seriestype = :steppost) end display(plothand) @@ -106,8 +110,8 @@ if dotestmean for (i, alg) in enumerate(SSAalgs) local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_prob) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_prob; rng) relerr = abs(means - expected_avg) / expected_avg doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, ", rel err = ", relerr) @@ -118,8 +122,8 @@ if dotestmean let alg = Direct() jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_prob; use_stepper = false) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_prob; use_stepper = false, rng) @test abs(means - expected_avg) < reltol * expected_avg end @@ -128,8 +132,8 @@ if dotestmean for alg in (Direct(), RSSA()) jump_probf = JumpProblem(probf, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_probf) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_probf; rng) relerr = abs(means - expected_avg) / expected_avg doprintmeans && println("Mean from method (Float64 u0): ", typeof(alg), " is = ", means, ", rel err = ", relerr) @@ -139,11 +143,11 @@ end # no-aggregator tests jump_prob = JumpProblem(prob, majumps; save_positions = (false, false), - vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs, rng) -@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg + vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs) +@test abs(runSSAs(jump_prob; rng) - expected_avg) < reltol * expected_avg -jump_prob = JumpProblem(prob, majumps, save_positions = (false, false), rng = rng) -@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg +jump_prob = JumpProblem(prob, majumps, save_positions = (false, false)) +@test abs(runSSAs(jump_prob; rng) - expected_avg) < reltol * expected_avg # crj/vrj accuracy test # k1, DNA --> mRNA + DNA @@ -187,20 +191,20 @@ let VariableRateJump(r6, a6!, save_positions = (false, false))) prob = DiscreteProblem(u0, (0.0, tf), rates) - crjprob = JumpProblem(prob, crjs; save_positions = (false, false), rng) - @test abs(runSSAs(crjprob) - expected_avg) < reltol * expected_avg + crjprob = JumpProblem(prob, crjs; save_positions = (false, false)) + @test abs(runSSAs(crjprob; rng) - expected_avg) < reltol * expected_avg # vrjs are very slow so test on a shorter time span and compare to the crjs prob = DiscreteProblem(u0, (0.0, tf / 5), rates) - crjprob = JumpProblem(prob, crjs; save_positions = (false, false), rng) - crjmean = runSSAs(crjprob) + crjprob = JumpProblem(prob, crjs; save_positions = (false, false)) + crjmean = runSSAs(crjprob; rng) f(du, u, p, t) = (du .= 0; nothing) oprob = ODEProblem(f, u0f, (0.0, tf / 5), rates) for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) vrjprob = JumpProblem( - oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng) - vrjmean = runSSAs_ode(vrjprob) + oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false)) + vrjmean = runSSAs_ode(vrjprob; rng) @test abs(vrjmean - crjmean) < reltol * crjmean end end diff --git a/test/gpu/regular_jumps.jl b/test/gpu/regular_jumps.jl index 5e60ee2d4..bb136dd18 100644 --- a/test/gpu/regular_jumps.jl +++ b/test/gpu/regular_jumps.jl @@ -1,9 +1,6 @@ using JumpProcesses, DiffEqBase using Test, LinearAlgebra, Statistics using KernelAbstractions, Adapt, CUDA -using StableRNGs -rng = StableRNG(12345) - Nsims = 100_000 # SIR model with influx @@ -72,7 +69,7 @@ let # Create JumpProblem prob_disc = DiscreteProblem(u0, tspan, p) rj = RegularJump(regular_rate, regular_c, 3) - jump_prob = JumpProblem(prob_disc, PureLeaping(), rj; rng = StableRNG(12345)) + jump_prob = JumpProblem(prob_disc, PureLeaping(), rj) sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0) diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index adf5a83dc..bae9ff2c5 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -65,7 +65,7 @@ function hawkes_problem(p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0), save_positions = (false, true), g = [[1]], h = [[]], uselrate = true, kwargs...) dprob = DiscreteProblem(u, tspan, p) jumps = hawkes_jump(u, g, h; uselrate) - jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions, rng) + jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions) return jprob end @@ -78,7 +78,7 @@ function hawkes_problem(p, agg; u = [0.0], tspan = (0.0, 50.0), save_positions = (false, true), g = [[1]], h = [[]], vr_aggregator = VR_FRM(), kwargs...) oprob = ODEProblem(f!, u, tspan, p) jumps = hawkes_jump(u, g, h) - jprob = JumpProblem(oprob, agg, jumps...; vr_aggregator, save_positions, rng) + jprob = JumpProblem(oprob, agg, jumps...; vr_aggregator, save_positions) return jprob end @@ -119,7 +119,7 @@ for (i, alg) in enumerate(algs) sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) - sols[n] = solve(jump_prob, stepper) + sols[n] = solve(jump_prob, stepper; rng) end if alg isa Coevolve @@ -137,12 +137,12 @@ let alg = Coevolve() for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) oprob = ODEProblem(f!, u0, tspan, p) jumps = hawkes_jump(u0, g, h) - jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng) + jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g) @test ((jprob.variable_jumps === nothing) || isempty(jprob.variable_jumps)) sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) - sols[n] = solve(jprob, Tsit5()) + sols[n] = solve(jprob, Tsit5(); rng) end λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) @test isapprox(mean(λs), Eλ; atol = 0.01) @@ -156,12 +156,12 @@ let alg = Coevolve() for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) oprob = ODEProblem(f!, u0, tspan, p) jumps = hawkes_jump(u0, g, h) - jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng, use_vrj_bounds = false) + jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, use_vrj_bounds = false) @test length(jprob.variable_jumps) == 1 sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) - sols[n] = solve(jprob, Tsit5()) + sols[n] = solve(jprob, Tsit5(); rng) end cols = length(u0) diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index 657982a6b..7883883c3 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -3,7 +3,7 @@ using DiffEqBase, JumpProcesses, Statistics using Test using StableRNGs -rng = StableRNG(12345) +const rng = StableRNG(12345) # using BenchmarkTools # dobenchmark = true @@ -29,7 +29,7 @@ exactmeanval = exactmean(tf, rates) function runSSAs(jump_prob) Asamp = zeros(Int, Nsims) for i in 1:Nsims - sol = solve(jump_prob, SSAStepper()) + sol = solve(jump_prob, SSAStepper(); rng) Asamp[i] = sol[1, end] end mean(Asamp) @@ -52,7 +52,7 @@ function A_to_B_tuple(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -74,7 +74,7 @@ function A_to_B_vec(N, method) # convert jumpvec to tuple to send to JumpProblem... jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -92,7 +92,7 @@ function A_to_B_ma(N, method) majumps = MassActionJump(rates, reactstoch, netstoch) jset = JumpSet((), (), nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -126,7 +126,7 @@ function A_to_B_hybrid(N, method) majumps = MassActionJump(rates[1:switchidx], reactstoch, netstoch) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -161,7 +161,7 @@ function A_to_B_hybrid_nojset(N, method) jumps = (constjumps..., majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) jump_prob = JumpProblem(prob, method, jumps...; save_positions = (false, false), - rng, namedpars...) + namedpars...) jump_prob end @@ -190,7 +190,7 @@ function A_to_B_hybrid_vecs(N, method) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -220,7 +220,7 @@ function A_to_B_hybrid_vecs_scalars(N, method) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -252,7 +252,7 @@ function A_to_B_hybrid_tups_scalars(N, method) jumps = ((maj for maj in majumpsv)..., (jump for jump in jumpvec)...) prob = DiscreteProblem([A0, 0], (0.0, tf)) jump_prob = JumpProblem(prob, method, jumps...; save_positions = (false, false), - rng, namedpars...) + namedpars...) jump_prob end @@ -282,7 +282,7 @@ function A_to_B_hybrid_tups(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob diff --git a/test/longtimes_test.jl b/test/longtimes_test.jl index 7c787b4f3..4a63422e1 100644 --- a/test/longtimes_test.jl +++ b/test/longtimes_test.jl @@ -11,6 +11,6 @@ u0 = [5] tspan = (0.0, 2e6) dt = tspan[2] / 1000 dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), saveat = tspan[1]:dt:tspan[2]) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); saveat = tspan[1]:dt:tspan[2], rng) @test length(unique(sol.u[(end - 10):end][:])) > 1 diff --git a/test/monte_carlo_test.jl b/test/monte_carlo_test.jl index 42566b3ad..2141013c8 100644 --- a/test/monte_carlo_test.jl +++ b/test/monte_carlo_test.jl @@ -8,28 +8,28 @@ prob = SDEProblem(f, g, [1.0], (0.0, 1.0)) rate = (u, p, t) -> 200.0 affect! = integrator -> (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate, affect!, save_positions = (false, true)) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM()) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) first_event(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)] @test first_event(sol.u[1]) != first_event(sol.u[2]) != first_event(sol.u[3]) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct()) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test allunique(sol.u[1].t) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW()) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test allunique(sol.u[1].t) jump = ConstantRateJump(rate, affect!) -jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng) +jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false)) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index f4d009eae..99c3375c7 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -36,11 +36,11 @@ end u0 = [999.0, 10.0, 0.0] # S, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng, save_positions = (false, false)) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; save_positions = (false, false)) # Solve with SSAStepper (save only at t_compare times) sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # RegularJump formulation for SimpleTauLeaping regular_rate = (out, u, p, t) -> begin @@ -55,22 +55,22 @@ end dc[3] = counts[2] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj) # Solve with SimpleTauLeaping (save only at t_compare times) sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); - trajectories = Nsims, dt = 0.1, saveat = t_compare) + trajectories = Nsims, dt = 0.1, saveat = t_compare, rng) # MassActionJump formulation for SimpleExplicitTauLeaping reactant_stoich = [[1 => 1, 2 => 1], [2 => 1], Pair{Int, Int}[]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [1 => 1]] param_idxs = [1, 2, 3] maj = MassActionJump(reactant_stoich, net_stoich; param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj) # Solve with SimpleExplicitTauLeaping (save only at t_compare times) sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # Compute mean I trajectories via direct indexing (I is index 2 in SIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 2) @@ -101,11 +101,11 @@ end u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng, save_positions = (false, false)) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; save_positions = (false, false)) # Solve with SSAStepper (save only at t_compare times) sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # RegularJump formulation for SimpleTauLeaping regular_rate = (out, u, p, t) -> begin @@ -121,22 +121,22 @@ end dc[4] = counts[3] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj) # Solve with SimpleTauLeaping (save only at t_compare times) sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); - trajectories = Nsims, dt = 0.1, saveat = t_compare) + trajectories = Nsims, dt = 0.1, saveat = t_compare, rng) # MassActionJump formulation for SimpleExplicitTauLeaping reactant_stoich = [[1 => 1, 3 => 1], [2 => 1], [3 => 1]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [3 => -1, 4 => 1]] param_idxs = [1, 2, 3] maj = MassActionJump(reactant_stoich, net_stoich; param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj) # Solve with SimpleExplicitTauLeaping (save only at t_compare times) sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # Compute mean I trajectories via direct indexing (I is index 3 in SEIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 3) @@ -183,7 +183,7 @@ end maj = MassActionJump(rates, reactant_stoich, net_stoich) # Test PureLeaping JumpProblem creation - jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) + jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj)) @test jp_pure.aggregator isa PureLeaping @test jp_pure.discrete_jump_aggregation === nothing @test jp_pure.massaction_jump !== nothing @@ -194,7 +194,7 @@ end affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) crj = ConstantRateJump(rate, affect!) - jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj)) @test jp_pure_crj.aggregator isa PureLeaping @test jp_pure_crj.discrete_jump_aggregation === nothing @test length(jp_pure_crj.constant_jumps) == 1 @@ -204,7 +204,7 @@ end vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) vrj = VariableRateJump(vrate, vaffect!) - jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj)) @test jp_pure_vrj.aggregator isa PureLeaping @test jp_pure_vrj.discrete_jump_aggregation === nothing @test length(jp_pure_vrj.variable_jumps) == 1 @@ -224,7 +224,7 @@ end regj = RegularJump(rj_rate, rj_c, 1) - jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj)) @test jp_pure_regj.aggregator isa PureLeaping @test jp_pure_regj.discrete_jump_aggregation === nothing @test jp_pure_regj.regular_jump !== nothing @@ -232,7 +232,7 @@ end # Test mixed jump types mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), variable_jumps = (vrj,), regular_jumps = regj) - jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) + jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps) @test jp_pure_mixed.aggregator isa PureLeaping @test jp_pure_mixed.discrete_jump_aggregation === nothing @test jp_pure_mixed.massaction_jump !== nothing @@ -243,14 +243,14 @@ end # Test spatial system error spatial_sys = CartesianGrid((2, 2)) hopping_consts = [1.0] - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); spatial_system = spatial_sys) - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); hopping_constants = hopping_consts) # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) - jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) + jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params)) scaled_rates = [p[1], p[2]/2] @test jp_params.massaction_jump.scaled_rates == scaled_rates end @@ -266,23 +266,23 @@ end prob = DiscreteProblem(u0, tspan) # SSAStepper with save_positions=(false,false) + saveat: only saveat times stored - jp = JumpProblem(prob, Direct(), crj; rng, save_positions = (false, false)) - sol = solve(jp, SSAStepper(); saveat = 1.0) + jp = JumpProblem(prob, Direct(), crj; save_positions = (false, false)) + sol = solve(jp, SSAStepper(); saveat = 1.0, rng) @test sol.t == collect(0.0:1.0:10.0) # SSAStepper with default save_positions + saveat: jump times stored too - jp2 = JumpProblem(prob, Direct(), crj; rng) - sol2 = solve(jp2, SSAStepper(); saveat = 1.0) + jp2 = JumpProblem(prob, Direct(), crj) + sol2 = solve(jp2, SSAStepper(); saveat = 1.0, rng) @test length(sol2.t) > length(sol.t) # --- SimpleTauLeaping save_start/save_end/saveat tests --- regular_rate = (out, u, p, t) -> (out[1] = 1.0) regular_c = (dc, u, p, t, counts, mark) -> (dc[1] = counts[1]) rj = RegularJump(regular_rate, regular_c, 1) - jp_tau = JumpProblem(prob, PureLeaping(), rj; rng) + jp_tau = JumpProblem(prob, PureLeaping(), rj) # No saveat: stores every dt step (save_start=true, save_end=true by default) - sol_tau = solve(jp_tau, SimpleTauLeaping(); dt = 1.0) + sol_tau = solve(jp_tau, SimpleTauLeaping(); dt = 1.0, rng) @test sol_tau.t == collect(0.0:1.0:10.0) # saveat as Number: defaults save_start=true, save_end=true @@ -334,7 +334,7 @@ end reactant_stoich = [[1 => 1]] net_stoich = [[1 => -1]] maj = MassActionJump([0.1], reactant_stoich, net_stoich) - jp_explicit = JumpProblem(prob_decay, PureLeaping(), maj; rng) + jp_explicit = JumpProblem(prob_decay, PureLeaping(), maj) # saveat as Number: defaults save_start=true, save_end=true sol = solve(jp_explicit, SimpleExplicitTauLeaping(); saveat = 2.0) diff --git a/test/remake_test.jl b/test/remake_test.jl index d622233ac..c01014b05 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -1,4 +1,4 @@ -using JumpProcesses, DiffEqBase, OrdinaryDiffEq +using JumpProcesses, DiffEqBase, OrdinaryDiffEq, Test using StableRNGs rng = StableRNG(12345) @@ -21,21 +21,20 @@ p = (0.1 / 1000, 0.01) tspan = (0.0, 2500.0) dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), jump, jump2, save_positions = (false, false), - rng = rng) -sol = solve(jprob, SSAStepper()) +jprob = JumpProblem(dprob, Direct(), jump, jump2, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng) @test sol[3, end] == 1000 u02 = [1000, 1, 0] p2 = (0.1 / 1000, 0.0) dprob2 = remake(dprob, u0 = u02, p = p2) jprob2 = remake(jprob, prob = dprob2) -sol2 = solve(jprob2, SSAStepper()) +sol2 = solve(jprob2, SSAStepper(); rng) @test sol2[2, end] == 1001 tspan2 = (0.0, 25000.0) jprob3 = remake(jprob, p = p2, tspan = tspan2) -sol3 = solve(jprob3, SSAStepper()) +sol3 = solve(jprob3, SSAStepper(); rng) @test sol3[2, end] == 1000 @test sol3.t[end] == 25000.0 @@ -46,19 +45,19 @@ ns = [[2 => -1, 3 => 1], [1 => -1, 2 => 1]] pidxs = [2, 1] maj = MassActionJump(rs, ns; param_idxs = pidxs) dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper()) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng) @test sol[3, end] == 1000 # update the MassActionJump dprob2 = remake(dprob, u0 = u02, p = p2) jprob2 = remake(jprob, prob = dprob2) -sol2 = solve(jprob2, SSAStepper()) +sol2 = solve(jprob2, SSAStepper(); rng) @test sol2[2, end] == 1001 tspan2 = (0.0, 25000.0) jprob3 = remake(jprob, p = p2, tspan = tspan2) -sol3 = solve(jprob3, SSAStepper()) +sol3 = solve(jprob3, SSAStepper(); rng) @test sol3[2, end] == 1000 @test sol3.t[end] == 25000.0 @@ -75,20 +74,20 @@ let rrate(u, p, t) = u[1] aaffect!(integrator) = (integrator.u[1] += 1; nothing) vrj = VariableRateJump(rrate, aaffect!) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) u0 = [4.0] jprob2 = remake(jprob; u0) @test jprob2.prob.u0 isa ExtendedJumpArray @test jprob2.prob.u0.u === u0 - sol = solve(jprob2, Tsit5()) + sol = solve(jprob2, Tsit5(); rng) u = sol[1, :] t = sol.t first_nontstart = findfirst(>(t[1]), t) @@ -97,7 +96,7 @@ let u0 = deepcopy(jprob2.prob.u0) u0.u .= 0 jprob3 = remake(jprob2; u0) - sol = solve(jprob3, Tsit5()) + sol = solve(jprob3, Tsit5(); rng) @test all(==(0.0), sol[1, :]) @test_throws ErrorException jprob4=remake(jprob, u0 = 1) end @@ -109,24 +108,24 @@ let rrate(u, p, t) = u[1] aaffect!(integrator) = (integrator.u[1] += 1; nothing) vrj = VariableRateJump(rrate, aaffect!) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) u0 = [4.0] prob2 = remake(jprob.prob; u0) @test_throws ErrorException jprob2=remake(jprob; prob = prob2) - u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0, rng) + u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0) prob3 = remake(jprob.prob; u0 = u0eja) jprob3 = remake(jprob; prob = prob3) @test jprob3.prob.u0 isa ExtendedJumpArray @test jprob3.prob.u0 === u0eja - sol = solve(jprob3, Tsit5()) + sol = solve(jprob3, Tsit5(); rng) u = sol[1, :] t = sol.t first_nontstart = findfirst(>(t[1]), t) diff --git a/test/reversible_binding.jl b/test/reversible_binding.jl index f872d92a4..d0e9b9a6d 100644 --- a/test/reversible_binding.jl +++ b/test/reversible_binding.jl @@ -20,10 +20,10 @@ tspan = (0.0, 5.0) prob = DiscreteProblem(u0, tspan, rates) majumps = MassActionJump(rates, reactstoch, netstoch) -function getmean(jprob, Nsims) +function getmean(jprob, Nsims; rng = nothing) Amean = 0 for i in 1:Nsims - sol = solve(jprob, SSAStepper()) + sol = solve(jprob, SSAStepper(); rng) Amean += sol[1, end] end Amean /= Nsims @@ -48,8 +48,7 @@ mastereq_mean = mastereqmean(u0, rates) algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) relative_tolerance = 0.01 for alg in algs - local jprob = JumpProblem(prob, alg, majumps, save_positions = (false, false), - rng = rng) - local Amean = getmean(jprob, Nsims) + local jprob = JumpProblem(prob, alg, majumps, save_positions = (false, false)) + local Amean = getmean(jprob, Nsims; rng) @test abs(Amean - mastereq_mean) / mastereq_mean < relative_tolerance end diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl new file mode 100644 index 000000000..a237cecdc --- /dev/null +++ b/test/rng_kwarg_tests.jl @@ -0,0 +1,327 @@ +using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq, Test +using StableRNGs, Random + +# ========================================================================== +# Test that rng/seed can be passed via solve/init kwargs for all pathways, +# and that JumpProblem(; rng=...) throws an error. +# ========================================================================== + +# -------------------------------------------------------------------------- +# Problem constructors +# -------------------------------------------------------------------------- +function make_ssa_jump_prob() + j1 = ConstantRateJump((u, p, t) -> 10.0, integrator -> (integrator.u[1] += 1)) + j2 = ConstantRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1)) + dprob = DiscreteProblem([10], (0.0, 20.0)) + JumpProblem(dprob, Direct(), j1, j2) +end + +function make_ode_vr_jump_prob() + f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) + oprob = ODEProblem(f!, [100.0], (0.0, 10.0)) + vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], + integrator -> (integrator.u[1] -= 1.0)) + JumpProblem(oprob, Direct(), vrj) +end + +function make_sde_vr_jump_prob() + f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) + g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) + sprob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) + vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], + integrator -> (integrator.u[1] -= 1.0)) + JumpProblem(sprob, Direct(), vrj) +end + +# Helpers +first_jump_time(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)] + +# ========================================================================== +# 1. JumpProblem(; rng=...) throws ArgumentError +# ========================================================================== +@testset "JumpProblem(; rng=...) throws ArgumentError" begin + dprob = DiscreteProblem([10], (0.0, 10.0)) + j1 = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) + @test_throws ArgumentError JumpProblem(dprob, Direct(), j1; rng = StableRNG(42)) +end + +# ========================================================================== +# 2. SSAStepper: rng via solve/init +# ========================================================================== +@testset "SSAStepper: rng via solve kwargs" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = Xoshiro(42)) + @test SciMLBase.get_rng(integrator) isa Xoshiro + sol = solve(jprob, SSAStepper(); rng = Xoshiro(42)) + @test sol.retcode == ReturnCode.Success +end + +# ========================================================================== +# 3. SSAStepper: reproducibility via solve rng +# ========================================================================== +@testset "SSAStepper: solve rng reproducibility" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(123)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(123)) + @test sol1.t == sol2.t + @test sol1.u == sol2.u +end + +# ========================================================================== +# 4. SSAStepper: different seeds → different trajectories +# ========================================================================== +@testset "SSAStepper: different seeds → different trajectories" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(100)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(200)) + sol3 = solve(jprob, SSAStepper(); rng = StableRNG(300)) + times = [sol1.t[2], sol2.t[2], sol3.t[2]] + @test allunique(times) +end + +# ========================================================================== +# 5. ODE + VR: rng via solve/init +# ========================================================================== +@testset "ODE + VR: rng via solve kwargs" begin + jprob = make_ode_vr_jump_prob() + integrator = init(jprob, Tsit5(); rng = Xoshiro(42)) + @test SciMLBase.get_rng(integrator) isa Xoshiro +end + +# ========================================================================== +# 6. ODE + VR: reproducibility via solve rng +# ========================================================================== +@testset "ODE + VR: solve rng reproducibility" begin + jprob = make_ode_vr_jump_prob() + sol1 = solve(jprob, Tsit5(); rng = StableRNG(123)) + sol2 = solve(jprob, Tsit5(); rng = StableRNG(123)) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] +end + +# ========================================================================== +# 7. ODE + VR: different seeds → different trajectories +# ========================================================================== +@testset "ODE + VR: different seeds → different trajectories" begin + jprob = make_ode_vr_jump_prob() + sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) +end + +# ========================================================================== +# 8. SDE + VR: rng via solve/init +# ========================================================================== +@testset "SDE + VR: rng via solve kwargs" begin + jprob = make_sde_vr_jump_prob() + integrator = init(jprob, EM(); dt = 0.01, rng = Xoshiro(42)) + @test SciMLBase.get_rng(integrator) isa Xoshiro +end + +# ========================================================================== +# 9. SDE + VR: reproducibility via solve rng +# ========================================================================== +@testset "SDE + VR: solve rng reproducibility" begin + jprob = make_sde_vr_jump_prob() + sol1 = solve(jprob, EM(); dt = 0.01, save_everystep = false, rng = StableRNG(123)) + sol2 = solve(jprob, EM(); dt = 0.01, save_everystep = false, rng = StableRNG(123)) + @test sol1.u[end] ≈ sol2.u[end] +end + +# ========================================================================== +# 10. SDE + VR: different seeds → different trajectories +# ========================================================================== +@testset "SDE + VR: different seeds → different trajectories" begin + jprob = make_sde_vr_jump_prob() + sols = [solve(jprob, SRIW1(); save_everystep = false, + rng = StableRNG(s)) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) +end + +# ========================================================================== +# 11. SimpleTauLeaping: rng via solve kwargs +# ========================================================================== +@testset "SimpleTauLeaping: rng via solve kwargs" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) + @test sol1.u == sol2.u +end + +# ========================================================================== +# 12. SimpleTauLeaping: different seeds → different trajectories +# ========================================================================== +@testset "SimpleTauLeaping: different seeds → different trajectories" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(99)) + @test sol1.u != sol2.u +end + +# ========================================================================== +# 13. SimpleTauLeaping: seed kwarg reproducibility +# ========================================================================== +@testset "SimpleTauLeaping: seed kwarg reproducibility" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 42) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 42) + @test sol1.u == sol2.u +end + +# ========================================================================== +# 14. SimpleTauLeaping: seed different seeds → different trajectories +# ========================================================================== +@testset "SimpleTauLeaping: seed different seeds → different trajectories" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 42) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 99) + @test sol1.u != sol2.u +end + +# ========================================================================== +# 15. has_rng / get_rng / set_rng! interface on SSAIntegrator +# ========================================================================== +@testset "SSAIntegrator RNG interface" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = StableRNG(42)) + + @test SciMLBase.has_rng(integrator) + rng = SciMLBase.get_rng(integrator) + @test rng isa StableRNG + + new_rng = StableRNG(99) + SciMLBase.set_rng!(integrator, new_rng) + @test SciMLBase.get_rng(integrator) === new_rng + + # mismatched RNG type should throw + @test_throws ArgumentError SciMLBase.set_rng!(integrator, Random.Xoshiro(123)) +end + +# ========================================================================== +# 16. No rng kwarg: uses default_rng (non-reproducible but functional) +# ========================================================================== +@testset "No rng kwarg: functional solve" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + sol = solve(jprob, SSAStepper()) + @test sol.retcode == ReturnCode.Success + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + sol = solve(jprob, Tsit5()) + @test sol.retcode == ReturnCode.Success + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + sol = solve(jprob, EM(); dt = 0.01) + @test sol.retcode == ReturnCode.Success + end +end + +# ========================================================================== +# 17. seed kwarg: creates Xoshiro from integer seed +# ========================================================================== +@testset "seed kwarg creates Xoshiro" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + integrator = init(jprob, Tsit5(); seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + integrator = init(jprob, EM(); dt = 0.01, seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro + end +end + +# ========================================================================== +# 18. seed kwarg reproducibility: same seed → same trajectory +# ========================================================================== +@testset "seed kwarg reproducibility" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); seed = 42) + sol2 = solve(jprob, SSAStepper(); seed = 42) + @test sol1.t == sol2.t + @test sol1.u == sol2.u + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + sol1 = solve(jprob, Tsit5(); seed = 42) + sol2 = solve(jprob, Tsit5(); seed = 42) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + sol1 = solve(jprob, EM(); dt = 0.01, save_everystep = false, seed = 42) + sol2 = solve(jprob, EM(); dt = 0.01, save_everystep = false, seed = 42) + @test sol1.u[end] ≈ sol2.u[end] + end +end + +# ========================================================================== +# 19. seed kwarg: different seeds → different trajectories +# ========================================================================== +@testset "seed kwarg: different seeds → different trajectories" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); seed = 100) + sol2 = solve(jprob, SSAStepper(); seed = 200) + sol3 = solve(jprob, SSAStepper(); seed = 300) + times = [sol1.t[2], sol2.t[2], sol3.t[2]] + @test allunique(times) + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + sols = [solve(jprob, Tsit5(); seed = s) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + sols = [solve(jprob, SRIW1(); save_everystep = false, + seed = s) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) + end +end + +# ========================================================================== +# 20. rng takes priority over seed +# ========================================================================== +@testset "rng takes priority over seed" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = StableRNG(42), seed = 99) + @test SciMLBase.get_rng(integrator) isa StableRNG +end diff --git a/test/runtests.jl b/test/runtests.jl index 9a2dfc246..6df569d28 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,15 +40,15 @@ end if GROUP == "All" || GROUP == "InterfaceII" @time @safetestset "Saveat Regression test" begin include("saveat_regression.jl") end @time @safetestset "Save_positions test" begin include("save_positions.jl") end + @time @safetestset "RNG kwarg tests" begin include("rng_kwarg_tests.jl") end @time @safetestset "Ensemble Uniqueness test" begin include("ensemble_uniqueness.jl") end - @time @safetestset "Thread Safety test" begin include("thread_safety.jl") end @time @safetestset "Ensemble Problem Tests" begin include("ensemble_problems.jl") end @time @safetestset "A + B <--> C" begin include("reversible_binding.jl") end @time @safetestset "Remake tests" begin include("remake_test.jl") end @time @safetestset "ExtendedJumpArray remake tests" begin include("extended_jump_array_remake.jl") end @time @safetestset "Symbol based problem indexing" begin include("jprob_symbol_indexing.jl") end @time @safetestset "Long time accuracy test" begin include("longtimes_test.jl") end - @time @safetestset "Hawkes process" begin include("hawkes_test.jl") end + @time @safetestset "Hawkes process" begin include("hawkes_test.jl") end @time @safetestset "Reaction rates" begin include("spatial/reaction_rates.jl") end @time @safetestset "Hop rates" begin include("spatial/hop_rates.jl") end @time @safetestset "Topology" begin include("spatial/topology.jl") end @@ -58,6 +58,10 @@ end @time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end end + if GROUP == "All" || GROUP == "ThreadSafety" + @time @safetestset "Thread Safety test (threaded)" begin include("thread_safety.jl") end + end + if GROUP == "CUDA" activate_gpu_env() @time @safetestset "GPU Tau Leaping test" begin include("gpu/regular_jumps.jl") end diff --git a/test/save_positions.jl b/test/save_positions.jl index 13413b5b2..5d0a683f2 100644 --- a/test/save_positions.jl +++ b/test/save_positions.jl @@ -15,8 +15,8 @@ let jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0) jumpproblem = JumpProblem(dprob, alg, jump; dep_graph = [[1]], - save_positions = (false, true), rng) - sol = solve(jumpproblem, SSAStepper()) + save_positions = (false, true)) + sol = solve(jumpproblem, SSAStepper(); rng) @test sol.t == [0.0, 30.0] oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan) @@ -26,8 +26,8 @@ let for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jumpproblem = JumpProblem( oprob, alg, jump; vr_aggregator = vr_agg, dep_graph = [[1]], - save_positions = (false, true), rng) - sol = solve(jumpproblem, Tsit5(); save_everystep = false) + save_positions = (false, true)) + sol = solve(jumpproblem, Tsit5(); rng, save_everystep = false) @test sol.t == [0.0, 30.0] end end @@ -46,20 +46,20 @@ let # for pure jump problems dense = save_everystep vals = (true, true, true, false) for (sp, val) in zip(sps, vals) - jprob = JumpProblem(dprob, Direct(), crj; save_positions = sp, rng) - sol = solve(jprob, SSAStepper()) + jprob = JumpProblem(dprob, Direct(), crj; save_positions = sp) + sol = solve(jprob, SSAStepper(); rng) @test SciMLBase.isdenseplot(sol) == val end # for mixed problems sol.dense currently ignores save_positions oprob = ODEProblem((du, u, p, t) -> du[1] = 0.1, u0, tspan) for sp in sps - jprob = JumpProblem(oprob, Direct(), crj; save_positions = sp, rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(oprob, Direct(), crj; save_positions = sp) + sol = solve(jprob, Tsit5(); rng) @test sol.dense == true @test SciMLBase.isdenseplot(sol) == true - sol = solve(jprob, Tsit5(); dense = false) + sol = solve(jprob, Tsit5(); rng, dense = false) @test sol.dense == false @test SciMLBase.isdenseplot(sol) == false end diff --git a/test/saveat_regression.jl b/test/saveat_regression.jl index 03665d76a..108fde97e 100644 --- a/test/saveat_regression.jl +++ b/test/saveat_regression.jl @@ -10,12 +10,12 @@ maj = MassActionJump(rate_consts, reactant_stoich, net_stoich) n0 = [1, 1, 0] tspan = (0, 0.2) dprob = DiscreteProblem(n0, tspan) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) ts = collect(0:0.002:tspan[2]) NA = zeros(length(ts)) Nsims = 10_000 -sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), saveat = ts, - trajectories = Nsims) +sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(); saveat = ts, + trajectories = Nsims, rng) for i in 1:length(sol) NA .+= sol.u[i][1, :] @@ -26,10 +26,10 @@ for i in 1:length(ts) end NA = zeros(length(ts)) -jprob = JumpProblem(dprob, Direct(), maj; rng = rng) +jprob = JumpProblem(dprob, Direct(), maj) sol = nothing; GC.gc(); -sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), trajectories = Nsims) +sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(); trajectories = Nsims, rng) for i in 1:Nsims for n in 1:length(ts) diff --git a/test/sir_model.jl b/test/sir_model.jl index e8cea455e..8f4cc1ccc 100644 --- a/test/sir_model.jl +++ b/test/sir_model.jl @@ -18,8 +18,8 @@ end jump2 = ConstantRateJump(rate, affect!) prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) -integrator = init(jump_prob, FunctionMap()) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) +integrator = init(jump_prob, FunctionMap(); rng) condition(u, t, integrator) = t == 100 function purge_affect!(integrator) @@ -27,8 +27,8 @@ function purge_affect!(integrator) reset_aggregated_jumps!(integrator) end cb = DiscreteCallback(condition, purge_affect!, save_positions = (false, false)) -sol = solve(jump_prob, FunctionMap(), callback = cb, tstops = [100]) -sol = solve(jump_prob, SSAStepper(), callback = cb, tstops = [100]) +sol = solve(jump_prob, FunctionMap(); callback = cb, tstops = [100], rng) +sol = solve(jump_prob, SSAStepper(); callback = cb, tstops = [100], rng) # test README example using the auto-solver selection runs let diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 8d7230a74..3935ffca9 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -38,10 +38,10 @@ prob = DiscreteProblem(starting_state, tspan, rates) hopping_constants = [hopping_rate for i in starting_state] # algs = [NSM(), DirectCRDirect()] -function get_mean_end_state(jump_prob, Nsims) +function get_mean_end_state(jump_prob, Nsims; rng = nothing) end_state = zeros(size(jump_prob.prob.u0)) for i in 1:Nsims - sol = solve(jump_prob, SSAStepper()) + sol = solve(jump_prob, SSAStepper(); rng) end_state .+= sol.u[end] end end_state / Nsims @@ -52,19 +52,19 @@ grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for grid in grids] push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) # setup flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) # test for spatial_jump_prob in jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) + solution = solve(spatial_jump_prob, SSAStepper(); rng) + mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims; rng) mean_end_state = reshape(mean_end_state, num_species, num_nodes) diff = sum(mean_end_state, dims = 2) - non_spatial_mean for (i, d) in enumerate(diff) diff --git a/test/spatial/diffusion.jl b/test/spatial/diffusion.jl index c014b5c52..e501ab07f 100644 --- a/test/spatial/diffusion.jl +++ b/test/spatial/diffusion.jl @@ -5,10 +5,10 @@ using Graphs using StableRNGs rng = StableRNG(12345) -function get_mean_sol(jump_prob, Nsims, saveat) - sol = solve(jump_prob, SSAStepper(), saveat = saveat).u +function get_mean_sol(jump_prob, Nsims, saveat; rng = nothing) + sol = solve(jump_prob, SSAStepper(); saveat, rng).u for i in 1:(Nsims - 1) - sol += solve(jump_prob, SSAStepper(), saveat = saveat).u + sol += solve(jump_prob, SSAStepper(); saveat, rng).u end sol / Nsims end @@ -66,24 +66,24 @@ grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, algs[2], majumps, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for grid in grids] sizehint!(jump_problems, 15 + length(jump_problems)) # flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) # hop rates of form D_s hop_constants = [hopping_rate] for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # hop rates of form L_{s,i,j} @@ -96,10 +96,10 @@ end for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # hop rates of form D_s * L_{i,j} @@ -112,11 +112,11 @@ for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # hop rates of form D_{s,i} * L_{i,j} @@ -129,16 +129,16 @@ for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # testing for (j, spatial_jump_prob) in enumerate(jump_problems) - mean_sol = get_mean_sol(spatial_jump_prob, Nsims, tf / num_time_points) + mean_sol = get_mean_sol(spatial_jump_prob, Nsims, tf / num_time_points; rng) for (i, t) in enumerate(times) local diff = analytic_solution(t) - reshape(mean_sol[i], num_nodes, 1) @test abs(sum(diff[1:center_node]) / sum(analytic_solution(t)[1:center_node])) < @@ -165,7 +165,7 @@ tspan = (0.0, 10.0) prob = DiscreteProblem(starting_state, tspan) jp = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, - spatial_system = grid, save_positions = (false, false), rng = rng) -sol = solve(jp, SSAStepper()) + spatial_system = grid, save_positions = (false, false)) +sol = solve(jp, SSAStepper(); rng) @test sol.u[end][1, 1] == sum(sol.u[end]) diff --git a/test/spatial/spatial_majump.jl b/test/spatial/spatial_majump.jl index 6afcd367c..6018b68fd 100644 --- a/test/spatial/spatial_majump.jl +++ b/test/spatial/spatial_majump.jl @@ -61,26 +61,26 @@ non_uniform_majumps = [non_uniform_majumps_1, non_uniform_majumps_2, non_uniform uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for majump in uniform_majumps] # flattenned append!(uniform_jump_problems, JumpProblem[JumpProblem(prob, NRM(), majump, hopping_constants = hopping_constants, - spatial_system = grid, save_positions = (false, false), rng = rng) + spatial_system = grid, save_positions = (false, false)) for majump in uniform_majumps]) # non-uniform non_uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for majump in non_uniform_majumps] # testing -function get_mean_end_state(jump_prob, Nsims) +function get_mean_end_state(jump_prob, Nsims; rng = nothing) end_state = zeros(size(jump_prob.prob.u0)) for i in 1:Nsims - sol = solve(jump_prob, SSAStepper()) + sol = solve(jump_prob, SSAStepper(); rng) end_state .+= sol.u[end] end end_state / Nsims @@ -106,8 +106,8 @@ ode_prob = ODEProblem(f, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) for spatial_jump_prob in uniform_jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) + solution = solve(spatial_jump_prob, SSAStepper(); rng) + mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) @@ -122,8 +122,8 @@ end ode_prob = ODEProblem(f2, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) -solution = solve(non_uniform_jump_problems[1], SSAStepper()) -mean_end_state = get_mean_end_state(non_uniform_jump_problems[1], Nsims) +solution = solve(non_uniform_jump_problems[1], SSAStepper(); rng) +mean_end_state = get_mean_end_state(non_uniform_jump_problems[1], Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) @@ -135,8 +135,8 @@ f3(u, p, t) = L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + ones(num_nod ode_prob = ODEProblem(f3, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) -solution = solve(non_uniform_jump_problems[2], SSAStepper()) -mean_end_state = get_mean_end_state(non_uniform_jump_problems[2], Nsims) +solution = solve(non_uniform_jump_problems[2], SSAStepper(); rng) +mean_end_state = get_mean_end_state(non_uniform_jump_problems[2], Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) @@ -150,8 +150,8 @@ end ode_prob = ODEProblem(f4, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) -solution = solve(non_uniform_jump_problems[3], SSAStepper()) -mean_end_state = get_mean_end_state(non_uniform_jump_problems[3], Nsims) +solution = solve(non_uniform_jump_problems[3], SSAStepper(); rng) +mean_end_state = get_mean_end_state(non_uniform_jump_problems[3], Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) diff --git a/test/splitcoupled.jl b/test/splitcoupled.jl index e871e2c0f..1030ec88f 100644 --- a/test/splitcoupled.jl +++ b/test/splitcoupled.jl @@ -12,16 +12,15 @@ jump1 = ConstantRateJump(rate, affect!) prob = DiscreteProblem([10], (0.0, 50.0)) prob_control = DiscreteProblem([10], (0.0, 50.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupling_map = [(1, 1)] coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + jump_prob, jump_prob_control, Direct(), coupling_map) -@time sol = solve(coupled_prob, FunctionMap()) -@time solve(jump_prob, FunctionMap()) +@time sol = solve(coupled_prob, FunctionMap(); rng) +@time solve(jump_prob, FunctionMap(); rng) @test [s[1] - s[2] for s in sol.u] == zeros(length(sol.t)) # coupling two copies of the same process should give zero rate = (u, p, t) -> 1.0 @@ -42,34 +41,31 @@ end # Jump ODE to jump ODE prob = ODEProblem(f, [1.0], (0.0, 1.0)) prob_control = ODEProblem(f, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump2) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, Tsit5()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, Tsit5(); rng) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Jump SDE prob = SDEProblem(f, g, [1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, SRIW1()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, SRIW1(); rng) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Jump ODE prob = ODEProblem(f, [1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, SRIW1()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, SRIW1(); rng) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Discrete @@ -79,12 +75,11 @@ affect! = function (integrator) end prob = DiscreteProblem([1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, SRIW1()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, SRIW1(); rng) # test mass action jumps coupled to ODE # 0 -> A (stochasic) and A -> 0 (ODE) @@ -96,13 +91,12 @@ f = function (du, u, p, t) du[1] = -1.0 * u[1] end odeprob = ODEProblem(f, [10.0], (0.0, 10.0)) -jump_prob = JumpProblem(odeprob, Direct(), majumps, save_positions = (false, false); - rng = rng) +jump_prob = JumpProblem(odeprob, Direct(), majumps, save_positions = (false, false)) Nsims = 8000 Amean = 0.0 for i in 1:Nsims global Amean - local sol = solve(jump_prob, Tsit5(), saveat = 10.0) + local sol = solve(jump_prob, Tsit5(); saveat = 10.0, rng) Amean += sol[1, end] end Amean /= Nsims diff --git a/test/ssa_callback_test.jl b/test/ssa_callback_test.jl index dbe9c969d..1cbf678b3 100644 --- a/test/ssa_callback_test.jl +++ b/test/ssa_callback_test.jl @@ -11,9 +11,9 @@ end jump = ConstantRateJump(rate, affect!) prob = DiscreteProblem([0.0, 0.0], (0.0, 10.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump) -sol = solve(jump_prob, SSAStepper()) +sol = solve(jump_prob, SSAStepper(); rng) @test sol.t == [0.0, 10.0] @test sol.u == [[0.0, 0.0], [0.0, 0.0]] @@ -25,13 +25,13 @@ function fuel_affect!(integrator) end cb = DiscreteCallback(condition, fuel_affect!, save_positions = (false, true)) -sol = solve(jump_prob, SSAStepper(); callback = cb, tstops = [5]) +sol = solve(jump_prob, SSAStepper(); rng, callback = cb, tstops = [5]) @test sol.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5 @test sol(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen # test can pass callbacks via JumpProblem -jump_prob2 = JumpProblem(prob, Direct(), jump; rng = rng, callback = cb) -sol2 = solve(jump_prob2, SSAStepper(); tstops = [5]) +jump_prob2 = JumpProblem(prob, Direct(), jump; callback = cb) +sol2 = solve(jump_prob2, SSAStepper(); rng, tstops = [5]) @test sol2.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5 @test sol2(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen @@ -48,7 +48,7 @@ finalizer_called = 0 fuel_finalize(cb, u, t, integrator) = global finalizer_called += 1 cb2 = DiscreteCallback(condition, fuel_affect!, initialize = fuel_init!, finalize = fuel_finalize) -sol = solve(jump_prob, SSAStepper(), callback = cb2) +sol = solve(jump_prob, SSAStepper(); rng, callback = cb2) for tstop in random_tstops @test tstop ∈ sol.t end @@ -62,37 +62,37 @@ maj = MassActionJump(rs, ns; param_idxs = [1, 2]) u₀ = [100, 0] tspan = (0.0, 2000.0) dprob = DiscreteProblem(u₀, tspan, p) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) pcondit(u, t, integrator) = t == 1000.0 function paffect!(integrator) integrator.p[1] = 0.0 integrator.p[2] = 1.0 reset_aggregated_jumps!(integrator) end -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 p .= [1.0, 0.0] maj1 = MassActionJump([1 => 1], [1 => -1, 2 => 1]; param_idxs = 1) maj2 = MassActionJump([2 => 1], [1 => 1, 2 => -1]; param_idxs = 2) -jprob = JumpProblem(dprob, Direct(), maj1, maj2, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +jprob = JumpProblem(dprob, Direct(), maj1, maj2, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 p2 = [1.0, 0.0, 0.0] maj3 = MassActionJump([1 => 1], [1 => -1, 2 => 1]; param_idxs = 3) dprob = DiscreteProblem(u₀, tspan, p2) -jprob = JumpProblem(dprob, Direct(), maj1, maj2, maj3, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +jprob = JumpProblem(dprob, Direct(), maj1, maj2, maj3, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p2 .== [0.0, 1.0, 0.0]) @test sol[1, end] == 100 p2 .= [1.0, 0.0, 0.0] jprob = JumpProblem(dprob, Direct(), JumpSet(; massaction_jumps = [maj1, maj2, maj3]), - save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) + save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p2 .== [0.0, 1.0, 0.0]) @test sol[1, end] == 100 @@ -100,8 +100,8 @@ p .= [1.0, 0.0] dprob = DiscreteProblem(u₀, tspan, p) maj4 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; param_idxs = [1, 2]) -jprob = JumpProblem(dprob, Direct(), maj4, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +jprob = JumpProblem(dprob, Direct(), maj4, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 @@ -109,9 +109,9 @@ sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback( p .= [1.0] dprob = DiscreteProblem(u₀, tspan, p) maj5 = MassActionJump([[1 => 2]], [[1 => -1, 2 => 1]]; param_idxs = [1]) -jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false)) @test all(jprob.massaction_jump.scaled_rates .== [0.5]) -jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), rng = rng, scale_rates = false) +jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), scale_rates = false) @test all(jprob.massaction_jump.scaled_rates .== [1.0]) # test for https://github.com/SciML/JumpProcesses.jl/issues/239 @@ -119,11 +119,11 @@ maj6 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1 param_idxs = [1, 2]) p = (0.1, 0.1) dprob = DiscreteProblem([10, 0], (0.0, 100.0), p) -jprob = JumpProblem(dprob, Direct(), maj6; save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj6; save_positions = (false, false)) cbtimes = [20.0, 30.0] affectpresets!(integrator) = integrator.u[1] += 10 cb = PresetTimeCallback(cbtimes, affectpresets!) -jsol = solve(jprob, SSAStepper(), saveat = 0.1, callback = cb) +jsol = solve(jprob, SSAStepper(); rng, saveat = 0.1, callback = cb) @test (jsol(20.00000000001) - jsol(19.9999999999))[1] == 10 # test periodic callbacks working, i.e. #417 @@ -134,18 +134,18 @@ let dprob = DiscreteProblem([0], (0.0, 10.0)) cbfun(integ) = (integ.u[1] += 1; nothing) cb = PeriodicCallback(cbfun, 1.0) - jprob = JumpProblem(dprob, crj; rng) - sol = solve(jprob; callback = cb) + jprob = JumpProblem(dprob, crj) + sol = solve(jprob; rng, callback = cb) @test sol[1, end] == 9 cb = PeriodicCallback(cbfun, 1.0; initial_affect = true) - jprob = JumpProblem(dprob, crj; rng) - sol = solve(jprob; callback = cb) + jprob = JumpProblem(dprob, crj) + sol = solve(jprob; rng, callback = cb) @test sol[1, end] == 10 cb = PeriodicCallback(cbfun, 1.0; initial_affect = true, final_affect = true) - jprob = JumpProblem(dprob, crj; rng) - sol = solve(jprob; callback = cb) + jprob = JumpProblem(dprob, crj) + sol = solve(jprob; rng, callback = cb) @test sol[1, end] == 11 end @@ -157,22 +157,22 @@ let dprob = DiscreteProblem([0], (0.0, 10.0)) cbfun(integ) = (integ.u[1] += 1; nothing) cb = PeriodicCallback(cbfun, 1.0) - jprob = JumpProblem(dprob, crj; rng) + jprob = JumpProblem(dprob, crj) tstops = Float64[] # tests for when aliasing system is in place - #sol = solve(jprob; callback = cb, tstops, alias_tstops = true) + #sol = solve(jprob; callback = cb, tstops, alias_tstops = true) # @test sol[1, end] == 9 - #@test tstops == 1.0:9.0 + #@test tstops == 1.0:9.0 # empty!(tstops) # sol = solve(jprob; callback = cb, tstops, alias_tstops = false) # @test sol[1, end] == 9 # @test isempty(tstops) - sol = solve(jprob; callback = cb, tstops) + sol = solve(jprob; rng, callback = cb, tstops) @test sol[1, end] == 9 @test isempty(tstops) empty!(tstops) - integ = init(jprob, SSAStepper(); callback = cb, tstops) + integ = init(jprob, SSAStepper(); rng, callback = cb, tstops) solve!(integ) @test integ.tstops !== tstops @test isempty(tstops) @@ -184,18 +184,18 @@ let affect!(integrator) = (integrator.u[1] += 1) crj = ConstantRateJump(rate, affect!) prob = DiscreteProblem([0], (0.0, 10.0), [10.0]) - jprob = JumpProblem(prob, Direct(), crj; rng) + jprob = JumpProblem(prob, Direct(), crj) # basic callable tstops my_tstops = (p, tspan) -> [3.0, 6.0] - sol = solve(jprob, SSAStepper(); tstops = my_tstops) + sol = solve(jprob, SSAStepper(); rng, tstops = my_tstops) @test sol.t[end] == 10.0 @test 3.0 ∈ sol.t @test 6.0 ∈ sol.t # parameter-dependent callable tstops param_tstops = (p, tspan) -> [p[1] / 5.0, p[1] / 2.0] - sol2 = solve(jprob, SSAStepper(); tstops = param_tstops) + sol2 = solve(jprob, SSAStepper(); rng, tstops = param_tstops) @test sol2.t[end] == 10.0 @test 2.0 ∈ sol2.t # 10.0 / 5.0 @test 5.0 ∈ sol2.t # 10.0 / 2.0 @@ -204,7 +204,7 @@ let condition(u, t, integrator) = t == 3.0 cb_affect!(integrator) = (integrator.u[1] += 1000) cb = DiscreteCallback(condition, cb_affect!) - sol3 = solve(jprob, SSAStepper(); tstops = my_tstops, callback = cb) + sol3 = solve(jprob, SSAStepper(); rng, tstops = my_tstops, callback = cb) @test sol3.t[end] == 10.0 @test 3.0 ∈ sol3.t # verify the callback fired: use findlast to get post-callback state at t=3.0 @@ -213,15 +213,15 @@ let # callable returning a tuple tuple_tstops = (p, tspan) -> (2.0, 7.0) - sol4 = solve(jprob, SSAStepper(); tstops = tuple_tstops) + sol4 = solve(jprob, SSAStepper(); rng, tstops = tuple_tstops) @test sol4.t[end] == 10.0 @test 2.0 ∈ sol4.t @test 7.0 ∈ sol4.t # callable tstops stored in JumpProblem via constructor kwarg - jprob2 = JumpProblem(prob, Direct(), crj; rng, tstops = my_tstops) + jprob2 = JumpProblem(prob, Direct(), crj; tstops = my_tstops) @test haskey(jprob2.kwargs, :tstops) - sol5 = solve(jprob2, SSAStepper()) + sol5 = solve(jprob2, SSAStepper(); rng) @test sol5.t[end] == 10.0 @test 3.0 ∈ sol5.t @test 6.0 ∈ sol5.t diff --git a/test/ssa_tests.jl b/test/ssa_tests.jl index e82c50959..79bdbe15f 100644 --- a/test/ssa_tests.jl +++ b/test/ssa_tests.jl @@ -16,31 +16,30 @@ end jump2 = ConstantRateJump(rate, affect!) prob = DiscreteProblem([10.0], (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) -integrator = init(jump_prob, SSAStepper()) +integrator = init(jump_prob, SSAStepper(); rng) step!(integrator) integrator.u[1] # test different saving behaviors -sol = solve(jump_prob, SSAStepper()) +sol = solve(jump_prob, SSAStepper(); rng) @test SciMLBase.successful_retcode(sol) @test sol.t[begin] == 0.0 @test sol.t[end] == 3.0 -sol = solve(jump_prob, SSAStepper(), save_end = false) +sol = solve(jump_prob, SSAStepper(); save_end = false, rng) @test sol.t[begin] == 0.0 @test sol.t[end] < 3.0 -sol = solve(jump_prob, SSAStepper(), save_start = false) +sol = solve(jump_prob, SSAStepper(); save_start = false, rng) @test sol.t[begin] > 0.0 @test sol.t[end] == 3.0 -jump_prob = JumpProblem(prob, Direct(), jump, jump2, save_positions = (false, false); - rng = rng) -sol = solve(jump_prob, SSAStepper(), save_start = false, save_end = false) +jump_prob = JumpProblem(prob, Direct(), jump, jump2, save_positions = (false, false)) +sol = solve(jump_prob, SSAStepper(); save_start = false, save_end = false, rng) @test isempty(sol.t) && isempty(sol.u) -sol = solve(jump_prob, SSAStepper(), saveat = 0.0:0.1:2.9) +sol = solve(jump_prob, SSAStepper(); saveat = 0.0:0.1:2.9, rng) @test sol.t == collect(0.0:0.1:3.0) diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 7356ce2fc..9f7c44918 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -1,7 +1,5 @@ using DiffEqBase, Test -using JumpProcesses, OrdinaryDiffEq -using StableRNGs -rng = StableRNG(12345) +using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq sr = [1.0, 2.0, 50.0] maj = MassActionJump(sr, [[1 => 1], [1 => 1], [0 => 1]], [[1 => 1], [1 => -1], [1 => 1]]) @@ -9,10 +7,25 @@ params = (1.0, 2.0, 50.0) tspan = (0.0, 4.0) u0 = [5] dprob = DiscreteProblem(u0, tspan, params) -jprob = JumpProblem(dprob, Direct(), maj; rng = rng) -solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); trajectories = 10) -solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); - trajectories = 10) +jprob = JumpProblem(dprob, Direct(), maj) + +# Verify threaded solves complete and produce distinct trajectories. +# NOTE: We intentionally do NOT pass `rng` here. In threaded ensembles, passing a +# shared rng object via `solve(...; rng=...)` does not yet provide correct +# per-trajectory stream handling. Until SciMLBase's ensemble RNG updates land +# (master rng -> per-trajectory rng), correctness in threaded contexts relies on +# task-local `Random.default_rng()`. +sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); + trajectories = 400) +@test length(sol) == 400 +firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)] +@test allunique(firstrx_time) + +sol2 = solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); + trajectories = 400) +@test length(sol2) == 400 +firstrx_time2 = [sol2.u[i].t[findfirst(>(sol2.u[i].t[1]), sol2.u[i].t)] for i in 1:length(sol2)] +@test allunique(firstrx_time2) # test for https://github.com/SciML/JumpProcesses.jl/issues/472 let @@ -34,3 +47,22 @@ let @test allunique(firstrx_time) end end + +# SDE + variable-rate jumps with EnsembleThreads +let + f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) + g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) + sde_prob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) + vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], + integrator -> (integrator.u[1] -= 1.0)) + + for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + jump_prob = JumpProblem(sde_prob, Direct(), vrj; vr_aggregator = agg) + prob_func(prob, i, repeat) = deepcopy(prob) + prob = EnsembleProblem(jump_prob; prob_func) + sol = solve(prob, SRIW1(), EnsembleThreads(); + trajectories = 400, save_everystep = false) + firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)] + @test allunique(firstrx_time) + end +end diff --git a/test/variable_rate.jl b/test/variable_rate.jl index ce08781db..195479ec1 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -30,15 +30,15 @@ f = function (du, u, p, t) end prob = ODEProblem(f, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) -integrator = init(jump_prob, Tsit5()) -sol = solve(jump_prob, Tsit5()) -sol = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff())) -sol = solve(jump_prob, Rosenbrock23()) - -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) -integrator = init(jump_prob_gill, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +integrator = init(jump_prob, Tsit5(); rng) +sol = solve(jump_prob, Tsit5(); rng) +sol = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff()); rng) +sol = solve(jump_prob, Rosenbrock23(); rng) + +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +integrator = init(jump_prob_gill, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) sol_gill = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff())) sol_gill = solve(jump_prob, Rosenbrock23()) @test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 @@ -48,10 +48,10 @@ g = function (du, u, p, t) du[1] = u[1] end prob = SDEProblem(f, g, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng = rng) -sol = solve(jump_prob, SRIW1()) -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng = rng) -sol_gill = solve(jump_prob_gill, SRIW1()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +sol = solve(jump_prob, SRIW1(); rng) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +sol_gill = solve(jump_prob_gill, SRIW1(); rng) @test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 @test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12 @@ -74,10 +74,10 @@ function affect_switch!(integrator) end jump_switch = VariableRateJump(rate_switch, affect_switch!) prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype = zeros(2, 2)) -jump_prob = JumpProblem(prob, jump_switch; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump_switch; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, SRA1(), dt = 1.0) -sol_gill = solve(jump_prob_gill, SRA1(), dt = 1.0) +jump_prob = JumpProblem(prob, jump_switch; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump_switch; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, SRA1(), dt = 1.0; rng) +sol_gill = solve(jump_prob_gill, SRA1(), dt = 1.0; rng) ## Some integration tests @@ -88,10 +88,10 @@ prob = ODEProblem(f2, [0.2], (0.0, 10.0)) rate2(u, p, t) = 2 affect2!(integrator) = (integrator.u[1] = integrator.u[1] / 2) jump = ConstantRateJump(rate2, affect2!) -jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) sol(4.0) sol.u[4] @@ -99,10 +99,10 @@ rate2b(u, p, t) = u[1] affect2b!(integrator) = (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate2b, affect2b!) jump2 = deepcopy(jump) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) sol(4.0) sol.u[4] @@ -110,10 +110,10 @@ function g2(du, u, p, t) du[1] = u[1] end prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, SRIW1()) -sol_gill = solve(jump_prob_gill, SRIW1()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, SRIW1(); rng) +sol_gill = solve(jump_prob_gill, SRIW1(); rng) sol(4.0) sol.u[4] @@ -129,10 +129,10 @@ function affect3!(integrator) integrator.u[4] = 1) end jump = VariableRateJump(rate3, affect3!) -jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) # test for https://discourse.julialang.org/t/differentialequations-jl-package-variable-rate-jumps-with-complex-variables/80366/2 function f4(dx, x, p, t) @@ -146,10 +146,10 @@ jump = VariableRateJump(rate4, affect4!) x₀ = 1.0 + 0.0im Δt = (0.0, 6.0) prob = ODEProblem(f4, [x₀], Δt) -jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) # Out of place test drift(x, p, t) = p * x @@ -158,7 +158,7 @@ affect!2(integrator) = (integrator.u ./= 2; nothing) x0 = rand(2) prob = ODEProblem(drift, x0, (0.0, 10.0), 2.0) jump = VariableRateJump(rate2c, affect!2) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM()) # test to check lack of dependency graphs is caught in Coevolve for systems with non-maj # jumps @@ -256,12 +256,12 @@ let d_jump = VariableRateJump(d_rate, death!) ode_prob = ODEProblem(ode_fxn, u0, tspan, p) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator = VR_FRM(), rng) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator = VR_FRM()) # After callback initialize, integrator.u.jump_u should have unique thresholds # that differ between sequential solves (RNG advances each time). jump_u_old = zeros(length(sjm_prob.prob.u0.jump_u)) for i in 1:Nsims - integrator = init(sjm_prob, Tsit5(); saveat = tspan[2]) + integrator = init(sjm_prob, Tsit5(); saveat = tspan[2], rng) @test allunique(integrator.u.jump_u) @test integrator.u.jump_u != jump_u_old jump_u_old .= integrator.u.jump_u @@ -273,9 +273,9 @@ end # https://github.com/SciML/JumpProcesses.jl/issues/320 # note that even with the seeded StableRNG this test is not # deterministic for some reason. -function getmean(Nsims, prob, alg, tsave, seed) +function getmean(Nsims, prob, alg, tsave, rng) umean = zeros(length(tsave)) - integrator = init(prob, alg; saveat = tsave, seed) + integrator = init(prob, alg; saveat = tsave, rng) solve!(integrator) for j in eachindex(umean) umean[j] += integrator.sol.u[j][1] @@ -292,8 +292,7 @@ function getmean(Nsims, prob, alg, tsave, seed) end let - seed = 12345 - rng = StableRNG(seed) + rng_seed = 12345 b = 2.0 d = 1.0 n0 = 1.0 @@ -325,12 +324,12 @@ let ode_prob = ODEProblem(ode_fxn, u0, tspan, p) tsave = range(tspan[1], tspan[2]; step = 0.1) for vr_aggregator in (VR_Direct(), VR_DirectFW(), VR_FRM()) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator) for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) - umean = getmean(Nsims, sjm_prob, alg, tsave, seed) + umean = getmean(Nsims, sjm_prob, alg, tsave, StableRNG(rng_seed)) @test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave)) - seed += Nsims + rng_seed += Nsims end end end @@ -340,10 +339,10 @@ end # Function to run ensemble and compute statistics function run_ensemble(prob, alg, jumps...; vr_aggregator = VR_FRM(), Nsims = 8000) rng = StableRNG(12345) - jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator, rng) + jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator) total = 0.0 for i in 1:Nsims - sol = solve(jump_prob, alg; save_everystep = false) + sol = solve(jump_prob, alg; save_everystep = false, rng) total += sol.u[end][1] end return total / Nsims @@ -439,10 +438,10 @@ let jump_counts = zeros(Int, Nsims) p = [0.0, 0.0, 0] prob = ODEProblem(f, u0, tspan, p) - jump_prob = JumpProblem(prob, Direct(), birth_jump, death_jump; vr_aggregator, rng) + jump_prob = JumpProblem(prob, Direct(), birth_jump, death_jump; vr_aggregator) for i in 1:Nsims - sol = solve(jump_prob, Tsit5(); save_everystep = false) + sol = solve(jump_prob, Tsit5(); save_everystep = false, rng) jump_counts[i] = jump_prob.prob.p[3] jump_prob.prob.p[3] = 0 end @@ -497,8 +496,8 @@ let ] for agg in aggregators - local jprob = JumpProblem(prob, agg, maj, vrj; rng) - local sol = solve(jprob, Tsit5()) + local jprob = JumpProblem(prob, agg, maj, vrj) + local sol = solve(jprob, Tsit5(); rng) @test SciMLBase.successful_retcode(sol) # Verify conservation: total population should be conserved @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] @@ -528,8 +527,8 @@ let prob = ODEProblem(f!, u0, tspan) # Test with Direct aggregator (most common case) - jprob = JumpProblem(prob, Direct(), crj, vrj; rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), crj, vrj) + sol = solve(jprob, Tsit5(); rng) @test SciMLBase.successful_retcode(sol) @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] end @@ -560,8 +559,8 @@ let # Test RSSA and RSSACR specifically (the aggregators that had the bug) for agg in [RSSA(), RSSACR()] - local jprob = JumpProblem(prob, agg, maj, vrj1, vrj2; rng) - local sol = solve(jprob, Tsit5()) + local jprob = JumpProblem(prob, agg, maj, vrj1, vrj2) + local sol = solve(jprob, Tsit5(); rng) @test SciMLBase.successful_retcode(sol) @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] end