Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .JuliaFormatter.toml

This file was deleted.

14 changes: 10 additions & 4 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
name: "Format Check"
name: format-check

on:
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
format-check:
name: "Format Check"
uses: "SciML/.github/.github/workflows/format-check.yml@v1"
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
version: '1'
4 changes: 2 additions & 2 deletions benchmarks/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ benchmark_out = ExtendedJumpArray(zeros(500000), zeros(500000))
benchmark_in = ExtendedJumpArray(rand(rng, 500000), rand(rng, 500000))

function test_single_dot(out, array)
@inbounds @. out = array + 1.23 * array
return @inbounds @. out = array + 1.23 * array
end

function test_double_dot(out, array)
@inbounds @.. out = array + 1.23 * array
return @inbounds @.. out = array + 1.23 * array
end

println("Base-case normal broadcasting")
Expand Down
34 changes: 22 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,33 @@ cp(joinpath(docpath, "Project.toml"), joinpath(assetpath, "Project.toml"), force

include("pages.jl")

mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]), :tex => Dict("inlineMath" => [["\$", "\$"], ["\\(", "\\)"]],
"packages" => [
"base",
"ams",
"autoload",
"mathtools",
"require"
])))
mathengine = MathJax3(
Dict(
:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]), :tex => Dict(
"inlineMath" => [["\$", "\$"], ["\\(", "\\)"]],
"packages" => [
"base",
"ams",
"autoload",
"mathtools",
"require",
]
)
)
)

makedocs(sitename = "JumpProcesses.jl", authors = "Chris Rackauckas", modules = [JumpProcesses],
makedocs(
sitename = "JumpProcesses.jl", authors = "Chris Rackauckas", modules = [JumpProcesses],
clean = true, doctest = false, linkcheck = true, warnonly = [:missing_docs],
format = Documenter.HTML(; assets = ["assets/favicon.ico"],
format = Documenter.HTML(;
assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/JumpProcesses/",
prettyurls = (get(ENV, "CI", nothing) == "true"),
mathengine,
edit_link = "master",
repolink = "https://github.com/SciML/JumpProcesses.jl"),
pages = pages)
repolink = "https://github.com/SciML/JumpProcesses.jl"
),
pages = pages
)

deploydocs(repo = "github.com/SciML/JumpProcesses.jl.git"; push_preview = true)
17 changes: 11 additions & 6 deletions docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Put in a separate page so it can be used by SciMLDocs.jl

pages = ["index.md",
"Tutorials" => Any["tutorials/simple_poisson_process.md",
pages = [
"index.md",
"Tutorials" => Any[
"tutorials/simple_poisson_process.md",
"tutorials/discrete_stochastic_example.md",
"tutorials/point_process_simulation.md",
"tutorials/jump_diffusion.md",
"tutorials/spatial.md"],
"tutorials/spatial.md",
],
"Applications" => Any["applications/advanced_point_process.md"],
"Type Documentation" => Any["Jumps, JumpProblem, and Aggregators" => "jump_types.md",
"Jump solvers" => "jump_solve.md"],
"Type Documentation" => Any[
"Jumps, JumpProblem, and Aggregators" => "jump_types.md",
"Jump solvers" => "jump_solve.md",
],
"FAQ" => "faq.md",
"API" => "api.md"
"API" => "api.md",
]
99 changes: 58 additions & 41 deletions ext/JumpProcessesKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@ using KernelAbstractions, Adapt
using StaticArrays
using PoissonRandom, Random

function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem,
function SciMLBase.__solve(
ensembleprob::SciMLBase.AbstractEnsembleProblem,
alg::SimpleTauLeaping,
ensemblealg::EnsembleGPUKernel;
trajectories,
seed = nothing,
dt = error("dt is required for SimpleTauLeaping."),
kwargs...)
kwargs...
)
if trajectories == 1
return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories = 1,
seed, dt, kwargs...)
return SciMLBase.__solve(
ensembleprob, alg, EnsembleSerial(); trajectories = 1,
seed, dt, kwargs...
)
end

ensemblealg.backend === nothing ? backend = CPU() :
backend = ensemblealg.backend
backend = ensemblealg.backend

jump_prob = ensembleprob.prob

Expand All @@ -31,36 +35,42 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem,

# Run vectorized solve
ts,
us = vectorized_solve(
probs, jump_prob, SimpleTauLeaping(); backend, trajectories, seed, dt)
us = vectorized_solve(
probs, jump_prob, SimpleTauLeaping(); backend, trajectories, seed, dt
)

# Convert to CPU for inspection
_ts = Array(ts)
_us = Array(us)

