Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
45019f4
Fix #106: Add generic cache iteration for DifferentiationInterface co…
ChrisRackauckas Aug 30, 2025
4a9f3cb
Refine generic cache handling to be more conservative
ChrisRackauckas Aug 30, 2025
6a9e6bc
Merge branch 'master' into fix-diffeq-cache-iteration
ChrisRackauckas Aug 31, 2025
b4be677
Fix jacobian config resizing for DifferentiationInterface types
ChrisRackauckas Sep 3, 2025
8e6c132
Simplify resize_jac_config! to always use DI.prepare!_jacobian
ChrisRackauckas Sep 3, 2025
83351aa
Remove all generic fallbacks and use DI.prepare!_jacobian directly
ChrisRackauckas Sep 3, 2025
21b5a11
Remove remaining generic grad_config fallbacks
ChrisRackauckas Sep 3, 2025
317f135
Add DifferentiationInterface compat constraint
ChrisRackauckas Sep 3, 2025
f2828cf
Handle jacobian configs that are tuples for default algorithms
ChrisRackauckas Sep 3, 2025
c4513de
Use safer backend access with hasproperty check
ChrisRackauckas Sep 4, 2025
55bfc9e
Use proper backend from OrdinaryDiffEqCore.alg_autodiff(integrator.alg)
ChrisRackauckas Sep 5, 2025
f9db432
Use function from existing jacobian config for DI.prepare!_jacobian
ChrisRackauckas Sep 5, 2025
8f249bc
Use UJacobianWrapper for proper function handling like OrdinaryDiffEq…
ChrisRackauckas Sep 5, 2025
927a3e3
Add gradient config handling with DI.prepare!_gradient
ChrisRackauckas Sep 5, 2025
85ee7c5
Simplify DI integration without OrdinaryDiffEqDifferentiation dependency
ChrisRackauckas Sep 5, 2025
2ed038f
Add SciMLBase dependency for proper jacobian wrappers
ChrisRackauckas Sep 5, 2025
f7c3f8a
Fix SciMLBase UUID
ChrisRackauckas Sep 5, 2025
d5c8320
Remove wrapper on gradient prep as requested
ChrisRackauckas Sep 5, 2025
979664b
Use two-argument UJacobianWrapper to fix test failure
ChrisRackauckas Sep 5, 2025
bef5089
Use two-argument DI.prepare!_gradient as required
ChrisRackauckas Sep 5, 2025
41a795e
Unwrap DI.TwoArgWrapper configs for gradient preparation
ChrisRackauckas Sep 5, 2025
653c7a4
Unwrap DI cache from TwoArgWrapper properly
ChrisRackauckas Sep 5, 2025
b2dbd67
Add missing p parameter to UJacobianWrapper
ChrisRackauckas Sep 5, 2025
203dd0e
Fix gradient config handling to only use DI prep for DI types
ChrisRackauckas Sep 5, 2025
00253d9
Unwrap DI.TwoArgWrapper using .prep property
ChrisRackauckas Sep 5, 2025
2e83e1a
Always grab the prep from DI wrapper as requested
ChrisRackauckas Sep 5, 2025
c72a362
Fix gradient config to use correct .cache field
ChrisRackauckas Sep 5, 2025
4727b3c
Use DI.prepare!_derivative for DerivativePrep types
ChrisRackauckas Sep 5, 2025
3065664
Use DI.prepare!_derivative with integrator.f and integrator.t
ChrisRackauckas Sep 5, 2025
f63ebee
Handle both in-place and out-of-place prepare!_derivative calls
ChrisRackauckas Sep 5, 2025
78615a3
Fix prepare!_derivative signature for in-place functions
ChrisRackauckas Sep 5, 2025
82b5ca0
Try out-of-place prepare!_derivative signature
ChrisRackauckas Sep 5, 2025
c5691c8
Remove gradient config resizing - let DI handle internally
ChrisRackauckas Sep 5, 2025
7838614
Restore gradient config resizing with in-place prepare!_derivative
ChrisRackauckas Sep 5, 2025
fb7bff5
Try 4-argument prepare!_derivative with TimeGradientWrapper
ChrisRackauckas Sep 5, 2025
f180b91
Fix: Use prepare!_gradient instead of prepare!_derivative
ChrisRackauckas Sep 5, 2025
d4d008a
Add proper type dispatches for DerivativePrep vs GradientPrep
ChrisRackauckas Sep 5, 2025
d2cf214
Add comprehensive gradient config type handling
ChrisRackauckas Sep 5, 2025
fa6b9a3
Fix argument order based on CI error analysis
ChrisRackauckas Sep 5, 2025
8eaab72
Fix 4th argument: use u instead of backend for prepare!_derivative
ChrisRackauckas Sep 5, 2025
027be63
Fix prepare!_derivative signature: add backend as 4th argument
ChrisRackauckas Sep 5, 2025
0670aa0
Use OrdinaryDiffEqDifferentiation resizing tools
ChrisRackauckas Sep 7, 2025
fd2cca3
chunk
ChrisRackauckas Sep 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ version = "1.15.0"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

Expand All @@ -24,10 +27,12 @@ MultiScaleArraysSparseDiffToolsExt = "SparseDiffTools"

