Skip to content

Commit 4e60a87

Browse files
test: add tests for ReverseDiff dual detection and promotion
1 parent eb42bc2 commit 4e60a87

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
133133
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
134134

135135
[targets]
136-
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
136+
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]

test/forwarddiff_dual_detection.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DiffEqBase, ForwardDiff, Test, InteractiveUtils
2+
using ReverseDiff, SciMLStructures
23
using Plots
34

45
u0 = 2.0
@@ -348,3 +349,20 @@ foo = SciMLBase.build_solution(
348349
prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0])
349350
DiffEqBase.anyeltypedual((; x = foo))
350351
DiffEqBase.anyeltypedual((; x = foo, y = prob.f))
352+
353+
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any
354+
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any
355+
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3))
356+
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3))
357+
358+
struct FakeParameterObject{T}
359+
tunables::T
360+
end
361+
362+
SciMLStructures.isscimlstructure(::FakeParameterObject) = true
363+
SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject) = f.tunables, x -> FakeParameterObject(x), true
364+
365+
@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray
366+
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal
367+
@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
368+
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}

0 commit comments

Comments
 (0)