Skip to content

Commit 75c02c3

Browse files
test: add tests for ReverseDiff dual detection and promotion
1 parent 565a4b9 commit 75c02c3

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/forwarddiff_dual_detection.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DiffEqBase, ForwardDiff, Test, InteractiveUtils
2+
using ReverseDiff, SciMLStructures
23
using Plots
34

45
u0 = 2.0
@@ -348,3 +349,20 @@ foo = SciMLBase.build_solution(
348349
prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0])
349350
DiffEqBase.anyeltypedual((; x = foo))
350351
DiffEqBase.anyeltypedual((; x = foo, y = prob.f))
352+
353+
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any
354+
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any
355+
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3))
356+
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3))
357+
358+
struct Foo{T}
359+
tunables::T
360+
end
361+
362+
SciMLStructures.isscimlstructure(::Foo) = true
363+
SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::Foo) = f.tunables, x -> Foo(x), true
364+
365+
@test DiffEqBase.promote_u0(ones(3), Foo(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray
366+
@test DiffEqBase.promote_u0(1.0, Foo(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal
367+
@test DiffEqBase.promote_u0(ones(3), Foo(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
368+
@test DiffEqBase.promote_u0(1.0, Foo(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}

0 commit comments

Comments
 (0)