time = @elapsed sol = [begin
ts = @view _ts[:, i]
us = @view _us[:, :, i]
sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts)
if sol_idx === nothing
@error "No solution found" tspan=probs[i].tspan[1] ts
error("Batch solve failed")
end
@views ensembleprob.output_func(
SciMLBase.build_solution(probs[i].prob,
alg,
ts[1:sol_idx],
[us[j, :] for j in 1:sol_idx],
k = nothing,
stats = nothing,
calculate_error = false,
retcode = sol_idx !=
length(ts) ?
ReturnCode.Terminated :
ReturnCode.Success),
i)[1]
end
for i in eachindex(probs)]
time = @elapsed sol = [
begin
ts = @view _ts[:, i]
us = @view _us[:, :, i]
sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts)
if sol_idx === nothing
@error "No solution found" tspan = probs[i].tspan[1] ts
error("Batch solve failed")
end
@views ensembleprob.output_func(
SciMLBase.build_solution(
probs[i].prob,
alg,
ts[1:sol_idx],
[us[j, :] for j in 1:sol_idx],
k = nothing,
stats = nothing,
calculate_error = false,
retcode = sol_idx !=
length(ts) ?
ReturnCode.Terminated :
ReturnCode.Success
),
i
)[1]
end
for i in eachindex(probs)
]
return SciMLBase.EnsembleSolution(sol, time, true)
end

Expand All @@ -82,7 +92,8 @@ end
@kernel function simple_tau_leaping_kernel(
@Const(probs_data), _us, _ts, dt, @Const(rj_data),
current_u_buf, rate_cache_buf, counts_buf, local_dc_buf,
seed::UInt64)
seed::UInt64
)
i = @index(Global, Linear)

# Get thread-local buffers
Expand Down Expand Up @@ -114,7 +125,7 @@ end

# Get input/output arrays
ts_view = @inbounds view(_ts, :, i)
us_view = @inbounds view(_us,:,:,i)
us_view = @inbounds view(_us, :, :, i)

# Initialize first time step and state
@inbounds ts_view[1] = tspan[1]
Expand All @@ -124,7 +135,7 @@ end

# Main loop
for j in 2:n
tprev = tspan[1] + (j-2) * dt
tprev = tspan[1] + (j - 2) * dt

# Compute rates and scale by dt
rate(rate_cache, current_u, p, tprev)
Expand All @@ -143,20 +154,24 @@ end
@inbounds for k in 1:state_dim
us_view[j, k] = current_u[k]
end
@inbounds ts_view[j] = tspan[1] + (j-1) * dt
@inbounds ts_view[j] = tspan[1] + (j - 1) * dt
end
end

# Vectorized solve function
function vectorized_solve(probs, prob::JumpProblem, alg::SimpleTauLeaping;
backend, trajectories, seed, dt, kwargs...)
function vectorized_solve(
probs, prob::JumpProblem, alg::SimpleTauLeaping;
backend, trajectories, seed, dt, kwargs...
)
# Extract common jump data
rj = prob.regular_jump
rj_data = JumpData(rj.rate, rj.c, rj.numjumps)

# Extract trajectory-specific data without static typing
probs_data = [TrajectoryData(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, p.prob.tspan)
for p in probs]
probs_data = [
TrajectoryData(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, p.prob.tspan)
for p in probs
]

# Adapt to GPU
probs_data_gpu = adapt(backend, probs_data)
Expand Down Expand Up @@ -197,13 +212,15 @@ function vectorized_solve(probs, prob::JumpProblem, alg::SimpleTauLeaping;
KernelAbstractions.synchronize(backend)

# Seed for Poisson sampling
seed = seed === nothing ? UInt64(12345) : UInt64(seed);
seed = seed === nothing ? UInt64(12345) : UInt64(seed)

# Launch main kernel
kernel = simple_tau_leaping_kernel(backend)
main_event = kernel(probs_data_gpu, us, ts, dt, rj_data_gpu,
main_event = kernel(
probs_data_gpu, us, ts, dt, rj_data_gpu,
current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, seed;
ndrange = n_trajectories)
ndrange = n_trajectories
)
KernelAbstractions.synchronize(backend)

return ts, us
Expand Down
12 changes: 6 additions & 6 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@ import SymbolicIndexingInterface as SII

# Import additional types and functions from DiffEqBase and SciMLBase
using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction,
DDEFunction, DiscreteProblem, ODEFunction, ODEProblem,
ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!,
deleteat!, isinplace, remake, savevalues!, step!,
u_modified!
DDEFunction, DiscreteProblem, ODEFunction, ODEProblem,
ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!,
deleteat!, isinplace, remake, savevalues!, step!,
u_modified!
using SciMLBase: SciMLBase, DEIntegrator

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
abstract type AbstractAggregatorAlgorithm end
abstract type AbstractJumpAggregator end
abstract type AbstractSSAIntegrator{Alg, IIP, U, T} <:
DEIntegrator{Alg, IIP, U, T} end
DEIntegrator{Alg, IIP, U, T} end

const DEFAULT_RNG = Random.default_rng()

Expand Down Expand Up @@ -126,7 +126,7 @@ export init, solve, solve!
include("SSA_stepper.jl")
export SSAStepper

# leaping:
# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, EnsembleGPUKernel

Expand Down
Loading
Loading