Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
1d38561
kernels setup done
sivasathyaseeelan Jul 23, 2025
b476393
ImplicitTauLeaping setup done for jump problem solver
sivasathyaseeelan Jul 24, 2025
bbad256
jump_problem solver fixed
sivasathyaseeelan Jul 28, 2025
e36c463
removed equilibrium_pair logic
sivasathyaseeelan Jul 31, 2025
94eccb2
tranculation error fixed
sivasathyaseeelan Jul 31, 2025
7e832f4
nonlinearsolver is implemented
sivasathyaseeelan Aug 1, 2025
8e1c84c
changed to SimpleImplicitTauLeaping
sivasathyaseeelan Aug 9, 2025
4c61d74
refactor
sivasathyaseeelan Aug 9, 2025
aa351c3
SimpleAdaptiveTauLeaping is done
sivasathyaseeelan Aug 9, 2025
227aeb6
simple version of SimpleImplicitTauLeaping
sivasathyaseeelan Aug 19, 2025
c56ee41
removed adaptive tau leap
sivasathyaseeelan Aug 19, 2025
fafe65d
poiss change
sivasathyaseeelan Aug 19, 2025
00ab7db
changed to inline non linear solver
sivasathyaseeelan Aug 19, 2025
73bd240
refactor
sivasathyaseeelan Aug 19, 2025
99b79a8
typo
sivasathyaseeelan Aug 19, 2025
9f81e7d
basic version of inplicit tau leap is done
sivasathyaseeelan Aug 19, 2025
606dcc6
added critical_threshold
sivasathyaseeelan Aug 20, 2025
6fa449f
residual update
sivasathyaseeelan Aug 20, 2025
199cede
added comment line
sivasathyaseeelan Aug 20, 2025
5be559c
SimpleImplicitTauLeaping
sivasathyaseeelan Sep 5, 2025
c250d32
project.toml
sivasathyaseeelan Sep 5, 2025
a348acd
project.toml
sivasathyaseeelan Sep 5, 2025
59bd4c1
some
sivasathyaseeelan Sep 5, 2025
e4c02c6
some
sivasathyaseeelan Sep 5, 2025
cba6df7
test update
sivasathyaseeelan Sep 5, 2025
3ef1e2e
refactor
sivasathyaseeelan Feb 13, 2026
9bcd583
some changes
sivasathyaseeelan Feb 13, 2026
e1941db
Merge branch 'master' into kernel-implicit-tau
sivasathyaseeelan Feb 14, 2026
6618348
refactor
sivasathyaseeelan Feb 14, 2026
5459dd9
comcat entries
sivasathyaseeelan Feb 14, 2026
67aab9a
Merge branch 'master' into kernel-implicit-tau
sivasathyaseeelan Feb 15, 2026
22b3198
typo
sivasathyaseeelan Feb 15, 2026
eb53701
test fix
sivasathyaseeelan Feb 15, 2026
7ae0a78
saveat implementation
sivasathyaseeelan Feb 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "9.21.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -17,6 +18,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

Expand Down Expand Up @@ -53,6 +56,8 @@ RecursiveArrayTools = "3.35"
Reexport = "1.2"
SafeTestsets = "0.1"
SciMLBase = "2.115"
Setfield = "1"
SimpleNonlinearSolve = "1, 2"
StableRNGs = "1"
StaticArrays = "1.9.8"
Statistics = "1"
Expand All @@ -62,7 +67,6 @@ Test = "1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
Expand All @@ -78,4 +82,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ADTypes", "Aqua", "ExplicitImports", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]
test = ["Aqua", "ExplicitImports", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]
7 changes: 5 additions & 2 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ using StaticArrays: StaticArrays, SVector, setindex
using Base.Threads: Threads
using Base.FastMath: add_fast

using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleNewtonRaphson
using ADTypes: ADTypes, AutoFiniteDiff

# Import functions we extend from Base
import Base: size, getindex, setindex!, length, similar, show, merge!, merge

Expand All @@ -36,7 +39,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction,
ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!,
deleteat!, isinplace, remake, savevalues!, step!,
u_modified!
using SciMLBase: SciMLBase, DEIntegrator
using SciMLBase: SciMLBase, DEIntegrator, NonlinearProblem

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
Expand Down Expand Up @@ -127,7 +130,7 @@ export SSAStepper

# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel
export SimpleTauLeaping, SimpleExplicitTauLeaping, SimpleImplicitTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down
193 changes: 190 additions & 3 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end

