diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 959ad88a6..000000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -style = "sciml" -format_markdown = true -format_docstrings = true \ No newline at end of file diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index c240796cc..6762c6f3e 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,13 +1,19 @@ -name: "Format Check" +name: format-check on: push: branches: - 'master' + - 'main' + - 'release-' tags: '*' pull_request: jobs: - format-check: - name: "Format Check" - uses: "SciML/.github/.github/workflows/format-check.yml@v1" + runic: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: fredrikekre/runic-action@v1 + with: + version: '1' diff --git a/benchmarks/extended_jump_array.jl b/benchmarks/extended_jump_array.jl index 0de459a3f..4e7cf2a16 100644 --- a/benchmarks/extended_jump_array.jl +++ b/benchmarks/extended_jump_array.jl @@ -10,11 +10,11 @@ benchmark_out = ExtendedJumpArray(zeros(500000), zeros(500000)) benchmark_in = ExtendedJumpArray(rand(rng, 500000), rand(rng, 500000)) function test_single_dot(out, array) - @inbounds @. out = array + 1.23 * array + return @inbounds @. out = array + 1.23 * array end function test_double_dot(out, array) - @inbounds @.. out = array + 1.23 * array + return @inbounds @.. out = array + 1.23 * array end println("Base-case normal broadcasting") diff --git a/docs/make.jl b/docs/make.jl index 69988bd2a..d3331f2e1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,23 +11,33 @@ cp(joinpath(docpath, "Project.toml"), joinpath(assetpath, "Project.toml"), force include("pages.jl") -mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]), :tex => Dict("inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], - "packages" => [ - "base", - "ams", - "autoload", - "mathtools", - "require" - ]))) +mathengine = MathJax3( + Dict( + :loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]), :tex => Dict( + "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], + "packages" => [ + "base", + "ams", + "autoload", + "mathtools", + "require", + ] + ) + ) +) -makedocs(sitename = "JumpProcesses.jl", authors = "Chris Rackauckas", modules = [JumpProcesses], +makedocs( + sitename = "JumpProcesses.jl", authors = "Chris Rackauckas", modules = [JumpProcesses], clean = true, doctest = false, linkcheck = true, warnonly = [:missing_docs], - format = Documenter.HTML(; assets = ["assets/favicon.ico"], + format = Documenter.HTML(; + assets = ["assets/favicon.ico"], canonical = "https://docs.sciml.ai/JumpProcesses/", prettyurls = (get(ENV, "CI", nothing) == "true"), mathengine, edit_link = "master", - repolink = "https://github.com/SciML/JumpProcesses.jl"), - pages = pages) + repolink = "https://github.com/SciML/JumpProcesses.jl" + ), + pages = pages +) deploydocs(repo = "github.com/SciML/JumpProcesses.jl.git"; push_preview = true) diff --git a/docs/pages.jl b/docs/pages.jl index 1f973f241..07c46f4cd 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,14 +1,19 @@ # Put in a separate page so it can be used by SciMLDocs.jl -pages = ["index.md", - "Tutorials" => Any["tutorials/simple_poisson_process.md", +pages = [ + "index.md", + "Tutorials" => Any[ + "tutorials/simple_poisson_process.md", "tutorials/discrete_stochastic_example.md", "tutorials/point_process_simulation.md", "tutorials/jump_diffusion.md", - "tutorials/spatial.md"], + "tutorials/spatial.md", + ], "Applications" => Any["applications/advanced_point_process.md"], - "Type Documentation" => Any["Jumps, JumpProblem, and Aggregators" => "jump_types.md", - "Jump solvers" => "jump_solve.md"], + "Type Documentation" => Any[ + "Jumps, JumpProblem, and Aggregators" => "jump_types.md", + "Jump solvers" => "jump_solve.md", + ], "FAQ" => "faq.md", - "API" => "api.md" + "API" => "api.md", ] diff --git a/ext/JumpProcessesKernelAbstractionsExt.jl b/ext/JumpProcessesKernelAbstractionsExt.jl index 2b345ebc0..8faba6838 100644 --- a/ext/JumpProcessesKernelAbstractionsExt.jl +++ b/ext/JumpProcessesKernelAbstractionsExt.jl @@ -5,20 +5,24 @@ using KernelAbstractions, Adapt using StaticArrays using PoissonRandom, Random -function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, +function SciMLBase.__solve( + ensembleprob::SciMLBase.AbstractEnsembleProblem, alg::SimpleTauLeaping, ensemblealg::EnsembleGPUKernel; trajectories, seed = nothing, dt = error("dt is required for SimpleTauLeaping."), - kwargs...) + kwargs... + ) if trajectories == 1 - return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories = 1, - seed, dt, kwargs...) + return SciMLBase.__solve( + ensembleprob, alg, EnsembleSerial(); trajectories = 1, + seed, dt, kwargs... + ) end ensemblealg.backend === nothing ? backend = CPU() : - backend = ensemblealg.backend + backend = ensemblealg.backend jump_prob = ensembleprob.prob @@ -31,36 +35,42 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, # Run vectorized solve ts, - us = vectorized_solve( - probs, jump_prob, SimpleTauLeaping(); backend, trajectories, seed, dt) + us = vectorized_solve( + probs, jump_prob, SimpleTauLeaping(); backend, trajectories, seed, dt + ) # Convert to CPU for inspection _ts = Array(ts) _us = Array(us) - time = @elapsed sol = [begin - ts = @view _ts[:, i] - us = @view _us[:, :, i] - sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts) - if sol_idx === nothing - @error "No solution found" tspan=probs[i].tspan[1] ts - error("Batch solve failed") - end - @views ensembleprob.output_func( - SciMLBase.build_solution(probs[i].prob, - alg, - ts[1:sol_idx], - [us[j, :] for j in 1:sol_idx], - k = nothing, - stats = nothing, - calculate_error = false, - retcode = sol_idx != - length(ts) ? - ReturnCode.Terminated : - ReturnCode.Success), - i)[1] - end - for i in eachindex(probs)] + time = @elapsed sol = [ + begin + ts = @view _ts[:, i] + us = @view _us[:, :, i] + sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts) + if sol_idx === nothing + @error "No solution found" tspan = probs[i].tspan[1] ts + error("Batch solve failed") + end + @views ensembleprob.output_func( + SciMLBase.build_solution( + probs[i].prob, + alg, + ts[1:sol_idx], + [us[j, :] for j in 1:sol_idx], + k = nothing, + stats = nothing, + calculate_error = false, + retcode = sol_idx != + length(ts) ? + ReturnCode.Terminated : + ReturnCode.Success + ), + i + )[1] + end + for i in eachindex(probs) + ] return SciMLBase.EnsembleSolution(sol, time, true) end @@ -82,7 +92,8 @@ end @kernel function simple_tau_leaping_kernel( @Const(probs_data), _us, _ts, dt, @Const(rj_data), current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, - seed::UInt64) + seed::UInt64 + ) i = @index(Global, Linear) # Get thread-local buffers @@ -114,7 +125,7 @@ end # Get input/output arrays ts_view = @inbounds view(_ts, :, i) - us_view = @inbounds view(_us,:,:,i) + us_view = @inbounds view(_us, :, :, i) # Initialize first time step and state @inbounds ts_view[1] = tspan[1] @@ -124,7 +135,7 @@ end # Main loop for j in 2:n - tprev = tspan[1] + (j-2) * dt + tprev = tspan[1] + (j - 2) * dt # Compute rates and scale by dt rate(rate_cache, current_u, p, tprev) @@ -143,20 +154,24 @@ end @inbounds for k in 1:state_dim us_view[j, k] = current_u[k] end - @inbounds ts_view[j] = tspan[1] + (j-1) * dt + @inbounds ts_view[j] = tspan[1] + (j - 1) * dt end end # Vectorized solve function -function vectorized_solve(probs, prob::JumpProblem, alg::SimpleTauLeaping; - backend, trajectories, seed, dt, kwargs...) +function vectorized_solve( + probs, prob::JumpProblem, alg::SimpleTauLeaping; + backend, trajectories, seed, dt, kwargs... + ) # Extract common jump data rj = prob.regular_jump rj_data = JumpData(rj.rate, rj.c, rj.numjumps) # Extract trajectory-specific data without static typing - probs_data = [TrajectoryData(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, p.prob.tspan) - for p in probs] + probs_data = [ + TrajectoryData(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, p.prob.tspan) + for p in probs + ] # Adapt to GPU probs_data_gpu = adapt(backend, probs_data) @@ -197,13 +212,15 @@ function vectorized_solve(probs, prob::JumpProblem, alg::SimpleTauLeaping; KernelAbstractions.synchronize(backend) # Seed for Poisson sampling - seed = seed === nothing ? UInt64(12345) : UInt64(seed); + seed = seed === nothing ? UInt64(12345) : UInt64(seed) # Launch main kernel kernel = simple_tau_leaping_kernel(backend) - main_event = kernel(probs_data_gpu, us, ts, dt, rj_data_gpu, + main_event = kernel( + probs_data_gpu, us, ts, dt, rj_data_gpu, current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, seed; - ndrange = n_trajectories) + ndrange = n_trajectories + ) KernelAbstractions.synchronize(backend) return ts, us diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index cf468777d..98ab22c36 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -33,10 +33,10 @@ import SymbolicIndexingInterface as SII # Import additional types and functions from DiffEqBase and SciMLBase using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction, - DDEFunction, DiscreteProblem, ODEFunction, ODEProblem, - ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, - deleteat!, isinplace, remake, savevalues!, step!, - u_modified! + DDEFunction, DiscreteProblem, ODEFunction, ODEProblem, + ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, + deleteat!, isinplace, remake, savevalues!, step!, + u_modified! using SciMLBase: SciMLBase, DEIntegrator abstract type AbstractJump end @@ -44,7 +44,7 @@ abstract type AbstractMassActionJump <: AbstractJump end abstract type AbstractAggregatorAlgorithm end abstract type AbstractJumpAggregator end abstract type AbstractSSAIntegrator{Alg, IIP, U, T} <: - DEIntegrator{Alg, IIP, U, T} end +DEIntegrator{Alg, IIP, U, T} end const DEFAULT_RNG = Random.default_rng() @@ -126,7 +126,7 @@ export init, solve, solve! include("SSA_stepper.jl") export SSAStepper -# leaping: +# leaping: include("simple_regular_solve.jl") export SimpleTauLeaping, EnsembleGPUKernel diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index ed3957a83..4a26da53a 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -68,7 +68,7 @@ Solution objects for pure jump problems solved via `SSAStepper`. $(FIELDS) """ mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} <: - AbstractSSAIntegrator{SSAStepper, Nothing, uType, tType} + AbstractSSAIntegrator{SSAStepper, Nothing, uType, tType} """The underlying `prob.f` function. Not currently used.""" f::F """The current solution values.""" @@ -113,14 +113,14 @@ end (integrator::SSAIntegrator)(out, t) = (out .= integrator.u) function DiffEqBase.u_modified!(integrator::SSAIntegrator, bool::Bool) - integrator.u_modified = bool + return integrator.u_modified = bool end function DiffEqBase.__solve(jump_prob::JumpProblem, alg::SSAStepper; kwargs...) # init will handle kwargs merging via init_call integrator = init(jump_prob, alg; kwargs...) solve!(integrator) - integrator.sol + return integrator.sol end function DiffEqBase.solve!(integrator::SSAIntegrator) @@ -135,14 +135,16 @@ function DiffEqBase.solve!(integrator::SSAIntegrator) # check callbacks one last time if !(integrator.opts.callback.discrete_callbacks isa Tuple{}) - DiffEqBase.apply_discrete_callback!(integrator, - integrator.opts.callback.discrete_callbacks...) + DiffEqBase.apply_discrete_callback!( + integrator, + integrator.opts.callback.discrete_callbacks... + ) end if integrator.saveat !== nothing && !isempty(integrator.saveat) # Split to help prediction while integrator.cur_saveat <= length(integrator.saveat) && - integrator.saveat[integrator.cur_saveat] < integrator.t + integrator.saveat[integrator.cur_saveat] < integrator.t push!(integrator.sol.t, integrator.saveat[integrator.cur_saveat]) push!(integrator.sol.u, copy(integrator.u)) integrator.cur_saveat += 1 @@ -160,7 +162,7 @@ function DiffEqBase.solve!(integrator::SSAIntegrator) SciMLBase.save_final_discretes!(integrator, integrator.opts.callback) end - if integrator.sol.retcode === ReturnCode.Default + return if integrator.sol.retcode === ReturnCode.Default integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, ReturnCode.Success) end end @@ -177,22 +179,27 @@ function check_continuous_callback_error(callback) end if callback isa DiffEqBase.ContinuousCallback - error("SSAStepper does not support continuous callbacks. Only DiscreteCallbacks " * - "are supported for event detection with SSAStepper. Please use an ODE/SDE " * - "solver (e.g., Tsit5()) if you need continuous event detection.") + error( + "SSAStepper does not support continuous callbacks. Only DiscreteCallbacks " * + "are supported for event detection with SSAStepper. Please use an ODE/SDE " * + "solver (e.g., Tsit5()) if you need continuous event detection." + ) elseif callback isa DiffEqBase.CallbackSet n_continuous = length(callback.continuous_callbacks) if n_continuous > 0 - error("SSAStepper does not support continuous callbacks (found $n_continuous " * - "continuous callback$(n_continuous > 1 ? "s" : "")). Only DiscreteCallbacks " * - "are supported for event detection with SSAStepper. Please use an ODE/SDE " * - "solver (e.g., Tsit5()) if you need continuous event detection.") + error( + "SSAStepper does not support continuous callbacks (found $n_continuous " * + "continuous callback$(n_continuous > 1 ? "s" : "")). Only DiscreteCallbacks " * + "are supported for event detection with SSAStepper. Please use an ODE/SDE " * + "solver (e.g., Tsit5()) if you need continuous event detection." + ) end end return nothing end -function DiffEqBase.__init(jump_prob::JumpProblem, +function DiffEqBase.__init( + jump_prob::JumpProblem, alg::SSAStepper; save_start = true, save_end = true, @@ -201,7 +208,8 @@ function DiffEqBase.__init(jump_prob::JumpProblem, saveat = nothing, callback = nothing, tstops = nothing, - numsteps_hint = 100) + numsteps_hint = 100 + ) # hack until alias system is in place alias_tstops = false @@ -213,8 +221,10 @@ function DiffEqBase.__init(jump_prob::JumpProblem, # Check for continuous callbacks in the jump system isempty(jump_prob.jump_callback.continuous_callbacks) || - error("SSAStepper does not support continuous callbacks in the jump system. " * - "Please use an ODE/SDE solver over ODE or SDE problems instead.") + error( + "SSAStepper does not support continuous callbacks in the jump system. " * + "Please use an ODE/SDE solver over ODE or SDE problems instead." + ) # Check for continuous callbacks passed via kwargs (from JumpProblem constructor or solve) check_continuous_callback_error(callback) @@ -243,10 +253,12 @@ function DiffEqBase.__init(jump_prob::JumpProblem, end save_everystep = any(cb.save_positions) - sol = DiffEqBase.build_solution(prob, alg, t, u, dense = save_everystep, + sol = DiffEqBase.build_solution( + prob, alg, t, u, dense = save_everystep, calculate_error = false, stats = DiffEqBase.Stats(0), - interp = DiffEqBase.ConstantInterpolation(t, u)) + interp = DiffEqBase.ConstantInterpolation(t, u) + ) _saveat = (saveat isa Number) ? (prob.tspan[1]:saveat:prob.tspan[2]) : saveat if _saveat !== nothing && !isempty(_saveat) && _saveat[1] == prob.tspan[1] @@ -277,19 +289,21 @@ function DiffEqBase.__init(jump_prob::JumpProblem, _tstops = tstops end - integrator = SSAIntegrator(prob.f, copy(prob.u0), prob.tspan[1], prob.tspan[1], tdir, + integrator = SSAIntegrator( + prob.f, copy(prob.u0), prob.tspan[1], prob.tspan[1], tdir, prob.p, sol, 1, prob.tspan[1], cb, _saveat, save_everystep, - save_end, cur_saveat, opts, _tstops, 1, false, true, alias_tstops, false) + save_end, cur_saveat, opts, _tstops, 1, false, true, alias_tstops, false + ) cb.initialize(cb, integrator.u, prob.tspan[1], integrator) DiffEqBase.initialize!(opts.callback, integrator.u, prob.tspan[1], integrator) if save_start SciMLBase.save_discretes_if_enabled!(integrator, opts.callback; skip_duplicates = true) end - integrator + return integrator end function DiffEqBase.get_tstops(integrator::SSAIntegrator) - @view integrator.tstops[(integrator.tstops_idx):end] + return @view integrator.tstops[(integrator.tstops_idx):end] end DiffEqBase.get_tstops_array(integrator::SSAIntegrator) = DiffEqBase.get_tstops(integrator) @@ -318,26 +332,28 @@ function DiffEqBase.add_tstop!(integrator::SSAIntegrator, tstop) Base.insert!(integrator.tstops, insert_index, tstop) end - nothing + return nothing end # The Jump aggregators should not register the next jump through add_tstop! for SSAIntegrator # such that we can achieve maximum performance -@inline function register_next_jump_time!(integrator::SSAIntegrator, - p::AbstractSSAJumpAggregator, t) +@inline function register_next_jump_time!( + integrator::SSAIntegrator, + p::AbstractSSAJumpAggregator, t + ) integrator.tstop = p.next_jump_time - nothing + return nothing end function DiffEqBase.step!(integrator::SSAIntegrator) integrator.tprev = integrator.t next_jump_time = integrator.tstop > integrator.t ? integrator.tstop : - typemax(integrator.tstop) + typemax(integrator.tstop) doaffect = false if !isempty(integrator.tstops) && - integrator.tstops_idx <= length(integrator.tstops) && - integrator.tstops[integrator.tstops_idx] < next_jump_time + integrator.tstops_idx <= length(integrator.tstops) && + integrator.tstops[integrator.tstops_idx] < next_jump_time integrator.t = integrator.tstops[integrator.tstops_idx] integrator.tstops_idx += 1 else @@ -348,7 +364,7 @@ function DiffEqBase.step!(integrator::SSAIntegrator) @inbounds if integrator.saveat !== nothing && !isempty(integrator.saveat) # Split to help prediction while integrator.cur_saveat < length(integrator.saveat) && - integrator.saveat[integrator.cur_saveat] < integrator.t + integrator.saveat[integrator.cur_saveat] < integrator.t saved = true push!(integrator.sol.t, integrator.saveat[integrator.cur_saveat]) push!(integrator.sol.u, copy(integrator.u)) @@ -368,15 +384,17 @@ function DiffEqBase.step!(integrator::SSAIntegrator) if !(integrator.opts.callback.discrete_callbacks isa Tuple{}) discrete_modified, - saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator, - integrator.opts.callback.discrete_callbacks...) + saved_in_cb = DiffEqBase.apply_discrete_callback!( + integrator, + integrator.opts.callback.discrete_callbacks... + ) else saved_in_cb = false end !saved_in_cb && jump_modified_u && savevalues!(integrator) - nothing + return nothing end function DiffEqBase.savevalues!(integrator::SSAIntegrator, force = false) @@ -392,7 +410,7 @@ function DiffEqBase.savevalues!(integrator::SSAIntegrator, force = false) push!(integrator.sol.u, copy(integrator.u)) end - saved, savedexactly + return saved, savedexactly end function should_continue_solve(integrator::SSAIntegrator) @@ -400,29 +418,32 @@ function should_continue_solve(integrator::SSAIntegrator) # we continue the solve if there is a tstop between now and end_time has_tstop = !isempty(integrator.tstops) && - integrator.tstops_idx <= length(integrator.tstops) && - integrator.tstops[integrator.tstops_idx] < end_time + integrator.tstops_idx <= length(integrator.tstops) && + integrator.tstops[integrator.tstops_idx] < end_time # we continue the solve if there will be a jump between now and end_time has_jump = integrator.t < integrator.tstop < end_time - integrator.keep_stepping && (has_jump || has_tstop) + return integrator.keep_stepping && (has_jump || has_tstop) end function reset_aggregated_jumps!(integrator::SSAIntegrator, uprev = nothing) reset_aggregated_jumps!(integrator, uprev, integrator.cb) - nothing + return nothing end function DiffEqBase.terminate!(integrator::SSAIntegrator, retcode = ReturnCode.Terminated) integrator.keep_stepping = false integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, retcode) - nothing + return nothing end -function SciMLBase.isdenseplot(sol::ODESolution{ - T, N, uType, uType2, DType, tType, rateType, discType, P, - SSAStepper}) where {T, N, uType, uType2, DType, tType, rateType, discType, P} - sol.dense +function SciMLBase.isdenseplot( + sol::ODESolution{ + T, N, uType, uType2, DType, tType, rateType, discType, P, + SSAStepper, + } + ) where {T, N, uType, uType2, DType, tType, rateType, discType, P} + return sol.dense end diff --git a/src/aggregators/aggregated_api.jl b/src/aggregators/aggregated_api.jl index a65c4ad82..0fe84508b 100644 --- a/src/aggregators/aggregated_api.jl +++ b/src/aggregators/aggregated_api.jl @@ -10,47 +10,63 @@ Notes MassActionJump that was built from the parameter vector. If the parameter vector is unchanged, this can safely be set to false to improve performance. """ -function reset_aggregated_jumps!(integrator, uprev = nothing; update_jump_params = true, - kwargs...) - reset_aggregated_jumps!(integrator, uprev, integrator.opts.callback, - update_jump_params = update_jump_params, kwargs...) - nothing +function reset_aggregated_jumps!( + integrator, uprev = nothing; update_jump_params = true, + kwargs... + ) + reset_aggregated_jumps!( + integrator, uprev, integrator.opts.callback, + update_jump_params = update_jump_params, kwargs... + ) + return nothing end -function reset_aggregated_jumps!(integrator, uprev, callback::Nothing; - update_jump_params = true, kwargs...) - nothing +function reset_aggregated_jumps!( + integrator, uprev, callback::Nothing; + update_jump_params = true, kwargs... + ) + return nothing end -function reset_aggregated_jumps!(integrator, uprev, callback::CallbackSet; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!( + integrator, uprev, callback::CallbackSet; + update_jump_params = true, kwargs... + ) if !isempty(callback.discrete_callbacks) - reset_aggregated_jumps!(integrator, uprev, callback.discrete_callbacks..., - update_jump_params = update_jump_params, kwargs...) + reset_aggregated_jumps!( + integrator, uprev, callback.discrete_callbacks..., + update_jump_params = update_jump_params, kwargs... + ) end - nothing + return nothing end -function reset_aggregated_jumps!(integrator, uprev, cb::DiscreteCallback, cbs...; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!( + integrator, uprev, cb::DiscreteCallback, cbs...; + update_jump_params = true, kwargs... + ) if cb.condition isa AbstractSSAJumpAggregator maj = cb.condition.ma_jumps update_jump_params && using_params(maj) && update_parameters!(cb.condition.ma_jumps, integrator.p; kwargs...) cb.condition(cb, integrator.u, integrator.t, integrator) end - reset_aggregated_jumps!(integrator, uprev, cbs...; - update_jump_params = update_jump_params, kwargs...) - nothing + reset_aggregated_jumps!( + integrator, uprev, cbs...; + update_jump_params = update_jump_params, kwargs... + ) + return nothing end -function reset_aggregated_jumps!(integrator, uprev, cb::DiscreteCallback; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!( + integrator, uprev, cb::DiscreteCallback; + update_jump_params = true, kwargs... + ) if cb.condition isa AbstractSSAJumpAggregator maj = cb.condition.ma_jumps update_jump_params && using_params(maj) && update_parameters!(cb.condition.ma_jumps, integrator.p; kwargs...) cb.condition(cb, integrator.u, integrator.t, integrator) end - nothing + return nothing end diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index b8dd8dd5b..8a2d69e2a 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -172,8 +172,10 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end -const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), - FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CCNRM()) +const JUMP_AGGREGATORS = ( + Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), + FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CCNRM(), +) # For JumpProblem construction without an aggregator struct NullAggregator <: AbstractAggregatorAlgorithm end @@ -204,9 +206,11 @@ is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true # return the fastest aggregator out of the available ones -function select_aggregator(jumps::JumpSet; vartojumps_map = nothing, +function select_aggregator( + jumps::JumpSet; vartojumps_map = nothing, jumptovars_map = nothing, dep_graph = nothing, spatial_system = nothing, - hopping_constants = nothing) + hopping_constants = nothing + ) # detect if a spatial SSA should be used !isnothing(spatial_system) && !isnothing(hopping_constants) && return DirectCRDirect diff --git a/src/aggregators/bracketing.jl b/src/aggregators/bracketing.jl index e20776e2f..f661217ba 100644 --- a/src/aggregators/bracketing.jl +++ b/src/aggregators/bracketing.jl @@ -26,7 +26,7 @@ BracketData{T1, T2}() where {T1, T2} = BracketData(T1(0.1), T2(25), T2(4)) @inline getΔu(bd::BracketData{T1, T2}, i) where {T1, T2 <: Number} = bd.Δu @inline function delta_bracket(u::Integer, δ) - (trunc(typeof(u), (one(δ) - δ) * u), trunc(typeof(u), (one(δ) + δ) * u)) + return (trunc(typeof(u), (one(δ) - δ) * u), trunc(typeof(u), (one(δ) + δ) * u)) end @inline delta_bracket(u, δ) = ((one(δ) - δ) * u), ((one(δ) + δ) * u) @@ -48,7 +48,7 @@ end # Get propensity brackets of massaction jump k. @inline function get_majump_brackets(ulow, uhigh, k, majumps) - evalrxrate(ulow, k, majumps), evalrxrate(uhigh, k, majumps) + return evalrxrate(ulow, k, majumps), evalrxrate(uhigh, k, majumps) end # for constant rate jumps we must check the ordering of the bracket values @@ -68,8 +68,10 @@ get brackets for the rate of reaction rx by first checking if the reaction is a if rx <= num_majumps return get_majump_brackets(p.ulow, p.uhigh, rx, ma_jumps) else - @inbounds return get_cjump_brackets(p.ulow, p.uhigh, p.rates[rx - num_majumps], - params, t) + @inbounds return get_cjump_brackets( + p.ulow, p.uhigh, p.rates[rx - num_majumps], + params, t + ) end end @@ -79,7 +81,7 @@ end @inbounds for (i, uval) in enumerate(u) ulow[i], uhigh[i] = get_spec_brackets(p.bracket_data, i, uval) end - nothing + return nothing end @inline function update_u_brackets!(p::AbstractSSAJumpAggregator, u::SVector) @@ -88,7 +90,7 @@ end p.ulow = setindex(p.ulow, ulow, i) p.uhigh = setindex(p.uhigh, uhigh, i) end - nothing + return nothing end # Set up bracketing. The aggregator must have fields @@ -117,5 +119,5 @@ function set_bracketing!(p::AbstractSSAJumpAggregator, u, params, t) end p.sum_rate = sum_rate - nothing + return nothing end diff --git a/src/aggregators/ccnrm.jl b/src/aggregators/ccnrm.jl index 379ad642a..d6446e867 100644 --- a/src/aggregators/ccnrm.jl +++ b/src/aggregators/ccnrm.jl @@ -4,7 +4,7 @@ # (2015). doi: 10.1063/1.4928635. mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -20,10 +20,12 @@ mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: ptt::PT end -function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, +function CCNRMJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -41,28 +43,35 @@ function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, binwidthconst = haskey(kwargs, :binwidthconst) ? kwargs[:binwidthconst] : 16 numbinsconst = haskey(kwargs, :numbinsconst) ? kwargs[:numbinsconst] : 20 - ptt = PriorityTimeTable(zeros(T, length(crs)), zero(T), one(T), - binwidthconst = binwidthconst, numbinsconst = numbinsconst) # We will re-initialize this in initialize!() + ptt = PriorityTimeTable( + zeros(T, length(crs)), zero(T), one(T), + binwidthconst = binwidthconst, numbinsconst = numbinsconst + ) # We will re-initialize this in initialize!() affecttype = F2 <: Tuple ? F2 : Any - CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}( + return CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}( nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, - rng, dg, ptt) + rng, dg, ptt + ) end -+############################# Required Functions ############################## ++ ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::CCNRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::CCNRM, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(CCNRMJumpAggregation, u, p, t, end_time, ma_jumps, + return build_jump_aggregation( + CCNRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; num_specs = length(u), - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -70,7 +79,7 @@ function initialize!(p::CCNRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] initialize_rates_and_times!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @@ -80,7 +89,7 @@ function execute_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t, affec # update current jump rates and times update_dependent_rates!(p, u, params, t) - nothing + return nothing end # calculate the next jump / jump time @@ -88,7 +97,7 @@ end function generate_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t) p.next_jump, p.next_jump_time = getfirst(p.ptt) - # Rebuild the table if no next jump is found. + # Rebuild the table if no next jump is found. if p.next_jump == 0 timestep = 1 / sum(p.cur_rates) min_time = minimum(p.ptt.times) @@ -96,7 +105,7 @@ function generate_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t) p.next_jump, p.next_jump_time = getfirst(p.ptt) end - nothing + return nothing end ######################## SSA specific helper routines ######################## @@ -113,8 +122,10 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) oldtime = times[rx] # update the jump rate - @inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, - params, t, rx) + @inbounds cur_rates[rx] = calculate_jump_rate( + ma_jumps, num_majumps, rates, u, + params, t, rx + ) # Calculate new jump times for dependent jumps if rx != p.next_jump && oldrate > zero(oldrate) @@ -131,10 +142,10 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) end end end - nothing + return nothing end -# Evaluate all the rates and initialize the times in the priority table. +# Evaluate all the rates and initialize the times in the priority table. function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) # Initialize next-reaction times for the mass action jumps majumps = p.ma_jumps @@ -154,9 +165,9 @@ function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) idx += 1 end - # Build the priority time table with the times and bin width. + # Build the priority time table with the times and bin width. timestep = 1 / sum(cur_rates) p.ptt.times = pttdata rebuild!(p.ptt, t, timestep) - nothing + return nothing end diff --git a/src/aggregators/coevolve.jl b/src/aggregators/coevolve.jl index dc48a5e21..5bfff6fc0 100644 --- a/src/aggregators/coevolve.jl +++ b/src/aggregators/coevolve.jl @@ -2,7 +2,7 @@ Queue method. This method handles variable intensity rates. """ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int # the next jump to execute prev_jump::Int # the previous jump that was executed next_jump_time::T # the time of the next jump @@ -23,11 +23,13 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: cur_lrates::Vector{T} # the last computed lower rate for each rate end -function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing, +function CoevolveJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U, dep_graph = nothing, lrates, urates, rateintervals, haslratevec, - cur_lrates::Vector{T}) where {T, S, F1, F2, RNG, U} + cur_lrates::Vector{T} + ) where {T, S, F1, F2, RNG, U} if dep_graph === nothing if (get_num_majumps(maj) == 0) || !isempty(urates) error("To use Coevolve a dependency graph between jumps must be supplied.") @@ -36,8 +38,10 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not end else # using a Set to ensure that edges are not duplicate - dgsets = [Set{Int}(append!(Int[], jumps, [var])) - for (var, jumps) in enumerate(dep_graph)] + dgsets = [ + Set{Int}(append!(Int[], jumps, [var])) + for (var, jumps) in enumerate(dep_graph) + ] dg = [sort!(collect(i)) for i in dgsets] end @@ -49,27 +53,33 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any - CoevolveJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), - typeof(pq)}(nj, nj, njt, et, crs, sr, maj, + return CoevolveJumpAggregation{ + T, S, F1, affecttype, RNG, typeof(dg), + typeof(pq), + }( + nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng, dg, pq, lrates, urates, rateintervals, - haslratevec, cur_lrates) + haslratevec, cur_lrates + ) end # display function num_constant_rate_jumps(aggregator::CoevolveJumpAggregation) - length(aggregator.urates) + return length(aggregator.urates) end # condition for jump to occur function (p::CoevolveJumpAggregation)(u, t, integrator) - p.next_jump_time == t && + return p.next_jump_time == t && accept_next_jump!(p, integrator, integrator.u, integrator.p, integrator.t) end # executing jump at the next jump time -function (p::CoevolveJumpAggregation)(integrator::I) where {I <: - AbstractSSAIntegrator} +function (p::CoevolveJumpAggregation)(integrator::I) where { + I <: + AbstractSSAIntegrator, + } if !accept_next_jump!(p, integrator, integrator.u, integrator.p, integrator.t) return nothing end @@ -81,27 +91,34 @@ function (p::CoevolveJumpAggregation)(integrator::I) where {I <: end generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) - nothing + return nothing end -function (p::CoevolveJumpAggregation{ - T, S, F1, F2})(integrator::AbstractSSAIntegrator) where - {T, S, F1, F2 <: Union{Tuple, Nothing}} +function ( + p::CoevolveJumpAggregation{ + T, S, F1, F2, + } + )(integrator::AbstractSSAIntegrator) where + {T, S, F1, F2 <: Union{Tuple, Nothing}} if !accept_next_jump!(p, integrator, integrator.u, integrator.p, integrator.t) return nothing end execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t, p.affects!) generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) - nothing + return nothing end # creating the JumpAggregation structure (tuple-based variable jumps) -function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, +function aggregate( + aggregator::Coevolve, u, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; dep_graph = nothing, - variable_jumps = nothing, kwargs...) - RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), - Tuple{typeof(u), typeof(p), typeof(t)}} + variable_jumps = nothing, kwargs... + ) + RateWrapper = FunctionWrappers.FunctionWrapper{ + typeof(t), + Tuple{typeof(u), typeof(p), typeof(t)}, + } ncrjs = (constant_jumps === nothing) ? 0 : length(constant_jumps) nvrjs = (variable_jumps === nothing) ? 0 : length(variable_jumps) @@ -140,10 +157,12 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, sum_rate = nothing next_jump = 0 next_jump_time = typemax(t) - CoevolveJumpAggregation(next_jump, next_jump_time, end_time, cur_rates, sum_rate, + return CoevolveJumpAggregation( + next_jump, next_jump_time, end_time, cur_rates, sum_rate, ma_jumps, rates, affects!, save_positions, rng; u, dep_graph, lrates, urates, rateintervals, haslratevec, - cur_lrates) + cur_lrates + ) end # set up a new simulation and calculate the first jump / jump time @@ -151,23 +170,25 @@ function initialize!(p::CoevolveJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state -function execute_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t, - affects!) +function execute_jumps!( + p::CoevolveJumpAggregation, integrator, u, params, t, + affects! + ) # execute jump update_state!(p, integrator, u, affects!) # update current jump rates and times update_dependent_rates!(p, integrator.u, integrator.p, t) - nothing + return nothing end # calculate the next jump / jump time function generate_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t) p.next_jump_time, p.next_jump = top_with_handle(p.pq) - nothing + return nothing end ######################## SSA specific helper routines ######################## @@ -233,7 +254,7 @@ function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t) update!(pq, i, ti) @inbounds cur_rates[i] = urate_i end - nothing + return nothing end @inline function get_ma_urate(p::CoevolveJumpAggregation, i, u, params, t) @@ -308,5 +329,5 @@ function fill_rates_and_get_times!(p::CoevolveJumpAggregation, u, params, t) jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i) end p.pq = MutableBinaryMinHeap(jump_times) - nothing + return nothing end diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index ab33dd842..4b3e05988 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -1,5 +1,5 @@ mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -12,50 +12,64 @@ mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: save_positions::Tuple{Bool, Bool} rng::RNG end -function DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, +function DirectJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} affecttype = F2 <: Tuple ? F2 : Any - DirectJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng) + return DirectJumpAggregation{T, S, F1, affecttype, RNG}( + nj, nj, njt, et, crs, sr, maj, rs, + affs!, sps, rng + ) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) -function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::Direct, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) - build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; kwargs...) + return build_jump_aggregation( + DirectJumpAggregation, u, p, t, end_time, ma_jumps, + rates, affects!, save_positions, rng; kwargs... + ) end # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::DirectFW, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; kwargs...) + return build_jump_aggregation( + DirectJumpAggregation, u, p, t, end_time, ma_jumps, + rates, affects!, save_positions, rng; kwargs... + ) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::DirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state -@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t, - affects!) +@inline function execute_jumps!( + p::DirectJumpAggregation, integrator, u, params, t, + affects! + ) update_state!(p, integrator, u, affects!) - nothing + return nothing end # calculate the next jump / jump time @@ -63,14 +77,16 @@ function generate_jumps!(p::DirectJumpAggregation, integrator, u, params, t) p.sum_rate, ttnj = time_to_next_jump(p, u, params, t) p.next_jump_time = add_fast(t, ttnj) @inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(p.rng) * p.sum_rate) - nothing + return nothing end ######################## SSA specific helper routines ######################## # tuple-based constant jumps -function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: Tuple} +function time_to_next_jump( + p::DirectJumpAggregation{T, S, F1}, u, params, + t + ) where {T, S, F1 <: Tuple} prev_rate = zero(t) new_rate = zero(t) cur_rates = p.cur_rates @@ -96,23 +112,25 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, end @inbounds sum_rate = cur_rates[end] - sum_rate, randexp(p.rng) / sum_rate + return sum_rate, randexp(p.rng) / sum_rate end @inline function fill_cur_rates(u, p, t, cur_rates, idx, rate, rates...) @inbounds cur_rates[idx] = rate(u, p, t) idx += 1 - fill_cur_rates(u, p, t, cur_rates, idx, rates...) + return fill_cur_rates(u, p, t, cur_rates, idx, rates...) end @inline function fill_cur_rates(u, p, t, cur_rates, idx, rate) @inbounds cur_rates[idx] = rate(u, p, t) - nothing + return nothing end # function wrapper-based constant jumps -function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: AbstractArray} +function time_to_next_jump( + p::DirectJumpAggregation{T, S, F1}, u, params, + t + ) where {T, S, F1 <: AbstractArray} prev_rate = zero(t) new_rate = zero(t) cur_rates = p.cur_rates @@ -137,5 +155,5 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, end @inbounds sum_rate = cur_rates[end] - sum_rate, randexp(p.rng) / sum_rate + return sum_rate, randexp(p.rng) / sum_rate end diff --git a/src/aggregators/directcr.jl b/src/aggregators/directcr.jl index 41d079fe8..31bb67e21 100644 --- a/src/aggregators/directcr.jl +++ b/src/aggregators/directcr.jl @@ -8,11 +8,13 @@ and by S. Mauch and M. Stalzer, ACM Trans. Comp. Biol. and Bioinf., 8, No. 1, 27-35 (2010). """ -const MINJUMPRATE = 2.0^exponent(1e-12) +const MINJUMPRATE = 2.0^exponent(1.0e-12) -mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTable, - W <: Function} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct DirectCRJumpAggregation{ + T, S, F1, F2, RNG, DEPGR, U <: PriorityTable, + W <: Function, + } <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -31,12 +33,14 @@ mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTa ratetogroup::W end -function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, +function DirectCRJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; num_specs, dep_graph = nothing, minrate = convert(T, MINJUMPRATE), maxrate = convert(T, Inf), - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -63,25 +67,33 @@ function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) affecttype = F2 <: Tuple ? F2 : Any - DirectCRJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), - typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crs, sr, maj, + return DirectCRJumpAggregation{ + T, S, F1, affecttype, RNG, typeof(dg), + typeof(rt), typeof(ratetogroup), + }( + nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng, dg, minrate, maxrate, rt, - ratetogroup) + ratetogroup + ) end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::DirectCR, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::DirectCR, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(DirectCRJumpAggregation, u, p, t, end_time, ma_jumps, + return build_jump_aggregation( + DirectCRJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; num_specs = length(u), - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -98,7 +110,7 @@ function initialize!(p::DirectCRJumpAggregation, integrator, u, params, t) end generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @@ -108,7 +120,7 @@ function execute_jumps!(p::DirectCRJumpAggregation, integrator, u, params, t, af # update current jump rates update_dependent_rates!(p, u, params, t) - nothing + return nothing end # calculate the next jump / jump time @@ -118,7 +130,7 @@ function generate_jumps!(p::DirectCRJumpAggregation, integrator, u, params, t) if p.next_jump_time < p.end_time p.next_jump = sample(p.rt, p.cur_rates, p.rng) end - nothing + return nothing end ######################## SSA specific helper routines ######################### @@ -141,5 +153,5 @@ function update_dependent_rates!(p::DirectCRJumpAggregation, u, params, t) end p.sum_rate = groupsum(rt) - nothing + return nothing end diff --git a/src/aggregators/frm.jl b/src/aggregators/frm.jl index 94baed375..557468cf9 100644 --- a/src/aggregators/frm.jl +++ b/src/aggregators/frm.jl @@ -1,5 +1,5 @@ mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -12,52 +12,62 @@ mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: save_positions::Tuple{Bool, Bool} rng::RNG end -function FRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, +function FRMJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} affecttype = F2 <: Tuple ? F2 : Any - FRMJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng) + return FRMJumpAggregation{T, S, F1, affecttype, RNG}( + nj, nj, njt, et, crs, sr, maj, rs, + affs!, sps, rng + ) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) -function aggregate(aggregator::FRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::FRM, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) - build_jump_aggregation( + return build_jump_aggregation( FRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, - save_positions, rng; kwargs...) + save_positions, rng; kwargs... + ) end # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::FRMFW, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::FRMFW, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation( + return build_jump_aggregation( FRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, - save_positions, rng; kwargs...) + save_positions, rng; kwargs... + ) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::FRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @inline function execute_jumps!(p::FRMJumpAggregation, integrator, u, params, t, affects!) # execute jump update_state!(p, integrator, u, affects!) - nothing + return nothing end # calculate the next jump / jump time @@ -73,7 +83,7 @@ function generate_jumps!(p::FRMJumpAggregation, integrator, u, params, t) p.next_jump = nextcrj p.next_jump_time = t + ttncrj end - nothing + return nothing end ######################## SSA specific helper routines ######################## @@ -91,12 +101,14 @@ function next_ma_jump(p::FRMJumpAggregation, u, params, t) nextrx = i end end - nextrx, ttnj + return nextrx, ttnj end # tuple-based constant jumps -function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params, - t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG} +function next_constant_rate_jump( + p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params, + t + ) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG} ttnj = typemax(typeof(t)) nextrx = zero(Int) if !isempty(p.rates) @@ -110,12 +122,14 @@ function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, pa end end end - nextrx, ttnj + return nextrx, ttnj end # function wrapper-based constant jumps -function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: AbstractArray} +function next_constant_rate_jump( + p::FRMJumpAggregation{T, S, F1}, u, params, + t + ) where {T, S, F1 <: AbstractArray} ttnj = typemax(typeof(t)) nextrx = zero(Int) if !isempty(p.rates) @@ -130,5 +144,5 @@ function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1}, u, params, idx += 1 end end - nextrx, ttnj + return nextrx, ttnj end diff --git a/src/aggregators/nrm.jl b/src/aggregators/nrm.jl index 7fbcd5964..c7279d281 100644 --- a/src/aggregators/nrm.jl +++ b/src/aggregators/nrm.jl @@ -2,7 +2,7 @@ # Gibson and Bruck, J. Phys. Chem. A, 104 (9), (2000) mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -18,10 +18,12 @@ mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: pq::PQ end -function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, +function NRMJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -40,23 +42,29 @@ function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any - NRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(pq)}(nj, nj, njt, et, + return NRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(pq)}( + nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, - rng, dg, pq) + rng, dg, pq + ) end -+############################# Required Functions ############################## ++ ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::NRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::NRM, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(NRMJumpAggregation, u, p, t, end_time, ma_jumps, + return build_jump_aggregation( + NRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; num_specs = length(u), - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -64,7 +72,7 @@ function initialize!(p::NRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @@ -74,14 +82,14 @@ function execute_jumps!(p::NRMJumpAggregation, integrator, u, params, t, affects # update current jump rates and times update_dependent_rates!(p, u, params, t) - nothing + return nothing end # calculate the next jump / jump time # just the top of the priority queue function generate_jumps!(p::NRMJumpAggregation, integrator, u, params, t) p.next_jump_time, p.next_jump = top_with_handle(p.pq) - nothing + return nothing end ######################## SSA specific helper routines ######################## @@ -96,8 +104,10 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) oldrate = cur_rates[rx] # update the jump rate - @inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, - params, t, rx) + @inbounds cur_rates[rx] = calculate_jump_rate( + ma_jumps, num_majumps, rates, u, + params, t, rx + ) # calculate new jump times for dependent jumps if rx != p.next_jump && oldrate > zero(oldrate) @@ -114,7 +124,7 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) end end end - nothing + return nothing end # reevaluate all rates, recalculate all jump times, and reinit the priority queue @@ -140,5 +150,5 @@ function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) # setup a new indexed priority queue to storing rx times p.pq = MutableBinaryMinHeap(pqdata) - nothing + return nothing end diff --git a/src/aggregators/prioritytable.jl b/src/aggregators/prioritytable.jl index 11ddf171b..285bc3866 100644 --- a/src/aggregators/prioritytable.jl +++ b/src/aggregators/prioritytable.jl @@ -41,7 +41,7 @@ PriorityGroup{U}(maxpriority::T) where {T, U} = PriorityGroup(maxpriority, 0, Ve end # return the position of insertion - pg.numpids += 1 + return pg.numpids += 1 end @inline function remove!(pg::PriorityGroup, pididx) @@ -54,18 +54,18 @@ end pg.numpids -= 1 # return the pid that was swapped to pididx - lastpid + return lastpid end @inline function ids(pg::PriorityGroup) - pg.pids[1:(pg.numpids)] + return pg.pids[1:(pg.numpids)] end function Base.show(io::IO, pg::PriorityGroup) println(io, " ", summary(pg)) println(io, " maxpriority = ", pg.maxpriority) println(io, " numpids = ", pg.numpids) - println(io, " pids = ", ids(pg)) + return println(io, " pids = ", ids(pg)) end """ @@ -100,8 +100,10 @@ end Setup table from a vector of priorities. The id of a priority is its position within this vector. """ -function PriorityTable(priortogid::Function, priorities::AbstractVector, minpriority, - maxpriority) +function PriorityTable( + priortogid::Function, priorities::AbstractVector, minpriority, + maxpriority + ) numgroups = priortogid(maxpriority) numgroups -= one(typeof(numgroups)) pidtype = typeof(numgroups) @@ -128,20 +130,20 @@ function PriorityTable(priortogid::Function, priorities::AbstractVector, minprio insert!(pt, pid, priority) end - pt + return pt end ########################## ACCESSORS ########################## @inline function numgroups(pt::PriorityTable) - length(pt.groups) + return length(pt.groups) end @inline function numpriorities(pt::PriorityTable) - length(pt.pidtogroup) + return length(pt.pidtogroup) end @inline function groupsum(pt::PriorityTable) - pt.gsum + return pt.gsum end """ @@ -157,7 +159,7 @@ function padtable!(pt::PriorityTable, pid, priority) push!(gsums, zero(eltype(gsums))) end pt.maxpriority = maxpriority - nothing + return nothing end # assumes pid is at most 1 greater than last priority (id) currently in table @@ -186,7 +188,7 @@ function insert!(pt::PriorityTable, pid, priority) push!(pidtogroup, (gid, pididx)) end - nothing + return nothing end function update!(pt::PriorityTable, pid, oldpriority, newpriority) @@ -225,11 +227,11 @@ function update!(pt::PriorityTable, pid, oldpriority, newpriority) # update sums, special case if group empty to avoid FP error in running sums grpsz = groups[oldgid].numpids gsums[oldgid] = (grpsz == zero(grpsz)) ? zero(oldpriority) : - gsums[oldgid] - oldpriority + gsums[oldgid] - oldpriority gsums[newgid] += newpriority end end - nothing + return nothing end function reset!(pt::PriorityTable{F, S, T, U}) where {F, S, T, U} @@ -240,6 +242,7 @@ function reset!(pt::PriorityTable{F, S, T, U}) where {F, S, T, U} for group in groups group.numpids = zero(T) end + return end function Base.show(io::IO, pt::PriorityTable) @@ -254,6 +257,7 @@ function Base.show(io::IO, pt::PriorityTable) Base.show(io, group) end end + return end ############################# @@ -286,7 +290,7 @@ end ((r - pididx) * maxpriority < priorities[pid]) && break end - pid + return pid end function sample(pt::PriorityTable, priorities, rng = DEFAULT_RNG) @@ -316,7 +320,7 @@ function sample(pt::PriorityTable, priorities, rng = DEFAULT_RNG) iszero(gid) && return gid # sample element within the group - @inbounds sample(groups[gid], priorities, rng) + return @inbounds sample(groups[gid], priorities, rng) end ########################## @@ -338,17 +342,18 @@ mutable struct PriorityTimeTable{T, F <: Int} times::Vector{T} timegrouper::TimeGrouper{T} minbin::F - steps::F # TODO: For adaptive rebuilding. + steps::F # TODO: For adaptive rebuilding. maxtime::T binwidthconst::F numbinsconst::F end -# Construct the time table with the default optimal bin width and number of bins. +# Construct the time table with the default optimal bin width and number of bins. # DEFAULT NUMBINS: 20 * √length(times) # DEFAULT BINWIDTH: 16 / sum(propensities) function PriorityTimeTable( - times::AbstractVector, mintime, timestep; binwidthconst = 16, numbinsconst = 20) + times::AbstractVector, mintime, timestep; binwidthconst = 16, numbinsconst = 20 + ) binwidth = binwidthconst * timestep numbins = floor(Int64, numbinsconst * sqrt(length(times))) maxtime = mintime + numbins * binwidth @@ -366,7 +371,8 @@ function PriorityTimeTable( ptt = PriorityTimeTable( groups, pidtogroup, times, ttgdata, zero(pidtype), - zero(pidtype), maxtime, binwidthconst, numbinsconst) + zero(pidtype), maxtime, binwidthconst, numbinsconst + ) # Insert priority ids into the groups for (pid, time) in enumerate(times) if time > maxtime @@ -378,11 +384,11 @@ function PriorityTimeTable( ptt.minbin = findfirst(g -> g.numpids > (0), groups) ptt.minbin === nothing && (ptt.minbin = 0) - ptt + return ptt end # Rebuild the table when there are no more reaction times within the current -# time window. +# time window. function rebuild!(ptt::PriorityTimeTable{T, F}, mintime, timestep) where {T, F} (; pidtogroup, groups, times, binwidthconst) = ptt fill!(pidtogroup, (zero(F), zero(F))) @@ -399,7 +405,7 @@ function rebuild!(ptt::PriorityTimeTable{T, F}, mintime, timestep) where {T, F} group.maxpriority = groupmaxtime end - # Reinsert the times into the groups. + # Reinsert the times into the groups. for (id, time) in enumerate(times) time > ptt.maxtime && continue insert!(ptt, id, time) @@ -449,13 +455,13 @@ end # Update the priority table when a reaction time gets updated. We only shift # between bins if the new time is within the current time window; otherwise -# we remove the reaction and wait until rebuild. +# we remove the reaction and wait until rebuild. function update!(ptt::PriorityTimeTable{T, F}, pid, oldtime, newtime) where {T, F} (; times, timegrouper, maxtime, pidtogroup, groups) = ptt times[pid] = newtime if oldtime >= maxtime - # If a reaction comes back into the time window, insert it. + # If a reaction comes back into the time window, insert it. newtime < maxtime ? insert!(ptt, pid, newtime) : return nothing elseif newtime >= maxtime # If the new time lands outside of current window, remove it. @@ -466,7 +472,7 @@ function update!(ptt::PriorityTimeTable{T, F}, pid, oldtime, newtime) where {T, pidtogroup[pid] = (zero(F), zero(F)) end else - # Move bins if the reaction was already inside. + # Move bins if the reaction was already inside. oldgid = timegrouper(oldtime) newgid = timegrouper(newtime) oldgid == newgid && return nothing diff --git a/src/aggregators/rdirect.jl b/src/aggregators/rdirect.jl index 8376b71b9..a1a8b5693 100644 --- a/src/aggregators/rdirect.jl +++ b/src/aggregators/rdirect.jl @@ -3,7 +3,7 @@ Direct with rejection sampling """ mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -21,11 +21,13 @@ mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: counter_threshold::Any end -function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, +function RDirectJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; num_specs, counter_threshold = length(crs), dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -42,25 +44,31 @@ function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, m max_rate = maximum(crs) affecttype = F2 <: Tuple ? F2 : Any - return RDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et, + return RDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}( + nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng, dg, max_rate, 0, - counter_threshold) + counter_threshold + ) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) -function aggregate(aggregator::RDirect, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::RDirect, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(RDirectJumpAggregation, u, p, t, end_time, ma_jumps, + return build_jump_aggregation( + RDirectJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; num_specs = length(u), - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -69,7 +77,7 @@ function initialize!(p::RDirectJumpAggregation, integrator, u, params, t) fill_rates_and_sum!(p, u, params, t) p.max_rate = maximum(p.cur_rates) generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end """ @@ -81,7 +89,7 @@ function execute_jumps!(p::RDirectJumpAggregation, integrator, u, params, t, aff # update rates update_dependent_rates!(p, u, params, t) - nothing + return nothing end """ @@ -106,7 +114,7 @@ function generate_jumps!(p::RDirectJumpAggregation, integrator, u, params, t) p.counter = counter p.next_jump = rx p.next_jump_time = t + randexp(p.rng) / sum_rate - nothing + return nothing end ######################## SSA specific helper routines ######################### @@ -118,7 +126,8 @@ function update_dependent_rates!(p::RDirectJumpAggregation, u, params, t) @inbounds for rx in dep_rxs @inbounds new_rate = calculate_jump_rate( ma_jumps, num_majumps, rates, u, params, t, - rx) + rx + ) sum_rate += new_rate - cur_rates[rx] if new_rate > p.max_rate p.max_rate = new_rate @@ -131,5 +140,5 @@ function update_dependent_rates!(p::RDirectJumpAggregation, u, params, t) end p.sum_rate = sum_rate - nothing + return nothing end diff --git a/src/aggregators/rssa.jl b/src/aggregators/rssa.jl index 2943bf88b..35305d69e 100644 --- a/src/aggregators/rssa.jl +++ b/src/aggregators/rssa.jl @@ -5,7 +5,7 @@ # requires vartojumps_map and fluct_rates as JumpProblem keywords mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -25,11 +25,13 @@ mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: uhigh::U end -function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, +function RSSAJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U, vartojumps_map = nothing, jumptovars_map = nothing, - bracket_data = nothing, kwargs...) where {T, S, F1, F2, RNG, U} + bracket_data = nothing, kwargs... + ) where {T, S, F1, F2, RNG, U} # a dependency graph is needed and must be provided if there are constant rate jumps if vartojumps_map === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -63,25 +65,33 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, uhigh = similar(u) affecttype = F2 <: Tuple ? F2 : Any - RSSAJumpAggregation{T, S, F1, affecttype, RNG, typeof(vtoj_map), - typeof(jtov_map), typeof(bd), U}(nj, nj, njt, et, crl_bnds, + return RSSAJumpAggregation{ + T, S, F1, affecttype, RNG, typeof(vtoj_map), + typeof(jtov_map), typeof(bd), U, + }( + nj, nj, njt, et, crl_bnds, crh_bnds, sr, maj, rs, affs!, sps, rng, vtoj_map, jtov_map, bd, ulow, - uhigh) + uhigh + ) end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::RSSA, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::RSSA, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(RSSAJumpAggregation, u, p, t, end_time, ma_jumps, + return build_jump_aggregation( + RSSAJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; u = u, - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -89,7 +99,7 @@ function initialize!(p::RSSAJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] set_bracketing!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @@ -97,7 +107,7 @@ function execute_jumps!(p::RSSAJumpAggregation, integrator, u, params, t, affect # execute jump u = update_state!(p, integrator, u, affects!) update_rates!(p, u, params, t) - nothing + return nothing end # calculate the next jump / jump time @@ -119,8 +129,10 @@ function generate_jumps!(p::RSSAJumpAggregation, integrator, u, params, t) return nothing end rerl += randexp(rng) - @inbounds while rejectrx(ma_jumps, num_majumps, rates, cur_rate_high, - cur_rate_low, rng, u, jidx, params, t) + @inbounds while rejectrx( + ma_jumps, num_majumps, rates, cur_rate_high, + cur_rate_low, rng, u, jidx, params, t + ) # sample candidate reaction r = rand(rng) * sum_rate jidx = linear_search(cur_rate_high, r) @@ -129,7 +141,7 @@ function generate_jumps!(p::RSSAJumpAggregation, integrator, u, params, t) p.next_jump = jidx p.next_jump_time = t + rerl / sum_rate - nothing + return nothing end # alt erlang sampling above @@ -164,7 +176,7 @@ Update rates end end end - p.sum_rate = sum_rate + return p.sum_rate = sum_rate end @inline function update_rates!(p::RSSAJumpAggregation, u::SVector, params, t) @@ -190,5 +202,5 @@ end end end end - p.sum_rate = sum_rate + return p.sum_rate = sum_rate end diff --git a/src/aggregators/rssacr.jl b/src/aggregators/rssacr.jl index 1caf3be53..fc49bb0f9 100644 --- a/src/aggregators/rssacr.jl +++ b/src/aggregators/rssacr.jl @@ -2,11 +2,13 @@ Composition-Rejection with Rejection sampling method (RSSA-CR) """ -const MINJUMPRATE = 2.0^exponent(1e-12) +const MINJUMPRATE = 2.0^exponent(1.0e-12) -mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, - P <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{F, S, F1, F2, RNG} +mutable struct RSSACRJumpAggregation{ + F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, + P <: PriorityTable, W <: Function, + } <: + AbstractSSAJumpAggregator{F, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::F @@ -30,12 +32,14 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, ratetogroup::W end -function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S, +function RSSACRJumpAggregation( + nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U, vartojumps_map = nothing, jumptovars_map = nothing, bracket_data = nothing, minrate = convert(F, MINJUMPRATE), maxrate = convert(F, Inf), - kwargs...) where {F, S, F1, F2, RNG, U} + kwargs... + ) where {F, S, F1, F2, RNG, U} # a dependency graph is needed and must be provided if there are constant rate jumps if vartojumps_map === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -80,25 +84,33 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate: rt = PriorityTable(ratetogroup, zeros(F, 1), minrate, 2 * minrate) affecttype = F2 <: Tuple ? F2 : Any - RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, RNG, U, typeof(vtoj_map), + return RSSACRJumpAggregation{ + typeof(njt), S, F1, affecttype, RNG, U, typeof(vtoj_map), typeof(jtov_map), typeof(bd), typeof(rt), - typeof(ratetogroup)}(nj, nj, njt, et, crl_bnds, crh_bnds, + typeof(ratetogroup), + }( + nj, nj, njt, et, crl_bnds, crh_bnds, sum_rate, maj, rs, affs!, sps, rng, vtoj_map, jtov_map, bd, ulow, uhigh, minrate, maxrate, - rt, ratetogroup) + rt, ratetogroup + ) end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::RSSACR, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::RSSACR, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(RSSACRJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; u = u, kwargs...) + return build_jump_aggregation( + RSSACRJumpAggregation, u, p, t, end_time, ma_jumps, + rates, affects!, save_positions, rng; u = u, kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -113,7 +125,7 @@ function initialize!(p::RSSACRJumpAggregation, integrator, u, params, t) end generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @@ -123,7 +135,7 @@ function execute_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t, affe # update rates update_dependent_rates!(p, u, params, t) - nothing + return nothing end # calculate the next jump / jump time @@ -144,8 +156,10 @@ function generate_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t) return nothing end rerl += randexp(rng) - while rejectrx(ma_jumps, num_majumps, rates, cur_rate_high, cur_rate_low, rng, u, jidx, - params, t) + while rejectrx( + ma_jumps, num_majumps, rates, cur_rate_high, cur_rate_low, rng, u, jidx, + params, t + ) # sample candidate reaction jidx = sample(rt, cur_rate_high, rng) rerl += randexp(rng) @@ -154,15 +168,17 @@ function generate_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t) # update time to next jump p.next_jump_time = t + rerl / sum_rate - nothing + return nothing end ######################## SSA specific helper routines ######################### """ update bracketing for species that depend on the just executed jump """ -@inline function update_dependent_rates!(p::RSSACRJumpAggregation, u::AbstractVector, - params, t) +@inline function update_dependent_rates!( + p::RSSACRJumpAggregation, u::AbstractVector, + params, t + ) # update bracketing intervals (; ulow, uhigh) = p crhigh = p.cur_rate_high @@ -186,7 +202,7 @@ update bracketing for species that depend on the just executed jump end p.sum_rate = groupsum(p.rt) - nothing + return nothing end @inline function update_dependent_rates!(p::RSSACRJumpAggregation, u::SVector, params, t) @@ -214,5 +230,5 @@ end end p.sum_rate = groupsum(p.rt) - nothing + return nothing end diff --git a/src/aggregators/sortingdirect.jl b/src/aggregators/sortingdirect.jl index f9048f039..dfcf445a9 100644 --- a/src/aggregators/sortingdirect.jl +++ b/src/aggregators/sortingdirect.jl @@ -3,7 +3,7 @@ # Comp. Bio. and Chem., 30, pg. 39-49 (2006). mutable struct SortingDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::Int prev_jump::Int next_jump_time::T @@ -20,10 +20,12 @@ mutable struct SortingDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: jump_search_idx::Int end -function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, +function SortingDirectJumpAggregation( + nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + kwargs... + ) where {T, S, F1, F2, RNG} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -42,25 +44,31 @@ function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr # map jump idx to idx in cur_rates jtoidx = collect(1:length(crs)) affecttype = F2 <: Tuple ? F2 : Any - SortingDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et, + return SortingDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}( + nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng, dg, jtoidx, - zero(Int)) + zero(Int) + ) end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::SortingDirect, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) +function aggregate( + aggregator::SortingDirect, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs... + ) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(SortingDirectJumpAggregation, u, p, t, end_time, ma_jumps, + return build_jump_aggregation( + SortingDirectJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; num_specs = length(u), - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -68,7 +76,7 @@ function initialize!(p::SortingDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_sum!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) - nothing + return nothing end # execute one jump, changing the system state @@ -87,7 +95,7 @@ function execute_jumps!(p::SortingDirectJumpAggregation, integrator, u, params, # update current jump rates update_dependent_rates!(p, u, params, t) - nothing + return nothing end # calculate the next jump / jump time @@ -110,5 +118,5 @@ function generate_jumps!(p::SortingDirectJumpAggregation, integrator, u, params, @inbounds p.next_jump = jso[p.jump_search_idx] end - nothing + return nothing end diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index 90c260c97..557071bcc 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -22,7 +22,7 @@ An aggregator interface for SSA-like algorithms. abstract type AbstractSSAJumpAggregator{T, S, F1, F2, RNG} <: AbstractJumpAggregator end function DiscreteCallback(c::AbstractSSAJumpAggregator) - DiscreteCallback(c, c, initialize = c, save_positions = c.save_positions) + return DiscreteCallback(c, c, initialize = c, save_positions = c.save_positions) end ########### The following routines are templates for all SSAs ########### @@ -36,7 +36,7 @@ end @inline function makewrapper(::Type{T}, aff) where {T} # rewrap existing wrappers - if aff isa FunctionWrappers.FunctionWrapper + return if aff isa FunctionWrappers.FunctionWrapper T(aff.obj[]) elseif aff isa Function T(aff) @@ -45,19 +45,23 @@ end end end -@inline function concretize_affects!(p::AbstractSSAJumpAggregator, - ::I) where {I <: SciMLBase.DEIntegrator} +@inline function concretize_affects!( + p::AbstractSSAJumpAggregator, + ::I + ) where {I <: SciMLBase.DEIntegrator} if (p.affects! isa Vector) && - !(p.affects! isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) + !(p.affects! isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}} p.affects! = AffectWrapper[makewrapper(AffectWrapper, aff) for aff in p.affects!] end - nothing + return nothing end -@inline function concretize_affects!(p::AbstractSSAJumpAggregator{T, S, F1, F2}, - ::I) where {T, S, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} - nothing +@inline function concretize_affects!( + p::AbstractSSAJumpAggregator{T, S, F1, F2}, + ::I + ) where {T, S, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} + return nothing end # setting up a new simulation @@ -66,12 +70,12 @@ function (p::AbstractSSAJumpAggregator)(dj, u, t, integrator) # initialize initialize!(p, integrator, u, integrator.p, t) register_next_jump_time!(integrator, p, integrator.t) u_modified!(integrator, false) - nothing + return nothing end # condition for jump to occur @inline function (p::AbstractSSAJumpAggregator)(u, t, integrator) - p.next_jump_time == t + return p.next_jump_time == t end # executing jump at the next jump time @@ -84,16 +88,19 @@ function (p::AbstractSSAJumpAggregator)(integrator::I) where {I <: SciMLBase.DEI end generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) - nothing + return nothing end -function (p::AbstractSSAJumpAggregator{ - T, S, F1, F2})(integrator::SciMLBase.DEIntegrator) where - {T, S, F1, F2 <: Union{Tuple, Nothing}} +function ( + p::AbstractSSAJumpAggregator{ + T, S, F1, F2, + } + )(integrator::SciMLBase.DEIntegrator) where + {T, S, F1, F2 <: Union{Tuple, Nothing}} execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t, p.affects!) generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) - nothing + return nothing end ############################## Generic Routines ############################### @@ -107,7 +114,7 @@ Adds a `tstop` to the integrator at the next jump time. if p.next_jump_time < p.end_time add_tstop!(integrator, p.next_jump_time) end - nothing + return nothing end """ @@ -116,15 +123,19 @@ end Helper routine for setting up standard fields of SSA jump aggregations. """ -function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, - affects!, save_positions, rng; kwargs...) +function build_jump_aggregation( + jump_agg_type, u, p, t, end_time, ma_jumps, rates, + affects!, save_positions, rng; kwargs... + ) # mass action jumps majumps = ma_jumps if majumps === nothing - majumps = MassActionJump(Vector{typeof(t)}(), + majumps = MassActionJump( + Vector{typeof(t)}(), Vector{Vector{Pair{Int, eltype(u)}}}(), - Vector{Vector{Pair{Int, eltype(u)}}}()) + Vector{Vector{Pair{Int, eltype(u)}}}() + ) end # current jump rates, allows mass action rates and constant jumps @@ -133,8 +144,10 @@ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rate sum_rate = zero(typeof(t)) next_jump = 0 next_jump_time = typemax(typeof(t)) - jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - majumps, rates, affects!, save_positions, rng; kwargs...) + return jump_agg_type( + next_jump, next_jump_time, end_time, cur_rates, sum_rate, + majumps, rates, affects!, save_positions, rng; kwargs... + ) end """ @@ -163,7 +176,7 @@ function fill_rates_and_sum!(p::AbstractSSAJumpAggregator, u, params, t) end p.sum_rate = sum_rate - nothing + return nothing end """ @@ -195,13 +208,15 @@ function update_dependent_rates!(p::AbstractSSAJumpAggregator, u, params, t) num_majumps = get_num_majumps(p.ma_jumps) @inbounds for rx in dep_rxs sum_rate -= cur_rates[rx] - @inbounds cur_rates[rx] = calculate_jump_rate(p.ma_jumps, num_majumps, p.rates, u, - params, t, rx) + @inbounds cur_rates[rx] = calculate_jump_rate( + p.ma_jumps, num_majumps, p.rates, u, + params, t, rx + ) sum_rate += cur_rates[rx] end p.sum_rate = sum_rate - nothing + return nothing end """ @@ -228,9 +243,11 @@ Execute `p.next_jump`. return integrator.u end -@generated function update_state!(p::AbstractSSAJumpAggregator, integrator, u, - affects!::T) where {T <: Tuple} - quote +@generated function update_state!( + p::AbstractSSAJumpAggregator, integrator, u, + affects!::T + ) where {T <: Tuple} + return quote (; ma_jumps, next_jump) = p num_ma_rates = get_num_majumps(ma_jumps) if next_jump <= num_ma_rates # is next jump a mass action jump @@ -241,7 +258,7 @@ end end else idx = next_jump - num_ma_rates - Base.Cartesian.@nif $(fieldcount(T)) i->(i == idx) i->(@inbounds affects![i](integrator)) i->(@inbounds affects![fieldcount(T)](integrator)) + Base.Cartesian.@nif $(fieldcount(T)) i -> (i == idx) i -> (@inbounds affects![i](integrator)) i -> (@inbounds affects![fieldcount(T)](integrator)) end # save jump that was just executed @@ -298,7 +315,8 @@ Perform rejection sampling test (used in RSSA methods). """ @inline function rejectrx( ma_jumps, num_majumps, rates, cur_rate_high, cur_rate_low, rng, u, - jidx, params, t) + jidx, params, t + ) # rejection test @inbounds r2 = rand(rng) * cur_rate_high[jidx] @inbounds crlow = cur_rate_low[jidx] diff --git a/src/coupled_array.jl b/src/coupled_array.jl index 05ff1164f..8529a8677 100644 --- a/src/coupled_array.jl +++ b/src/coupled_array.jl @@ -7,7 +7,7 @@ end Base.length(A::CoupledArray) = length(A.u) + length(A.u_control) Base.size(A::CoupledArray) = (length(A),) @inline function Base.getindex(A::CoupledArray, i::Int) - if A.order == true + return if A.order == true i <= length(A.u) ? A.u[i] : A.u_control[i - length(A.u)] else i <= length(A.u) ? A.u_control[i] : A.u[i - length(A.u)] @@ -15,17 +15,17 @@ Base.size(A::CoupledArray) = (length(A),) end @inline function Base.getindex(A::CoupledArray, I...) - A[CartesianIndices(A.u, I...)] + return A[CartesianIndices(A.u, I...)] end @inline function Base.getindex(A::CoupledArray, I::CartesianIndex{1}) - A[I[1]] + return A[I[1]] end @inline Base.setindex!(A::CoupledArray, v, I...) = (A[CartesianIndices(A.u, I...)] = v) @inline Base.setindex!(A::CoupledArray, v, I::CartesianIndex{1}) = (A[I[1]] = v) @inline function Base.setindex!(A::CoupledArray, v, i::Int) - if A.order == true + return if A.order == true i <= length(A.u) ? (A.u[i] = v) : (A.u_control[i - length(A.u)] = v) else i <= length(A.u) ? (A.u_control[i] = v) : (A.u[i - length(A.u)] = v) @@ -35,13 +35,13 @@ end Base.IndexStyle(::Type{<:CoupledArray}) = IndexLinear() Base.similar(A::CoupledArray) = CoupledArray(similar(A.u), similar(A.u_control), A.order) function Base.similar(A::CoupledArray, ::Type{S}) where {S} - CoupledArray(similar(A.u, S), similar(A.u_control, S), A.order) + return CoupledArray(similar(A.u, S), similar(A.u_control, S), A.order) end function recursivecopy!(dest::T, src::T) where {T <: CoupledArray} recursivecopy!(dest.u, src.u) recursivecopy!(dest.u_control, src.u_control) - dest.order = src.order + return dest.order = src.order end add_idxs1(::Type{T}, expr) where {T <: CoupledArray} = :($(expr).u) @@ -53,7 +53,7 @@ add_idxs2(::Type{T}, expr) where {T <: CoupledArray} = :($(expr).u_control) broadcast!(f, A.u, $(exs1...)) broadcast!(f, A.u_control, $(exs2...)) end - res + return res end Base.show(io::IO, A::CoupledArray) = show(io, A.u) diff --git a/src/coupling.jl b/src/coupling.jl index bce40869b..5aebb72d5 100644 --- a/src/coupling.jl +++ b/src/coupling.jl @@ -3,22 +3,28 @@ David F. Anderson, Masanori Koyama; An asymptotic relationship between coupling methods for stochastically modeled population processes. IMA J Numer Anal 2015; 35 (4): 1757-1778. doi: 10.1093/imanum/dru044 """ -function SplitCoupledJumpProblem(prob::DiffEqBase.AbstractJumpProblem, +function SplitCoupledJumpProblem( + prob::DiffEqBase.AbstractJumpProblem, prob_control::DiffEqBase.AbstractJumpProblem, aggregator::AbstractAggregatorAlgorithm, - coupling_map::Vector{Tuple{Int, Int}}; kwargs...) - JumpProblem(cat_problems(prob.prob, prob_control.prob), aggregator, - build_split_jumps(prob, prob_control, coupling_map)...; kwargs...) + coupling_map::Vector{Tuple{Int, Int}}; kwargs... + ) + return JumpProblem( + cat_problems(prob.prob, prob_control.prob), aggregator, + build_split_jumps(prob, prob_control, coupling_map)...; kwargs... + ) end # make new problem by joining initial_data function cat_problems(prob::DiscreteProblem, prob_control::DiscreteProblem) u0_coupled = CoupledArray(prob.u0, prob_control.u0, true) - DiscreteProblem(u0_coupled, prob.tspan, prob.p) + return DiscreteProblem(u0_coupled, prob.tspan, prob.p) end -function cat_problems(prob::DiffEqBase.AbstractODEProblem, - prob_control::DiffEqBase.AbstractODEProblem) +function cat_problems( + prob::DiffEqBase.AbstractODEProblem, + prob_control::DiffEqBase.AbstractODEProblem + ) l = length(prob.u0) # add l_c = length(prob_control.u0) _f = SciMLBase.unwrapped_f(prob.f) @@ -26,10 +32,10 @@ function cat_problems(prob::DiffEqBase.AbstractODEProblem, new_f = function (du, u, p, t) _f(@view(du[1:l]), u.u, p, t) - _f_control(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) + return _f_control(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) end u0_coupled = CoupledArray(prob.u0, prob_control.u0, true) - ODEProblem(new_f, u0_coupled, prob.tspan, prob.p) + return ODEProblem(new_f, u0_coupled, prob.tspan, prob.p) end function cat_problems(prob::DiscreteProblem, prob_control::DiffEqBase.AbstractODEProblem) @@ -43,29 +49,33 @@ function cat_problems(prob::DiscreteProblem, prob_control::DiffEqBase.AbstractOD new_f = function (du, u, p, t) _f(@view(du[1:l]), u.u, p, t) - _f_control(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) + return _f_control(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) end u0_coupled = CoupledArray(prob.u0, prob_control.u0, true) - ODEProblem(new_f, u0_coupled, prob.tspan, prob.p) + return ODEProblem(new_f, u0_coupled, prob.tspan, prob.p) end -function cat_problems(prob::DiffEqBase.AbstractSDEProblem, - prob_control::DiffEqBase.AbstractSDEProblem) +function cat_problems( + prob::DiffEqBase.AbstractSDEProblem, + prob_control::DiffEqBase.AbstractSDEProblem + ) l = length(prob.u0) new_f = function (du, u, p, t) prob.f(@view(du[1:l]), u.u, p, t) - prob_control.f(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) + return prob_control.f(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) end new_g = function (du, u, p, t) prob.g(@view(du[1:l]), u.u, p, t) - prob_control.g(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) + return prob_control.g(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) end u0_coupled = CoupledArray(prob.u0, prob_control.u0, true) - SDEProblem(new_f, new_g, u0_coupled, prob.tspan, prob.p) + return SDEProblem(new_f, new_g, u0_coupled, prob.tspan, prob.p) end -function cat_problems(prob::DiffEqBase.AbstractSDEProblem, - prob_control::DiffEqBase.AbstractODEProblem) +function cat_problems( + prob::DiffEqBase.AbstractSDEProblem, + prob_control::DiffEqBase.AbstractODEProblem + ) l = length(prob.u0) _f = SciMLBase.unwrapped_f(prob.f) @@ -73,16 +83,17 @@ function cat_problems(prob::DiffEqBase.AbstractSDEProblem, new_f = function (du, u, p, t) _f(@view(du[1:l]), u.u, p, t) - _f_control(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) + return _f_control(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) end new_g = function (du, u, p, t) prob.g(@view(du[1:l]), u.u, p, t) for i in (l + 1):(2 * l) du[i] = 0.0 end + return end u0_coupled = CoupledArray(prob.u0, prob_control.u0, true) - SDEProblem(new_f, new_g, u0_coupled, prob.tspan, prob.p) + return SDEProblem(new_f, new_g, u0_coupled, prob.tspan, prob.p) end function cat_problems(prob::DiffEqBase.AbstractSDEProblem, prob_control::DiscreteProblem) @@ -92,33 +103,38 @@ function cat_problems(prob::DiffEqBase.AbstractSDEProblem, prob_control::Discret end new_f = function (du, u, p, t) prob.f(@view(du[1:l]), u.u, p, t) - prob_control.f(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) + return prob_control.f(@view(du[(l + 1):(2 * l)]), u.u_control, p, t) end new_g = function (du, u, p, t) prob.g(@view(du[1:l]), u.u, p, t) for i in (l + 1):(2 * l) du[i] = 0.0 end + return end u0_coupled = CoupledArray(prob.u0, prob_control.u0, true) - SDEProblem(new_f, new_g, u0_coupled, prob.tspan) + return SDEProblem(new_f, new_g, u0_coupled, prob.tspan) end function cat_problems(prob_control::DiffEqBase.AbstractODEProblem, prob::DiscreteProblem) - cat_problems(prob, prob_control) + return cat_problems(prob, prob_control) end function cat_problems(prob_control::DiscreteProblem, prob::DiffEqBase.AbstractSDEProblem) - cat_problems(prob, prob_control) + return cat_problems(prob, prob_control) end -function cat_problems(prob_control::DiffEqBase.AbstractODEProblem, - prob::DiffEqBase.AbstractSDEProblem) - cat_problems(prob, prob_control) +function cat_problems( + prob_control::DiffEqBase.AbstractODEProblem, + prob::DiffEqBase.AbstractSDEProblem + ) + return cat_problems(prob, prob_control) end # this only depends on the jumps in prob, not prob.prob -function build_split_jumps(prob::DiffEqBase.AbstractJumpProblem, +function build_split_jumps( + prob::DiffEqBase.AbstractJumpProblem, prob_control::DiffEqBase.AbstractJumpProblem, - coupling_map::Vector{Tuple{Int, Int}}) + coupling_map::Vector{Tuple{Int, Int}} + ) num_jumps = length(prob.discrete_jump_aggregation.rates) num_jumps_control = length(prob_control.discrete_jump_aggregation.rates) jumps = [] @@ -137,7 +153,7 @@ function build_split_jumps(prob::DiffEqBase.AbstractJumpProblem, new_affect! = function (integrator) flip_u!(integrator.u) affect!(integrator) - flip_u!(integrator.u) + return flip_u!(integrator.u) end push!(jumps, ConstantRateJump(new_rate, new_affect!)) end @@ -152,24 +168,24 @@ function build_split_jumps(prob::DiffEqBase.AbstractJumpProblem, affect!(integrator) flip_u!(integrator.u) affect_control!(integrator) - flip_u!(integrator.u) + return flip_u!(integrator.u) end new_rate = (u, p, t) -> min(rate(u.u, p, t), rate_control(u.u_control, p, t)) push!(jumps, ConstantRateJump(new_rate, new_affect!)) # only prob new_affect! = affect! new_rate = (u, p, t) -> rate(u.u, p, t) - - min(rate(u.u, p, t), rate_control(u.u_control, p, t)) + min(rate(u.u, p, t), rate_control(u.u_control, p, t)) push!(jumps, ConstantRateJump(new_rate, new_affect!)) # only prob_control new_affect! = function (integrator) flip_u!(integrator.u) affect!(integrator) - flip_u!(integrator.u) + return flip_u!(integrator.u) end new_rate = (u, p, t) -> rate_control(u.u_control, p, t) - - min(rate(u.u, p, t), rate_control(u.u_control, p, t)) + min(rate(u.u, p, t), rate_control(u.u_control, p, t)) push!(jumps, ConstantRateJump(new_rate, new_affect!)) end - jumps + return jumps end diff --git a/src/extended_jump_array.jl b/src/extended_jump_array.jl index f8bb07d37..702e2e7f9 100644 --- a/src/extended_jump_array.jl +++ b/src/extended_jump_array.jl @@ -58,7 +58,7 @@ sol = solve(jprob, Tsit5()) operations should use `ueja.u.u` to obtain the aliased state object. """ struct ExtendedJumpArray{T3 <: Number, T1, T <: AbstractArray{T3, T1}, T2} <: - AbstractArray{T3, 1} + AbstractArray{T3, 1} """The current state.""" u::T """The current rate (i.e. hazard, intensity, or propensity) values for the `VariableRateJump`s.""" @@ -68,36 +68,38 @@ end Base.length(A::ExtendedJumpArray) = length(A.u) + length(A.jump_u) Base.size(A::ExtendedJumpArray) = (length(A),) @inline function Base.getindex(A::ExtendedJumpArray, i::Int) - i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)] + return i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)] end @inline function Base.getindex(A::ExtendedJumpArray, I::Int...) - prod(I) <= length(A.u) ? A.u[I...] : A.jump_u[prod(I) - length(A.u)] + return prod(I) <= length(A.u) ? A.u[I...] : A.jump_u[prod(I) - length(A.u)] end @inline function Base.getindex(A::ExtendedJumpArray, I::CartesianIndex{1}) - A[I[1]] + return A[I[1]] end @inline Base.setindex!(A::ExtendedJumpArray, v, I...) = (A[CartesianIndices(A.u, I...)] = v) @inline Base.setindex!(A::ExtendedJumpArray, v, I::CartesianIndex{1}) = (A[I[1]] = v) @inline function Base.setindex!(A::ExtendedJumpArray, v, i::Int) - i <= length(A.u) ? (A.u[i] = v) : (A.jump_u[i - length(A.u)] = v) + return i <= length(A.u) ? (A.u[i] = v) : (A.jump_u[i - length(A.u)] = v) end Base.IndexStyle(::Type{<:ExtendedJumpArray}) = IndexLinear() Base.similar(A::ExtendedJumpArray) = ExtendedJumpArray(similar(A.u), similar(A.jump_u)) function Base.similar(A::ExtendedJumpArray, ::Type{S}) where {S} - ExtendedJumpArray(similar(A.u, S), similar(A.jump_u, S)) + return ExtendedJumpArray(similar(A.u, S), similar(A.jump_u, S)) end Base.zero(A::ExtendedJumpArray) = fill!(similar(A), 0) # Required for non-diagonal noise function LinearAlgebra.mul!(c::ExtendedJumpArray, A::AbstractVecOrMat, u::AbstractVector) - mul!(c.u, A, u) + return mul!(c.u, A, u) end # Ignore axes -function Base.similar(A::ExtendedJumpArray, ::Type{S}, - axes::Tuple{Base.OneTo{Int}}) where {S} - ExtendedJumpArray(similar(A.u, S), similar(A.jump_u, S)) +function Base.similar( + A::ExtendedJumpArray, ::Type{S}, + axes::Tuple{Base.OneTo{Int}} + ) where {S} + return ExtendedJumpArray(similar(A.u, S), similar(A.jump_u, S)) end # plotting @@ -105,21 +107,21 @@ SciMLBase.plottable_indices(u::ExtendedJumpArray) = SciMLBase.plottable_indices( # ODE norm to prevent type-unstable fallback @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ExtendedJumpArray, t) - Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1)) + return Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1)) end # Stiff ODE solver function ArrayInterface.zeromatrix(A::ExtendedJumpArray) u = [vec(A.u); vec(A.jump_u)] - u .* u' .* false + return u .* u' .* false end function LinearAlgebra.ldiv!(A::LinearAlgebra.LU, b::ExtendedJumpArray) - LinearAlgebra.ldiv!(A, [vec(b.u); vec(b.jump_u)]) + return LinearAlgebra.ldiv!(A, [vec(b.u); vec(b.jump_u)]) end function recursivecopy!(dest::T, src::T) where {T <: ExtendedJumpArray} recursivecopy!(dest.u, src.u) - recursivecopy!(dest.jump_u, src.jump_u) + return recursivecopy!(dest.jump_u, src.jump_u) end Base.show(io::IO, A::ExtendedJumpArray) = show(io, A.u) plot_indices(A::ExtendedJumpArray) = eachindex(A.u) @@ -128,46 +130,71 @@ plot_indices(A::ExtendedJumpArray) = eachindex(A.u) # The jump array styles stores two sub-styles in the type, # one for the `u` array and one for the `jump_u` array -struct ExtendedJumpArrayStyle{UStyle <: Broadcast.BroadcastStyle, - JumpUStyle <: Broadcast.BroadcastStyle} <: - Broadcast.BroadcastStyle end +struct ExtendedJumpArrayStyle{ + UStyle <: Broadcast.BroadcastStyle, + JumpUStyle <: Broadcast.BroadcastStyle, + } <: + Broadcast.BroadcastStyle end # Init style based on type of u/jump_u function ExtendedJumpArrayStyle(::US, ::JumpUS) where {US, JumpUS} - ExtendedJumpArrayStyle{US, JumpUS}() + return ExtendedJumpArrayStyle{US, JumpUS}() end -function Base.BroadcastStyle(::Type{ExtendedJumpArray{ - T3, T1, UType, JumpUType}}) where {T3, +function Base.BroadcastStyle( + ::Type{ + ExtendedJumpArray{ + T3, T1, UType, JumpUType, + }, + } + ) where { + T3, T1, UType, - JumpUType -} - ExtendedJumpArrayStyle(Base.BroadcastStyle(UType), Base.BroadcastStyle(JumpUType)) + JumpUType, + } + return ExtendedJumpArrayStyle(Base.BroadcastStyle(UType), Base.BroadcastStyle(JumpUType)) end # Combine with other styles by combining individually with u/jump_u styles -function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, - ::Style) where {UStyle, JumpUStyle, - Style <: Base.Broadcast.BroadcastStyle} - ExtendedJumpArrayStyle(Broadcast.result_style(UStyle(), Style()), - Broadcast.result_style(JumpUStyle(), Style())) +function Base.BroadcastStyle( + ::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + ::Style + ) where { + UStyle, JumpUStyle, + Style <: Base.Broadcast.BroadcastStyle, + } + return ExtendedJumpArrayStyle( + Broadcast.result_style(UStyle(), Style()), + Broadcast.result_style(JumpUStyle(), Style()) + ) end # Decay back to the DefaultArrayStyle for higher-order default styles, to support adding to raw vectors as needed -function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, - ::Broadcast.DefaultArrayStyle{0}) where {UStyle, JumpUStyle} - ExtendedJumpArrayStyle(UStyle(), JumpUStyle()) +function Base.BroadcastStyle( + ::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + ::Broadcast.DefaultArrayStyle{0} + ) where {UStyle, JumpUStyle} + return ExtendedJumpArrayStyle(UStyle(), JumpUStyle()) end -function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, - ::Broadcast.DefaultArrayStyle{N}) where {N, UStyle, JumpUStyle} - Broadcast.DefaultArrayStyle{N}() +function Base.BroadcastStyle( + ::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + ::Broadcast.DefaultArrayStyle{N} + ) where {N, UStyle, JumpUStyle} + return Broadcast.DefaultArrayStyle{N}() end -function Base.Broadcast.BroadcastStyle(::S, - ::Base.Broadcast.Unknown) where { - UStyle, JumpUStyle, S <: JumpProcesses.ExtendedJumpArrayStyle{UStyle, JumpUStyle}} - return throw(ArgumentError("Cannot broadcast JumpProcesses.ExtendedJumpArray with" * - " something of type Base.Broadcast.Unknown."),) +function Base.Broadcast.BroadcastStyle( + ::S, + ::Base.Broadcast.Unknown + ) where { + UStyle, JumpUStyle, S <: JumpProcesses.ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + } + return throw( + ArgumentError( + "Cannot broadcast JumpProcesses.ExtendedJumpArray with" * + " something of type Base.Broadcast.Unknown." + ), + ) end # Lookup the first ExtendedJumpArray to pick output container size @@ -181,23 +208,29 @@ find_eja(::Tuple{}) = nothing find_eja(a::ExtendedJumpArray, rest) = a find_eja(::Any, rest) = find_eja(rest) -function Base.similar(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, - ::Type{ElType}) where {US, JumpUS, ElType} +function Base.similar( + bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, + ::Type{ElType} + ) where {US, JumpUS, ElType} A = find_eja(bc) - ExtendedJumpArray(similar(A.u, ElType), similar(A.jump_u, ElType)) + return ExtendedJumpArray(similar(A.u, ElType), similar(A.jump_u, ElType)) end # Helper functions that repack broadcasted functions @inline function repack(bc::Broadcast.Broadcasted{Style}, i) where {Style} - Broadcast.Broadcasted{Style}(bc.f, repack_args(i, bc.args)) + return Broadcast.Broadcasted{Style}(bc.f, repack_args(i, bc.args)) end -@inline function repack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, - i::Val{:u}) where {US, JumpUS} - Broadcast.Broadcasted{US}(bc.f, repack_args(i, bc.args)) +@inline function repack( + bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, + i::Val{:u} + ) where {US, JumpUS} + return Broadcast.Broadcasted{US}(bc.f, repack_args(i, bc.args)) end -@inline function repack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, - i::Val{:jump_u}) where {US, JumpUS} - Broadcast.Broadcasted{JumpUS}(bc.f, repack_args(i, bc.args)) +@inline function repack( + bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, + i::Val{:jump_u} + ) where {US, JumpUS} + return Broadcast.Broadcasted{JumpUS}(bc.f, repack_args(i, bc.args)) end # Helper functions that repack arguments @@ -215,19 +248,21 @@ end end end -@inline function Base.copyto!(dest::ExtendedJumpArray, - bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}) where { +@inline function Base.copyto!( + dest::ExtendedJumpArray, + bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}} + ) where { US, - JumpUS -} + JumpUS, + } copyto!(dest.u, repack(bc, Val(:u))) copyto!(dest.jump_u, repack(bc, Val(:jump_u))) - dest + return dest end Base.:*(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(y .* x.u, y .* x.jump_u) Base.:*(y::Number, x::ExtendedJumpArray) = ExtendedJumpArray(y .* x.u, y .* x.jump_u) Base.:/(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(x.u ./ y, x.jump_u ./ y) function Base.:+(x::ExtendedJumpArray, y::ExtendedJumpArray) - ExtendedJumpArray(x.u .+ y.u, x.jump_u .+ y.jump_u) + return ExtendedJumpArray(x.u .+ y.u, x.jump_u .+ y.jump_u) end diff --git a/src/jumps.jl b/src/jumps.jl index 63704ef92..c7ef621d5 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -171,15 +171,17 @@ function VariableRateJump(rate, affect!; lrate = nothing, urate = nothing, abstol = 1e-12, reltol = 0) ``` """ -function VariableRateJump(rate, affect!; +function VariableRateJump( + rate, affect!; lrate = nothing, urate = nothing, rateinterval = nothing, rootfind = true, idxs = nothing, save_positions = (false, true), interp_points = 10, - abstol = 1e-12, reltol = 0) + abstol = 1.0e-12, reltol = 0 + ) if !(urate !== nothing && rateinterval !== nothing) && - !(urate === nothing && rateinterval === nothing) + !(urate === nothing && rateinterval === nothing) error("`urate` and `rateinterval` must both be `nothing`, or must both be defined.") end @@ -188,8 +190,10 @@ function VariableRateJump(rate, affect!; error("If a lower bound rate, `lrate`, is given than an upper bound rate, `urate`, and rate interval, `rateinterval`, must also be provided.") end - VariableRateJump(rate, affect!, lrate, urate, rateinterval, idxs, rootfind, - interp_points, save_positions, abstol, reltol) + return VariableRateJump( + rate, affect!, lrate, urate, rateinterval, idxs, rootfind, + interp_points, save_positions, abstol, reltol + ) end """ @@ -241,14 +245,14 @@ struct RegularJump{iip, R, C, MD} """ A distribution for marks. Not currently used or supported. """ mark_dist::MD function RegularJump{iip}(rate, c, numjumps::Int; mark_dist = nothing) where {iip} - new{iip, typeof(rate), typeof(c), typeof(mark_dist)}(rate, c, numjumps, mark_dist) + return new{iip, typeof(rate), typeof(c), typeof(mark_dist)}(rate, c, numjumps, mark_dist) end end DiffEqBase.isinplace(::RegularJump{iip, R, C, MD}) where {iip, R, C, MD} = iip function RegularJump(rate, c, numjumps::Int; kwargs...) - RegularJump{DiffEqBase.isinplace(rate, 4)}(rate, c, numjumps; kwargs...) + return RegularJump{DiffEqBase.isinplace(rate, 4)}(rate, c, numjumps; kwargs...) end # deprecate old call @@ -256,9 +260,9 @@ function RegularJump(rate, c, dc::AbstractMatrix; constant_c = false, mark_dist @warn("The RegularJump interface has changed to be matrix-free. See the documentation for more details.") function _c(du, u, p, t, counts, mark) c(dc, u, p, t, mark) - mul!(du, dc, counts) + return mul!(du, dc, counts) end - RegularJump{true}(rate, _c, size(dc, 2); mark_dist = mark_dist) + return RegularJump{true}(rate, _c, size(dc, 2); mark_dist = mark_dist) end """ @@ -334,9 +338,11 @@ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump """Parameter mapping functor to identify reaction rate constants with parameters in `p` vectors.""" param_mapper::V - function MassActionJump{T, S, U, V}(rates::T, rs_in::S, ns::U, pmapper::V, + function MassActionJump{T, S, U, V}( + rates::T, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {T <: AbstractVector, S, U, V} + nocopy::Bool + ) where {T <: AbstractVector, S, U, V} sr = nocopy ? rates : copy(rates) rs = nocopy ? rs_in : copy(rs_in) for i in eachindex(rs) @@ -348,61 +354,83 @@ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump if scale_rates && !isempty(sr) scalerates!(sr, rs) end - new(sr, rs, ns, pmapper) + return new(sr, rs, ns, pmapper) end - function MassActionJump{Nothing, Vector{S}, - Vector{U}, V}(::Nothing, rs_in::Vector{S}, + function MassActionJump{ + Nothing, Vector{S}, + Vector{U}, V, + }( + ::Nothing, rs_in::Vector{S}, ns::Vector{U}, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {S <: AbstractVector, - U <: AbstractVector, V} + nocopy::Bool + ) where { + S <: AbstractVector, + U <: AbstractVector, V, + } rs = nocopy ? rs_in : copy(rs_in) for i in eachindex(rs) if useiszero && (length(rs[i]) == 1) && iszero(rs[i][1][1]) rs[i] = typeof(rs[i])() end end - new(nothing, rs, ns, pmapper) + return new(nothing, rs, ns, pmapper) end - function MassActionJump{T, S, U, V}(rate::T, rs_in::S, ns::U, pmapper::V, + function MassActionJump{T, S, U, V}( + rate::T, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {T <: Number, S, U, V} + nocopy::Bool + ) where {T <: Number, S, U, V} rs = rs_in if useiszero && (length(rs) == 1) && iszero(rs[1][1]) rs = typeof(rs)() end sr = scale_rates ? scalerate(rate, rs) : rate - new(sr, rs, ns, pmapper) + return new(sr, rs, ns, pmapper) end - function MassActionJump{Nothing, S, U, V}(::Nothing, rs_in::S, ns::U, pmapper::V, + function MassActionJump{Nothing, S, U, V}( + ::Nothing, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {S, U, V} + nocopy::Bool + ) where {S, U, V} rs = rs_in if useiszero && (length(rs) == 1) && iszero(rs[1][1]) rs = typeof(rs)() end - new(nothing, rs, ns, pmapper) + return new(nothing, rs, ns, pmapper) end end -function MassActionJump(usr::T, rs::S, ns::U, pmapper::V; scale_rates = true, - useiszero = true, nocopy = false) where {T, S, U, V} - MassActionJump{T, S, U, V}(usr, rs, ns, pmapper, scale_rates, useiszero, nocopy) -end -function MassActionJump(usr::T, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {T <: AbstractVector} - MassActionJump(usr, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, - nocopy = nocopy) -end -function MassActionJump(usr::T, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {T <: Number} - MassActionJump(usr, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, - nocopy = nocopy) +function MassActionJump( + usr::T, rs::S, ns::U, pmapper::V; scale_rates = true, + useiszero = true, nocopy = false + ) where {T, S, U, V} + return MassActionJump{T, S, U, V}(usr, rs, ns, pmapper, scale_rates, useiszero, nocopy) +end +function MassActionJump( + usr::T, rs, ns; scale_rates = true, useiszero = true, + nocopy = false + ) where {T <: AbstractVector} + return MassActionJump( + usr, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, + nocopy = nocopy + ) +end +function MassActionJump( + usr::T, rs, ns; scale_rates = true, useiszero = true, + nocopy = false + ) where {T <: Number} + return MassActionJump( + usr, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, + nocopy = nocopy + ) end # with parameter indices or mapping, multiple jump case -function MassActionJump(rs, ns; param_idxs = nothing, param_mapper = nothing, - scale_rates = true, useiszero = true, nocopy = false) +function MassActionJump( + rs, ns; param_idxs = nothing, param_mapper = nothing, + scale_rates = true, useiszero = true, nocopy = false + ) if param_mapper === nothing (param_idxs === nothing) && error("If no parameter indices are given via param_idxs, an explicit parameter mapping must be passed in via param_mapper.") @@ -413,8 +441,10 @@ function MassActionJump(rs, ns; param_idxs = nothing, param_mapper = nothing, pmapper = param_mapper end - MassActionJump(nothing, nocopy ? rs : copy(rs), ns, pmapper; scale_rates = scale_rates, - useiszero = useiszero, nocopy = true) + return MassActionJump( + nothing, nocopy ? rs : copy(rs), ns, pmapper; scale_rates = scale_rates, + useiszero = useiszero, nocopy = true + ) end using_params(maj::MassActionJump{T, S, U, Nothing}) where {T, S, U} = false @@ -430,43 +460,53 @@ end # create the initial parameter vector for use in a MassActionJump # Note these are unscaled function (ratemap::MassActionJumpParamMapper{U})(params) where {U <: AbstractArray} - [params[pidx] for pidx in ratemap.param_idxs] + return [params[pidx] for pidx in ratemap.param_idxs] end # Note this is unscaled function (ratemap::MassActionJumpParamMapper{U})(params) where {U <: Int} - params[ratemap.param_idxs] + return params[ratemap.param_idxs] end # update a maj with parameter vectors -function (ratemap::MassActionJumpParamMapper{U})(maj::MassActionJump, newparams; +function (ratemap::MassActionJumpParamMapper{U})( + maj::MassActionJump, newparams; scale_rates, - kwargs...) where {U <: AbstractArray} + kwargs... + ) where {U <: AbstractArray} for i in 1:get_num_majumps(maj) maj.scaled_rates[i] = newparams[ratemap.param_idxs[i]] end scale_rates && scalerates!(maj.scaled_rates, maj.reactant_stoch) - nothing + return nothing end function to_collection(ratemap::MassActionJumpParamMapper{Int}) - MassActionJumpParamMapper([ratemap.param_idxs]) + return MassActionJumpParamMapper([ratemap.param_idxs]) end -function Base.merge!(pmap1::MassActionJumpParamMapper{U}, - pmap2::MassActionJumpParamMapper{U}) where {U <: AbstractVector} - append!(pmap1.param_idxs, pmap2.param_idxs) +function Base.merge!( + pmap1::MassActionJumpParamMapper{U}, + pmap2::MassActionJumpParamMapper{U} + ) where {U <: AbstractVector} + return append!(pmap1.param_idxs, pmap2.param_idxs) end -function Base.merge!(pmap1::MassActionJumpParamMapper{U}, - pmap2::MassActionJumpParamMapper{V}) where {U <: AbstractVector, - V <: Int} - push!(pmap1.param_idxs, pmap2.param_idxs) +function Base.merge!( + pmap1::MassActionJumpParamMapper{U}, + pmap2::MassActionJumpParamMapper{V} + ) where { + U <: AbstractVector, + V <: Int, + } + return push!(pmap1.param_idxs, pmap2.param_idxs) end -function Base.merge(pmap1::MassActionJumpParamMapper{Int}, - pmap2::MassActionJumpParamMapper{Int}) - MassActionJumpParamMapper([pmap1.param_idxs, pmap2.param_idxs]) +function Base.merge( + pmap1::MassActionJumpParamMapper{Int}, + pmap2::MassActionJumpParamMapper{Int} + ) + return MassActionJumpParamMapper([pmap1.param_idxs, pmap2.param_idxs]) end """ @@ -484,7 +524,7 @@ Notes: function update_parameters!(maj::MassActionJump, newparams; scale_rates = true, kwargs...) (maj.param_mapper === nothing) && error("MassActionJumps must be constructed with param_idxs or a param_mapper to be updateable.") - maj.param_mapper(maj, newparams; scale_rates, kwargs) + return maj.param_mapper(maj, newparams; scale_rates, kwargs) end """ @@ -532,22 +572,24 @@ struct JumpSet{T1, T2, T3, T4} <: AbstractJump massaction_jump::T4 end function JumpSet(vj, cj, rj, maj::MassActionJump{S, T, U, V}) where {S <: Number, T, U, V} - JumpSet(vj, cj, rj, check_majump_type(maj)) + return JumpSet(vj, cj, rj, check_majump_type(maj)) end JumpSet(jump::ConstantRateJump) = JumpSet((), (jump,), nothing, nothing) JumpSet(jump::VariableRateJump) = JumpSet((jump,), (), nothing, nothing) JumpSet(jump::RegularJump) = JumpSet((), (), jump, nothing) JumpSet(jump::AbstractMassActionJump) = JumpSet((), (), nothing, jump) -function JumpSet(; variable_jumps = (), constant_jumps = (), - regular_jumps = nothing, massaction_jumps = nothing) - JumpSet(variable_jumps, constant_jumps, regular_jumps, massaction_jumps) +function JumpSet(; + variable_jumps = (), constant_jumps = (), + regular_jumps = nothing, massaction_jumps = nothing + ) + return JumpSet(variable_jumps, constant_jumps, regular_jumps, massaction_jumps) end JumpSet(jb::Nothing) = JumpSet() # For Varargs, use recursion to make it type-stable function JumpSet(jumps::AbstractJump...) - JumpSet(split_jumps((), (), nothing, nothing, jumps...)...) + return JumpSet(split_jumps((), (), nothing, nothing, jumps...)...) end # handle vector of mass action jumps @@ -556,32 +598,34 @@ function JumpSet(vjs, cjs, rj, majv::Vector{T}) where {T <: MassActionJump} error("JumpSets do not accept empty mass action jump collections; use \"nothing\" instead.") end - maj = setup_majump_to_merge(majv[1].scaled_rates, majv[1].reactant_stoch, - majv[1].net_stoch, majv[1].param_mapper) + maj = setup_majump_to_merge( + majv[1].scaled_rates, majv[1].reactant_stoch, + majv[1].net_stoch, majv[1].param_mapper + ) for i in 2:length(majv) massaction_jump_combine(maj, majv[i]) end - JumpSet(vjs, cjs, rj, maj) + return JumpSet(vjs, cjs, rj, maj) end @inline get_num_majumps(jset::JumpSet) = get_num_majumps(jset.massaction_jump) @inline num_majumps(jset::JumpSet) = get_num_majumps(jset) @inline function num_crjs(jset::JumpSet) - (jset.constant_jumps !== nothing) ? length(jset.constant_jumps) : 0 + return (jset.constant_jumps !== nothing) ? length(jset.constant_jumps) : 0 end @inline function num_vrjs(jset::JumpSet) - (jset.variable_jumps !== nothing) ? length(jset.variable_jumps) : 0 + return (jset.variable_jumps !== nothing) ? length(jset.variable_jumps) : 0 end @inline function num_bndvrjs(jset::JumpSet) - (jset.variable_jumps !== nothing) ? count(isbounded, jset.variable_jumps) : 0 + return (jset.variable_jumps !== nothing) ? count(isbounded, jset.variable_jumps) : 0 end @inline function num_continvrjs(jset::JumpSet) - (jset.variable_jumps !== nothing) ? count(!isbounded, jset.variable_jumps) : 0 + return (jset.variable_jumps !== nothing) ? count(!isbounded, jset.variable_jumps) : 0 end num_jumps(jset::JumpSet) = num_majumps(jset) + num_crjs(jset) + num_vrjs(jset) @@ -590,22 +634,24 @@ num_cdiscretejumps(jset::JumpSet) = num_majumps(jset) + num_crjs(jset) @inline split_jumps(vj, cj, rj, maj) = vj, cj, rj, maj @inline function split_jumps(vj, cj, rj, maj, v::VariableRateJump, args...) - split_jumps((vj..., v), cj, rj, maj, args...) + return split_jumps((vj..., v), cj, rj, maj, args...) end @inline function split_jumps(vj, cj, rj, maj, c::ConstantRateJump, args...) - split_jumps(vj, (cj..., c), rj, maj, args...) + return split_jumps(vj, (cj..., c), rj, maj, args...) end @inline function split_jumps(vj, cj, rj, maj, c::RegularJump, args...) - split_jumps(vj, cj, regular_jump_combine(rj, c), maj, args...) + return split_jumps(vj, cj, regular_jump_combine(rj, c), maj, args...) end @inline function split_jumps(vj, cj, rj, maj, c::MassActionJump, args...) - split_jumps(vj, cj, rj, massaction_jump_combine(maj, c), args...) + return split_jumps(vj, cj, rj, massaction_jump_combine(maj, c), args...) end @inline function split_jumps(vj, cj, rj, maj, j::JumpSet, args...) - split_jumps((vj..., j.variable_jumps...), + return split_jumps( + (vj..., j.variable_jumps...), (cj..., j.constant_jumps...), regular_jump_combine(rj, j.regular_jump), - massaction_jump_combine(maj, j.massaction_jump), args...) + massaction_jump_combine(maj, j.massaction_jump), args... + ) end regular_jump_combine(rj1::RegularJump, rj2::Nothing) = rj1 @@ -617,43 +663,65 @@ end # functionality to merge two mass action jumps together function check_majump_type(maj::MassActionJump{S, T, U, V}) where {S <: Number, T, U, V} - setup_majump_to_merge(maj.scaled_rates, maj.reactant_stoch, maj.net_stoch, - maj.param_mapper) + return setup_majump_to_merge( + maj.scaled_rates, maj.reactant_stoch, maj.net_stoch, + maj.param_mapper + ) end function check_majump_type(maj::MassActionJump{Nothing, T, U, V}) where {T, U, V} - setup_majump_to_merge(maj.scaled_rates, maj.reactant_stoch, maj.net_stoch, - maj.param_mapper) + return setup_majump_to_merge( + maj.scaled_rates, maj.reactant_stoch, maj.net_stoch, + maj.param_mapper + ) end # if given containers of rates and stoichiometry directly create a jump -function setup_majump_to_merge(sr::T, rs::AbstractVector{S}, ns::AbstractVector{U}, - pmapper) where {T <: AbstractVector, S <: AbstractArray, - U <: AbstractArray} - MassActionJump(sr, rs, ns, pmapper; scale_rates = false) +function setup_majump_to_merge( + sr::T, rs::AbstractVector{S}, ns::AbstractVector{U}, + pmapper + ) where { + T <: AbstractVector, S <: AbstractArray, + U <: AbstractArray, + } + return MassActionJump(sr, rs, ns, pmapper; scale_rates = false) end # if just given the data for one jump (and not in a container) wrap in a vector -function setup_majump_to_merge(sr::S, rs::T, ns::U, - pmapper) where {S <: Number, T <: AbstractArray, - U <: AbstractArray} - MassActionJump([sr], [rs], [ns], +function setup_majump_to_merge( + sr::S, rs::T, ns::U, + pmapper + ) where { + S <: Number, T <: AbstractArray, + U <: AbstractArray, + } + return MassActionJump( + [sr], [rs], [ns], (pmapper === nothing) ? pmapper : to_collection(pmapper); - scale_rates = false) + scale_rates = false + ) end # if no rate field setup yet -function setup_majump_to_merge(::Nothing, rs::T, ns::U, - pmapper) where {T <: AbstractArray, U <: AbstractArray} - MassActionJump(nothing, [rs], [ns], +function setup_majump_to_merge( + ::Nothing, rs::T, ns::U, + pmapper + ) where {T <: AbstractArray, U <: AbstractArray} + return MassActionJump( + nothing, [rs], [ns], (pmapper === nothing) ? pmapper : to_collection(pmapper); - scale_rates = false) + scale_rates = false + ) end # when given a collection of reactions to add to maj -function majump_merge!(maj::MassActionJump{U, <:AbstractVector{V}, <:AbstractVector{W}, X}, +function majump_merge!( + maj::MassActionJump{U, <:AbstractVector{V}, <:AbstractVector{W}, X}, sr::U, rs::AbstractVector{V}, ns::AbstractVector{W}, - param_mapper) where {U <: Union{AbstractVector, Nothing}, - V <: AbstractVector, W <: AbstractVector, X} + param_mapper + ) where { + U <: Union{AbstractVector, Nothing}, + V <: AbstractVector, W <: AbstractVector, X, + } (U <: AbstractVector) && append!(maj.scaled_rates, sr) append!(maj.reactant_stoch, rs) append!(maj.net_stoch, ns) @@ -663,16 +731,20 @@ function majump_merge!(maj::MassActionJump{U, <:AbstractVector{V}, <:AbstractVec else merge!(maj.param_mapper, param_mapper) end - maj + return maj end # when given a single jump's worth of data to add to maj -function majump_merge!(maj::MassActionJump{U, V, W, X}, sr::T, rs::S1, ns::S2, - param_mapper) where {T <: Union{Number, Nothing}, +function majump_merge!( + maj::MassActionJump{U, V, W, X}, sr::T, rs::S1, ns::S2, + param_mapper + ) where { + T <: Union{Number, Nothing}, S1 <: AbstractArray, S2 <: AbstractArray, U <: Union{AbstractVector{T}, Nothing}, V <: AbstractVector{S1}, - W <: AbstractVector{S2}, X} + W <: AbstractVector{S2}, X, + } (T <: Number) && push!(maj.scaled_rates, sr) push!(maj.reactant_stoch, rs) push!(maj.net_stoch, ns) @@ -683,24 +755,32 @@ function majump_merge!(maj::MassActionJump{U, V, W, X}, sr::T, rs::S1, ns::S2, merge!(maj.param_mapper, param_mapper) end - maj + return maj end # when maj only stores a single jump's worth of data (and not in a collection) # create a new jump with the merged data stored in vectors -function majump_merge!(maj::MassActionJump{T, S, U, V}, sr::T, rs::S, ns::U, - param_mapper::V) where {T <: Union{Number, Nothing}, +function majump_merge!( + maj::MassActionJump{T, S, U, V}, sr::T, rs::S, ns::U, + param_mapper::V + ) where { + T <: Union{Number, Nothing}, S <: AbstractArray{<:Pair}, - U <: AbstractArray{<:Pair}, V} + U <: AbstractArray{<:Pair}, V, + } rates = (T <: Nothing) ? nothing : [maj.scaled_rates, sr] if maj.param_mapper === nothing (param_mapper === nothing) || error("Error, trying to merge a MassActionJump with a parameter mapping to one without a parameter mapping.") - return MassActionJump(rates, [maj.reactant_stoch, rs], [maj.net_stoch, ns], - param_mapper; scale_rates = false) + return MassActionJump( + rates, [maj.reactant_stoch, rs], [maj.net_stoch, ns], + param_mapper; scale_rates = false + ) else - return MassActionJump(rates, [maj.reactant_stoch, rs], [maj.net_stoch, ns], - merge(maj.param_mapper, param_mapper); scale_rates = false) + return MassActionJump( + rates, [maj.reactant_stoch, rs], [maj.net_stoch, ns], + merge(maj.param_mapper, param_mapper); scale_rates = false + ) end end @@ -708,8 +788,10 @@ massaction_jump_combine(maj1::MassActionJump, maj2::Nothing) = maj1 massaction_jump_combine(maj1::Nothing, maj2::MassActionJump) = maj2 massaction_jump_combine(maj1::Nothing, maj2::Nothing) = maj1 function massaction_jump_combine(maj1::MassActionJump, maj2::MassActionJump) - majump_merge!(maj1, maj2.scaled_rates, maj2.reactant_stoch, maj2.net_stoch, - maj2.param_mapper) + return majump_merge!( + maj1, maj2.scaled_rates, maj2.reactant_stoch, maj2.net_stoch, + maj2.param_mapper + ) end ##### helper methods for unpacking rates and affects! from constant jumps ##### @@ -722,12 +804,14 @@ function get_jump_info_tuples(jumps) affects! = () end - rates, affects! + return rates, affects! end function get_jump_info_fwrappers(u, p, t, jumps) - RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), - Tuple{typeof(u), typeof(p), typeof(t)}} + RateWrapper = FunctionWrappers.FunctionWrapper{ + typeof(t), + Tuple{typeof(u), typeof(p), typeof(t)}, + } if (jumps !== nothing) && !isempty(jumps) rates = [RateWrapper(c.rate) for c in jumps] @@ -737,5 +821,5 @@ function get_jump_info_fwrappers(u, p, t, jumps) affects! = Any[] end - rates, affects! + return rates, affects! end diff --git a/src/massaction_rates.jl b/src/massaction_rates.jl index 06d4dfa8c..9ebd173c0 100644 --- a/src/massaction_rates.jl +++ b/src/massaction_rates.jl @@ -3,8 +3,10 @@ # stochiometric coefficient. ############################################################################### -@inline function evalrxrate(speciesvec::AbstractVector{T}, rxidx, - majump::MassActionJump{U})::R where {T <: Integer, R, U <: AbstractVector{R}} +@inline function evalrxrate( + speciesvec::AbstractVector{T}, rxidx, + majump::MassActionJump{U} + )::R where {T <: Integer, R, U <: AbstractVector{R}} val = one(T) @inbounds for specstoch in majump.reactant_stoch[rxidx] specpop = speciesvec[specstoch[1]] @@ -18,8 +20,10 @@ @inbounds return val * majump.scaled_rates[rxidx] end -@inline function evalrxrate(speciesvec::AbstractVector{T}, rxidx, - majump::MassActionJump{U})::R where {T <: Real, R, U <: AbstractVector{R}} +@inline function evalrxrate( + speciesvec::AbstractVector{T}, rxidx, + majump::MassActionJump{U} + )::R where {T <: Real, R, U <: AbstractVector{R}} val = one(T) @inbounds for specstoch in majump.reactant_stoch[rxidx] specpop = speciesvec[specstoch[1]] @@ -29,30 +33,36 @@ end val *= specpop end # we need to check the smallest rate law term is positive - # i.e. for an order k reaction: x - k + 1 > 0 + # i.e. for an order k reaction: x - k + 1 > 0 (specpop <= 0) && return zero(R) end @inbounds return val * majump.scaled_rates[rxidx] end -@inline function executerx!(speciesvec::AbstractVector{T}, rxidx::S, - majump::M) where {T, S, M <: AbstractMassActionJump} +@inline function executerx!( + speciesvec::AbstractVector{T}, rxidx::S, + majump::M + ) where {T, S, M <: AbstractMassActionJump} @inbounds net_stoch = majump.net_stoch[rxidx] @inbounds for specstoch in net_stoch speciesvec[specstoch[1]] += specstoch[2] end - nothing + return nothing end -@inline function executerx(speciesvec::SVector{T}, rxidx::S, - majump::M) where {T, S, M <: AbstractMassActionJump} +@inline function executerx( + speciesvec::SVector{T}, rxidx::S, + majump::M + ) where {T, S, M <: AbstractMassActionJump} @inbounds net_stoch = majump.net_stoch[rxidx] @inbounds for specstoch in net_stoch - speciesvec = setindex(speciesvec, speciesvec[specstoch[1]] + specstoch[2], - specstoch[1]) + speciesvec = setindex( + speciesvec, speciesvec[specstoch[1]] + specstoch[2], + specstoch[1] + ) end - speciesvec + return speciesvec #= map(net_stoch) do stoch @@ -61,9 +71,13 @@ end =# end -function scalerates!(unscaled_rates::AbstractVector{U}, - stochmat::AbstractVector{V}) where {U, S, T, W <: Pair{S, T}, - V <: AbstractVector{W}} +function scalerates!( + unscaled_rates::AbstractVector{U}, + stochmat::AbstractVector{V} + ) where { + U, S, T, W <: Pair{S, T}, + V <: AbstractVector{W}, + } @inbounds for i in eachindex(unscaled_rates) coef = one(T) @inbounds for specstoch in stochmat[i] @@ -71,12 +85,16 @@ function scalerates!(unscaled_rates::AbstractVector{U}, end unscaled_rates[i] /= coef end - nothing + return nothing end -function scalerates!(unscaled_rates::AbstractMatrix{U}, - stochmat::AbstractVector{V}) where {U, S, T, W <: Pair{S, T}, - V <: AbstractVector{W}} +function scalerates!( + unscaled_rates::AbstractMatrix{U}, + stochmat::AbstractVector{V} + ) where { + U, S, T, W <: Pair{S, T}, + V <: AbstractVector{W}, + } @inbounds for i in size(unscaled_rates, 1) coef = one(T) @inbounds for specstoch in stochmat[i] @@ -84,16 +102,18 @@ function scalerates!(unscaled_rates::AbstractMatrix{U}, end unscaled_rates[i, :] /= coef end - nothing + return nothing end -function scalerate(unscaled_rate::U, - stochmat::AbstractVector{Pair{S, T}}) where {U <: Number, S, T} +function scalerate( + unscaled_rate::U, + stochmat::AbstractVector{Pair{S, T}} + ) where {U <: Number, S, T} coef = one(T) @inbounds for specstoch in stochmat coef *= factorial(specstoch[2]) end - unscaled_rate /= coef + return unscaled_rate /= coef end ############################################################################### @@ -116,14 +136,14 @@ function var_to_jumps_map(numspec, ma_jumps::AbstractMassActionJump) end foreach(s -> unique!(sort!(s)), spec_to_dep_rxs) - spec_to_dep_rxs + return spec_to_dep_rxs end """ make a map from reactions to dependent species """ function jump_to_vars_map(majumps) - [[s for (s, c) in majumps.net_stoch[i]] for i in 1:get_num_majumps(majumps)] + return [[s for (s, c) in majumps.net_stoch[i]] for i in 1:get_num_majumps(majumps)] end # dependency graph is a map from a reaction to a vector of reactions @@ -146,7 +166,7 @@ function make_dependency_graph(numspec, ma_jumps::AbstractMassActionJump) add_self_dependencies!(dep_graph, dosort = false) foreach(deps -> unique!(sort!(deps)), dep_graph) - dep_graph + return dep_graph end # update dependency graph to make sure jumps depend on themselves @@ -157,4 +177,5 @@ function add_self_dependencies!(dg; dosort = true) dosort && sort!(jump_deps) end end + return end diff --git a/src/problem.jl b/src/problem.jl index aeb97485e..a6dfff35b 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -1,6 +1,6 @@ function isinplace_jump(p, rj) - if p isa DiscreteProblem && p.f === DiffEqBase.DISCRETE_INPLACE_DEFAULT && - rj !== nothing + return if p isa DiscreteProblem && p.f === DiffEqBase.DISCRETE_INPLACE_DEFAULT && + rj !== nothing # Just a default discrete problem f, so don't use it for iip DiffEqBase.isinplace(rj) else @@ -67,8 +67,10 @@ page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_e the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and commonly asked questions. """ -mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, - J2, J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J} +mutable struct JumpProblem{ + iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, + J2, J3, J4, R, K, + } <: DiffEqBase.AbstractJumpProblem{P, J} """The type of problem to couple the jumps to. For a pure jump process use `DiscreteProblem`, to couple to ODEs, `ODEProblem`, etc.""" prob::P """The aggregator algorithm that determines the next jump times and types for `ConstantRateJump`s and `MassActionJump`s. Examples include `Direct`.""" @@ -90,22 +92,26 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega """kwargs to pass on to solve call.""" kwargs::K end -function JumpProblem(p::P, a::A, dj::J, jc::C, cj::J1, vj::J2, rj::J3, mj::J4, - rng::R, kwargs::K) where {P, A, J, C, J1, J2, J3, J4, R, K} +function JumpProblem( + p::P, a::A, dj::J, jc::C, cj::J1, vj::J2, rj::J3, mj::J4, + rng::R, kwargs::K + ) where {P, A, J, C, J1, J2, J3, J4, R, K} iip = isinplace_jump(p, rj) - JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, R, K}(p, a, dj, jc, cj, vj, rj, mj, - rng, kwargs) + return JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, R, K}( + p, a, dj, jc, cj, vj, rj, mj, + rng, kwargs + ) end ######## remaking ###### -# for a problem where prob.u0 is an ExtendedJumpArray, create an ExtendedJumpArray that +# for a problem where prob.u0 is an ExtendedJumpArray, create an ExtendedJumpArray that # aliases and resets prob.u0.jump_u while having newu0 as the new u component. function remake_extended_u0(prob, newu0, rng) jump_u = prob.u0.jump_u ttype = eltype(prob.tspan) @. jump_u = -randexp(rng, ttype) - ExtendedJumpArray(newu0, jump_u) + return ExtendedJumpArray(newu0, jump_u) end Base.@pure remaker_of(prob::T) where {T <: JumpProblem} = DiffEqBase.parameterless_type(T) @@ -156,9 +162,11 @@ function DiffEqBase.remake(jprob::JumpProblem; kwargs...) end end - T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, - jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump, - jprob.massaction_jump, jprob.rng, jprob.kwargs) + return T( + newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, + jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump, + jprob.massaction_jump, jprob.rng, jprob.kwargs + ) end # for updating parameters in JumpProblems to update MassActionJumps @@ -166,51 +174,63 @@ function SII.finalize_parameters_hook!(prob::JumpProblem, p) if using_params(prob.massaction_jump) update_parameters!(prob.massaction_jump, SII.parameter_values(prob)) end - nothing + return nothing end DiffEqBase.isinplace(::JumpProblem{iip}) where {iip} = iip JumpProblem(prob::JumpProblem) = prob function JumpProblem(prob, jumps::ConstantRateJump; kwargs...) - JumpProblem(prob, JumpSet(jumps); kwargs...) + return JumpProblem(prob, JumpSet(jumps); kwargs...) end function JumpProblem(prob, jumps::VariableRateJump; kwargs...) - JumpProblem(prob, JumpSet(jumps); kwargs...) + return JumpProblem(prob, JumpSet(jumps); kwargs...) end function JumpProblem(prob, jumps::RegularJump; kwargs...) - JumpProblem(prob, JumpSet(jumps); kwargs...) + return JumpProblem(prob, JumpSet(jumps); kwargs...) end function JumpProblem(prob, jumps::MassActionJump; kwargs...) - JumpProblem(prob, JumpSet(jumps); kwargs...) + return JumpProblem(prob, JumpSet(jumps); kwargs...) end function JumpProblem(prob, jumps::AbstractJump...; kwargs...) - JumpProblem(prob, JumpSet(jumps...); kwargs...) + return JumpProblem(prob, JumpSet(jumps...); kwargs...) end -function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, - jumps::ConstantRateJump; kwargs...) - JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) +function JumpProblem( + prob, aggregator::AbstractAggregatorAlgorithm, + jumps::ConstantRateJump; kwargs... + ) + return JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) end -function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, - jumps::VariableRateJump; kwargs...) - JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) +function JumpProblem( + prob, aggregator::AbstractAggregatorAlgorithm, + jumps::VariableRateJump; kwargs... + ) + return JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) end -function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::RegularJump; - kwargs...) - JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) +function JumpProblem( + prob, aggregator::AbstractAggregatorAlgorithm, jumps::RegularJump; + kwargs... + ) + return JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) end -function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, - jumps::AbstractMassActionJump; kwargs...) - JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) +function JumpProblem( + prob, aggregator::AbstractAggregatorAlgorithm, + jumps::AbstractMassActionJump; kwargs... + ) + return JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...) end -function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::AbstractJump...; - kwargs...) - JumpProblem(prob, aggregator, JumpSet(jumps...); kwargs...) +function JumpProblem( + prob, aggregator::AbstractAggregatorAlgorithm, jumps::AbstractJump...; + kwargs... + ) + return JumpProblem(prob, aggregator, JumpSet(jumps...); kwargs...) end -function JumpProblem(prob, jumps::JumpSet; vartojumps_map = nothing, +function JumpProblem( + prob, jumps::JumpSet; vartojumps_map = nothing, jumptovars_map = nothing, dep_graph = nothing, - spatial_system = nothing, hopping_constants = nothing, kwargs...) + spatial_system = nothing, hopping_constants = nothing, kwargs... + ) ps = (; vartojumps_map, jumptovars_map, dep_graph, spatial_system, hopping_constants) aggtype = select_aggregator(jumps; ps...) return JumpProblem(prob, aggtype(), jumps; ps..., kwargs...) @@ -218,27 +238,31 @@ end # this makes it easier to test the aggregator selection function JumpProblem(prob, aggregator::NullAggregator, jumps::JumpSet; kwargs...) - JumpProblem(prob, jumps; kwargs...) + return JumpProblem(prob, jumps; kwargs...) end make_kwarg(; kwargs...) = kwargs -function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpSet; +function JumpProblem( + prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpSet; vr_aggregator::VariableRateAggregator = VR_FRM(), save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? - (false, true) : (true, true), + (false, true) : (true, true), rng = DEFAULT_RNG, scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, - callback = nothing, use_vrj_bounds = true, kwargs...) + callback = nothing, use_vrj_bounds = true, kwargs... + ) # initialize the MassActionJump rate constants with the user parameters if using_params(jumps.massaction_jump) rates = jumps.massaction_jump.param_mapper(prob.p) - maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch, + maj = MassActionJump( + rates, jumps.massaction_jump.reactant_stoch, jumps.massaction_jump.net_stoch, jumps.massaction_jump.param_mapper; scale_rates = scale_rates, useiszero = useiszero, - nocopy = true) + nocopy = true + ) else maj = jumps.massaction_jump end @@ -276,16 +300,20 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS disc_agg = nothing constant_jump_callback = CallbackSet() else - disc_agg = aggregate(aggregator, u, prob.p, t, end_time, crjs, maj, - save_positions, rng; kwargs...) + disc_agg = aggregate( + aggregator, u, prob.p, t, end_time, crjs, maj, + save_positions, rng; kwargs... + ) constant_jump_callback = DiscreteCallback(disc_agg) end # handle any remaining vrjs if length(cvrjs) > 0 # Handle variable rate jumps based on vr_aggregator - new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator, - jumps, cvrjs; rng) + new_prob, variable_jump_callback = configure_jump_problem( + prob, vr_aggregator, + jumps, cvrjs; rng + ) else new_prob = prob variable_jump_callback = CallbackSet() @@ -296,19 +324,25 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS iip = isinplace_jump(prob, jumps.regular_jump) solkwargs = make_kwarg(; callback) - JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs), + return JumpProblem{ + iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump), - typeof(maj), typeof(rng), typeof(solkwargs)}(new_prob, aggregator, disc_agg, - jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs) + typeof(maj), typeof(rng), typeof(solkwargs), + }( + new_prob, aggregator, disc_agg, + jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs + ) end # Special dispatch for PureLeaping aggregator - bypasses all aggregation -function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; +function JumpProblem( + prob, aggregator::PureLeaping, jumps::JumpSet; save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? - (false, true) : (true, true), + (false, true) : (true, true), rng = DEFAULT_RNG, scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, - callback = nothing, kwargs...) + callback = nothing, kwargs... + ) # Validate no spatial systems (not currently supported) (spatial_system !== nothing || hopping_constants !== nothing) && @@ -317,11 +351,13 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; # Initialize the MassActionJump rate constants with the user parameters if using_params(jumps.massaction_jump) rates = jumps.massaction_jump.param_mapper(prob.p) - maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch, + maj = MassActionJump( + rates, jumps.massaction_jump.reactant_stoch, jumps.massaction_jump.net_stoch, jumps.massaction_jump.param_mapper; scale_rates = scale_rates, useiszero = useiszero, - nocopy = true) + nocopy = true + ) else maj = jumps.massaction_jump end @@ -330,24 +366,28 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; # No discrete jump aggregation or variable rate callbacks are created disc_agg = nothing jump_cbs = CallbackSet() - + # Store all jump types for access by tau-leaping solver crjs = jumps.constant_jumps vrjs = jumps.variable_jumps - + iip = isinplace_jump(prob, jumps.regular_jump) solkwargs = make_kwarg(; callback) - JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs), + return JumpProblem{ + iip, typeof(prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump), - typeof(maj), typeof(rng), typeof(solkwargs)}(prob, aggregator, disc_agg, - jump_cbs, crjs, vrjs, jumps.regular_jump, maj, rng, solkwargs) + typeof(maj), typeof(rng), typeof(solkwargs), + }( + prob, aggregator, disc_agg, + jump_cbs, crjs, vrjs, jumps.regular_jump, maj, rng, solkwargs + ) end aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A -@inline function extend_tstops!(tstops, jp::JumpProblem) - !(jp.jump_callback.discrete_callbacks isa Tuple{}) && +@inline function extend_tstops!(tstops, jp::JumpProblem) + return !(jp.jump_callback.discrete_callbacks isa Tuple{}) && push!(tstops, jp.jump_callback.discrete_callbacks[1].condition.next_jump_time) end @@ -356,23 +396,27 @@ num_constant_rate_jumps(aggregator::AbstractSSAJumpAggregator) = length(aggregat function Base.summary(io::IO, prob::JumpProblem) type_color, no_color = SciMLBase.get_colorizers(io) - print(io, + return print( + io, type_color, nameof(typeof(prob)), no_color, " with problem ", type_color, nameof(typeof(prob.prob)), no_color, " with aggregator ", - type_color, typeof(prob.aggregator)) + type_color, typeof(prob.aggregator) + ) end function Base.show(io::IO, mime::MIME"text/plain", A::JumpProblem) summary(io, A) println(io) - println(io, "Number of jumps with discrete aggregation: ", + println( + io, "Number of jumps with discrete aggregation: ", A.discrete_jump_aggregation === nothing ? 0 : - num_constant_rate_jumps(A.discrete_jump_aggregation)) + num_constant_rate_jumps(A.discrete_jump_aggregation) + ) println(io, "Number of jumps with continuous aggregation: ", length(A.variable_jumps)) nmajs = (A.massaction_jump !== nothing) ? get_num_majumps(A.massaction_jump) : 0 println(io, "Number of mass action jumps: ", nmajs) - if A.regular_jump !== nothing + return if A.regular_jump !== nothing println(io, "Have a regular jump") end end diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index d512953ae..8eb269fb3 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -6,16 +6,18 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." end - isempty(jump_prob.jump_callback.continuous_callbacks) && - isempty(jump_prob.jump_callback.discrete_callbacks) && - isempty(jump_prob.constant_jumps) && - isempty(jump_prob.variable_jumps) && - get_num_majumps(jump_prob.massaction_jump) == 0 && - jump_prob.regular_jump !== nothing + return isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + get_num_majumps(jump_prob.massaction_jump) == 0 && + jump_prob.regular_jump !== nothing end -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; - seed = nothing, dt = error("dt is required for SimpleTauLeaping.")) +function DiffEqBase.solve( + jump_prob::JumpProblem, alg::SimpleTauLeaping; + seed = nothing, dt = error("dt is required for SimpleTauLeaping.") + ) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only RegularJumps.") @@ -56,9 +58,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; u[i] = du + uprev end - sol = DiffEqBase.build_solution(prob, alg, t, u, + return sol = DiffEqBase.build_solution( + prob, alg, t, u, calculate_error = false, - interp = DiffEqBase.ConstantInterpolation(t, u)) + interp = DiffEqBase.ConstantInterpolation(t, u) + ) end struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm @@ -67,9 +71,9 @@ struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm end function EnsembleGPUKernel(backend) - EnsembleGPUKernel(backend, 0.0) + return EnsembleGPUKernel(backend, 0.0) end function EnsembleGPUKernel() - EnsembleGPUKernel(nothing, 0.0) + return EnsembleGPUKernel(nothing, 0.0) end diff --git a/src/solve.jl b/src/solve.jl index b80f00ad8..523bc174a 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,49 +1,59 @@ -function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, +function DiffEqBase.__solve( + jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg::DiffEqBase.DEAlgorithm; - merge_callbacks = true, kwargs...) where {P} + merge_callbacks = true, kwargs... + ) where {P} # Merge jump_prob.kwargs with passed kwargs kwargs = DiffEqBase.merge_problem_kwargs(jump_prob; merge_callbacks, kwargs...) integrator = __jump_init(jump_prob, alg; kwargs...) solve!(integrator) - integrator.sol + return integrator.sol end #Ambiguity Fix -function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, +function DiffEqBase.__solve( + jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg::Union{SciMLBase.AbstractRODEAlgorithm, SciMLBase.AbstractSDEAlgorithm}; - merge_callbacks = true, kwargs...) where {P} + merge_callbacks = true, kwargs... + ) where {P} # Merge jump_prob.kwargs with passed kwargs kwargs = DiffEqBase.merge_problem_kwargs(jump_prob; merge_callbacks, kwargs...) integrator = __jump_init(jump_prob, alg; kwargs...) solve!(integrator) - integrator.sol + return integrator.sol end # if passed a JumpProblem over a DiscreteProblem, and no aggregator is selected use # SSAStepper -function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}; - kwargs...) where {P <: DiscreteProblem} - DiffEqBase.__solve(jump_prob, SSAStepper(); kwargs...) +function DiffEqBase.__solve( + jump_prob::DiffEqBase.AbstractJumpProblem{P}; + kwargs... + ) where {P <: DiscreteProblem} + return DiffEqBase.__solve(jump_prob, SSAStepper(); kwargs...) end function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem; kwargs...) error("Auto-solver selection is currently only implemented for JumpProblems defined over DiscreteProblems. Please explicitly specify a solver algorithm in calling solve.") end -function DiffEqBase.__init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, - alg::DiffEqBase.DEAlgorithm; merge_callbacks = true, kwargs...) where {P} +function DiffEqBase.__init( + _jump_prob::DiffEqBase.AbstractJumpProblem{P}, + alg::DiffEqBase.DEAlgorithm; merge_callbacks = true, kwargs... + ) where {P} # Merge jump_prob.kwargs with passed kwargs kwargs = DiffEqBase.merge_problem_kwargs(_jump_prob; merge_callbacks, kwargs...) - __jump_init(_jump_prob, alg; kwargs...) -end + return __jump_init(_jump_prob, alg; kwargs...) +end -function __jump_init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg; +function __jump_init( + _jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg; callback = nothing, seed = nothing, alias_jump = Threads.threadid() == 1, - kwargs...) where {P} + kwargs... + ) where {P} if alias_jump jump_prob = _jump_prob reset_jump_problem!(jump_prob, seed) @@ -52,16 +62,20 @@ function __jump_init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg; end # DDEProblems do not have a recompile_flag argument - if jump_prob.prob isa DiffEqBase.AbstractDDEProblem + return if jump_prob.prob isa DiffEqBase.AbstractDDEProblem # callback comes after jump consistent with SSAStepper - integrator = init(jump_prob.prob, alg; + integrator = init( + jump_prob.prob, alg; callback = CallbackSet(jump_prob.jump_callback, callback), - kwargs...) + kwargs... + ) else # callback comes after jump consistent with SSAStepper - integrator = init(jump_prob.prob, alg; + integrator = init( + jump_prob.prob, alg; callback = CallbackSet(jump_prob.jump_callback, callback), - kwargs...) + kwargs... + ) end end @@ -80,7 +94,7 @@ function resetted_jump_problem(_jump_prob, seed) randexp!(_jump_prob.rng, jump_prob.prob.u0.jump_u) jump_prob.prob.u0.jump_u .*= -1 end - jump_prob + return jump_prob end function reset_jump_problem!(jump_prob, seed) @@ -88,7 +102,7 @@ function reset_jump_problem!(jump_prob, seed) Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed) end - if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray + return if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray randexp!(jump_prob.rng, jump_prob.prob.u0.jump_u) jump_prob.prob.u0.jump_u .*= -1 end diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 6346fa4be..f58c92ef5 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -12,21 +12,21 @@ end function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) println(io, "Low: \n $(low_high.low)") - println(io, "High: \n $(low_high.high)") + return println(io, "High: \n $(low_high.high)") end @inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) @inbounds for (i, uval) in enumerate(u) u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval)) end - nothing + return nothing end ### convenience functions for LowHigh ### function setindex!(low_high::LowHigh, val::LowHigh, i) low_high.low[i] = val.low low_high.high[i] = val.high - val + return val end function getindex(low_high::LowHigh, i) @@ -36,15 +36,16 @@ end function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) return LowHigh( total_site_rate(rx_rates.low, hop_rates.low, site), - total_site_rate(rx_rates.high, hop_rates.high, site)) + total_site_rate(rx_rates.high, hop_rates.high, site) + ) end function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site) update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site) - update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) + return update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) end function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatial_system) update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system) - update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system) + return update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system) end diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index f282bd1ae..e52699206 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -1,14 +1,16 @@ # site chosen with DirectCR, rx or hop chosen with Direct ############################ DirectCRDirect ################################### -const MINJUMPRATE = 2.0^exponent(1e-12) +const MINJUMPRATE = 2.0^exponent(1.0e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j #NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j -mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, - VJMAP, JVMAP, SS, U <: PriorityTable, - W <: Function} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct DirectCRDirectJumpAggregation{ + T, S, F1, F2, RNG, J, RX, HOP, DEPGR, + VJMAP, JVMAP, SS, U <: PriorityTable, + W <: Function, + } <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::SpatialJump{J} #some structure to identify the next event: reaction or hop prev_jump::SpatialJump{J} #some structure to identify the previous event: reaction or hop next_jump_time::T @@ -29,13 +31,15 @@ mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPG ratetogroup::W end -function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, +function DirectCRDirectJumpAggregation( + nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop_rates::HOP, site_rates::Vector{T}, sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; num_specs, minrate = convert(T, MINJUMPRATE), vartojumps_map = nothing, jumptovars_map = nothing, dep_graph = nothing, - kwargs...) where {J, T, RX, HOP, RNG, SS} + kwargs... + ) where {J, T, RX, HOP, RNG, SS} # a dependency graph is needed if dep_graph === nothing @@ -69,27 +73,35 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - DirectCRDirectJumpAggregation{T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, + return DirectCRDirectJumpAggregation{ + T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), - typeof(ratetogroup)}(nj, nj, njt, et, rx_rates, hop_rates, + typeof(ratetogroup), + }( + nj, nj, njt, et, rx_rates, hop_rates, site_rates, nothing, nothing, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, - rt, ratetogroup) + rt, ratetogroup + ) end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, +function aggregate( + aggregator::DirectCRDirect, starting_state, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; hopping_constants, - spatial_system, kwargs...) + spatial_system, kwargs... + ) num_species = size(starting_state, 1) majumps = ma_jumps if majumps === nothing - majumps = MassActionJump(Vector{typeof(end_time)}(), + majumps = MassActionJump( + Vector{typeof(end_time)}(), Vector{Vector{Pair{Int, Int}}}(), - Vector{Vector{Pair{Int, Int}}}()) + Vector{Vector{Pair{Int, Int}}}() + ) end next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder @@ -98,9 +110,11 @@ function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, hop_rates = HopRates(hopping_constants, spatial_system) site_rates = zeros(typeof(end_time), num_sites(spatial_system)) - DirectCRDirectJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, + return DirectCRDirectJumpAggregation( + next_jump, next_jump_time, end_time, rx_rates, hop_rates, site_rates, save_positions, rng, spatial_system; - num_specs = num_species, kwargs...) + num_specs = num_species, kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -108,7 +122,7 @@ function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) generate_jumps!(p, integrator, params, u, t) - nothing + return nothing end # calculate the next jump / jump time @@ -117,18 +131,20 @@ function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u p.next_jump_time >= p.end_time && return nothing site = sample(p.rt, p.site_rates, p.rng) p.next_jump = sample_jump_direct(p, site) - nothing + return nothing end # execute one jump, changing the system state -function execute_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t, - affects!) +function execute_jumps!( + p::DirectCRDirectJumpAggregation, integrator, u, params, t, + affects! + ) # execute jump update_state!(p, integrator) # update current jump rates and times update_dependent_rates_and_firing_times!(p, integrator, t) - nothing + return nothing end ######################## SSA specific helper routines ######################## @@ -138,7 +154,8 @@ end reset all structs, reevaluate all rates, repopulate the priority table """ function fill_rates_and_get_times!( - aggregation::DirectCRDirectJumpAggregation, integrator, t) + aggregation::DirectCRDirectJumpAggregation, integrator, t + ) (; spatial_system, rx_rates, hop_rates, site_rates, rt) = aggregation u = integrator.u @@ -159,7 +176,7 @@ function fill_rates_and_get_times!( for (pid, priority) in enumerate(site_rates) insert!(rt, pid, priority) end - nothing + return nothing end """ @@ -168,11 +185,12 @@ end recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) """ function update_dependent_rates_and_firing_times!( - p::DirectCRDirectJumpAggregation, integrator, t) + p::DirectCRDirectJumpAggregation, integrator, t + ) u = integrator.u site_rates = p.site_rates jump = p.prev_jump - if is_hop(p, jump) + return if is_hop(p, jump) source_site = jump.src target_site = jump.dst update_rates_after_hop!(p, integrator, source_site, target_site, jump.jidx) diff --git a/src/spatial/flatten.jl b/src/spatial/flatten.jl index aed6279a2..354dd951a 100644 --- a/src/spatial/flatten.jl +++ b/src/spatial/flatten.jl @@ -3,14 +3,18 @@ using JumpProcesses, DiffEqBase, Graphs """ prob.u0 must be a Matrix with prob.u0[i,j] being the number of species i at site j """ -function flatten(ma_jump, prob::DiscreteProblem, spatial_system, hopping_constants; - kwargs...) +function flatten( + ma_jump, prob::DiscreteProblem, spatial_system, hopping_constants; + kwargs... + ) tspan = prob.tspan u0 = prob.u0 if ma_jump === nothing - ma_jump = MassActionJump(Vector{typeof(tspan[1])}(), + ma_jump = MassActionJump( + Vector{typeof(tspan[1])}(), Vector{Vector{Pair{Int, eltype(u0)}}}(), - Vector{Vector{Pair{Int, eltype(u0)}}}()) + Vector{Vector{Pair{Int, eltype(u0)}}}() + ) end netstoch = ma_jump.net_stoch reactstoch = ma_jump.reactant_stoch @@ -23,57 +27,75 @@ function flatten(ma_jump, prob::DiscreteProblem, spatial_system, hopping_constan elseif isnothing(ma_jump.uniform_rates) ma_jump.spatial_rates elseif isnothing(ma_jump.spatial_rates) - reshape(repeat(ma_jump.uniform_rates, num_nodes), - length(ma_jump.uniform_rates), num_nodes) + reshape( + repeat(ma_jump.uniform_rates, num_nodes), + length(ma_jump.uniform_rates), num_nodes + ) else @assert size(ma_jump.spatial_rates, 2) == num_nodes - cat(dims = 1, - reshape(repeat(ma_jump.uniform_rates, num_nodes), - length(ma_jump.uniform_rates), num_nodes), - ma_jump.spatial_rates) + cat( + dims = 1, + reshape( + repeat(ma_jump.uniform_rates, num_nodes), + length(ma_jump.uniform_rates), num_nodes + ), + ma_jump.spatial_rates + ) end else error("flatten: unsupported jump type $(typeof(ma_jump))") end - flatten(netstoch, reactstoch, rx_rates, spatial_system, u0, tspan, hopping_constants; - scale_rates = false, kwargs...) + return flatten( + netstoch, reactstoch, rx_rates, spatial_system, u0, tspan, hopping_constants; + scale_rates = false, kwargs... + ) end """ if hopping_constants is a matrix, assume hopping_constants[i,j] is the hopping constant of species i from site j to any neighbor """ -function flatten(netstoch::AbstractArray, reactstoch::AbstractArray, +function flatten( + netstoch::AbstractArray, reactstoch::AbstractArray, rx_rates::AbstractArray, spatial_system, u0::Matrix{Int}, tspan, - hopping_constants::Matrix{F}; kwargs...) where {F <: Number} + hopping_constants::Matrix{F}; kwargs... + ) where {F <: Number} @assert size(hopping_constants) == size(u0) hop_constants = Matrix{Vector{F}}(undef, size(hopping_constants)) for ci in CartesianIndices(hop_constants) (species, site) = Tuple(ci) hop_constants[ci] = hopping_constants[species, site] * ones(outdegree(spatial_system, site)) end - flatten(netstoch, reactstoch, rx_rates, spatial_system, u0, tspan, hop_constants; - kwargs...) + return flatten( + netstoch, reactstoch, rx_rates, spatial_system, u0, tspan, hop_constants; + kwargs... + ) end """ if reaction rates is a vector, assume reaction rates are equal across sites """ -function flatten(netstoch::AbstractArray, reactstoch::AbstractArray, rx_rates::Vector, +function flatten( + netstoch::AbstractArray, reactstoch::AbstractArray, rx_rates::Vector, spatial_system, u0::Matrix{Int}, tspan, - hopping_constants::Matrix{Vector{F}}; kwargs...) where {F <: Number} + hopping_constants::Matrix{Vector{F}}; kwargs... + ) where {F <: Number} num_nodes = num_sites(spatial_system) rates = reshape(repeat(rx_rates, num_nodes), length(rx_rates), num_nodes) - flatten(netstoch, reactstoch, rates, spatial_system, u0, tspan, hopping_constants; - kwargs...) + return flatten( + netstoch, reactstoch, rates, spatial_system, u0, tspan, hopping_constants; + kwargs... + ) end """ "flatten" the spatial jump problem. Return flattened DiscreteProblem and MassActionJump. """ -function flatten(netstoch::Vector{R}, reactstoch::Vector{R}, rx_rates::Matrix{F}, +function flatten( + netstoch::Vector{R}, reactstoch::Vector{R}, rx_rates::Matrix{F}, spatial_system, u0::Matrix{Int}, tspan, hopping_constants::Matrix{Vector{F}}; scale_rates = true, - kwargs...) where {R, F <: Number} + kwargs... + ) where {R, F <: Number} num_species = size(u0, 1) num_nodes = num_sites(spatial_system) num_rxs = length(reactstoch) @@ -119,9 +141,11 @@ function flatten(netstoch::Vector{R}, reactstoch::Vector{R}, rx_rates::Matrix{F} append!(total_rates, vec(rx_rates)) # assuming rx_rates isa Matrix where rx_rates[rx, site] is the rate of rx at site # put everything together - ma_jump = MassActionJump(total_rates, total_reactstoch, total_netstoch; nocopy = true, - scale_rates = scale_rates) + ma_jump = MassActionJump( + total_rates, total_reactstoch, total_netstoch; nocopy = true, + scale_rates = scale_rates + ) flattened_u0 = vec(u0) prob = DiscreteProblem(flattened_u0, tspan, total_rates) - prob, ma_jump + return prob, ma_jump end diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index d986d23c5..043eaa93b 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -11,49 +11,67 @@ A file with structs and functions for sampling hops and updating hopping rates abstract type AbstractHopRates end function HopRates(hopping_constants::Vector{F}, spatial_system) where {F <: Number} - HopRatesGraphDs(hopping_constants, num_sites(spatial_system)) + return HopRatesGraphDs(hopping_constants, num_sites(spatial_system)) end -function HopRates(hopping_constants::Vector{F}, - grid::CartesianGridRej) where {F <: Number} - HopRatesGraphDs(hopping_constants, num_sites(grid)) +function HopRates( + hopping_constants::Vector{F}, + grid::CartesianGridRej + ) where {F <: Number} + return HopRatesGraphDs(hopping_constants, num_sites(grid)) end function HopRates(hopping_constants::Matrix{F}, spatial_system) where {F <: Number} - HopRatesGraphDsi(hopping_constants) + return HopRatesGraphDsi(hopping_constants) end -function HopRates(hopping_constants::Matrix{F}, - grid::CartesianGridRej) where {F <: Number} - HopRatesGraphDsi(hopping_constants) +function HopRates( + hopping_constants::Matrix{F}, + grid::CartesianGridRej + ) where {F <: Number} + return HopRatesGraphDsi(hopping_constants) end function HopRates(hopping_constants::Matrix{Vector{F}}, spatial_system) where {F <: Number} - HopRatesGraphDsij(hopping_constants) -end -function HopRates(hopping_constants::Matrix{Vector{F}}, - grid::CartesianGridRej) where {F <: Number} - HopRatesGridDsij(hopping_constants, grid) -end - -function HopRates(p::Pair{SpecHop, SiteHop}, - spatial_system) where {F <: Number, SpecHop <: Vector{F}, - SiteHop <: Vector{Vector{F}}} - HopRatesGraphDsLij(p...) -end -function HopRates(p::Pair{SpecHop, SiteHop}, - grid::CartesianGridRej) where - {F <: Number, SpecHop <: Vector{F}, SiteHop <: Vector{Vector{F}}} - HopRatesGridDsLij(p..., grid) -end - -function HopRates(p::Pair{SpecHop, SiteHop}, - spatial_system) where {F <: Number, SpecHop <: Matrix{F}, - SiteHop <: Vector{Vector{F}}} - HopRatesGraphDsiLij(p...) -end -function HopRates(p::Pair{SpecHop, SiteHop}, - grid::CartesianGridRej) where - {SpecHop <: Matrix{F}, SiteHop <: Vector{Vector{F}}} where {F <: Number} - HopRatesGridDsiLij(p..., grid) + return HopRatesGraphDsij(hopping_constants) +end +function HopRates( + hopping_constants::Matrix{Vector{F}}, + grid::CartesianGridRej + ) where {F <: Number} + return HopRatesGridDsij(hopping_constants, grid) +end + +function HopRates( + p::Pair{SpecHop, SiteHop}, + spatial_system + ) where { + F <: Number, SpecHop <: Vector{F}, + SiteHop <: Vector{Vector{F}}, + } + return HopRatesGraphDsLij(p...) +end +function HopRates( + p::Pair{SpecHop, SiteHop}, + grid::CartesianGridRej + ) where + {F <: Number, SpecHop <: Vector{F}, SiteHop <: Vector{Vector{F}}} + return HopRatesGridDsLij(p..., grid) +end + +function HopRates( + p::Pair{SpecHop, SiteHop}, + spatial_system + ) where { + F <: Number, SpecHop <: Matrix{F}, + SiteHop <: Vector{Vector{F}}, + } + return HopRatesGraphDsiLij(p...) +end +function HopRates( + p::Pair{SpecHop, SiteHop}, + grid::CartesianGridRej + ) where + {SpecHop <: Matrix{F}, SiteHop <: Vector{Vector{F}}} where {F <: Number} + return HopRatesGridDsiLij(p..., grid) end """ @@ -61,9 +79,11 @@ end update rates of all specs in species at site """ -function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, - spatial_system) - @inbounds for spec in species +function update_hop_rates!( + hop_rates::AbstractHopRates, species::AbstractArray, u, site, + spatial_system + ) + return @inbounds for spec in species update_hop_rate!(hop_rates, spec, u, site, spatial_system) end end @@ -76,10 +96,12 @@ update rates of single species at site function update_hop_rate!(hop_rates::AbstractHopRates, species, u, site, spatial_system) rates = hop_rates.rates @inbounds old_rate = rates[species, site] - @inbounds rates[species, site] = evalhoprate(hop_rates, u, species, site, - spatial_system) + @inbounds rates[species, site] = evalhoprate( + hop_rates, u, species, site, + spatial_system + ) @inbounds hop_rates.sum_rates[site] += rates[species, site] - old_rate - old_rate + return old_rate end """ @@ -97,7 +119,7 @@ make all rates zero function reset!(hop_rates::AbstractHopRates) hop_rates.rates .= zero(eltype(hop_rates.rates)) hop_rates.sum_rates .= zero(eltype(hop_rates.sum_rates)) - nothing + return nothing end """ @@ -106,8 +128,10 @@ end sample species to hop from site """ function sample_species(hop_rates::AbstractHopRates, site, rng) - @inbounds linear_search((@view hop_rates.rates[:, site]), - rand(rng) * total_site_hop_rate(hop_rates, site)) + return @inbounds linear_search( + (@view hop_rates.rates[:, site]), + rand(rng) * total_site_hop_rate(hop_rates, site) + ) end """ @@ -135,8 +159,10 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGraphDs) num_specs, num_sites = size(hop_rates.rates) - println(io, - "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s} where s is species.") + return println( + io, + "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s} where s is species." + ) end """ @@ -146,18 +172,18 @@ initializes HopRatesGraphDs with zero rates """ function HopRatesGraphDs(hopping_constants::Vector{F}, num_nodes) where {F <: Number} rates = zeros(F, length(hopping_constants), num_nodes) - HopRatesGraphDs{F}(hopping_constants, rates, zeros(F, size(rates, 2))) + return HopRatesGraphDs{F}(hopping_constants, rates, zeros(F, size(rates, 2))) end function sample_target_site(hop_rates::HopRatesGraphDs, site, species, rng, spatial_system) - rand_nbr(rng, spatial_system, site) + return rand_nbr(rng, spatial_system, site) end """ return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDs, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hopping_constants[species] * outdegree(spatial_system, site) + return @inbounds u[species, site] * hop_rates.hopping_constants[species] * outdegree(spatial_system, site) end ############## hopping rates of form D_{s,i} ################ @@ -174,8 +200,10 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGraphDsi) num_specs, num_sites = size(hop_rates.rates) - println(io, - "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s,i} where s is species, and i is source.") + return println( + io, + "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s,i} where s is species, and i is source." + ) end """ @@ -185,18 +213,18 @@ initializes HopRatesGraphDsi with zero rates """ function HopRatesGraphDsi(hopping_constants::Matrix{F}) where {F <: Number} rates = zeros(F, size(hopping_constants)) - HopRatesGraphDsi{F}(hopping_constants, rates, zeros(F, size(rates, 2))) + return HopRatesGraphDsi{F}(hopping_constants, rates, zeros(F, size(rates, 2))) end function sample_target_site(hop_rates::HopRatesGraphDsi, site, species, rng, spatial_system) - rand_nbr(rng, spatial_system, site) + return rand_nbr(rng, spatial_system, site) end """ return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) + return @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) end ############## hopping rates of form D_{s,i,j} ################ @@ -213,8 +241,10 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGraphDsij) num_specs, num_sites = size(hop_rates.rates) - println(io, - "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s,i,j} where s is species, i is source and j is destination.") + return println( + io, + "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s,i,j} where s is species, i is source and j is destination." + ) end """ @@ -222,23 +252,27 @@ end initializes HopRates with zero rates """ -function HopRatesGraphDsij(hopping_constants::Matrix{Vector{F}}; - do_cumsum = true) where {F <: Number} +function HopRatesGraphDsij( + hopping_constants::Matrix{Vector{F}}; + do_cumsum = true + ) where {F <: Number} do_cumsum && (hopping_constants = map(cumsum, hopping_constants)) rates = zeros(F, size(hopping_constants)) sum_rates = zeros(F, size(rates, 2)) - HopRatesGraphDsij{F}(hopping_constants, rates, sum_rates) + return HopRatesGraphDsij{F}(hopping_constants, rates, sum_rates) end -function sample_target_site(hop_rates::HopRatesGraphDsij, site, species, rng, - spatial_system) +function sample_target_site( + hop_rates::HopRatesGraphDsij, site, species, rng, + spatial_system + ) @inbounds cum_hop_consts = hop_rates.hop_const_cumulative_sums[species, site] @inbounds n = searchsortedfirst(cum_hop_consts, rand(rng) * cum_hop_consts[end]) return nth_nbr(spatial_system, site, n) end function evalhoprate(hop_rates::HopRatesGraphDsij, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hop_const_cumulative_sums[species, site][end] + return @inbounds u[species, site] * hop_rates.hop_const_cumulative_sums[species, site][end] end ################# hopping rates of form L_{s,i,j} optimized for cartesian grid ###################### @@ -258,8 +292,10 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGridDsij) num_specs, num_sites = size(hop_rates.rates) - println(io, - "HopRates with $num_specs species and $num_sites sites, optimized for CartesianGrid. \nHopping constants of form L_{s,i,j} where s is species, i is source and j is destination.") + return println( + io, + "HopRates with $num_specs species and $num_sites sites, optimized for CartesianGrid. \nHopping constants of form L_{s,i,j} where s is species, i is source and j is destination." + ) end """ @@ -267,24 +303,28 @@ end initializes HopRates with zero rates """ -function HopRatesGridDsij(hopping_constants::Array{F, 3}; - do_cumsum = true) where {F <: Number} +function HopRatesGridDsij( + hopping_constants::Array{F, 3}; + do_cumsum = true + ) where {F <: Number} do_cumsum && (hopping_constants = mapslices(cumsum, hopping_constants, dims = 1)) rates = zeros(F, size(hopping_constants)[2:3]) sum_rates = zeros(F, size(rates, 2)) - HopRatesGridDsij{F}(hopping_constants, rates, sum_rates) + return HopRatesGridDsij{F}(hopping_constants, rates, sum_rates) end function HopRatesGridDsij(hopping_constants::Matrix{Vector{F}}, grid) where {F <: Number} - new_hopping_constants = Array{F, 3}(undef, 2 * dimension(grid), - size(hopping_constants)...) + new_hopping_constants = Array{F, 3}( + undef, 2 * dimension(grid), + size(hopping_constants)... + ) for ci in CartesianIndices(hopping_constants) species, site = Tuple(ci) nb_constants = @view new_hopping_constants[:, species, site] pad_hop_vec!(nb_constants, grid, site, hopping_constants[ci]) cumsum!(nb_constants, nb_constants) end - HopRatesGridDsij(new_hopping_constants, do_cumsum = false) + return HopRatesGridDsij(new_hopping_constants, do_cumsum = false) end function sample_target_site(hop_rates::HopRatesGridDsij, site, species, rng, grid) @@ -294,7 +334,7 @@ function sample_target_site(hop_rates::HopRatesGridDsij, site, species, rng, gri end function evalhoprate(hop_rates::HopRatesGridDsij, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hop_const_cumulative_sums[end, species, site] + return @inbounds u[species, site] * hop_rates.hop_const_cumulative_sums[end, species, site] end ############## hopping rates of form D_s * L_{i,j} ################ @@ -314,10 +354,12 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGraphDsLij) num_specs, - num_sites = length(hop_rates.species_hop_constants), - length(hop_rates.hop_const_cumulative_sums) - println(io, - "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_s * L_{i,j} where s is species, i is source and j is destination.") + num_sites = length(hop_rates.species_hop_constants), + length(hop_rates.hop_const_cumulative_sums) + return println( + io, + "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_s * L_{i,j} where s is species, i is source and j is destination." + ) end """ @@ -325,24 +367,28 @@ end initializes HopRates with zero rates """ -function HopRatesGraphDsLij(species_hop_constants::Vector{F}, +function HopRatesGraphDsLij( + species_hop_constants::Vector{F}, site_hop_constants::Vector{Vector{F}}; - do_cumsum = true) where {F <: Number} + do_cumsum = true + ) where {F <: Number} do_cumsum && (site_hop_constants = map(cumsum, site_hop_constants)) rates = zeros(F, length(species_hop_constants), length(site_hop_constants)) sum_rates = zeros(size(rates, 2)) - HopRatesGraphDsLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) + return HopRatesGraphDsLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) end -function sample_target_site(hop_rates::HopRatesGraphDsLij, site, species, rng, - spatial_system) +function sample_target_site( + hop_rates::HopRatesGraphDsLij, site, species, rng, + spatial_system + ) @inbounds cum_hop_consts = hop_rates.hop_const_cumulative_sums[site] @inbounds n = searchsortedfirst(cum_hop_consts, rand(rng) * cum_hop_consts[end]) return nth_nbr(spatial_system, site, n) end function evalhoprate(hop_rates::HopRatesGraphDsLij, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.species_hop_constants[species] * hop_rates.hop_const_cumulative_sums[site][end] + return @inbounds u[species, site] * hop_rates.species_hop_constants[species] * hop_rates.hop_const_cumulative_sums[site][end] end ############## hopping rates of form D_s * L_{i,j} optimized for cartesian grid ################ @@ -362,30 +408,38 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGridDsLij) num_specs, - num_sites = length(hop_rates.species_hop_constants), - size(hop_rates.hop_const_cumulative_sums, 2) - println(io, - "HopRates with $num_specs species and $num_sites sites, optimized for CartesianGrid. \nHopping constants of form D_s * L_{i,j} where s is species, i is source and j is destination.") -end - -function HopRatesGridDsLij(species_hop_constants::Vector{F}, site_hop_constants::Matrix{F}; - do_cumsum = true) where {F <: Number} + num_sites = length(hop_rates.species_hop_constants), + size(hop_rates.hop_const_cumulative_sums, 2) + return println( + io, + "HopRates with $num_specs species and $num_sites sites, optimized for CartesianGrid. \nHopping constants of form D_s * L_{i,j} where s is species, i is source and j is destination." + ) +end + +function HopRatesGridDsLij( + species_hop_constants::Vector{F}, site_hop_constants::Matrix{F}; + do_cumsum = true + ) where {F <: Number} do_cumsum && (site_hop_constants = mapslices(cumsum, site_hop_constants, dims = 1)) rates = zeros(F, length(species_hop_constants), size(site_hop_constants, 2)) sum_rates = zeros(size(rates, 2)) - HopRatesGridDsLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) + return HopRatesGridDsLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) end -function HopRatesGridDsLij(species_hop_constants::Vector{F}, - site_hop_constants::Vector{Vector{F}}, grid) where {F <: Number} - new_hopping_constants = Matrix{F}(undef, 2 * dimension(grid), - length(site_hop_constants)) +function HopRatesGridDsLij( + species_hop_constants::Vector{F}, + site_hop_constants::Vector{Vector{F}}, grid + ) where {F <: Number} + new_hopping_constants = Matrix{F}( + undef, 2 * dimension(grid), + length(site_hop_constants) + ) for site in 1:length(site_hop_constants) nb_constants = @view new_hopping_constants[:, site] pad_hop_vec!(nb_constants, grid, site, site_hop_constants[site]) cumsum!(nb_constants, nb_constants) end - HopRatesGridDsLij(species_hop_constants, new_hopping_constants, do_cumsum = false) + return HopRatesGridDsLij(species_hop_constants, new_hopping_constants, do_cumsum = false) end function sample_target_site(hop_rates::HopRatesGridDsLij, site, species, rng, grid) @@ -395,7 +449,7 @@ function sample_target_site(hop_rates::HopRatesGridDsLij, site, species, rng, gr end function evalhoprate(hop_rates::HopRatesGridDsLij, u, species, site, grid) - @inbounds u[species, site] * hop_rates.species_hop_constants[species] * hop_rates.hop_const_cumulative_sums[end, site] + return @inbounds u[species, site] * hop_rates.species_hop_constants[species] * hop_rates.hop_const_cumulative_sums[end, site] end ############## hopping rates of form D_{s,i} * L_{i,j} ################ @@ -415,29 +469,35 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGraphDsiLij) num_specs, num_sites = size(hop_rates.species_hop_constants) - println(io, - "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s,i} * L_{i,j} where s is species, i is source and j is destination.") + return println( + io, + "HopRates with $num_specs species and $num_sites sites. \nHopping constants of form D_{s,i} * L_{i,j} where s is species, i is source and j is destination." + ) end -function HopRatesGraphDsiLij(species_hop_constants::Matrix{F}, +function HopRatesGraphDsiLij( + species_hop_constants::Matrix{F}, site_hop_constants::Vector{Vector{F}}; - do_cumsum = true) where {F <: Number} + do_cumsum = true + ) where {F <: Number} @assert size(species_hop_constants, 2) == length(site_hop_constants) do_cumsum && (site_hop_constants = map(cumsum, site_hop_constants)) rates = zeros(F, length(species_hop_constants), length(site_hop_constants)) sum_rates = zeros(size(rates, 2)) - HopRatesGraphDsiLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) + return HopRatesGraphDsiLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) end -function sample_target_site(hop_rates::HopRatesGraphDsiLij, site, species, rng, - spatial_system) +function sample_target_site( + hop_rates::HopRatesGraphDsiLij, site, species, rng, + spatial_system + ) @inbounds cum_hop_consts = hop_rates.hop_const_cumulative_sums[site] @inbounds n = searchsortedfirst(cum_hop_consts, rand(rng) * cum_hop_consts[end]) return nth_nbr(spatial_system, site, n) end function evalhoprate(hop_rates::HopRatesGraphDsiLij, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.species_hop_constants[species, site] * hop_rates.hop_const_cumulative_sums[site][end] + return @inbounds u[species, site] * hop_rates.species_hop_constants[species, site] * hop_rates.hop_const_cumulative_sums[site][end] end ############## hopping rates of form D_{s,i} * L_{i,j} optimized for cartesian grid ################ @@ -457,32 +517,39 @@ end function Base.show(io::IO, ::MIME"text/plain", hop_rates::HopRatesGridDsiLij) num_specs, - num_sites = length(hop_rates.species_hop_constants), - size(hop_rates.hop_const_cumulative_sums, 2) - println(io, - "HopRates with $num_specs species and $num_sites sites, optimized for CartesianGrid. \nHopping constants of form D_{s,i} * L_{i,j} where s is species, i is source and j is destination.") + num_sites = length(hop_rates.species_hop_constants), + size(hop_rates.hop_const_cumulative_sums, 2) + return println( + io, + "HopRates with $num_specs species and $num_sites sites, optimized for CartesianGrid. \nHopping constants of form D_{s,i} * L_{i,j} where s is species, i is source and j is destination." + ) end function HopRatesGridDsiLij( species_hop_constants::Matrix{F}, site_hop_constants::Matrix{F}; - do_cumsum = true) where {F <: Number} + do_cumsum = true + ) where {F <: Number} @assert size(species_hop_constants, 2) == size(site_hop_constants, 2) do_cumsum && (site_hop_constants = mapslices(cumsum, site_hop_constants, dims = 1)) rates = zeros(F, size(species_hop_constants)) sum_rates = zeros(size(rates, 2)) - HopRatesGridDsiLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) + return HopRatesGridDsiLij{F}(species_hop_constants, site_hop_constants, rates, sum_rates) end -function HopRatesGridDsiLij(species_hop_constants::Matrix{F}, - site_hop_constants::Vector{Vector{F}}, grid) where {F <: Number} - new_hopping_constants = Matrix{F}(undef, 2 * dimension(grid), - length(site_hop_constants)) +function HopRatesGridDsiLij( + species_hop_constants::Matrix{F}, + site_hop_constants::Vector{Vector{F}}, grid + ) where {F <: Number} + new_hopping_constants = Matrix{F}( + undef, 2 * dimension(grid), + length(site_hop_constants) + ) for site in 1:length(site_hop_constants) nb_constants = @view new_hopping_constants[:, site] pad_hop_vec!(nb_constants, grid, site, site_hop_constants[site]) cumsum!(nb_constants, nb_constants) end - HopRatesGridDsiLij(species_hop_constants, new_hopping_constants, do_cumsum = false) + return HopRatesGridDsiLij(species_hop_constants, new_hopping_constants, do_cumsum = false) end function sample_target_site(hop_rates::HopRatesGridDsiLij, site, species, rng, grid) @@ -492,5 +559,5 @@ function sample_target_site(hop_rates::HopRatesGridDsiLij, site, species, rng, g end function evalhoprate(hop_rates::HopRatesGridDsiLij, u, species, site, grid) - @inbounds u[species, site] * hop_rates.species_hop_constants[species, site] * hop_rates.hop_const_cumulative_sums[end, site] + return @inbounds u[species, site] * hop_rates.species_hop_constants[species, site] * hop_rates.hop_const_cumulative_sums[end, site] end diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 07c48da6e..4b629665c 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -3,9 +3,11 @@ ############################ NSM ################################### #NOTE state vector u is a matrix. u[i,j] is species i, site j #NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j -mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, - PQ, SS} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct NSMJumpAggregation{ + T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, + PQ, SS, + } <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::SpatialJump{J} #some structure to identify the next event: reaction or hop prev_jump::SpatialJump{J} #some structure to identify the previous event: reaction or hop next_jump_time::T @@ -29,7 +31,8 @@ function NSMJumpAggregation( sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; num_specs, vartojumps_map = nothing, jumptovars_map = nothing, - dep_graph = nothing, kwargs...) where {J, T, RX, HOP, RNG, SS} + dep_graph = nothing, kwargs... + ) where {J, T, RX, HOP, RNG, SS} # a dependency graph is needed if dep_graph === nothing @@ -55,8 +58,11 @@ function NSMJumpAggregation( pq = MutableBinaryMinHeap{T}() - NSMJumpAggregation{T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, typeof(dg), - typeof(vtoj_map), typeof(jtov_map), typeof(pq), SS}(nj, nj, njt, et, + return NSMJumpAggregation{ + T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, typeof(dg), + typeof(vtoj_map), typeof(jtov_map), typeof(pq), SS, + }( + nj, nj, njt, et, rx_rates, hop_rates, nothing, @@ -66,20 +72,25 @@ function NSMJumpAggregation( jtov_map, pq, spatial_system, - num_specs) + num_specs + ) end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jumps, +function aggregate( + aggregator::NSM, starting_state, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; hopping_constants, spatial_system, - kwargs...) + kwargs... + ) num_species = size(starting_state, 1) majumps = ma_jumps if majumps === nothing - majumps = MassActionJump(Vector{typeof(end_time)}(), + majumps = MassActionJump( + Vector{typeof(end_time)}(), Vector{Vector{Pair{Int, Int}}}(), - Vector{Vector{Pair{Int, Int}}}()) + Vector{Vector{Pair{Int, Int}}}() + ) end next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder @@ -87,9 +98,11 @@ function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jum rx_rates = RxRates(num_sites(spatial_system), majumps) hop_rates = HopRates(hopping_constants, spatial_system) - NSMJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, + return NSMJumpAggregation( + next_jump, next_jump_time, end_time, rx_rates, hop_rates, save_positions, rng, spatial_system; num_specs = num_species, - kwargs...) + kwargs... + ) end # set up a new simulation and calculate the first jump / jump time @@ -97,7 +110,7 @@ function initialize!(p::NSMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) generate_jumps!(p, integrator, params, u, t) - nothing + return nothing end # calculate the next jump / jump time @@ -105,7 +118,7 @@ function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing p.next_jump = sample_jump_direct(p, site) - nothing + return nothing end # execute one jump, changing the system state @@ -115,7 +128,7 @@ function execute_jumps!(p::NSMJumpAggregation, integrator, u, params, t, affects # update current jump rates and times update_dependent_rates_and_firing_times!(p, integrator, t) - nothing + return nothing end ######################## SSA specific helper routines ######################## @@ -143,7 +156,7 @@ function fill_rates_and_get_times!(aggregation::NSMJumpAggregation, integrator, end aggregation.pq = MutableBinaryMinHeap(pqdata) - nothing + return nothing end """ @@ -165,7 +178,7 @@ function update_dependent_rates_and_firing_times!(p::NSMJumpAggregation, integra update_rates_after_reaction!(p, integrator, site, reaction_id_from_jump(p, jump)) update_site_time!(p, site, t) end - nothing + return nothing end """ @@ -180,7 +193,7 @@ function update_site_time!(p::NSMJumpAggregation, site, t) else update!(p.pq, site, typemax(t)) end - nothing + return nothing end """ diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 737cc5c98..63f0eab31 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -22,7 +22,7 @@ initializes RxRates with zero rates function RxRates(num_sites::Int, ma_jumps::M) where {M} numrxjumps = get_num_majumps(ma_jumps) rates = zeros(Float64, numrxjumps, num_sites) - RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps) + return RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps) end num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) @@ -35,7 +35,7 @@ make all rates zero function reset!(rx_rates::RxRates) fill!(rx_rates.rates, zero(eltype(rx_rates.rates))) fill!(rx_rates.sum_rates, zero(eltype(rx_rates.sum_rates))) - nothing + return nothing end """ @@ -44,7 +44,7 @@ end return total reaction rate at site """ function total_site_rx_rate(rx_rates::RxRates, site) - @inbounds rx_rates.sum_rates[site] + return @inbounds rx_rates.sum_rates[site] end """ @@ -52,19 +52,23 @@ end update rates of all reactions in rxs at site """ -function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, - site) +function update_rx_rates!( + rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, + site + ) ma_jumps = rx_rates.ma_jumps - @inbounds for rx in rxs + return @inbounds for rx in rxs rate = eval_massaction_rate(u, rx, ma_jumps, site) set_rx_rate_at_site!(rx_rates, site, rx, rate) end end -function update_rx_rates!(rx_rates::RxRates, rxs, integrator, - site) +function update_rx_rates!( + rx_rates::RxRates, rxs, integrator, + site + ) u = integrator.u - update_rx_rates!(rx_rates, rxs, u, integrator, site) + return update_rx_rates!(rx_rates, rxs, u, integrator, site) end """ @@ -73,8 +77,10 @@ end sample a reaction at site, return reaction index """ function sample_rx_at_site(rx_rates::RxRates, site, rng) - linear_search((@view rx_rates.rates[:, site]), - rand(rng) * total_site_rx_rate(rx_rates, site)) + return linear_search( + (@view rx_rates.rates[:, site]), + rand(rng) * total_site_rx_rate(rx_rates, site) + ) end # helper functions @@ -82,17 +88,17 @@ function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] @inbounds rx_rates.rates[rx, site] = rate @inbounds rx_rates.sum_rates[site] += rate - old_rate - old_rate + return old_rate end function Base.show(io::IO, ::MIME"text/plain", rx_rates::RxRates) num_rxs, num_sites = size(rx_rates.rates) - println(io, "RxRates with $num_rxs reactions and $num_sites sites") + return println(io, "RxRates with $num_rxs reactions and $num_sites sites") end function eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: SpatialMassActionJump} - evalrxrate(u, rx, ma_jumps, site) + return evalrxrate(u, rx, ma_jumps, site) end function eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: MassActionJump} - evalrxrate((@view u[:, site]), rx, ma_jumps) + return evalrxrate((@view u[:, site]), rx, ma_jumps) end diff --git a/src/spatial/spatial_massaction_jump.jl b/src/spatial/spatial_massaction_jump.jl index f1a4fea55..340e089ce 100644 --- a/src/spatial/spatial_massaction_jump.jl +++ b/src/spatial/spatial_massaction_jump.jl @@ -2,7 +2,7 @@ const AVecOrNothing = Union{AbstractVector, Nothing} const AMatOrNothing = Union{AbstractMatrix, Nothing} struct SpatialMassActionJump{A <: AVecOrNothing, B <: AMatOrNothing, S, U, V} <: - AbstractMassActionJump + AbstractMassActionJump uniform_rates::A # reactions that are uniform in space spatial_rates::B # reactions whose rate depends on the site reactant_stoch::S @@ -12,21 +12,25 @@ struct SpatialMassActionJump{A <: AVecOrNothing, B <: AMatOrNothing, S, U, V} <: """ uniform rates go first in ordering """ - function SpatialMassActionJump{A, B, S, U, V}(uniform_rates::A, spatial_rates::B, + function SpatialMassActionJump{A, B, S, U, V}( + uniform_rates::A, spatial_rates::B, reactant_stoch::S, net_stoch::U, param_mapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {A <: AVecOrNothing, + nocopy::Bool + ) where { + A <: AVecOrNothing, B <: AMatOrNothing, - S, U, V} + S, U, V, + } uniform_rates = (nocopy || isnothing(uniform_rates)) ? uniform_rates : - copy(uniform_rates) + copy(uniform_rates) spatial_rates = (nocopy || isnothing(spatial_rates)) ? spatial_rates : - copy(spatial_rates) + copy(spatial_rates) reactant_stoch = nocopy ? reactant_stoch : copy(reactant_stoch) for i in eachindex(reactant_stoch) if useiszero && (length(reactant_stoch[i]) == 1) && - iszero(reactant_stoch[i][1][1]) + iszero(reactant_stoch[i][1][1]) reactant_stoch[i] = typeof(reactant_stoch[i])() end end @@ -37,100 +41,152 @@ struct SpatialMassActionJump{A <: AVecOrNothing, B <: AMatOrNothing, S, U, V} <: if scale_rates && !isnothing(spatial_rates) && !isempty(spatial_rates) scalerates!(spatial_rates, reactant_stoch[(num_unif_rates + 1):end]) end - new(uniform_rates, spatial_rates, reactant_stoch, net_stoch, param_mapper) + return new(uniform_rates, spatial_rates, reactant_stoch, net_stoch, param_mapper) end end ################ Constructors ################## -function SpatialMassActionJump(urates::A, srates::B, rs::S, ns::U, pmapper::V; +function SpatialMassActionJump( + urates::A, srates::B, rs::S, ns::U, pmapper::V; scale_rates = true, useiszero = true, - nocopy = false) where {A <: AVecOrNothing, - B <: AMatOrNothing, S, U, V} - SpatialMassActionJump{A, B, S, U, V}(urates, srates, rs, ns, pmapper, scale_rates, - useiszero, nocopy) -end -function SpatialMassActionJump(urates::A, srates::B, rs, ns; scale_rates = true, + nocopy = false + ) where { + A <: AVecOrNothing, + B <: AMatOrNothing, S, U, V, + } + return SpatialMassActionJump{A, B, S, U, V}( + urates, srates, rs, ns, pmapper, scale_rates, + useiszero, nocopy + ) +end +function SpatialMassActionJump( + urates::A, srates::B, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {A <: AVecOrNothing, - B <: AMatOrNothing} - SpatialMassActionJump(urates, srates, rs, ns, nothing; scale_rates = scale_rates, - useiszero = useiszero, nocopy = nocopy) + nocopy = false + ) where { + A <: AVecOrNothing, + B <: AMatOrNothing, + } + return SpatialMassActionJump( + urates, srates, rs, ns, nothing; scale_rates = scale_rates, + useiszero = useiszero, nocopy = nocopy + ) end -function SpatialMassActionJump(srates::B, rs, ns, pmapper; scale_rates = true, +function SpatialMassActionJump( + srates::B, rs, ns, pmapper; scale_rates = true, useiszero = true, - nocopy = false) where {B <: AMatOrNothing} - SpatialMassActionJump(nothing, srates, rs, ns, pmapper; scale_rates = scale_rates, - useiszero = useiszero, nocopy = nocopy) -end -function SpatialMassActionJump(srates::B, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {B <: AMatOrNothing} - SpatialMassActionJump(nothing, srates, rs, ns, nothing; scale_rates = scale_rates, - useiszero = useiszero, nocopy = nocopy) + nocopy = false + ) where {B <: AMatOrNothing} + return SpatialMassActionJump( + nothing, srates, rs, ns, pmapper; scale_rates = scale_rates, + useiszero = useiszero, nocopy = nocopy + ) +end +function SpatialMassActionJump( + srates::B, rs, ns; scale_rates = true, useiszero = true, + nocopy = false + ) where {B <: AMatOrNothing} + return SpatialMassActionJump( + nothing, srates, rs, ns, nothing; scale_rates = scale_rates, + useiszero = useiszero, nocopy = nocopy + ) end -function SpatialMassActionJump(urates::A, rs, ns, pmapper; scale_rates = true, +function SpatialMassActionJump( + urates::A, rs, ns, pmapper; scale_rates = true, useiszero = true, - nocopy = false) where {A <: AVecOrNothing} - SpatialMassActionJump(urates, nothing, rs, ns, pmapper; scale_rates = scale_rates, - useiszero = useiszero, nocopy = nocopy) -end -function SpatialMassActionJump(urates::A, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {A <: AVecOrNothing} - SpatialMassActionJump(urates, nothing, rs, ns, nothing; scale_rates = scale_rates, - useiszero = useiszero, nocopy = nocopy) + nocopy = false + ) where {A <: AVecOrNothing} + return SpatialMassActionJump( + urates, nothing, rs, ns, pmapper; scale_rates = scale_rates, + useiszero = useiszero, nocopy = nocopy + ) +end +function SpatialMassActionJump( + urates::A, rs, ns; scale_rates = true, useiszero = true, + nocopy = false + ) where {A <: AVecOrNothing} + return SpatialMassActionJump( + urates, nothing, rs, ns, nothing; scale_rates = scale_rates, + useiszero = useiszero, nocopy = nocopy + ) end -function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates = true, - useiszero = true, nocopy = false) where {T, S, U, V} - SpatialMassActionJump(ma_jumps.scaled_rates, ma_jumps.reactant_stoch, +function SpatialMassActionJump( + ma_jumps::MassActionJump{T, S, U, V}; scale_rates = true, + useiszero = true, nocopy = false + ) where {T, S, U, V} + return SpatialMassActionJump( + ma_jumps.scaled_rates, ma_jumps.reactant_stoch, ma_jumps.net_stoch, ma_jumps.param_mapper; - scale_rates = scale_rates, useiszero = useiszero, nocopy = nocopy) + scale_rates = scale_rates, useiszero = useiszero, nocopy = nocopy + ) end ############################################## -function get_num_majumps(smaj::SpatialMassActionJump{ - Nothing, Nothing, S, U, V}) where - {S, U, V} - 0 -end -function get_num_majumps(smaj::SpatialMassActionJump{ - Nothing, B, S, U, V}) where - {B, S, U, V} - size(smaj.spatial_rates, 1) -end -function get_num_majumps(smaj::SpatialMassActionJump{ - A, Nothing, S, U, V}) where - {A, S, U, V} - length(smaj.uniform_rates) -end -function get_num_majumps(smaj::SpatialMassActionJump{ - A, B, S, U, V}) where - {A <: AbstractVector, B <: AbstractMatrix, S, U, V} - length(smaj.uniform_rates) + size(smaj.spatial_rates, 1) +function get_num_majumps( + smaj::SpatialMassActionJump{ + Nothing, Nothing, S, U, V, + } + ) where + {S, U, V} + return 0 +end +function get_num_majumps( + smaj::SpatialMassActionJump{ + Nothing, B, S, U, V, + } + ) where + {B, S, U, V} + return size(smaj.spatial_rates, 1) +end +function get_num_majumps( + smaj::SpatialMassActionJump{ + A, Nothing, S, U, V, + } + ) where + {A, S, U, V} + return length(smaj.uniform_rates) +end +function get_num_majumps( + smaj::SpatialMassActionJump{ + A, B, S, U, V, + } + ) where + {A <: AbstractVector, B <: AbstractMatrix, S, U, V} + return length(smaj.uniform_rates) + size(smaj.spatial_rates, 1) end using_params(smaj::SpatialMassActionJump) = false -function rate_at_site(rx, site, - smaj::SpatialMassActionJump{Nothing, B, S, U, V}) where {B, S, U, V} - smaj.spatial_rates[rx, site] -end -function rate_at_site(rx, site, - smaj::SpatialMassActionJump{A, Nothing, S, U, V}) where {A, S, U, V} - smaj.uniform_rates[rx] -end -function rate_at_site(rx, site, - smaj::SpatialMassActionJump{A, B, S, U, V}) where - {A <: AbstractVector, B <: AbstractMatrix, S, U, V} +function rate_at_site( + rx, site, + smaj::SpatialMassActionJump{Nothing, B, S, U, V} + ) where {B, S, U, V} + return smaj.spatial_rates[rx, site] +end +function rate_at_site( + rx, site, + smaj::SpatialMassActionJump{A, Nothing, S, U, V} + ) where {A, S, U, V} + return smaj.uniform_rates[rx] +end +function rate_at_site( + rx, site, + smaj::SpatialMassActionJump{A, B, S, U, V} + ) where + {A <: AbstractVector, B <: AbstractMatrix, S, U, V} num_unif_rxs = length(smaj.uniform_rates) - rx <= num_unif_rxs ? smaj.uniform_rates[rx] : - smaj.spatial_rates[rx - num_unif_rxs, site] + return rx <= num_unif_rxs ? smaj.uniform_rates[rx] : + smaj.spatial_rates[rx - num_unif_rxs, site] end -function evalrxrate(speciesmat::AbstractMatrix{T}, rxidx::S, majump::SpatialMassActionJump, - site::Int) where {T, S} +function evalrxrate( + speciesmat::AbstractMatrix{T}, rxidx::S, majump::SpatialMassActionJump, + site::Int + ) where {T, S} val = one(T) @inbounds for specstoch in majump.reactant_stoch[rxidx] specpop = speciesmat[specstoch[1], site] diff --git a/src/spatial/topology.jl b/src/spatial/topology.jl index bf2d5deb7..17d644a4a 100644 --- a/src/spatial/topology.jl +++ b/src/spatial/topology.jl @@ -15,7 +15,7 @@ const offsets_2D = [ CartesianIndex(0, -1), CartesianIndex(-1, 0), CartesianIndex(1, 0), - CartesianIndex(0, 1) + CartesianIndex(0, 1), ] const offsets_3D = [ CartesianIndex(0, 0, -1), @@ -23,7 +23,7 @@ const offsets_3D = [ CartesianIndex(-1, 0, 0), CartesianIndex(1, 0, 0), CartesianIndex(0, 1, 0), - CartesianIndex(0, 0, 1) + CartesianIndex(0, 0, 1), ] """ @@ -61,7 +61,7 @@ function nth_nbr(grid, site, n) CI = grid.CI offsets = grid.offsets @inbounds I = CI[site] - @inbounds for off in offsets + return @inbounds for off in offsets nb = I + off if nb in CI n -= 1 @@ -79,7 +79,7 @@ function neighbors(grid, site) CI = grid.CI LI = grid.LI I = CI[site] - Iterators.map(off -> LI[off + I], Iterators.filter(off -> off + I in CI, grid.offsets)) + return Iterators.map(off -> LI[off + I], Iterators.filter(off -> off + I in CI, grid.offsets)) end """ @@ -97,7 +97,7 @@ function pad_hop_vec!(to_pad::AbstractVector, grid, site, hop_vec::AbstractVecto to_pad[i] = zero(eltype(to_pad)) end end - to_pad + return to_pad end CartesianGrid(dims) = CartesianGridRej(dims) # use CartesianGridRej by default @@ -127,11 +127,11 @@ function CartesianGridRej(dims::Tuple) LI = LinearIndices(dims) offsets = potential_offsets(dim) nums_neighbors = Int8[count(x -> x + CI[site] in CI, offsets) for site in 1:prod(dims)] - CartesianGridRej(dims, nums_neighbors, CI, LI, offsets) + return CartesianGridRej(dims, nums_neighbors, CI, LI, offsets) end CartesianGridRej(dims) = CartesianGridRej(Tuple(dims)) function CartesianGridRej(dimension, linear_size::Int) - CartesianGridRej([linear_size for i in 1:dimension]) + return CartesianGridRej([linear_size for i in 1:dimension]) end function rand_nbr(rng, grid::CartesianGridRej, site::Int) CI = grid.CI @@ -141,9 +141,12 @@ function rand_nbr(rng, grid::CartesianGridRej, site::Int) @inbounds nb = rand(rng, offsets) + I @inbounds nb in CI && return grid.LI[nb] end + return end -function Base.show(io::IO, ::MIME"text/plain", - grid::CartesianGridRej) - println(io, "A Cartesian grid with dimensions $(grid.dims)") +function Base.show( + io::IO, ::MIME"text/plain", + grid::CartesianGridRej + ) + return println(io, "A Cartesian grid with dimensions $(grid.dims)") end diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index 370e86015..877929748 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -17,8 +17,10 @@ struct SpatialJump{J} end function Base.show(io::IO, ::MIME"text/plain", jump::SpatialJump) - println(io, - "SpatialJump with source $(jump.src), destination $(jump.dst) and index $(jump.jidx).") + return println( + io, + "SpatialJump with source $(jump.src), destination $(jump.dst) and index $(jump.jidx)." + ) end ######################## helper routines for all spatial SSAs ######################## @@ -29,25 +31,27 @@ sample jump at site with direct method """ function sample_jump_direct(p, site) if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < - total_site_rx_rate(p.rx_rates, site) + total_site_rx_rate(p.rx_rates, site) rx = sample_rx_at_site(p.rx_rates, site, p.rng) return SpatialJump(site, rx + p.numspecies, site) else species_to_diffuse, - target_site = sample_hop_at_site(p.hop_rates, site, p.rng, - p.spatial_system) + target_site = sample_hop_at_site( + p.hop_rates, site, p.rng, + p.spatial_system + ) return SpatialJump(site, species_to_diffuse, target_site) end end function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) - total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) + return total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) end function update_rates_after_reaction!(p, integrator, site, reaction_id) u = integrator.u update_rx_rates!(p.rx_rates, p.dep_gr[reaction_id], integrator, site) - update_hop_rates!(p.hop_rates, p.jumptovars_map[reaction_id], u, site, p.spatial_system) + return update_hop_rates!(p.hop_rates, p.jumptovars_map[reaction_id], u, site, p.spatial_system) end function update_rates_after_hop!(p, integrator, source_site, target_site, species) @@ -56,7 +60,7 @@ function update_rates_after_hop!(p, integrator, source_site, target_site, specie update_hop_rate!(p.hop_rates, species, u, source_site, p.spatial_system) update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, target_site) - update_hop_rate!(p.hop_rates, species, u, target_site, p.spatial_system) + return update_hop_rate!(p.hop_rates, species, u, target_site, p.spatial_system) end """ @@ -70,12 +74,14 @@ function update_state!(p, integrator) execute_hop!(integrator, jump.src, jump.dst, jump.jidx) else rx_index = reaction_id_from_jump(p, jump) - @inbounds executerx!((@view integrator.u[:, jump.src]), rx_index, - p.rx_rates.ma_jumps) + @inbounds executerx!( + (@view integrator.u[:, jump.src]), rx_index, + p.rx_rates.ma_jumps + ) end # save jump that was just executed p.prev_jump = jump - nothing + return nothing end """ @@ -84,7 +90,7 @@ end true if jump is a hop """ function is_hop(p, jump) - jump.jidx <= p.numspecies + return jump.jidx <= p.numspecies end """ @@ -94,7 +100,7 @@ documentation """ function execute_hop!(integrator, source_site, target_site, species) @inbounds integrator.u[species, source_site] -= 1 - @inbounds integrator.u[species, target_site] += 1 + return @inbounds integrator.u[species, target_site] += 1 end """ @@ -103,5 +109,5 @@ end return reaction id by subtracting the number of hops """ function reaction_id_from_jump(p, jump) - jump.jidx - p.numspecies + return jump.jidx - p.numspecies end diff --git a/src/variable_rate.jl b/src/variable_rate.jl index 4dc57c717..e7c15ba34 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -58,8 +58,10 @@ sol = solve(jprob, Tsit5()) """ struct VR_FRM <: VariableRateAggregator end -function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs; - rng = DEFAULT_RNG) +function configure_jump_problem( + prob, vr_aggregator::VR_FRM, jumps, cvrjs; + rng = DEFAULT_RNG + ) new_prob = extend_problem(prob, cvrjs; rng) variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng) return new_prob, variable_jump_callback @@ -84,7 +86,7 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL jump_f = let _f = _f function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t) _f(du.u, u.u, p, t) - update_jumps!(du, u, p, t, length(u.u), jumps...) + return update_jumps!(du, u, p, t, length(u.u), jumps...) end end else @@ -98,9 +100,11 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL end u0 = extend_u0(prob, length(jumps), rng) - f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, u0) + f = ODEFunction{isinplace(prob)}( + jump_f; sys = prob.f.sys, + observed = prob.f.observed + ) + return remake(prob; f, u0) end function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG) @@ -110,7 +114,7 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL jump_f = let _f = _f function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t) _f(du.u, u.u, p, t) - update_jumps!(du, u, p, t, length(u.u), jumps...) + return update_jumps!(du, u, p, t, length(u.u), jumps...) end end else @@ -125,18 +129,20 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL if prob.noise_rate_prototype === nothing jump_g = function (du, u, p, t) - prob.g(du.u, u.u, p, t) + return prob.g(du.u, u.u, p, t) end else jump_g = function (du, u, p, t) - prob.g(du, u.u, p, t) + return prob.g(du, u.u, p, t) end end u0 = extend_u0(prob, length(jumps), rng) - f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, g = jump_g, u0) + f = SDEFunction{isinplace(prob)}( + jump_f, jump_g; sys = prob.f.sys, + observed = prob.f.observed + ) + return remake(prob; f, g = jump_g, u0) end function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG) @@ -146,7 +152,7 @@ function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAUL jump_f = let _f = _f function (du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t) _f(du.u, u.u, h, p, t) - update_jumps!(du, u, p, t, length(u.u), jumps...) + return update_jumps!(du, u, p, t, length(u.u), jumps...) end end else @@ -160,9 +166,11 @@ function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAUL end u0 = extend_u0(prob, length(jumps), rng) - f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, u0) + f = DDEFunction{isinplace(prob)}( + jump_f; sys = prob.f.sys, + observed = prob.f.observed + ) + return remake(prob; f, u0) end # Not sure if the DAE one is correct: Should be a residual of sorts @@ -173,7 +181,7 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL jump_f = let _f = _f function (out, du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t) _f(out, du.u, u.u, h, p, t) - update_jumps!(out, u, p, t, length(u.u), jumps...) + return update_jumps!(out, u, p, t, length(u.u), jumps...) end end else @@ -187,50 +195,54 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL end u0 = extend_u0(prob, length(jumps), rng) - f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, u0) + f = DAEFunction{isinplace(prob)}( + jump_f, sys = prob.f.sys, + observed = prob.f.observed + ) + return remake(prob; f, u0) end function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) condition = function (u, t, integrator) - u.jump_u[idx] + return u.jump_u[idx] end affect! = function (integrator) jump.affect!(integrator) integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t)) - nothing + return nothing end - new_cb = ContinuousCallback(condition, affect!; + new_cb = ContinuousCallback( + condition, affect!; idxs = jump.idxs, rootfind = jump.rootfind, interp_points = jump.interp_points, save_positions = jump.save_positions, abstol = jump.abstol, - reltol = jump.reltol) + reltol = jump.reltol + ) return new_cb end function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) idx += 1 new_cb = wrap_jump_in_callback(idx, jump; rng) - build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) + return build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) end function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG) idx += 1 - CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) + return CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) end @inline function update_jumps!(du, u, p, t, idx, jump) idx += 1 - du[idx] = jump.rate(u.u, p, t) + return du[idx] = jump.rate(u.u, p, t) end @inline function update_jumps!(du, u, p, t, idx, jump, jumps...) idx += 1 du[idx] = jump.rate(u.u, p, t) - update_jumps!(du, u, p, t, idx, jumps...) + return update_jumps!(du, u, p, t, idx, jumps...) end ################################### VR_Direct and VR_DirectFW #################################### @@ -294,7 +306,8 @@ mutable struct VR_DirectEventCache{T, RNG, F1, F2} end function VR_DirectEventCache( - jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG) where {T} + jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG + ) where {T} initial_threshold = randexp(rng, T) vjumps = jumps.variable_jumps @@ -303,13 +316,16 @@ function VR_DirectEventCache( cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), + return VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}( + zero(T), initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, - affect_funcs, cum_rate_sum) + affect_funcs, cum_rate_sum + ) end function VR_DirectEventCache( - jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG) where {T} + jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG + ) where {T} initial_threshold = randexp(rng, T) vjumps = jumps.variable_jumps @@ -320,9 +336,11 @@ function VR_DirectEventCache( cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}(zero(T), + return VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}( + zero(T), initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, - affect_funcs, cum_rate_sum) + affect_funcs, cum_rate_sum + ) end # Initialization function for VR_DirectEventCache @@ -333,23 +351,29 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato cache.current_threshold = cache.prev_threshold cache.total_rate = zero(integrator.t) cache.cum_rate_sum .= 0 - nothing + return nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache, - ::I) where {I <: SciMLBase.DEIntegrator} +@inline function concretize_vr_direct_affects!( + cache::VR_DirectEventCache, + ::I + ) where {I <: SciMLBase.DEIntegrator} if (cache.affect_funcs isa Vector) && - !(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) + !(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}} - cache.affect_funcs = AffectWrapper[makewrapper(AffectWrapper, aff) - for aff in cache.affect_funcs] + cache.affect_funcs = AffectWrapper[ + makewrapper(AffectWrapper, aff) + for aff in cache.affect_funcs + ] end - nothing + return nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - ::I) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} - nothing +@inline function concretize_vr_direct_affects!( + cache::VR_DirectEventCache{T, RNG, F1, F2}, + ::I + ) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} + return nothing end # Wrapper for initialize to match ContinuousCallback signature @@ -357,7 +381,7 @@ function initialize_vr_direct_wrapper(cb::ContinuousCallback, u, t, integrator) concretize_vr_direct_affects!(cb.condition, integrator) initialize_vr_direct_cache!(cb.condition, u, t, integrator) u_modified!(integrator, false) - nothing + return nothing end # Merge callback parameters across all jumps for VR_Direct @@ -372,8 +396,10 @@ function build_variable_integcallback(cache::VR_DirectEventCache, jumps) reltol = min(reltol, jump.reltol) end - return ContinuousCallback(cache, cache; initialize = initialize_vr_direct_wrapper, - save_positions, abstol, reltol) + return ContinuousCallback( + cache, cache; initialize = initialize_vr_direct_wrapper, + save_positions, abstol, reltol + ) end function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_RNG) @@ -390,26 +416,28 @@ function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT return new_prob, variable_jump_callback end -# recursively evaluate the cumulative sum of the rates for type stability +# recursively evaluate the cumulative sum of the rates for type stability @inline function cumsum_rates!(cum_rate_sum, u, p, t, rates) cur_sum = zero(eltype(cum_rate_sum)) - cumsum_rates!(cum_rate_sum, u, p, t, 1, cur_sum, rates...) + return cumsum_rates!(cum_rate_sum, u, p, t, 1, cur_sum, rates...) end @inline function cumsum_rates!(cum_rate_sum, u, p, t, idx, cur_sum, rate, rates...) new_sum = cur_sum + rate(u, p, t) @inbounds cum_rate_sum[idx] = new_sum idx += 1 - cumsum_rates!(cum_rate_sum, u, p, t, idx, new_sum, rates...) + return cumsum_rates!(cum_rate_sum, u, p, t, idx, new_sum, rates...) end @inline function cumsum_rates!(cum_rate_sum, u, p, t, idx, cur_sum, rate) - @inbounds cum_rate_sum[idx] = cur_sum + rate(u, p, t) + return @inbounds cum_rate_sum[idx] = cur_sum + rate(u, p, t) end function total_variable_rate( cache::VR_DirectEventCache{ - T, RNG, F1, F2}, u, p, t) where {T, RNG, F1, F2} + T, RNG, F1, F2, + }, u, p, t + ) where {T, RNG, F1, F2} (; cum_rate_sum, rate_funcs) = cache sum_rate = cumsum_rates!(cum_rate_sum, u, p, t, rate_funcs) return sum_rate @@ -454,18 +482,22 @@ function (cache::VR_DirectEventCache)(u, t, integrator) return cache.current_threshold end -@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - integrator::I, idx) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} - quote +@generated function execute_affect!( + cache::VR_DirectEventCache{T, RNG, F1, F2}, + integrator::I, idx + ) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} + return quote (; affect_funcs) = cache Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount(F2)](integrator)) end end -@inline function execute_affect!(cache::VR_DirectEventCache, - integrator::I, idx) where {I <: SciMLBase.DEIntegrator} +@inline function execute_affect!( + cache::VR_DirectEventCache, + integrator::I, idx + ) where {I <: SciMLBase.DEIntegrator} (; affect_funcs) = cache - if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}} + return if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}} @inbounds affect_funcs[idx](integrator) else error("Error, invalid affect_funcs type. Expected a vector of function wrappers and got $(typeof(affect_funcs))") diff --git a/test/allocations.jl b/test/allocations.jl index 64efb7582..263864aa6 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -14,7 +14,7 @@ let function affect1!(integrator) integrator.u[1] -= 1 # S -> S - 1 integrator.u[2] += 1 # I -> I + 1 - nothing + return nothing end jump = ConstantRateJump(rate1, affect1!) @@ -22,7 +22,7 @@ let function affect2!(integrator) integrator.u[2] -= 1 # I -> I - 1 integrator.u[3] += 1 # R -> R + 1 - nothing + return nothing end jump2 = ConstantRateJump(rate2, affect2!) @@ -55,8 +55,10 @@ let return (η / K) * (K - (X + Y)) end - function makeprob(; T = 100.0, alg = Direct(), save_positions = (false, false), - graphkwargs = (;), rng) + function makeprob(; + T = 100.0, alg = Direct(), save_positions = (false, false), + graphkwargs = (;), rng + ) r1(u, p, t) = rate(p[1], u[1], u[2], p[2]) * u[1] r2(u, p, t) = rate(p[1], u[2], u[1], p[2]) * u[2] r3(u, p, t) = p[3] * u[1] @@ -69,24 +71,26 @@ let aff4!(integrator) = integrator.u[2] -= 1 function aff5!(integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end function aff6!(integrator) integrator.u[1] += 1 - integrator.u[2] -= 1 + return integrator.u[2] -= 1 end # η K μ γ ρ - p = (1.0, 1e4, 0.1, 1e-4, 0.01) + p = (1.0, 1.0e4, 0.1, 1.0e-4, 0.01) u0 = [1000, 10] tspan = (0.0, T) dprob = DiscreteProblem(u0, tspan, p) - jprob = JumpProblem(dprob, alg, + jprob = JumpProblem( + dprob, alg, ConstantRateJump(r1, aff1!), ConstantRateJump(r2, aff2!), ConstantRateJump(r3, aff3!), ConstantRateJump(r4, aff4!), ConstantRateJump(r5, aff5!), ConstantRateJump(r6, aff6!); - save_positions, rng, graphkwargs...) + save_positions, rng, graphkwargs... + ) return jprob end diff --git a/test/bimolerx_test.jl b/test/bimolerx_test.jl index befeefab1..3fd3af4eb 100644 --- a/test/bimolerx_test.jl +++ b/test/bimolerx_test.jl @@ -40,14 +40,14 @@ reactstoch = [ [2 => 1], [1 => 1, 2 => 1], [3 => 1], - [3 => 3] + [3 => 3], ] netstoch = [ [1 => -2, 2 => 1], [1 => 2, 2 => -1], [1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1], - [1 => 3, 3 => -3] + [1 => 3, 3 => -3], ] rates = [1.0, 2.0, 0.5, 0.75, 0.25] spec_to_dep_jumps = [[1, 3], [2, 3], [4, 5]] @@ -61,7 +61,7 @@ function runSSAs(jump_prob; use_stepper = true) sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob) Psamp[i] = sol[1, end] end - mean(Psamp) + return mean(Psamp) end # TESTING: @@ -70,9 +70,11 @@ prob = DiscreteProblem(u0, (0.0, tf), rates) # plotting one full trajectory if doplot for alg in SSAalgs - local jump_prob = JumpProblem(prob, alg, majumps, + local jump_prob = JumpProblem( + prob, alg, majumps, vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) + jumptovars_map = jump_to_dep_specs, rng = rng + ) local sol = solve(jump_prob, SSAStepper()) local plothand = plot(sol, seriestype = :steppost, reuse = false) display(plothand) @@ -82,13 +84,17 @@ end # test the means if dotestmean for (i, alg) in enumerate(SSAalgs) - local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), + local jump_prob = JumpProblem( + prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) + jumptovars_map = jump_to_dep_specs, rng = rng + ) means = runSSAs(jump_prob) relerr = abs(means - expected_avg) / expected_avg - doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, - ", rel err = ", relerr) + doprintmeans && println( + "Mean from method: ", typeof(alg), " is = ", means, + ", rel err = ", relerr + ) @test abs(means - expected_avg) < reltol * expected_avg # test not specifying SSAStepper @@ -105,14 +111,18 @@ if dotestmean push!(majump_vec, MassActionJump(rates[i], reactstoch[i], netstoch[i])) end jset = JumpSet((), (), nothing, majump_vec) - jump_prob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), + jump_prob = JumpProblem( + prob, Direct(), jset, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) + jumptovars_map = jump_to_dep_specs, rng = rng + ) meanval = runSSAs(jump_prob) relerr = abs(meanval - expected_avg) / expected_avg if doprintmeans - println("Using individual MassActionJumps; Mean from method: ", typeof(Direct()), - " is = ", meanval, ", rel err = ", relerr) + println( + "Using individual MassActionJumps; Mean from method: ", typeof(Direct()), + " is = ", meanval, ", rel err = ", relerr + ) end @test abs(meanval - expected_avg) < reltol * expected_avg end diff --git a/test/bracketing.jl b/test/bracketing.jl index 7a4776da4..80cde8693 100644 --- a/test/bracketing.jl +++ b/test/bracketing.jl @@ -8,7 +8,7 @@ bd = BracketData(fluctuation_rate, threshold, Δu) ### Getters ### species_index = 1 -# The fluctuation rate δ corresponds to species brackets (1-δ)*u, (1+δ)*u. So 0 < δ < 1. +# The fluctuation rate δ corresponds to species brackets (1-δ)*u, (1+δ)*u. So 0 < δ < 1. @test 0 < JP.getfr(bd, species_index) < 1 # If u < threshold, then the brackets are (max(u-Δu, 0), u+Δu). So 0 <= threshold and 0 <= Δu. @test 0 <= JP.gettv(bd, species_index) @@ -23,8 +23,8 @@ species_index = 2 @test JP.get_spec_brackets(bd, species_index, u)[1] == u[2] - Δu @test JP.get_spec_brackets(bd, species_index, u)[2] == u[2] + Δu species_index = 3 -@test JP.get_spec_brackets(bd, species_index, u)[1]≈u[3] * (1 - fluctuation_rate) atol=1 -@test JP.get_spec_brackets(bd, species_index, u)[2]≈u[3] * (1 + fluctuation_rate) atol=1 +@test JP.get_spec_brackets(bd, species_index, u)[1] ≈ u[3] * (1 - fluctuation_rate) atol = 1 +@test JP.get_spec_brackets(bd, species_index, u)[2] ≈ u[3] * (1 + fluctuation_rate) atol = 1 ### Reaction rate brackets ### ulow = [2] @@ -34,8 +34,10 @@ uhigh = [10] majump_rates = [0.1] # death at rate 0.1 reactstoch = [[1 => 1]] netstoch = [[1 => -1]] -majump = MassActionJump(majump_rates, reactstoch, - netstoch) +majump = MassActionJump( + majump_rates, reactstoch, + netstoch +) reaction_index = 1 @test JP.get_majump_brackets(ulow, uhigh, reaction_index, majump)[1] == majump_rates[1] * ulow[1] # low @test JP.get_majump_brackets(ulow, uhigh, reaction_index, majump)[2] == majump_rates[1] * uhigh[1] # high @@ -49,7 +51,7 @@ t = 0.0 ### Aggregator ### mutable struct DummyAggregator{T, M, R, BD} <: - JP.AbstractSSAJumpAggregator{T, M, R, Nothing, Nothing} + JP.AbstractSSAJumpAggregator{T, M, R, Nothing, Nothing} ulow::Vector{Int} uhigh::Vector{Int} cur_rate_low::Vector{T} @@ -67,8 +69,8 @@ p = DummyAggregator([0], [0], cur_rate_low, cur_rate_high, sum_rate, majump, [ra u = [100] JP.update_u_brackets!(p, u) -@test p.ulow[1]≈u[1] * (1 - fluctuation_rate) atol=1 -@test p.uhigh[1]≈u[1] * (1 + fluctuation_rate) atol=1 +@test p.ulow[1] ≈ u[1] * (1 - fluctuation_rate) atol = 1 +@test p.uhigh[1] ≈ u[1] * (1 + fluctuation_rate) atol = 1 reaction_index = 1 @test JP.get_jump_brackets(reaction_index, p, params, t)[1] == majump_rates[1] * p.ulow[1] @@ -79,10 +81,10 @@ reaction_index = 2 p = DummyAggregator([0], [0], cur_rate_low, cur_rate_high, sum_rate, majump, [rate], bd) JP.set_bracketing!(p, u, params, t) -@test p.ulow[1]≈u[1] * (1 - fluctuation_rate) atol=1 -@test p.uhigh[1]≈u[1] * (1 + fluctuation_rate) atol=1 -@test p.cur_rate_low[1]≈majump_rates[1] * u[1] * (1 - fluctuation_rate) atol=1 -@test p.cur_rate_high[1]≈majump_rates[1] * u[1] * (1 + fluctuation_rate) atol=1 +@test p.ulow[1] ≈ u[1] * (1 - fluctuation_rate) atol = 1 +@test p.uhigh[1] ≈ u[1] * (1 + fluctuation_rate) atol = 1 +@test p.cur_rate_low[1] ≈ majump_rates[1] * u[1] * (1 - fluctuation_rate) atol = 1 +@test p.cur_rate_high[1] ≈ majump_rates[1] * u[1] * (1 + fluctuation_rate) atol = 1 @test p.cur_rate_low[2] == rate(p.uhigh, params, t) @test p.cur_rate_high[2] == rate(p.ulow, params, t) @test p.sum_rate ≈ sum(p.cur_rate_high) diff --git a/test/constant_rate.jl b/test/constant_rate.jl index 86c237c06..22289e6c3 100644 --- a/test/constant_rate.jl +++ b/test/constant_rate.jl @@ -5,13 +5,13 @@ rng = StableRNG(12345) rate = (u, p, t) -> u affect! = function (integrator) - integrator.u += 1 + return integrator.u += 1 end jump = ConstantRateJump(rate, affect!) rate = (u, p, t) -> 0.5u affect! = function (integrator) - integrator.u -= 1 + return integrator.u -= 1 end jump2 = ConstantRateJump(rate, affect!) diff --git a/test/degenerate_rx_cases.jl b/test/degenerate_rx_cases.jl index 3b211decd..289f974e9 100644 --- a/test/degenerate_rx_cases.jl +++ b/test/degenerate_rx_cases.jl @@ -84,7 +84,7 @@ ns = [1 => 1] jump = MassActionJump(rate, rs, ns) ratefun = (u, p, t) -> 2.0 * u[1] affect! = function (integrator) - integrator.u[1] -= 1 + return integrator.u[1] -= 1 end jump2 = ConstantRateJump(ratefun, affect!) if doplot @@ -93,7 +93,7 @@ end dep_graph = [ [1, 2], - [1, 2] + [1, 2], ] spec_to_dep_jumps = [[2]] jump_to_dep_specs = [[1], [1]] @@ -128,12 +128,16 @@ maj1 = MassActionJump(rs1, ns1; param_idxs = 1) maj2 = MassActionJump(rs2, ns2; param_idxs = 2) js = JumpSet(maj1, maj2) maj = MassActionJump([rs1, rs2], [ns1, ns2]; param_idxs = [1, 2]) -@test all(getfield(maj, fn) == getfield(js.massaction_jump, fn) -for fn in [:scaled_rates, :reactant_stoch, :net_stoch]) +@test all( + getfield(maj, fn) == getfield(js.massaction_jump, fn) + for fn in [:scaled_rates, :reactant_stoch, :net_stoch] +) @test all(maj.param_mapper.param_idxs .== js.massaction_jump.param_mapper.param_idxs) maj1 = MassActionJump([rs1], [ns1]; param_idxs = [1]) maj2 = MassActionJump([rs2], [ns2]; param_idxs = [2]) js = JumpSet(maj1, maj2) -@test all(getfield(maj, fn) == getfield(js.massaction_jump, fn) -for fn in [:scaled_rates, :reactant_stoch, :net_stoch]) +@test all( + getfield(maj, fn) == getfield(js.massaction_jump, fn) + for fn in [:scaled_rates, :reactant_stoch, :net_stoch] +) @test all(maj.param_mapper.param_idxs .== js.massaction_jump.param_mapper.param_idxs) diff --git a/test/extended_jump_array.jl b/test/extended_jump_array.jl index 349b5bb36..1e362c37a 100644 --- a/test/extended_jump_array.jl +++ b/test/extended_jump_array.jl @@ -5,16 +5,22 @@ using StableRNGs rng = StableRNG(123) # Check that the new broadcast norm gives the same result as the old one -rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 5), - rand(rng, 2)) +rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}( + rand(rng, 5), + rand(rng, 2) +) old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / max(DiffEqBase.recursive_length(rand_array), 1)) new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0) @test old_norm ≈ new_norm # Check for an ExtendedJumpArray where the types differ (Float64/Int64) -rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Int64}}(rand(rng, 5), - rand(rng, 1:1000, - 2)) +rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Int64}}( + rand(rng, 5), + rand( + rng, 1:1000, + 2 + ) +) old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / max(DiffEqBase.recursive_length(rand_array), 1)) new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0) @test old_norm ≈ new_norm @@ -41,7 +47,7 @@ bc_out .= 3.14 .* bc_eja_1 + 2.7 .* bc_eja_2 # Test that mismatched arrays cannot be broadcasted bc_mismatch = ExtendedJumpArray(rand(rng, 8), rand(rng, 4)) -@test_throws DimensionMismatch bc_mismatch+bc_eja_1 +@test_throws DimensionMismatch bc_mismatch + bc_eja_1 @test_throws DimensionMismatch bc_mismatch .+ bc_eja_1 # Test that datatype mixing persists through broadcasting @@ -63,7 +69,7 @@ out_result .= bc_dtype_1 .+ bc_dtype_2 .* 2 oop_test_rate(u, p, t) = exp(t) function oop_test_affect!(integrator) integrator.u[1] += 1 - nothing + return nothing end oop_test_jump = VariableRateJump(oop_test_rate, oop_test_affect!) @@ -109,7 +115,7 @@ let function f!(du, u, p, t) du .= 0 - nothing + return nothing end u₀ = [0, 0] oprob = ODEProblem(f!, u₀, (0.0, 10.0), p) diff --git a/test/extinction_test.jl b/test/extinction_test.jl index 880254ba0..9ae2c2561 100644 --- a/test/extinction_test.jl +++ b/test/extinction_test.jl @@ -4,11 +4,11 @@ using StableRNGs rng = StableRNG(12345) reactstoch = [ - [1 => 1] + [1 => 1], ] netstoch = [ - [1 => -1] + [1 => -1], ] Nsims = 10 @@ -16,13 +16,15 @@ rates = [1.0] dg = [[1]] majump = MassActionJump(rates, reactstoch, netstoch) u0 = [100000] -dprob = DiscreteProblem(u0, (0.0, 1e5), rates) +dprob = DiscreteProblem(u0, (0.0, 1.0e5), rates) algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) for n in 1:Nsims for ssa in algs - local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false), - rng = rng) + local jprob = JumpProblem( + dprob, ssa, majump, save_positions = (false, false), + rng = rng + ) local sol = solve(jprob, SSAStepper()) @test sol[1, end] == 0 @test sol.t[end] < Inf @@ -33,8 +35,10 @@ u0 = SA[10] dprob = DiscreteProblem(u0, (0.0, 100.0), rates) for ssa in algs - local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false), - rng = rng) + local jprob = JumpProblem( + dprob, ssa, majump, save_positions = (false, false), + rng = rng + ) local sol = solve(jprob, SSAStepper(), saveat = 100.0) @test sol[1, end] == 0 @test sol.t[end] < Inf @@ -45,14 +49,14 @@ Base.@kwdef mutable struct ExtinctionTest cnt::Int = 0 end function (e::ExtinctionTest)(u, t, integrator) - (e.cnt == 0) && (integrator.cb.affect!.next_jump_time == Inf) + return (e.cnt == 0) && (integrator.cb.affect!.next_jump_time == Inf) end function (e::ExtinctionTest)(integrator) (saved, savedexactly) = savevalues!(integrator, true) @test saved == true @test savedexactly == true e.cnt += 1 - nothing + return nothing end et = ExtinctionTest() cb = DiscreteCallback(et, et, save_positions = (false, false)) @@ -63,15 +67,17 @@ sol = solve(jprob, SSAStepper(), callback = cb, save_end = false) # test terminate function extinction_condition2(u, t, integrator) - u[1] == 1 + return u[1] == 1 end function extinction_affect!2(integrator) (saved, savedexactly) = savevalues!(integrator, true) terminate!(integrator) - nothing + return nothing end -cb = DiscreteCallback(extinction_condition2, extinction_affect!2, - save_positions = (false, false)) +cb = DiscreteCallback( + extinction_condition2, extinction_affect!2, + save_positions = (false, false) +) dprob = DiscreteProblem(u0, (0.0, 1000.0), rates) jprob = JumpProblem(dprob, majump; save_positions = (false, false), rng) sol = solve(jprob; callback = cb, save_end = false) diff --git a/test/fp_unknowns.jl b/test/fp_unknowns.jl index 6a8b5f6d6..d9463acdf 100644 --- a/test/fp_unknowns.jl +++ b/test/fp_unknowns.jl @@ -9,16 +9,22 @@ rng = StableRNG(12345) # 1, X --> ∅ function test(rng) # dep graphs - dg = [[1, 2, 3, 4], + dg = [ + [1, 2, 3, 4], [2, 3, 4], [2, 3, 4], - [2, 3, 4]] - vtoj = [[2, 3, 4], - [2]] - jtov = [[1], + [2, 3, 4], + ] + vtoj = [ + [2, 3, 4], + [2], + ] + jtov = [ + [1], [1, 2], [1, 2], - [1]] + [1], + ] # reaction as MassActionJump sr = [1.0, 0.5, 4.0, 1.0] @@ -34,9 +40,11 @@ function test(rng) Xmeans = zeros(length(SSAalgs)) Ymeans = zeros(length(SSAalgs)) for (j, agg) in enumerate(SSAalgs) - jprob = JumpProblem(dprob, agg, maj; save_positions = (false, false), rng, + jprob = JumpProblem( + dprob, agg, maj; save_positions = (false, false), rng, vartojumps_map = vtoj, jumptovars_map = jtov, dep_graph = dg, - scale_rates = false) + scale_rates = false + ) for i in 1:Nsims sol = solve(jprob, SSAStepper()) Xmeans[j] += sol[1, end] @@ -44,7 +52,7 @@ function test(rng) end end Xmeans ./= Nsims - Ymeans ./= Nsims + return Ymeans ./= Nsims # for i in 2:length(SSAalgs) # @test abs(Xmeans[i] - Xmeans[1]) < (.1 * Xmeans[1]) # @test abs(Ymeans[i] - Ymeans[1]) < (.1 * Ymeans[1]) diff --git a/test/functionwrappers.jl b/test/functionwrappers.jl index 2f009ead4..ceeec52d3 100644 --- a/test/functionwrappers.jl +++ b/test/functionwrappers.jl @@ -6,10 +6,12 @@ let rate(u, p, t; debug = true) = 5.0 function affect!(integrator) integrator.u[1] += 1 - nothing + return nothing end - jump = VariableRateJump(rate, affect!; urate = (u, p, t) -> 10.0, - rateinterval = (u, p, t) -> 0.1) + jump = VariableRateJump( + rate, affect!; urate = (u, p, t) -> 10.0, + rateinterval = (u, p, t) -> 0.1 + ) prob = DiscreteProblem([0.0], (0.0, 2.0), [1.0]) jprob = JumpProblem(prob, Coevolve(), jump; dep_graph = [[1]], rng) diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index 708e91563..af4b44054 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -18,7 +18,7 @@ SSAalgs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) Nsims = 8000 tf = 1000.0 u0 = [1, 0, 0, 0] -expected_avg = 5.926553750000000e+02 +expected_avg = 5.92655375e+2 reltol = 0.01 # average number of proteins in a simulation @@ -28,7 +28,7 @@ function runSSAs(jump_prob; use_stepper = true) sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob) Psamp[i] = sol[3, end] end - mean(Psamp) + return mean(Psamp) end function runSSAs_ode(vrjprob) @@ -61,7 +61,7 @@ reactstoch = [ [2 => 1], [3 => 1], [1 => 1, 3 => 1], - [4 => 1] + [4 => 1], ] netstoch = [ [2 => 1], @@ -69,7 +69,7 @@ netstoch = [ [2 => -1], [3 => -1], [1 => -1, 3 => -1, 4 => 1], - [1 => 1, 3 => 1, 4 => -1] + [1 => 1, 3 => 1, 4 => -1], ] spec_to_dep_jumps = [[1, 5], [2, 3], [4, 5], [6]] jump_to_dep_specs = [[2], [3], [2], [3], [1, 3, 4], [1, 3, 4]] @@ -87,9 +87,11 @@ probf = DiscreteProblem(u0f, (0.0, tf), rates) if doplot plothand = plot(reuse = false) for alg in SSAalgs - local jump_prob = JumpProblem(prob, alg, majumps, + local jump_prob = JumpProblem( + prob, alg, majumps, vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) + jumptovars_map = jump_to_dep_specs, rng = rng + ) local sol = solve(jump_prob, SSAStepper()) plot!(plothand, sol.t, sol[3, :], seriestype = :steppost) end @@ -99,33 +101,43 @@ end # test the means if dotestmean for (i, alg) in enumerate(SSAalgs) - local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), + local jump_prob = JumpProblem( + prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) + jumptovars_map = jump_to_dep_specs, rng = rng + ) means = runSSAs(jump_prob) relerr = abs(means - expected_avg) / expected_avg - doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, - ", rel err = ", relerr) + doprintmeans && println( + "Mean from method: ", typeof(alg), " is = ", means, + ", rel err = ", relerr + ) @test abs(means - expected_avg) < reltol * expected_avg means = runSSAs(jump_prob; use_stepper = false) relerr = abs(means - expected_avg) / expected_avg @test abs(means - expected_avg) < reltol * expected_avg - jump_probf = JumpProblem(probf, alg, majumps, save_positions = (false, false), + jump_probf = JumpProblem( + probf, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) + jumptovars_map = jump_to_dep_specs, rng = rng + ) means = runSSAs(jump_probf) relerr = abs(means - expected_avg) / expected_avg - doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, - ", rel err = ", relerr) + doprintmeans && println( + "Mean from method: ", typeof(alg), " is = ", means, + ", rel err = ", relerr + ) @test abs(means - expected_avg) < reltol * expected_avg end end # no-aggregator tests -jump_prob = JumpProblem(prob, majumps; save_positions = (false, false), - vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs, rng) +jump_prob = JumpProblem( + prob, majumps; save_positions = (false, false), + vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs, rng +) @test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg @test abs(runSSAs(jump_prob; use_stepper = false) - expected_avg) < reltol * expected_avg @@ -156,23 +168,27 @@ let integ.u[1] -= 1 integ.u[3] -= 1 integ.u[4] += 1 - nothing + return nothing end function a6!(integ) integ.u[1] += 1 integ.u[3] += 1 integ.u[4] -= 1 - nothing + return nothing end - crjs = JumpSet(ConstantRateJump(r1, a1!), ConstantRateJump(r2, a2!), + crjs = JumpSet( + ConstantRateJump(r1, a1!), ConstantRateJump(r2, a2!), ConstantRateJump(r3, a3!), ConstantRateJump(r4, a4!), ConstantRateJump(r5, a5!), - ConstantRateJump(r6, a6!)) - vrjs = JumpSet(VariableRateJump(r1, a1!; save_positions = (false, false)), + ConstantRateJump(r6, a6!) + ) + vrjs = JumpSet( + VariableRateJump(r1, a1!; save_positions = (false, false)), VariableRateJump(r2, a2!, save_positions = (false, false)), VariableRateJump(r3, a3!, save_positions = (false, false)), VariableRateJump(r4, a4!, save_positions = (false, false)), VariableRateJump(r5, a5!, save_positions = (false, false)), - VariableRateJump(r6, a6!, save_positions = (false, false))) + VariableRateJump(r6, a6!, save_positions = (false, false)) + ) prob = DiscreteProblem(u0, (0.0, tf), rates) crjprob = JumpProblem(prob, crjs; save_positions = (false, false), rng) @@ -187,7 +203,8 @@ let for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) vrjprob = JumpProblem( - oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng) + oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng + ) vrjmean = runSSAs_ode(vrjprob) @test abs(vrjmean - crjmean) < reltol * crjmean end diff --git a/test/gpu/regular_jumps.jl b/test/gpu/regular_jumps.jl index 5e60ee2d4..f6fb0c711 100644 --- a/test/gpu/regular_jumps.jl +++ b/test/gpu/regular_jumps.jl @@ -33,12 +33,16 @@ let rj = RegularJump(regular_rate, regular_c, 3) jump_prob = JumpProblem(prob_disc, PureLeaping(), rj) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), - EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0) + sol = solve( + EnsembleProblem(jump_prob), SimpleTauLeaping(), + EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0 + ) mean_kernel = mean(sol.u[i][1, end] for i in 1:Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), - EnsembleSerial(); trajectories = Nsims, dt = 1.0) + sol = solve( + EnsembleProblem(jump_prob), SimpleTauLeaping(), + EnsembleSerial(); trajectories = Nsims, dt = 1.0 + ) mean_serial = mean(sol.u[i][1, end] for i in 1:Nsims) @test isapprox(mean_kernel, mean_serial, rtol = 0.05) @@ -74,12 +78,16 @@ let rj = RegularJump(regular_rate, regular_c, 3) jump_prob = JumpProblem(prob_disc, PureLeaping(), rj; rng = StableRNG(12345)) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), - EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0) + sol = solve( + EnsembleProblem(jump_prob), SimpleTauLeaping(), + EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0 + ) mean_kernel = mean(sol.u[i][end, end] for i in 1:Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), - EnsembleSerial(); trajectories = Nsims, dt = 1.0) + sol = solve( + EnsembleProblem(jump_prob), SimpleTauLeaping(), + EnsembleSerial(); trajectories = Nsims, dt = 1.0 + ) mean_serial = mean(sol.u[i][end, end] for i in 1:Nsims) @test isapprox(mean_kernel, mean_serial, rtol = 0.05) diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index adf5a83dc..a6d7c6b8e 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -7,7 +7,7 @@ function reset_history!(h; start_time = nothing) @inbounds for i in 1:length(h) h[i] = eltype(h)[] end - nothing + return nothing end function empirical_rate(sol) @@ -38,7 +38,8 @@ function hawkes_jump(i::Int, g, h; uselrate = true) if uselrate lrate(u, p, t) = p[1] rateinterval = ( - u, p, t) -> begin + u, p, t, + ) -> begin _lrate = lrate(u, p, t) _urate = urate(u, p, t) return _urate == _lrate ? typemax(t) : 1 / (2 * _urate) @@ -52,7 +53,7 @@ function hawkes_jump(i::Int, g, h; uselrate = true) end function affect!(integrator) push!(h[i], integrator.t) - integrator.u[i] += 1 + return integrator.u[i] += 1 end return VariableRateJump(rate, affect!; lrate, urate, rateinterval) end @@ -61,8 +62,10 @@ function hawkes_jump(u, g, h; uselrate = true) return [hawkes_jump(i, g, h; uselrate) for i in 1:length(u)] end -function hawkes_problem(p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0), - save_positions = (false, true), g = [[1]], h = [[]], uselrate = true, kwargs...) +function hawkes_problem( + p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0), + save_positions = (false, true), g = [[1]], h = [[]], uselrate = true, kwargs... + ) dprob = DiscreteProblem(u, tspan, p) jumps = hawkes_jump(u, g, h; uselrate) jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions, rng) @@ -71,11 +74,13 @@ end function f!(du, u, p, t) du .= 0 - nothing + return nothing end -function hawkes_problem(p, agg; u = [0.0], tspan = (0.0, 50.0), - save_positions = (false, true), g = [[1]], h = [[]], vr_aggregator = VR_FRM(), kwargs...) +function hawkes_problem( + p, agg; u = [0.0], tspan = (0.0, 50.0), + save_positions = (false, true), g = [[1]], h = [[]], vr_aggregator = VR_FRM(), kwargs... + ) oprob = ODEProblem(f!, u, tspan, p) jumps = hawkes_jump(u, g, h) jprob = JumpProblem(oprob, agg, jumps...; vr_aggregator, save_positions, rng) diff --git a/test/jprob_symbol_indexing.jl b/test/jprob_symbol_indexing.jl index 4cad845cb..9c1be8d2c 100644 --- a/test/jprob_symbol_indexing.jl +++ b/test/jprob_symbol_indexing.jl @@ -7,8 +7,10 @@ affect2!(integ) = (integ.u[2] += 1) crj1 = ConstantRateJump(rate1, affect1!) crj2 = ConstantRateJump(rate2, affect2!) maj = MassActionJump([[1 => 1], [1 => 1]], [[1 => -1], [1 => -1]]; param_idxs = [1, 2]) -g = DiscreteFunction((du, u, p, t) -> nothing; - sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t)) +g = DiscreteFunction( + (du, u, p, t) -> nothing; + sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t) +) dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0]) jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj) diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index 657982a6b..17412abcc 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -32,7 +32,7 @@ function runSSAs(jump_prob) sol = solve(jump_prob, SSAStepper()) Asamp[i] = sol[1, end] end - mean(Asamp) + return mean(Asamp) end # uses constant jumps as a tuple within a JumpSet @@ -43,7 +43,7 @@ function A_to_B_tuple(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumpvec, ConstantRateJump(ratefunc, affect!)) end @@ -52,10 +52,12 @@ function A_to_B_tuple(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end # uses constant jumps as a vector within a JumpSet @@ -66,7 +68,7 @@ function A_to_B_vec(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumps, ConstantRateJump(ratefunc, affect!)) end @@ -74,10 +76,12 @@ function A_to_B_vec(N, method) # convert jumpvec to tuple to send to JumpProblem... jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end # uses a single mass action jump to represent all reactions @@ -92,10 +96,12 @@ function A_to_B_ma(N, method) majumps = MassActionJump(rates, reactstoch, netstoch) jset = JumpSet((), (), nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end # uses one mass action jump to represent half the reactions and a vector @@ -118,7 +124,7 @@ function A_to_B_hybrid(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumps, ConstantRateJump(ratefunc, affect!)) end @@ -126,10 +132,12 @@ function A_to_B_hybrid(N, method) majumps = MassActionJump(rates[1:switchidx], reactstoch, netstoch) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end # uses a mass action jump to represent half the reactions and a vector @@ -152,7 +160,7 @@ function A_to_B_hybrid_nojset(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumpvec, ConstantRateJump(ratefunc, affect!)) end @@ -160,10 +168,12 @@ function A_to_B_hybrid_nojset(N, method) majumps = MassActionJump(rates[1:switchidx], reactstoch, netstoch) jumps = (constjumps..., majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jumps...; save_positions = (false, false), - rng, namedpars...) + jump_prob = JumpProblem( + prob, method, jumps...; save_positions = (false, false), + rng, namedpars... + ) - jump_prob + return jump_prob end # uses a vector of mass action jumps of vectors to represent half the reactions and a vector @@ -184,16 +194,18 @@ function A_to_B_hybrid_vecs(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumpvec, ConstantRateJump(ratefunc, affect!)) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end # uses a vector of scalar mass action jumps to represent half the reactions and a vector @@ -214,16 +226,18 @@ function A_to_B_hybrid_vecs_scalars(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumpvec, ConstantRateJump(ratefunc, affect!)) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end # uses a vector of scalar mass action jumps to represent half the reactions and a vector @@ -244,17 +258,19 @@ function A_to_B_hybrid_tups_scalars(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumpvec, ConstantRateJump(ratefunc, affect!)) end jumps = ((maj for maj in majumpsv)..., (jump for jump in jumpvec)...) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jumps...; save_positions = (false, false), - rng, namedpars...) + jump_prob = JumpProblem( + prob, method, jumps...; save_positions = (false, false), + rng, namedpars... + ) - jump_prob + return jump_prob end # uses a mass action jump to represent half the reactions and a tuple @@ -275,22 +291,26 @@ function A_to_B_hybrid_tups(N, method) ratefunc = (u, p, t) -> rates[i] * u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end push!(jumpvec, ConstantRateJump(ratefunc, affect!)) end jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, - namedpars...) + jump_prob = JumpProblem( + prob, method, jset; save_positions = (false, false), rng, + namedpars... + ) - jump_prob + return jump_prob end -jump_prob_gens = [A_to_B_tuple, A_to_B_vec, A_to_B_ma, A_to_B_hybrid, A_to_B_hybrid_nojset, +jump_prob_gens = [ + A_to_B_tuple, A_to_B_vec, A_to_B_ma, A_to_B_hybrid, A_to_B_hybrid_nojset, A_to_B_hybrid_vecs, A_to_B_hybrid_vecs_scalars, A_to_B_hybrid_tups, - A_to_B_hybrid_tups_scalars] + A_to_B_hybrid_tups_scalars, +] #jump_prob_gens = [A_to_B_tuple, A_to_B_ma, A_to_B_hybrid, A_to_B_hybrid_vecs, A_to_B_hybrid_vecs_scalars,A_to_B_hybrid_tups_scalars] for method in SSAalgs @@ -302,8 +322,10 @@ for method in SSAalgs local jump_prob = jump_prob_gen(Nrxs, method) meanval = runSSAs(jump_prob) if doprint - println("Method: ", method, ", Jump input types: ", jump_prob_gen, - ", sample mean = ", meanval, ", actual mean = ", exactmeanval) + println( + "Method: ", method, ", Jump input types: ", jump_prob_gen, + ", sample mean = ", meanval, ", actual mean = ", exactmeanval + ) end @test abs(meanval - exactmeanval) < 1.0 end @@ -319,8 +341,10 @@ for method in SSAalgs local jump_prob = jump_prob_gen(Nrxs, method) meanval = runSSAs(jump_prob) if doprint - println("Method: ", method, ", Jump input types: ", jump_prob_gen, - ", sample mean = ", meanval, ", actual mean = ", exactmeanval) + println( + "Method: ", method, ", Jump input types: ", jump_prob_gen, + ", sample mean = ", meanval, ", actual mean = ", exactmeanval + ) end @test abs(meanval - exactmeanval) < 1.0 end diff --git a/test/longtimes_test.jl b/test/longtimes_test.jl index 7c787b4f3..28608a612 100644 --- a/test/longtimes_test.jl +++ b/test/longtimes_test.jl @@ -8,7 +8,7 @@ ns = [[1 => 1], [1 => -1], [1 => 1]] rs = [[1 => 1], [1 => 1], Pair{Int64, Int64}[]] maj = MassActionJump(p, rs, ns) u0 = [5] -tspan = (0.0, 2e6) +tspan = (0.0, 2.0e6) dt = tspan[2] / 1000 dprob = DiscreteProblem(u0, tspan, p) jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) diff --git a/test/monte_carlo_test.jl b/test/monte_carlo_test.jl index c197df9d5..b1764d35e 100644 --- a/test/monte_carlo_test.jl +++ b/test/monte_carlo_test.jl @@ -10,25 +10,33 @@ affect! = integrator -> (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate, affect!, save_positions = (false, true)) jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve( + monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false +) @test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct(), rng) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve( + monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false +) @test allunique(sol.u[1].t) jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW(), rng) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve( + monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false +) @test allunique(sol.u[1].t) jump = ConstantRateJump(rate, affect!) jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve( + monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false +) @test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] diff --git a/test/qa.jl b/test/qa.jl index 29a637087..3fe536e28 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -5,26 +5,28 @@ using Test @testset "QA Tests" begin @testset "Aqua tests" begin - Aqua.test_all(JumpProcesses; - ambiguities = false, # TODO: fix ambiguities and enable - deps_compat = true, - piracies = false, # We define default solvers for AbstractJumpProblem - unbound_args = true, - undefined_exports = true, - project_extras = true, - stale_deps = true, - persistent_tasks = false) # disabled due to false positives + Aqua.test_all( + JumpProcesses; + ambiguities = false, # TODO: fix ambiguities and enable + deps_compat = true, + piracies = false, # We define default solvers for AbstractJumpProblem + unbound_args = true, + undefined_exports = true, + project_extras = true, + stale_deps = true, + persistent_tasks = false + ) # disabled due to false positives end @testset "ExplicitImports tests" begin # Check that we're using explicit imports @test check_no_implicit_imports(JumpProcesses) === nothing - + # Check for stale explicit imports (imports that are not used) @test check_no_stale_explicit_imports(JumpProcesses) === nothing - + # Allow some flexibility for non-public imports during transition # This can be made stricter once all non-public API usage is resolved @test_nowarn check_all_explicit_imports_via_owners(JumpProcesses) end -end \ No newline at end of file +end diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 3ccc67404..7819ac4e0 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -5,7 +5,7 @@ rng = StableRNG(12345) function regular_rate(out, u, p, t) out[1] = (0.1 / 1000.0) * u[1] * u[2] - out[2] = 0.01u[2] + return out[2] = 0.01u[2] end const dc = zeros(3, 2) @@ -15,7 +15,7 @@ dc[2, 2] = -1 dc[3, 2] = 1 function regular_c(du, u, p, t, counts, mark) - mul!(du, dc, counts) + return mul!(du, dc, counts) end rj = RegularJump(regular_rate, regular_c, 2) @@ -31,63 +31,65 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) tspan = (0.0, 10.0) p = [0.1, 0.2] prob = DiscreteProblem(u0, tspan, p) - + # Create MassActionJump reactant_stoich = [[1 => 1], [1 => 2]] net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]] rates = [0.1, 0.05] maj = MassActionJump(rates, reactant_stoich, net_stoich) - + # Test PureLeaping JumpProblem creation jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) @test jp_pure.aggregator isa PureLeaping @test jp_pure.discrete_jump_aggregation === nothing @test jp_pure.massaction_jump !== nothing @test length(jp_pure.jump_callback.discrete_callbacks) == 0 - + # Test with ConstantRateJump rate(u, p, t) = p[1] * u[1] affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) crj = ConstantRateJump(rate, affect!) - + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) @test jp_pure_crj.aggregator isa PureLeaping @test jp_pure_crj.discrete_jump_aggregation === nothing @test length(jp_pure_crj.constant_jumps) == 1 - + # Test with VariableRateJump vrate(u, p, t) = t * p[1] * u[1] vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) vrj = VariableRateJump(vrate, vaffect!) - + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) @test jp_pure_vrj.aggregator isa PureLeaping @test jp_pure_vrj.discrete_jump_aggregation === nothing @test length(jp_pure_vrj.variable_jumps) == 1 - + # Test with RegularJump function rj_rate(out, u, p, t) out[1] = p[1] * u[1] end - + rj_dc = zeros(3, 1) rj_dc[1, 1] = -1 rj_dc[3, 1] = 1 - + function rj_c(du, u, p, t, counts, mark) mul!(du, rj_dc, counts) end - + regj = RegularJump(rj_rate, rj_c, 1) - + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) @test jp_pure_regj.aggregator isa PureLeaping @test jp_pure_regj.discrete_jump_aggregation === nothing @test jp_pure_regj.regular_jump !== nothing - + # Test mixed jump types - mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), - variable_jumps = (vrj,), regular_jumps = regj) + mixed_jumps = JumpSet(; + massaction_jumps = maj, constant_jumps = (crj,), + variable_jumps = (vrj,), regular_jumps = regj + ) jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) @test jp_pure_mixed.aggregator isa PureLeaping @test jp_pure_mixed.discrete_jump_aggregation === nothing @@ -95,18 +97,22 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) @test length(jp_pure_mixed.constant_jumps) == 1 @test length(jp_pure_mixed.variable_jumps) == 1 @test jp_pure_mixed.regular_jump !== nothing - + # Test spatial system error spatial_sys = CartesianGrid((2, 2)) hopping_consts = [1.0] - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, - spatial_system = spatial_sys) - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, - hopping_constants = hopping_consts) - + @test_throws ErrorException JumpProblem( + prob, PureLeaping(), JumpSet(maj); rng, + spatial_system = spatial_sys + ) + @test_throws ErrorException JumpProblem( + prob, PureLeaping(), JumpSet(maj); rng, + hopping_constants = hopping_consts + ) + # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) - scaled_rates = [p[1], p[2]/2] + scaled_rates = [p[1], p[2] / 2] @test jp_params.massaction_jump.scaled_rates == scaled_rates end diff --git a/test/remake_test.jl b/test/remake_test.jl index 91ab4daba..c439234b4 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -5,14 +5,14 @@ rng = StableRNG(12345) rate = (u, p, t) -> p[1] * u[1] * u[2] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end jump = ConstantRateJump(rate, affect!) rate = (u, p, t) -> p[2] * u[2] affect! = function (integrator) integrator.u[2] -= 1 - integrator.u[3] += 1 + return integrator.u[3] += 1 end jump2 = ConstantRateJump(rate, affect!) @@ -21,8 +21,10 @@ p = (0.1 / 1000, 0.01) tspan = (0.0, 2500.0) dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), jump, jump2, save_positions = (false, false), - rng = rng) +jprob = JumpProblem( + dprob, Direct(), jump, jump2, save_positions = (false, false), + rng = rng +) sol = solve(jprob, SSAStepper()) @test sol[3, end] == 1000 @@ -65,8 +67,8 @@ sol3 = solve(jprob3, SSAStepper()) ################# # test error handling -@test_throws ErrorException jprob4=remake(jprob, prob = dprob2, p = p2) -@test_throws ErrorException jprob5=remake(jprob, aggregator = RSSA()) +@test_throws ErrorException jprob4 = remake(jprob, prob = dprob2, p = p2) +@test_throws ErrorException jprob5 = remake(jprob, aggregator = RSSA()) # test for #446 let @@ -97,7 +99,7 @@ let jprob3 = remake(jprob2; u0) sol = solve(jprob3, Tsit5()) @test all(==(0.0), sol[1, :]) - @test_throws ErrorException jprob4=remake(jprob, u0 = 1) + @test_throws ErrorException jprob4 = remake(jprob, u0 = 1) end # tests when changing u0 via a passed in prob @@ -118,7 +120,7 @@ let @test all(==(0.0), sol[1, :]) u0 = [4.0] prob2 = remake(jprob.prob; u0) - @test_throws ErrorException jprob2=remake(jprob; prob = prob2) + @test_throws ErrorException jprob2 = remake(jprob; prob = prob2) u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0, rng) prob3 = remake(jprob.prob; u0 = u0eja) jprob3 = remake(jprob; prob = prob3) diff --git a/test/reversible_binding.jl b/test/reversible_binding.jl index f872d92a4..cf307ba7b 100644 --- a/test/reversible_binding.jl +++ b/test/reversible_binding.jl @@ -3,16 +3,16 @@ using Test, LinearAlgebra using StableRNGs rng = StableRNG(12345) -Nsims = 1e4 +Nsims = 1.0e4 # ABC model A + B <--> C reactstoch = [ [1 => 1, 2 => 1], - [3 => 1] + [3 => 1], ] netstoch = [ [1 => -1, 2 => -1, 3 => 1], - [1 => 1, 2 => 1, 3 => -1] + [1 => 1, 2 => 1, 3 => -1], ] rates = [0.1, 1.0] u0 = [500, 500, 0] @@ -27,7 +27,7 @@ function getmean(jprob, Nsims) Amean += sol[1, end] end Amean /= Nsims - Amean + return Amean end function mastereqmean(u, rates) @@ -41,15 +41,17 @@ function mastereqmean(u, rates) P_a = nullspace(L) P_a ./= sum(P_a) P_a .= abs.(P_a) - sum((a - 1) * p for (a, p) in enumerate(P_a)) + return sum((a - 1) * p for (a, p) in enumerate(P_a)) end mastereq_mean = mastereqmean(u0, rates) algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) relative_tolerance = 0.01 for alg in algs - local jprob = JumpProblem(prob, alg, majumps, save_positions = (false, false), - rng = rng) + local jprob = JumpProblem( + prob, alg, majumps, save_positions = (false, false), + rng = rng + ) local Amean = getmean(jprob, Nsims) @test abs(Amean - mastereq_mean) / mastereq_mean < relative_tolerance end diff --git a/test/save_positions.jl b/test/save_positions.jl index 13413b5b2..9160d4656 100644 --- a/test/save_positions.jl +++ b/test/save_positions.jl @@ -12,21 +12,28 @@ let # set the rate to 0, so that no jump ever occurs; but urate is positive so # Coevolve will consider many candidates before the end of the simmulation. # None of these points should be saved. - jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; - urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0) - jumpproblem = JumpProblem(dprob, alg, jump; dep_graph = [[1]], - save_positions = (false, true), rng) + jump = VariableRateJump( + (u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; + urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0 + ) + jumpproblem = JumpProblem( + dprob, alg, jump; dep_graph = [[1]], + save_positions = (false, true), rng + ) sol = solve(jumpproblem, SSAStepper()) @test sol.t == [0.0, 30.0] oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan) - jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; - urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0) + jump = VariableRateJump( + (u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; + urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0 + ) for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jumpproblem = JumpProblem( oprob, alg, jump; vr_aggregator = vr_agg, dep_graph = [[1]], - save_positions = (false, true), rng) + save_positions = (false, true), rng + ) sol = solve(jumpproblem, Tsit5(); save_everystep = false) @test sol.t == [0.0, 30.0] end diff --git a/test/saveat_regression.jl b/test/saveat_regression.jl index 03665d76a..6bd5a57e8 100644 --- a/test/saveat_regression.jl +++ b/test/saveat_regression.jl @@ -14,15 +14,17 @@ jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = ts = collect(0:0.002:tspan[2]) NA = zeros(length(ts)) Nsims = 10_000 -sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), saveat = ts, - trajectories = Nsims) +sol = JumpProcesses.solve( + EnsembleProblem(jprob), SSAStepper(), saveat = ts, + trajectories = Nsims +) for i in 1:length(sol) NA .+= sol.u[i][1, :] end for i in 1:length(ts) - @test NA[i] / Nsims≈exp(-10 * ts[i]) rtol=1e-1 + @test NA[i] / Nsims ≈ exp(-10 * ts[i]) rtol = 1.0e-1 end NA = zeros(length(ts)) @@ -38,5 +40,5 @@ for i in 1:Nsims end for i in 1:length(ts) - @test NA[i] / Nsims≈exp(-10 * ts[i]) rtol=1e-1 + @test NA[i] / Nsims ≈ exp(-10 * ts[i]) rtol = 1.0e-1 end diff --git a/test/sir_model.jl b/test/sir_model.jl index e8cea455e..34c11a63f 100644 --- a/test/sir_model.jl +++ b/test/sir_model.jl @@ -6,14 +6,14 @@ rng = StableRNG(12345) rate = (u, p, t) -> (0.1 / 1000.0) * u[1] * u[2] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end jump = ConstantRateJump(rate, affect!) rate = (u, p, t) -> 0.01u[2] affect! = function (integrator) integrator.u[2] -= 1 - integrator.u[3] += 1 + return integrator.u[3] += 1 end jump2 = ConstantRateJump(rate, affect!) @@ -24,7 +24,7 @@ integrator = init(jump_prob, FunctionMap()) condition(u, t, integrator) = t == 100 function purge_affect!(integrator) integrator.u[2] ÷= 10 - reset_aggregated_jumps!(integrator) + return reset_aggregated_jumps!(integrator) end cb = DiscreteCallback(condition, purge_affect!, save_positions = (false, false)) sol = solve(jump_prob, FunctionMap(), callback = cb, tstops = [100]) @@ -34,11 +34,15 @@ sol = solve(jump_prob, SSAStepper(), callback = cb, tstops = [100]) let # here we order S = 1, I = 2, and R = 3 # substrate stoichiometry: - substoich = [[1 => 1, 2 => 1], # 1*S + 1*I - [2 => 1]] # 1*I + substoich = [ + [1 => 1, 2 => 1], # 1*S + 1*I + [2 => 1], + ] # 1*I # net change by each jump type - netstoich = [[1 => -1, 2 => 1], # S -> S-1, I -> I+1 - [2 => -1, 3 => 1]] # I -> I-1, R -> R+1 + netstoich = [ + [1 => -1, 2 => 1], # S -> S-1, I -> I+1 + [2 => -1, 3 => 1], + ] # I -> I-1, R -> R+1 # rate constants for each jump p = (0.1 / 1000, 0.01) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 8d7230a74..359e83391 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -44,23 +44,35 @@ function get_mean_end_state(jump_prob, Nsims) sol = solve(jump_prob, SSAStepper()) end_state .+= sol.u[end] end - end_state / Nsims + return end_state / Nsims end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] -jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false), rng = rng) - for grid in grids] -push!(jump_problems, - JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) +jump_problems = JumpProblem[ + JumpProblem( + prob, NSM(), majumps, + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false), rng = rng + ) + for grid in grids +] +push!( + jump_problems, + JumpProblem( + prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) +) # setup flattenned jump prob -push!(jump_problems, - JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) +push!( + jump_problems, + JumpProblem( + prob, NRM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) +) # test for spatial_jump_prob in jump_problems solution = solve(spatial_jump_prob, SSAStepper()) diff --git a/test/spatial/bracketing.jl b/test/spatial/bracketing.jl index 31c1e23bf..7264d3fab 100644 --- a/test/spatial/bracketing.jl +++ b/test/spatial/bracketing.jl @@ -16,8 +16,10 @@ site_rates = JP.LowHigh(zeros(n), zeros(n)) majump_rates = [0.1] # death at rate 0.1 reactstoch = [[1 => 1]] netstoch = [[1 => -1]] -majump = MassActionJump(majump_rates, reactstoch, - netstoch) +majump = MassActionJump( + majump_rates, reactstoch, + netstoch +) rx_rates = JP.LowHigh(JP.RxRates(n, majump)) # set up hop rates @@ -40,18 +42,18 @@ for site in 1:num_sites(spatial_system) end # test species brackets -@test u_low_high.low[1, 1]≈u[1, 1] * (1 - fluctuation_rate) atol=1 -@test u_low_high.high[1, 1]≈u[1, 1] * (1 + fluctuation_rate) atol=1 +@test u_low_high.low[1, 1] ≈ u[1, 1] * (1 - fluctuation_rate) atol = 1 +@test u_low_high.high[1, 1] ≈ u[1, 1] * (1 + fluctuation_rate) atol = 1 # test site rate brackets site = 1 rx = 1 species = 1 @test JP.total_site_rx_rate(rx_rates.low, site) == - majump_rates[rx] * u_low_high.low[species, site] + majump_rates[rx] * u_low_high.low[species, site] @test JP.total_site_rx_rate(rx_rates.high, site) == - majump_rates[rx] * u_low_high.high[species, site] + majump_rates[rx] * u_low_high.high[species, site] @test JP.total_site_hop_rate(hop_rates.low, site) == - hop_constants[site] * u_low_high.low[species, site] + hop_constants[site] * u_low_high.low[species, site] @test JP.total_site_hop_rate(hop_rates.high, site) == - hop_constants[site] * u_low_high.high[species, site] + hop_constants[site] * u_low_high.high[species, site] diff --git a/test/spatial/diffusion.jl b/test/spatial/diffusion.jl index c014b5c52..ea8703426 100644 --- a/test/spatial/diffusion.jl +++ b/test/spatial/diffusion.jl @@ -10,7 +10,7 @@ function get_mean_sol(jump_prob, Nsims, saveat) for i in 1:(Nsims - 1) sol += solve(jump_prob, SSAStepper(), saveat = saveat).u end - sol / Nsims + return sol / Nsims end # assume sites are labeled from 1 to num_sites(spatial_system) @@ -24,7 +24,7 @@ function discrete_laplacian_from_spatial_system(spatial_system, hopping_rate) end end laplacian .*= hopping_rate - laplacian + return laplacian end # problem setup @@ -63,43 +63,69 @@ times = 0.0:(tf / num_time_points):tf algs = [NSM(), DirectCRDirect()] grids = [CartesianGridRej(dims), Graphs.grid(dims)] -jump_problems = JumpProblem[JumpProblem(prob, algs[2], majumps, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false), rng = rng) - for grid in grids] +jump_problems = JumpProblem[ + JumpProblem( + prob, algs[2], majumps, + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false), rng = rng + ) + for grid in grids +] sizehint!(jump_problems, 15 + length(jump_problems)) # flattenned jump prob -push!(jump_problems, - JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) +push!( + jump_problems, + JumpProblem( + prob, NRM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) +) # hop rates of form D_s hop_constants = [hopping_rate] for alg in algs - push!(jump_problems, - JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) - push!(jump_problems, - JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = hop_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) + ) + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = hop_constants, + spatial_system = grids[end], save_positions = (false, false), rng = rng + ) + ) end # hop rates of form L_{s,i,j} hop_constants = Matrix{Vector{Float64}}(undef, size(hopping_constants)) for ci in CartesianIndices(hop_constants) (species, site) = Tuple(ci) - hop_constants[ci] = repeat([hopping_constants[species, site]], - outdegree(grids[1], site)) + hop_constants[ci] = repeat( + [hopping_constants[species, site]], + outdegree(grids[1], site) + ) end for alg in algs - push!(jump_problems, - JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) - push!(jump_problems, - JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = hop_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) + ) + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = hop_constants, + spatial_system = grids[end], save_positions = (false, false), rng = rng + ) + ) end # hop rates of form D_s * L_{i,j} @@ -109,14 +135,22 @@ for site in 1:num_nodes site_hop_constants[site] = repeat([1.0], JumpProcesses.outdegree(grids[1], site)) end for alg in algs - push!(jump_problems, - JumpProblem(prob, alg, majumps, + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[1], save_positions = (false, false), rng = rng)) - push!(jump_problems, - JumpProblem(prob, alg, majumps, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) + ) + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false), rng = rng + ) + ) end # hop rates of form D_{s,i} * L_{i,j} @@ -126,14 +160,22 @@ for site in 1:num_nodes site_hop_constants[site] = repeat([1.0], JumpProcesses.outdegree(grids[1], site)) end for alg in algs - push!(jump_problems, - JumpProblem(prob, alg, majumps, + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[1], save_positions = (false, false), rng = rng)) - push!(jump_problems, - JumpProblem(prob, alg, majumps, + spatial_system = grids[1], save_positions = (false, false), rng = rng + ) + ) + push!( + jump_problems, + JumpProblem( + prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false), rng = rng + ) + ) end # testing @@ -142,7 +184,7 @@ for (j, spatial_jump_prob) in enumerate(jump_problems) for (i, t) in enumerate(times) local diff = analytic_solution(t) - reshape(mean_sol[i], num_nodes, 1) @test abs(sum(diff[1:center_node]) / sum(analytic_solution(t)[1:center_node])) < - rel_tol + rel_tol end end @@ -164,8 +206,10 @@ starting_state = 25 * ones(Int, length(u0), num_nodes) tspan = (0.0, 10.0) prob = DiscreteProblem(starting_state, tspan) -jp = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, - spatial_system = grid, save_positions = (false, false), rng = rng) +jp = JumpProblem( + prob, NSM(), majumps, hopping_constants = hopping_constants, + spatial_system = grid, save_positions = (false, false), rng = rng +) sol = solve(jp, SSAStepper()) @test sol.u[end][1, 1] == sum(sol.u[end]) diff --git a/test/spatial/hop_rates.jl b/test/spatial/hop_rates.jl index 2c1d93684..f53335403 100644 --- a/test/spatial/hop_rates.jl +++ b/test/spatial/hop_rates.jl @@ -13,6 +13,7 @@ function test_reset(hop_rates, num_nodes) for site in 1:num_nodes @test JP.total_site_hop_rate(hop_rates, site) == 0.0 end + return end function normalized(distribution::Dict) @@ -26,8 +27,10 @@ end normalized(distribution) = distribution / sum(distribution) -function statistical_test(hop_rates, spec_propensities, target_propensities::Dict, - num_species, u, site, g, rng, rel_tol) +function statistical_test( + hop_rates, spec_propensities, target_propensities::Dict, + num_species, u, site, g, rng, rel_tol + ) spec_probs = normalized(spec_propensities) target_probs = normalized(target_propensities) JP.update_hop_rates!(hop_rates, 1:num_species, u, site, g) @@ -45,6 +48,7 @@ function statistical_test(hop_rates, spec_propensities, target_propensities::Dic for target in JP.neighbors(g, site) @test abs(site_dict[target]) / num_samples - target_probs[target] < rel_tol end + return end io = IOBuffer() @@ -68,8 +72,10 @@ for site in 1:num_nodes for (i, target) in enumerate(JP.neighbors(g, site)) target_propensities[target] = 1.0 end - statistical_test(hop_rates, spec_propensities, target_propensities, num_species, u, - site, g, rng, rel_tol) + statistical_test( + hop_rates, spec_propensities, target_propensities, num_species, u, + site, g, rng, rel_tol + ) end test_reset(hop_rates, num_nodes) @@ -84,8 +90,10 @@ for site in 1:num_nodes for (i, target) in enumerate(JP.neighbors(g, site)) target_propensities[target] = 1.0 end - statistical_test(hop_rates, spec_propensities, target_propensities, num_species, u, - site, g, rng, rel_tol) + statistical_test( + hop_rates, spec_propensities, target_propensities, num_species, u, + site, g, rng, rel_tol + ) end test_reset(hop_rates, num_nodes) @@ -97,7 +105,7 @@ hop_constants[1, :] = [ [1.0, 2.0], [1.0, 2.0], [1.0, 2.0, 4.0], - [1.0, 2.0] + [1.0, 2.0], ] hop_constants[2, :] = [ [3.0, 12.0], @@ -105,11 +113,11 @@ hop_constants[2, :] = [ [3.0, 6.0], [3.0, 6.0], [3.0, 6.0, 12.0], - [3.0, 6.0] + [3.0, 6.0], ] hop_rates_structs = [ JP.HopRatesGraphDsij(hop_constants), - JP.HopRates(hop_constants, g) + JP.HopRates(hop_constants, g), ] @test hop_rates_structs[2] isa JP.HopRatesGridDsij for hop_rates in hop_rates_structs @@ -118,8 +126,12 @@ for hop_rates in hop_rates_structs spec_propensities = [sum(hop_constants[species, site]) for species in 1:num_species] target_propensities = Dict{Int, Float64}() for (i, target) in enumerate(JP.neighbors(g, site)) - target_propensities[target] = sum([hop_constants[species, site][i] - for species in 1:num_species]) + target_propensities[target] = sum( + [ + hop_constants[species, site][i] + for species in 1:num_species + ] + ) end statistical_test(hop_rates, spec_propensities, target_propensities, num_species, u, site, g, rng, rel_tol) end @@ -134,23 +146,29 @@ site_hop_constants = [ [1.0, 2.0], [1.0, 2.0], [1.0, 2.0, 4.0], - [1.0, 2.0] + [1.0, 2.0], ] #[site][target_site] hop_rates_structs = [ JP.HopRatesGraphDsLij(species_hop_constants, site_hop_constants), - JP.HopRates((species_hop_constants => site_hop_constants), g) + JP.HopRates((species_hop_constants => site_hop_constants), g), ] @test hop_rates_structs[2] isa JP.HopRatesGridDsLij for hop_rates in hop_rates_structs show(io, "text/plain", hop_rates) for site in 1:num_nodes - spec_propensities = [species_hop_constants[species] * sum(site_hop_constants[site]) - for species in 1:num_species] + spec_propensities = [ + species_hop_constants[species] * sum(site_hop_constants[site]) + for species in 1:num_species + ] target_propensities = Dict{Int, Float64}() for (i, target) in enumerate(JP.neighbors(g, site)) - target_propensities[target] = sum([species_hop_constants[species] * - site_hop_constants[site][i] - for species in 1:num_species]) + target_propensities[target] = sum( + [ + species_hop_constants[species] * + site_hop_constants[site][i] + for species in 1:num_species + ] + ) end statistical_test(hop_rates, spec_propensities, target_propensities, num_species, u, site, g, rng, rel_tol) end @@ -165,11 +183,11 @@ site_hop_constants = [ [1.0, 2.0], [1.0, 2.0], [1.0, 2.0, 4.0], - [1.0, 2.0] + [1.0, 2.0], ] #[site][target_site] hop_rates_structs = [ JP.HopRatesGraphDsiLij(species_hop_constants, site_hop_constants), - JP.HopRates((species_hop_constants => site_hop_constants), g) + JP.HopRates((species_hop_constants => site_hop_constants), g), ] @test hop_rates_structs[2] isa JP.HopRatesGridDsiLij for hop_rates in hop_rates_structs @@ -178,9 +196,13 @@ for hop_rates in hop_rates_structs spec_propensities = [species_hop_constants[species, site] * sum(site_hop_constants[site]) for species in 1:num_species] target_propensities = Dict{Int, Float64}() for (i, target) in enumerate(JP.neighbors(g, site)) - target_propensities[target] = sum([species_hop_constants[species, site] * - site_hop_constants[site][i] - for species in 1:num_species]) + target_propensities[target] = sum( + [ + species_hop_constants[species, site] * + site_hop_constants[site][i] + for species in 1:num_species + ] + ) end statistical_test(hop_rates, spec_propensities, target_propensities, num_species, u, site, g, rng, rel_tol) end diff --git a/test/spatial/spatial_majump.jl b/test/spatial/spatial_majump.jl index 6afcd367c..882a3bf85 100644 --- a/test/spatial/spatial_majump.jl +++ b/test/spatial/spatial_majump.jl @@ -37,44 +37,70 @@ uniform_majumps_1 = SpatialMassActionJump(uniform_rates[:, 1], reactstoch, netst uniform_majumps_2 = SpatialMassActionJump(uniform_rates, reactstoch, netstoch) uniform_majumps_3 = SpatialMassActionJump( [1.0], reshape(uniform_rates[2, :], 1, num_nodes), - reactstoch, netstoch) # hybrid -uniform_majumps_4 = SpatialMassActionJump(MassActionJump(uniform_rates[:, 1], reactstoch, - netstoch)) + reactstoch, netstoch +) # hybrid +uniform_majumps_4 = SpatialMassActionJump( + MassActionJump( + uniform_rates[:, 1], reactstoch, + netstoch + ) +) uniform_majumps = [ uniform_majumps_1, uniform_majumps_2, uniform_majumps_3, - uniform_majumps_4 + uniform_majumps_4, ] non_uniform_majumps_1 = SpatialMassActionJump(non_uniform_rates, reactstoch, netstoch) # reactions are zero outside of center site -non_uniform_majumps_2 = SpatialMassActionJump([1.0], - reshape(non_uniform_rates[2, :], 1, - num_nodes), reactstoch, netstoch) # birth everywhere, death only at center site +non_uniform_majumps_2 = SpatialMassActionJump( + [1.0], + reshape( + non_uniform_rates[2, :], 1, + num_nodes + ), reactstoch, netstoch +) # birth everywhere, death only at center site non_uniform_majumps_3 = SpatialMassActionJump( - [1.0 0.0 0.0 0.0 0.0; - 0.0 0.0 0.0 0.0 death_rate], reactstoch, - netstoch) # birth on the left, death on the right + [ + 1.0 0.0 0.0 0.0 0.0; + 0.0 0.0 0.0 0.0 death_rate + ], reactstoch, + netstoch +) # birth on the left, death on the right non_uniform_majumps = [non_uniform_majumps_1, non_uniform_majumps_2, non_uniform_majumps_3] # put together the JumpProblem's -uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false), rng = rng) - for majump in uniform_majumps] +uniform_jump_problems = JumpProblem[ + JumpProblem( + prob, NSM(), majump, + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false), rng = rng + ) + for majump in uniform_majumps +] # flattenned -append!(uniform_jump_problems, - JumpProblem[JumpProblem(prob, NRM(), majump, hopping_constants = hopping_constants, - spatial_system = grid, save_positions = (false, false), rng = rng) - for majump in uniform_majumps]) +append!( + uniform_jump_problems, + JumpProblem[ + JumpProblem( + prob, NRM(), majump, hopping_constants = hopping_constants, + spatial_system = grid, save_positions = (false, false), rng = rng + ) + for majump in uniform_majumps + ] +) # non-uniform -non_uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false), rng = rng) - for majump in non_uniform_majumps] +non_uniform_jump_problems = JumpProblem[ + JumpProblem( + prob, NSM(), majump, + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false), rng = rng + ) + for majump in non_uniform_majumps +] # testing function get_mean_end_state(jump_prob, Nsims) @@ -83,7 +109,7 @@ function get_mean_end_state(jump_prob, Nsims) sol = solve(jump_prob, SSAStepper()) end_state .+= sol.u[end] end - end_state / Nsims + return end_state / Nsims end function discrete_laplacian_from_spatial_system(spatial_system, hopping_rate) @@ -96,7 +122,7 @@ function discrete_laplacian_from_spatial_system(spatial_system, hopping_rate) end end laplacian .*= hopping_rate - laplacian + return laplacian end L = discrete_laplacian_from_spatial_system(grid, diffusivity) @@ -117,7 +143,7 @@ end # birth and death zero outside of center site function f2(u, p, t) - L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + [0.0, 0.0, 1.0, 0.0, 0.0] + return L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + [0.0, 0.0, 1.0, 0.0, 0.0] end ode_prob = ODEProblem(f2, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) @@ -145,7 +171,7 @@ end # birth on left end, death on right end function f4(u, p, t) - L * u - diagm([0.0, 0.0, 0.0, 0.0, death_rate]) * u + [1.0, 0.0, 0.0, 0.0, 0.0] + return L * u - diagm([0.0, 0.0, 0.0, 0.0, death_rate]) * u + [1.0, 0.0, 0.0, 0.0, 0.0] end ode_prob = ODEProblem(f4, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) diff --git a/test/spatial/topology.jl b/test/spatial/topology.jl index 9c93445bf..dbbe7d13b 100644 --- a/test/spatial/topology.jl +++ b/test/spatial/topology.jl @@ -17,7 +17,7 @@ num_samples = 10^5 rel_tol = 0.01 grids = [ JP.CartesianGridRej(dims), - Graphs.grid(dims) + Graphs.grid(dims), ] for grid in grids show(io, "text/plain", grid) @@ -30,7 +30,7 @@ for grid in grids @test JP.outdegree(grid, 6) == 5 for site in sites @test [JP.nth_nbr(grid, site, n) for n in 1:outdegree(grid, site)] == - collect(neighbors(grid, site)) + collect(neighbors(grid, site)) d = Dict{Int, Int}() for i in 1:num_samples nb = JP.rand_nbr(rng, grid, site) diff --git a/test/splitcoupled.jl b/test/splitcoupled.jl index e871e2c0f..bf6e241d9 100644 --- a/test/splitcoupled.jl +++ b/test/splitcoupled.jl @@ -5,7 +5,7 @@ rng = StableRNG(12345) rate = (u, p, t) -> 1.0 * u[1] affect! = function (integrator) - integrator.u[1] = 1.0 + return integrator.u[1] = 1.0 end jump1 = ConstantRateJump(rate, affect!) @@ -18,7 +18,8 @@ jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) coupling_map = [(1, 1)] coupled_prob = SplitCoupledJumpProblem( jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + rng = rng +) @time sol = solve(coupled_prob, FunctionMap()) @time solve(jump_prob, FunctionMap()) @@ -26,17 +27,17 @@ coupled_prob = SplitCoupledJumpProblem( rate = (u, p, t) -> 1.0 affect! = function (integrator) - integrator.u[1] = 1.0 + return integrator.u[1] = 1.0 end jump1 = ConstantRateJump(rate, affect!) rate = (u, p, t) -> 2.0 jump2 = ConstantRateJump(rate, affect!) f = function (du, u, p, t) - du[1] = u[1] + return du[1] = u[1] end g = function (du, u, p, t) - du[1] = 0.1 + return du[1] = 0.1 end # Jump ODE to jump ODE @@ -46,7 +47,8 @@ jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) jump_prob_control = JumpProblem(prob_control, Direct(), jump2; rng = rng) coupled_prob = SplitCoupledJumpProblem( jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + rng = rng +) sol = solve(coupled_prob, Tsit5()) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 @@ -57,7 +59,8 @@ jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) coupled_prob = SplitCoupledJumpProblem( jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + rng = rng +) sol = solve(coupled_prob, SRIW1()) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 @@ -68,14 +71,15 @@ jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) coupled_prob = SplitCoupledJumpProblem( jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + rng = rng +) sol = solve(coupled_prob, SRIW1()) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Discrete rate = (u, p, t) -> 1.0 affect! = function (integrator) - integrator.u[1] += 1.0 + return integrator.u[1] += 1.0 end prob = DiscreteProblem([1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) @@ -83,7 +87,8 @@ jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) coupled_prob = SplitCoupledJumpProblem( jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + rng = rng +) sol = solve(coupled_prob, SRIW1()) # test mass action jumps coupled to ODE @@ -93,11 +98,13 @@ react_stoch = [Vector{Pair{Int, Int}}()] net_stoch = [[1 => 1]] majumps = MassActionJump(rate, react_stoch, net_stoch) f = function (du, u, p, t) - du[1] = -1.0 * u[1] + return du[1] = -1.0 * u[1] end odeprob = ODEProblem(f, [10.0], (0.0, 10.0)) -jump_prob = JumpProblem(odeprob, Direct(), majumps, save_positions = (false, false); - rng = rng) +jump_prob = JumpProblem( + odeprob, Direct(), majumps, save_positions = (false, false); + rng = rng +) Nsims = 8000 Amean = 0.0 for i in 1:Nsims diff --git a/test/ssa_callback_test.jl b/test/ssa_callback_test.jl index 62d1c8c58..4399d06cd 100644 --- a/test/ssa_callback_test.jl +++ b/test/ssa_callback_test.jl @@ -6,7 +6,7 @@ rng = StableRNG(12345) rate = (u, p, t) -> u[1] affect! = function (integrator) integrator.u[1] -= 1 - integrator.u[2] += 1 + return integrator.u[2] += 1 end jump = ConstantRateJump(rate, affect!) @@ -21,19 +21,19 @@ sol = solve(jump_prob, SSAStepper()) condition(u, t, integrator) = t == 5 function fuel_affect!(integrator) integrator.u[1] += 100 - reset_aggregated_jumps!(integrator) + return reset_aggregated_jumps!(integrator) end cb = DiscreteCallback(condition, fuel_affect!, save_positions = (false, true)) sol = solve(jump_prob, SSAStepper(); callback = cb, tstops = [5]) @test sol.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5 -@test sol(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen +@test sol(5 + 1.0e-10) == [100, 0] # state just after fueling before any decays can happen # test can pass callbacks via JumpProblem jump_prob2 = JumpProblem(prob, Direct(), jump; rng = rng, callback = cb) sol2 = solve(jump_prob2, SSAStepper(); tstops = [5]) @test sol2.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5 -@test sol2(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen +@test sol2(5 + 1.0e-10) == [100, 0] # state just after fueling before any decays can happen # test that callback initializer/finalizer is called and add_tstop! works as expected random_tstops = rand(rng, 100) .* 10 # 100 random Float64 between 0.0 and 10.0 @@ -42,7 +42,7 @@ function fuel_init!(cb, u, t, integrator) for tstop in random_tstops add_tstop!(integrator, tstop) end - @test issorted(integrator.tstops) + return @test issorted(integrator.tstops) end finalizer_called = 0 fuel_finalize(cb, u, t, integrator) = global finalizer_called += 1 @@ -67,7 +67,7 @@ pcondit(u, t, integrator) = t == 1000.0 function paffect!(integrator) integrator.p[1] = 0.0 integrator.p[2] = 1.0 - reset_aggregated_jumps!(integrator) + return reset_aggregated_jumps!(integrator) end sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @@ -90,16 +90,20 @@ sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback( @test sol[1, end] == 100 p2 .= [1.0, 0.0, 0.0] -jprob = JumpProblem(dprob, Direct(), JumpSet(; massaction_jumps = [maj1, maj2, maj3]), - save_positions = (false, false), rng = rng) +jprob = JumpProblem( + dprob, Direct(), JumpSet(; massaction_jumps = [maj1, maj2, maj3]), + save_positions = (false, false), rng = rng +) sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p2 .== [0.0, 1.0, 0.0]) @test sol[1, end] == 100 p .= [1.0, 0.0] dprob = DiscreteProblem(u₀, tspan, p) -maj4 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; - param_idxs = [1, 2]) +maj4 = MassActionJump( + [[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; + param_idxs = [1, 2] +) jprob = JumpProblem(dprob, Direct(), maj4, save_positions = (false, false), rng = rng) sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @@ -115,8 +119,10 @@ jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), rng @test all(jprob.massaction_jump.scaled_rates .== [1.0]) # test for https://github.com/SciML/JumpProcesses.jl/issues/239 -maj6 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; - param_idxs = [1, 2]) +maj6 = MassActionJump( + [[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; + param_idxs = [1, 2] +) p = (0.1, 0.1) dprob = DiscreteProblem([10, 0], (0.0, 100.0), p) jprob = JumpProblem(dprob, Direct(), maj6; save_positions = (false, false), rng = rng) @@ -160,9 +166,9 @@ let jprob = JumpProblem(dprob, crj; rng) tstops = Float64[] # tests for when aliasing system is in place - #sol = solve(jprob; callback = cb, tstops, alias_tstops = true) + #sol = solve(jprob; callback = cb, tstops, alias_tstops = true) # @test sol[1, end] == 9 - #@test tstops == 1.0:9.0 + #@test tstops == 1.0:9.0 # empty!(tstops) # sol = solve(jprob; callback = cb, tstops, alias_tstops = false) # @test sol[1, end] == 9 diff --git a/test/ssa_tests.jl b/test/ssa_tests.jl index e82c50959..ddca540b3 100644 --- a/test/ssa_tests.jl +++ b/test/ssa_tests.jl @@ -5,13 +5,13 @@ rng = StableRNG(12345) rate = (u, p, t) -> u[1] affect! = function (integrator) - integrator.u[1] += 1 + return integrator.u[1] += 1 end jump = ConstantRateJump(rate, affect!) rate = (u, p, t) -> 0.5u[1] affect! = function (integrator) - integrator.u[1] -= 1 + return integrator.u[1] -= 1 end jump2 = ConstantRateJump(rate, affect!) @@ -37,8 +37,10 @@ sol = solve(jump_prob, SSAStepper(), save_start = false) @test sol.t[begin] > 0.0 @test sol.t[end] == 3.0 -jump_prob = JumpProblem(prob, Direct(), jump, jump2, save_positions = (false, false); - rng = rng) +jump_prob = JumpProblem( + prob, Direct(), jump, jump2, save_positions = (false, false); + rng = rng +) sol = solve(jump_prob, SSAStepper(), save_start = false, save_end = false) @test isempty(sol.t) && isempty(sol.u) diff --git a/test/table_test.jl b/test/table_test.jl index 555ffaf1e..0fd67c745 100644 --- a/test/table_test.jl +++ b/test/table_test.jl @@ -3,9 +3,9 @@ using Test const DJ = JumpProcesses # test data -minpriority = 2.0^exponent(1e-12) -maxpriority = 2.0^exponent(1e12) -priorities = [1e-13, 0.99 * minpriority, minpriority, 1.01e-4, 1e-4, 5.0, 0.0, 1e10] +minpriority = 2.0^exponent(1.0e-12) +maxpriority = 2.0^exponent(1.0e12) +priorities = [1.0e-13, 0.99 * minpriority, minpriority, 1.01e-4, 1.0e-4, 5.0, 0.0, 1.0e10] mingid = exponent(minpriority) # = -40 ptog = priority -> DJ.priortogid(priority, mingid) @@ -54,7 +54,7 @@ priorities[10] = 0.0 # test sampling cnt = 0 -Nsamps = Int(1e7) +Nsamps = Int(1.0e7) for i in 1:Nsamps global cnt pid = DJ.sample(pt, priorities) @@ -76,7 +76,7 @@ ptt = DJ.PriorityTimeTable(times, mintime, timestep) DJ.update!(ptt, 1, times[1], 10 * times[1]) # 2. -> 20., group 2 to group 14 @test ptt.groups[14].numpids == 1 @test DJ.getfirst(ptt) == (2, 8.0) -# Updating beyond the time window should not change the max priority. +# Updating beyond the time window should not change the max priority. DJ.update!(ptt, 1, times[1], 70.0) # 20. -> 70. @test ptt.groups[14].numpids == 0 @test ptt.maxtime == 66.0 diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 197ed72c6..7ab081d58 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -11,14 +11,16 @@ u0 = [5] dprob = DiscreteProblem(u0, tspan, params) jprob = JumpProblem(dprob, Direct(), maj; rng = rng) solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); trajectories = 10) -solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); - trajectories = 10) +solve( + EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); + trajectories = 10 +) # test for https://github.com/SciML/JumpProcesses.jl/issues/472 let function f!(du, u, p, t) du[1] = -u[1] - nothing + return nothing end u_0 = [1.0] ode_prob = ODEProblem(f!, u_0, (0.0, 10)) @@ -28,8 +30,10 @@ let jump_prob = JumpProblem(ode_prob, Direct(), vrj; vr_aggregator = agg) prob_func(prob, i, repeat) = deepcopy(prob) prob = EnsembleProblem(jump_prob, prob_func = prob_func) - sol = solve(prob, Tsit5(), EnsembleThreads(), trajectories = 400, - save_everystep = false) + sol = solve( + prob, Tsit5(), EnsembleThreads(), trajectories = 400, + save_everystep = false + ) firstrx_time = [sol.u[i].t[2] for i in 1:length(sol)] @test allunique(firstrx_time) end diff --git a/test/variable_rate.jl b/test/variable_rate.jl index ed6aecd0b..2abb0677f 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -26,7 +26,7 @@ jump = VariableRateJump(rate, affect!, interp_points = 1000) jump2 = deepcopy(jump) f = function (du, u, p, t) - du[1] = u[1] + return du[1] = u[1] end prob = ODEProblem(f, [0.2], (0.0, 10.0)) @@ -41,22 +41,22 @@ integrator = init(jump_prob_gill, Tsit5()) sol_gill = solve(jump_prob_gill, Tsit5()) sol_gill = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff())) sol_gill = solve(jump_prob, Rosenbrock23()) -@test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 -@test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12 +@test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1.0e-12 +@test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1.0e-12 g = function (du, u, p, t) - du[1] = u[1] + return du[1] = u[1] end prob = SDEProblem(f, g, [0.2], (0.0, 10.0)) jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng = rng) sol = solve(jump_prob, SRIW1()) jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng = rng) sol_gill = solve(jump_prob_gill, SRIW1()) -@test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 -@test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12 +@test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1.0e-12 +@test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1.0e-12 function ff(du, u, p, t) - if p == 0 + return if p == 0 du .= 1.01u else du .= 2.01u @@ -66,11 +66,11 @@ function gg(du, u, p, t) du[1, 1] = 0.3u[1] du[1, 2] = 0.6u[1] du[2, 1] = 1.2u[1] - du[2, 2] = 0.2u[2] + return du[2, 2] = 0.2u[2] end rate_switch(u, p, t) = u[1] * 1.0 function affect_switch!(integrator) - integrator.p = 1 + return integrator.p = 1 end jump_switch = VariableRateJump(rate_switch, affect_switch!) prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype = zeros(2, 2)) @@ -82,7 +82,7 @@ sol_gill = solve(jump_prob_gill, SRA1(), dt = 1.0) ## Some integration tests function f2(du, u, p, t) - du[1] = u[1] + return du[1] = u[1] end prob = ODEProblem(f2, [0.2], (0.0, 10.0)) rate2(u, p, t) = 2 @@ -107,7 +107,7 @@ sol(4.0) sol.u[4] function g2(du, u, p, t) - du[1] = u[1] + return du[1] = u[1] end prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0)) jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) @@ -118,15 +118,17 @@ sol(4.0) sol.u[4] function f3(du, u, p, t) - du .= u + return du .= u end prob = ODEProblem(f3, [1.0 2.0; 3.0 4.0], (0.0, 1.0)) rate3(u, p, t) = u[1] + u[2] function affect3!(integrator) - (integrator.u[1] = 0.25; + ( + integrator.u[1] = 0.25; integrator.u[2] = 0.5; integrator.u[3] = 0.75; - integrator.u[4] = 1) + integrator.u[4] = 1 + ) end jump = VariableRateJump(rate3, affect3!) jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) @@ -136,11 +138,11 @@ sol_gill = solve(jump_prob_gill, Tsit5()) # test for https://discourse.julialang.org/t/differentialequations-jl-package-variable-rate-jumps-with-complex-variables/80366/2 function f4(dx, x, p, t) - dx[1] = x[1] + return dx[1] = x[1] end rate4(x, p, t) = t function affect4!(integrator) - integrator.u[1] = integrator.u[1] * 0.5 + return integrator.u[1] = integrator.u[1] * 0.5 end jump = VariableRateJump(rate4, affect4!) x₀ = 1.0 + 0.0im @@ -166,11 +168,13 @@ let maj_rate = [1.0] react_stoich_ = [Vector{Pair{Int, Int}}()] net_stoich_ = [[1 => 1]] - mass_action_jump_ = MassActionJump(maj_rate, react_stoich_, net_stoich_; - scale_rates = false) + mass_action_jump_ = MassActionJump( + maj_rate, react_stoich_, net_stoich_; + scale_rates = false + ) affect! = function (integrator) - integrator.u[1] -= 1 + return integrator.u[1] -= 1 end cs_rate1(u, p, t) = 0.2 * u[1] constant_rate_jump = ConstantRateJump(cs_rate1, affect!) @@ -180,13 +184,19 @@ let u0 = [0] tspan = (0.0, 30.0) dprob_ = DiscreteProblem(u0, tspan) - @test_throws ErrorException JumpProblem(dprob_, alg, jumpset_, - save_positions = (false, false)) - - vrj = VariableRateJump(cs_rate1, affect!; urate = ((u, p, t) -> 1.0), - rateinterval = ((u, p, t) -> 1.0)) - @test_throws ErrorException JumpProblem(dprob_, alg, mass_action_jump_, vrj; - save_positions = (false, false)) + @test_throws ErrorException JumpProblem( + dprob_, alg, jumpset_, + save_positions = (false, false) + ) + + vrj = VariableRateJump( + cs_rate1, affect!; urate = ((u, p, t) -> 1.0), + rateinterval = ((u, p, t) -> 1.0) + ) + @test_throws ErrorException JumpProblem( + dprob_, alg, mass_action_jump_, vrj; + save_positions = (false, false) + ) end end @@ -216,8 +226,10 @@ let end end - test_jump = VariableRateJump(test_rate, test_affect!; urate = test_urate, - rateinterval = (u, p, t) -> 1.0) + test_jump = VariableRateJump( + test_rate, test_affect!; urate = test_urate, + rateinterval = (u, p, t) -> 1.0 + ) dprob = DiscreteProblem([0], (0.0, 1.0), nothing) jprob = JumpProblem(dprob, Coevolve(), test_jump; dep_graph = [[1]]) @@ -239,19 +251,19 @@ let function ode_fxn(du, u, p, t) du .= 0 - nothing + return nothing end b_rate(u, p, t) = (u[1] * p[1]) function birth!(integrator) integrator.u[1] += 1 - nothing + return nothing end b_jump = VariableRateJump(b_rate, birth!) d_rate(u, p, t) = (u[1] * p[2]) function death!(integrator) integrator.u[1] -= 1 - nothing + return nothing end d_jump = VariableRateJump(d_rate, death!) @@ -267,9 +279,9 @@ let end end -# accuracy test based on +# accuracy test based on # https://github.com/SciML/JumpProcesses.jl/issues/320 -# note that even with the seeded StableRNG this test is not +# note that even with the seeded StableRNG this test is not # deterministic for some reason. function getmean(Nsims, prob, alg, dt, tsave, seed) umean = zeros(length(tsave)) @@ -296,20 +308,20 @@ let function ode_fxn(du, u, p, t) du .= 0 - nothing + return nothing end b_rate(u, p, t) = (u[1] * p[1]) function birth!(integrator) integrator.u[1] += 1 - nothing + return nothing end b_jump = VariableRateJump(b_rate, birth!) d_rate(u, p, t) = (u[1] * p[2]) function death!(integrator) integrator.u[1] -= 1 - nothing + return nothing end d_jump = VariableRateJump(d_rate, death!) @@ -327,7 +339,7 @@ let end end -# Correctness test based on +# Correctness test based on # VR_Direct and VR_FRM # Function to run ensemble and compute statistics function run_ensemble(prob, alg, jumps...; vr_aggregator = VR_FRM(), Nsims = 8000) @@ -391,7 +403,7 @@ let t = 10.0 u0 = 0.2 - analytical_mean = u0 * exp(-t) + λ*(1 - exp(-t)) + analytical_mean = u0 * exp(-t) + λ * (1 - exp(-t)) @test isapprox(mean_vrfr, analytical_mean, rtol = 0.05) @test isapprox(mean_vrfr, mean_vrcb, rtol = 0.05) @@ -407,7 +419,7 @@ let function birth_affect!(integrator) integrator.u[1] += 1 integrator.p[3] += 1 - nothing + return nothing end birth_jump = VariableRateJump(birth_rate, birth_affect!) @@ -416,7 +428,7 @@ let function death_affect!(integrator) integrator.u[1] -= 1 integrator.p[3] += 1 - nothing + return nothing end death_jump = VariableRateJump(death_rate, death_affect!)