[compat]
DiffEqBase = "6.5"
DifferentiationInterface = "0.7.7"
FiniteDiff = "2.3"
ForwardDiff = "0.10"
OrdinaryDiffEq = "5.33, 6"
OrdinaryDiffEqCore = "1"
OrdinaryDiffEq = "6"
OrdinaryDiffEqCore = "1.30.0"
OrdinaryDiffEqDifferentiation = "1.16"
OrdinaryDiffEqRosenbrock = "1.17.0"
RecursiveArrayTools = "1,2,3"
SparseDiffTools = "1.6, 2"
Expand Down
3 changes: 3 additions & 0 deletions src/MultiScaleArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ abstract type AbstractMultiScaleArrayHead{B} <: AbstractMultiScaleArray{B} end

using DiffEqBase, Statistics, LinearAlgebra, FiniteDiff
import OrdinaryDiffEq, OrdinaryDiffEqCore, OrdinaryDiffEqRosenbrock, StochasticDiffEq, ForwardDiff
import OrdinaryDiffEqDifferentiation
import SciMLBase
import DifferentiationInterface as DI

Base.show(io::IO, x::AbstractMultiScaleArray) = invoke(show, Tuple{IO, Any}, io, x)
Base.show(io::IO, ::MIME"text/plain", x::AbstractMultiScaleArray) = show(io, x)
Expand Down
15 changes: 9 additions & 6 deletions src/diffeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,
i = length(integrator.u)
cache.J = similar(cache.J, i, i)
cache.W = similar(cache.W, i, i)
add_node_jac_config!(cache, cache.jac_config, i, x)
add_node_grad_config!(cache, cache.grad_config, i, x)
OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator)
OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator)
nothing
end

Expand All @@ -97,8 +97,8 @@ function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,
i = length(integrator.u)
cache.J = similar(cache.J, i, i)
cache.W = similar(cache.W, i, i)
add_node_jac_config!(cache, cache.jac_config, i, x, node...)
add_node_grad_config!(cache, cache.grad_config, i, x, node...)
OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator)
OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator)
nothing
end

Expand All @@ -108,11 +108,12 @@ function remove_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrato
i = length(integrator.u)
cache.J = similar(cache.J, i, i)
cache.W = similar(cache.W, i, i)
remove_node_jac_config!(cache, cache.jac_config, i, node...)
remove_node_grad_config!(cache, cache.grad_config, i, node...)
OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator)
OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator)
nothing
end

# Specific implementation for FiniteDiff.JacobianCache (keeps backward compatibility)
function add_node_jac_config!(cache, config::FiniteDiff.JacobianCache, i, x)
#add_node!(cache.x1, fill!(similar(x, eltype(cache.x1)),0))
add_node!(config.fx, recursivecopy(x))
Expand All @@ -137,6 +138,8 @@ function remove_node_jac_config!(cache, config::FiniteDiff.JacobianCache, i, I..
nothing
end


# Specific implementation for ForwardDiff.DerivativeConfig (keeps backward compatibility)
function add_node_grad_config!(cache, grad_config::ForwardDiff.DerivativeConfig, i, x)
cache.grad_config = ForwardDiff.DerivativeConfig(cache.tf, cache.du1, cache.uf.t)
nothing
Expand Down
6 changes: 2 additions & 4 deletions test/dynamic_diffeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ test_embryo = deepcopy(embryo)
sol = solve(prob, Tsit5(), callback = growing_cb, tstops = tstop)
sol = solve(prob, Rosenbrock23(autodiff = false), tstops = tstop)
sol = solve(prob, Rosenbrock23(autodiff = false), callback = growing_cb, tstops = tstop)
sol = solve(prob, Rosenbrock23(), callback = growing_cb, tstops = tstop)

@test length(sol[end]) == 23
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = growing_cb, tstops = tstop)

affect_del! = function (integrator)
remove_node!(integrator, 1, 1, 1)
Expand All @@ -78,7 +76,7 @@ sol = solve(prob, Tsit5(), callback = shrinking_cb, tstops = tstop)

sol = solve(prob, Rosenbrock23(autodiff = false), callback = shrinking_cb, tstops = tstop)

sol = solve(prob, Rosenbrock23(), callback = shrinking_cb, tstops = tstop)
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = shrinking_cb, tstops = tstop)

@test length(sol[end]) == 17

Expand Down
4 changes: 2 additions & 2 deletions test/single_layer_diffeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ add_node!(pop, pop.nodes[1])

sol = solve(prob, Tsit5(), callback = growing_cb, tstops = tstop)

sol = solve(prob, Rosenbrock23(), callback = growing_cb, tstops = tstop)
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = growing_cb, tstops = tstop)

@test length(sol[end]) == 13

Expand All @@ -62,7 +62,7 @@ prob = ODEProblem(f4, deepcopy(pop), (0.0, 1.0))
sol = solve(prob, Tsit5(), callback = shrinking_cb, tstops = tstop)

prob = ODEProblem(f4, deepcopy(pop), (0.0, 1.0))
sol = solve(prob, Rosenbrock23(), callback = shrinking_cb, tstops = tstop)
sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = shrinking_cb, tstops = tstop)
@test length(sol[end]) == 10

println("Do the SDE Part")
Expand Down
Loading