# Define solver type hierarchy
abstract type AbstractImplicitSolver end
struct NewtonImplicitSolver <: AbstractImplicitSolver end
struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end

struct SimpleImplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter for tau selection
solver::AbstractImplicitSolver # Solver type: Newton or Trapezoidal
end

SimpleImplicitTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = SimpleImplicitTauLeaping(epsilon, solver)

struct SimpleExplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter
end
Expand Down Expand Up @@ -33,6 +45,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleExplici
jump_prob.massaction_jump !== nothing
end

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

"""
_process_saveat(saveat, tspan, save_start, save_end)

Expand Down Expand Up @@ -153,10 +178,10 @@ end
function compute_hor(reactant_stoch, numjumps)
stoch_type = eltype(first(first(reactant_stoch)))
hor = zeros(stoch_type, numjumps)
max_order = 3 * one(stoch_type) # Maximum supported reaction order (type-aware)
for j in 1:numjumps
order = sum(
stoch for (spec_idx, stoch) in reactant_stoch[j]; init = zero(stoch_type))
if order > 3
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type))
if order > max_order
error("Reaction $j has order $order, which is not supported (maximum order is 3).")
end
hor[j] = order
Expand Down Expand Up @@ -257,6 +282,40 @@ function compute_tau(
return max(tau, dtmin)
end

# Define residual for implicit equation
# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004)
# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau
function implicit_equation!(resid, u_new, params)
u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params
rate(rate_cache, u_new, p, t + tau)
resid .= u_new .- u_current
if isa(solver, NewtonImplicitSolver)
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004)
end
end
else # TrapezoidalImplicitSolver
rate_current = similar(rate_cache)
rate(rate_current, u_current, p, t)
half = one(eltype(rate_cache)) / 2
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
resid[spec_idx] -= nu[spec_idx, j] * half * (rate_cache[j] + rate_current[j]) * tau
end
end
end
resid .= max.(resid, -u_new) # Ensure non-negative solution
end

# Solve implicit equation using SimpleNonlinearSolve
function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)
u_new = convert(Vector{float(eltype(u_current))}, u_current)
prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver))
sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6)
return sol.u, sol.retcode == ReturnCode.Success
end

# Function to generate a mass action rate function
function massaction_rate(maj, numjumps)
return (out, u, p, t) -> begin
Expand Down Expand Up @@ -405,6 +464,134 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
return sol
end

function simple_implicit_tau_leaping_loop!(
prob, alg, u_current, t_current, t_end, p, rng,
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
maj, solver, save_end)
save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end

u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver)
if !converged
tau /= 2
continue
end

rate(rate_cache, u_new_float, p, t_current + tau)
zero_rate = zero(eltype(rate_cache))
counts .= pois_rand.(rng, max.(rate_cache * tau, zero_rate))
du .= zero(eltype(du))
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
u_new = u_current + du

zero_pop = zero(eltype(u_new))
if any(<(zero_pop), u_new)
# Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3
tau /= 2
continue
end
# Ensure non-negativity, as per Cao et al. (2006), Section 3.3
for i in eachindex(u_new)
u_new[i] = max(u_new[i], zero_pop)
end
t_new = t_current + tau

if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, copy(u_new))
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current = u_new
t_current = t_new
end

# Save endpoint if requested and not already saved
if save_end && (isempty(tsave) || tsave[end] != t_end)
push!(usave, copy(u_current))
push!(tsave, t_end)
end
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
seed = nothing,
dtmin = nothing,
saveat = nothing, save_start = nothing, save_end = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

(; prob, rng) = jump_prob
(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
# Extract rates
rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps)
c = rj !== nothing ? rj.c : nothing
u0 = copy(prob.u0)
tspan = prob.tspan
p = prob.p

if dtmin === nothing
dtmin = 1e-10 * one(typeof(tspan[2]))
end

saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)

# Initialize current state and saved history
u_current = copy(u0)
t_current = tspan[1]
if save_start
usave = [copy(u0)]
tsave = [tspan[1]]
else
usave = typeof(u0)[]
tsave = typeof(tspan[1])[]
end
rate_cache = zeros(float(eltype(u0)), numjumps)
counts = zero(rate_cache)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon
solver = alg.solver

nu = zeros(float(eltype(u0)), length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)
max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps)

simple_implicit_tau_leaping_loop!(
prob, alg, u_current, t_current, t_end, p, rng,
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
maj, solver, save_end)

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error=false,
interp=DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand Down
Loading
Loading