Skip to content

Commit 2627f2b

Browse files
Merge pull request #1078 from AayushSabharwal/as/reversediff
fix: fix usage of ReverseDiff in parameters
2 parents 13ac2da + 4e60a87 commit 2627f2b

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2727
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2828
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2929
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
30+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
3031
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3132
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
3233
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
@@ -95,6 +96,7 @@ Reexport = "1.0"
9596
ReverseDiff = "1"
9697
SciMLBase = "2.28.0"
9798
SciMLOperators = "0.3"
99+
SciMLStructures = "1.5"
98100
Setfield = "1"
99101
SparseArrays = "1.9"
100102
Static = "1"
@@ -131,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
131133
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
132134

133135
[targets]
134-
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
136+
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ using DiffEqBase
44
import DiffEqBase: value
55
import ReverseDiff
66
import DiffEqBase.ArrayInterface
7+
import DiffEqBase.ForwardDiff
8+
9+
function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
10+
DiffEqBase.anyeltypedual(V, Val{counter})
11+
end
712

813
DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
914
function DiffEqBase.value(x::Type{
@@ -33,6 +38,7 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
3338
u0
3439
end
3540
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
41+
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0))
3642
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
3743

3844
# Support adaptive with non-tracked time

src/DiffEqBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficien
101101

102102
import SciMLBase: AbstractDiffEqLinearOperator # deprecation path
103103

104+
import SciMLStructures
105+
104106
import Tricks
105107

106108
using Reexport

src/forwarddiff.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}
352352
@inline promote_u0(::Nothing, p, t0) = nothing
353353

354354
@inline function promote_u0(u0, p, t0)
355+
if SciMLStructures.isscimlstructure(p)
356+
_p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
357+
if _p != p
358+
return promote_u0(u0, _p, t0)
359+
end
360+
end
355361
Tu = eltype(u0)
356362
if Tu <: ForwardDiff.Dual
357363
return u0
@@ -373,6 +379,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}
373379
end
374380

375381
@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
382+
if SciMLStructures.isscimlstructure(p)
383+
_p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
384+
if _p != p
385+
return promote_u0(u0, _p, t0)
386+
end
387+
end
376388
Tu = real(eltype(u0))
377389
if Tu <: ForwardDiff.Dual
378390
return u0

test/forwarddiff_dual_detection.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DiffEqBase, ForwardDiff, Test, InteractiveUtils
2+
using ReverseDiff, SciMLStructures
23
using Plots
34

45
u0 = 2.0
@@ -348,3 +349,20 @@ foo = SciMLBase.build_solution(
348349
prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0])
349350
DiffEqBase.anyeltypedual((; x = foo))
350351
DiffEqBase.anyeltypedual((; x = foo, y = prob.f))
352+
353+
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any
354+
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any
355+
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3))
356+
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3))
357+
358+
struct FakeParameterObject{T}
359+
tunables::T
360+
end
361+
362+
SciMLStructures.isscimlstructure(::FakeParameterObject) = true
363+
SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject) = f.tunables, x -> FakeParameterObject(x), true
364+
365+
@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray
366+
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal
367+
@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
368+
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}

0 commit comments

Comments
 (0)