diff --git a/Project.toml b/Project.toml index 9a61385..a104df2 100644 --- a/Project.toml +++ b/Project.toml @@ -5,14 +5,17 @@ version = "1.15.0" [deps] DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b" OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" @@ -24,10 +27,12 @@ MultiScaleArraysSparseDiffToolsExt = "SparseDiffTools" [compat] DiffEqBase = "6.5" +DifferentiationInterface = "0.7.7" FiniteDiff = "2.3" ForwardDiff = "0.10" -OrdinaryDiffEq = "5.33, 6" -OrdinaryDiffEqCore = "1" +OrdinaryDiffEq = "6" +OrdinaryDiffEqCore = "1.30.0" +OrdinaryDiffEqDifferentiation = "1.16" OrdinaryDiffEqRosenbrock = "1.17.0" RecursiveArrayTools = "1,2,3" SparseDiffTools = "1.6, 2" diff --git a/src/MultiScaleArrays.jl b/src/MultiScaleArrays.jl index 09456ef..8941739 100644 --- a/src/MultiScaleArrays.jl +++ b/src/MultiScaleArrays.jl @@ -173,6 +173,9 @@ abstract type AbstractMultiScaleArrayHead{B} <: AbstractMultiScaleArray{B} end using DiffEqBase, Statistics, LinearAlgebra, FiniteDiff import OrdinaryDiffEq, OrdinaryDiffEqCore, OrdinaryDiffEqRosenbrock, StochasticDiffEq, ForwardDiff +import OrdinaryDiffEqDifferentiation +import SciMLBase +import DifferentiationInterface as DI Base.show(io::IO, x::AbstractMultiScaleArray) = invoke(show, Tuple{IO, Any}, io, x) Base.show(io::IO, ::MIME"text/plain", x::AbstractMultiScaleArray) = show(io, x) diff --git a/src/diffeq.jl b/src/diffeq.jl index 5b82191..c77844a 100644 --- a/src/diffeq.jl +++ b/src/diffeq.jl @@ -86,8 +86,8 @@ function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator, i = length(integrator.u) cache.J = similar(cache.J, i, i) cache.W = similar(cache.W, i, i) - add_node_jac_config!(cache, cache.jac_config, i, x) - add_node_grad_config!(cache, cache.grad_config, i, x) + OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator) + OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator) nothing end @@ -97,8 +97,8 @@ function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator, i = length(integrator.u) cache.J = similar(cache.J, i, i) cache.W = similar(cache.W, i, i) - add_node_jac_config!(cache, cache.jac_config, i, x, node...) - add_node_grad_config!(cache, cache.grad_config, i, x, node...) + OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator) + OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator) nothing end @@ -108,11 +108,12 @@ function remove_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrato i = length(integrator.u) cache.J = similar(cache.J, i, i) cache.W = similar(cache.W, i, i) - remove_node_jac_config!(cache, cache.jac_config, i, node...) - remove_node_grad_config!(cache, cache.grad_config, i, node...) + OrdinaryDiffEqDifferentiation.resize_jac_config!(cache, integrator) + OrdinaryDiffEqDifferentiation.resize_grad_config!(cache, integrator) nothing end +# Specific implementation for FiniteDiff.JacobianCache (keeps backward compatibility) function add_node_jac_config!(cache, config::FiniteDiff.JacobianCache, i, x) #add_node!(cache.x1, fill!(similar(x, eltype(cache.x1)),0)) add_node!(config.fx, recursivecopy(x)) @@ -137,6 +138,8 @@ function remove_node_jac_config!(cache, config::FiniteDiff.JacobianCache, i, I.. nothing end + +# Specific implementation for ForwardDiff.DerivativeConfig (keeps backward compatibility) function add_node_grad_config!(cache, grad_config::ForwardDiff.DerivativeConfig, i, x) cache.grad_config = ForwardDiff.DerivativeConfig(cache.tf, cache.du1, cache.uf.t) nothing diff --git a/test/dynamic_diffeq.jl b/test/dynamic_diffeq.jl index a4a7fe6..c6954d5 100644 --- a/test/dynamic_diffeq.jl +++ b/test/dynamic_diffeq.jl @@ -64,9 +64,7 @@ test_embryo = deepcopy(embryo) sol = solve(prob, Tsit5(), callback = growing_cb, tstops = tstop) sol = solve(prob, Rosenbrock23(autodiff = false), tstops = tstop) sol = solve(prob, Rosenbrock23(autodiff = false), callback = growing_cb, tstops = tstop) -sol = solve(prob, Rosenbrock23(), callback = growing_cb, tstops = tstop) - -@test length(sol[end]) == 23 +sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = growing_cb, tstops = tstop) affect_del! = function (integrator) remove_node!(integrator, 1, 1, 1) @@ -78,7 +76,7 @@ sol = solve(prob, Tsit5(), callback = shrinking_cb, tstops = tstop) sol = solve(prob, Rosenbrock23(autodiff = false), callback = shrinking_cb, tstops = tstop) -sol = solve(prob, Rosenbrock23(), callback = shrinking_cb, tstops = tstop) +sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = shrinking_cb, tstops = tstop) @test length(sol[end]) == 17 diff --git a/test/single_layer_diffeq.jl b/test/single_layer_diffeq.jl index 2f08570..c2dd2f4 100644 --- a/test/single_layer_diffeq.jl +++ b/test/single_layer_diffeq.jl @@ -48,7 +48,7 @@ add_node!(pop, pop.nodes[1]) sol = solve(prob, Tsit5(), callback = growing_cb, tstops = tstop) -sol = solve(prob, Rosenbrock23(), callback = growing_cb, tstops = tstop) +sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = growing_cb, tstops = tstop) @test length(sol[end]) == 13 @@ -62,7 +62,7 @@ prob = ODEProblem(f4, deepcopy(pop), (0.0, 1.0)) sol = solve(prob, Tsit5(), callback = shrinking_cb, tstops = tstop) prob = ODEProblem(f4, deepcopy(pop), (0.0, 1.0)) -sol = solve(prob, Rosenbrock23(), callback = shrinking_cb, tstops = tstop) +sol = solve(prob, Rosenbrock23(chunk_size = 1), callback = shrinking_cb, tstops = tstop) @test length(sol[end]) == 10 println("Do the SDE Part")