Skip to content

Commit 079b8d5

Browse files
Merge pull request #2829 from ChrisRackauckas-Claude/add-discreteproblem-functionmap-dispatches
Add FunctionMap as default algorithm for DiscreteProblem
2 parents 5ab7675 + 25d40ad commit 079b8d5

File tree

4 files changed

+106
-2
lines changed

4 files changed

+106
-2
lines changed

lib/OrdinaryDiffEqFunctionMap/src/OrdinaryDiffEqFunctionMap.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ include("interpolants.jl")
2525
include("functionmap_perform_step.jl")
2626
include("fixed_timestep_perform_step.jl")
2727

28+
# Default algorithm for DiscreteProblem
29+
function SciMLBase.__solve(prob::SciMLBase.DiscreteProblem, ::Nothing, args...;
30+
kwargs...)
31+
SciMLBase.__solve(prob, FunctionMap(), args...; kwargs...)
32+
end
33+
34+
function SciMLBase.__init(prob::SciMLBase.DiscreteProblem, ::Nothing, args...;
35+
kwargs...)
36+
SciMLBase.__init(prob, FunctionMap(), args...; kwargs...)
37+
end
38+
2839
export FunctionMap
2940

3041
end
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import OrdinaryDiffEqFunctionMap
2+
import OrdinaryDiffEqCore
3+
using Test
4+
using SciMLBase
5+
using SciMLBase: solve, init, DiscreteProblem
6+
7+
const FunctionMap = OrdinaryDiffEqFunctionMap.FunctionMap
8+
9+
# Helper functions to check algorithm properties regardless of module context
10+
is_functionmap(alg) = typeof(alg).name.name == :FunctionMap
11+
function get_scale_by_time(alg)
12+
# Access the type parameter directly since it's FunctionMap{scale_by_time}
13+
T = typeof(alg)
14+
# The parameter is stored as a type parameter
15+
return T.parameters[1]
16+
end
17+
18+
@testset "DiscreteProblem Default Algorithm" begin
19+
# Test scalar DiscreteProblem
20+
f(u, p, t) = 1.01 * u
21+
prob_scalar = DiscreteProblem(f, 0.5, (0.0, 1.0))
22+
23+
@testset "Scalar DiscreteProblem" begin
24+
# Test solve without explicit algorithm
25+
sol = solve(prob_scalar)
26+
@test is_functionmap(sol.alg)
27+
@test get_scale_by_time(sol.alg) == false
28+
@test length(sol.u) > 1
29+
30+
# Test init without explicit algorithm
31+
integrator = init(prob_scalar)
32+
@test is_functionmap(integrator.alg)
33+
@test get_scale_by_time(integrator.alg) == false
34+
end
35+
36+
# Test array DiscreteProblem
37+
function f_array!(du, u, p, t)
38+
du[1] = 1.01 * u[1]
39+
du[2] = 0.99 * u[2]
40+
end
41+
prob_array = DiscreteProblem(f_array!, [0.5, 1.0], (0.0, 1.0))
42+
43+
@testset "Array DiscreteProblem" begin
44+
# Test solve without explicit algorithm
45+
sol = solve(prob_array)
46+
@test is_functionmap(sol.alg)
47+
@test get_scale_by_time(sol.alg) == false
48+
@test length(sol.u) > 1
49+
50+
# Test init without explicit algorithm
51+
integrator = init(prob_array)
52+
@test is_functionmap(integrator.alg)
53+
@test get_scale_by_time(integrator.alg) == false
54+
end
55+
56+
# Test that explicit algorithm specification still works
57+
@testset "Explicit FunctionMap specification" begin
58+
sol1 = solve(prob_scalar, FunctionMap())
59+
@test is_functionmap(sol1.alg)
60+
@test get_scale_by_time(sol1.alg) == false
61+
62+
sol2 = solve(prob_scalar, FunctionMap(scale_by_time=true), dt=0.1)
63+
@test is_functionmap(sol2.alg)
64+
@test get_scale_by_time(sol2.alg) == true
65+
66+
integrator1 = init(prob_scalar, FunctionMap())
67+
@test is_functionmap(integrator1.alg)
68+
@test get_scale_by_time(integrator1.alg) == false
69+
70+
integrator2 = init(prob_scalar, FunctionMap(scale_by_time=true), dt=0.1)
71+
@test is_functionmap(integrator2.alg)
72+
@test get_scale_by_time(integrator2.alg) == true
73+
end
74+
75+
# Test that the default behaves correctly with different problem types
76+
@testset "DiscreteProblem with integer time" begin
77+
henon_map!(u_next, u, p, t) = begin
78+
u_next[1] = 1 + u[2] - 1.4 * u[1]^2
79+
u_next[2] = 0.3 * u[1]
80+
end
81+
82+
prob_int = DiscreteProblem(henon_map!, [0.5, 0.5], (0, 10))
83+
84+
sol = solve(prob_int)
85+
@test is_functionmap(sol.alg)
86+
@test eltype(sol.t) <: Integer
87+
88+
integrator = init(prob_int)
89+
@test is_functionmap(integrator.alg)
90+
end
91+
end

lib/OrdinaryDiffEqFunctionMap/test/qa.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Aqua
33

44
@testset "Aqua" begin
55
Aqua.test_all(
6-
OrdinaryDiffEqFunctionMap
6+
OrdinaryDiffEqFunctionMap;
7+
piracies = false # Piracy is necessary for default algorithm dispatch
78
)
89
end
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SafeTestsets
22

33
@time @safetestset "JET Tests" include("jet.jl")
4-
@time @safetestset "Aqua" include("qa.jl")
4+
@time @safetestset "Aqua" include("qa.jl")
5+
@time @safetestset "DiscreteProblem Defaults" include("discrete_problem_defaults.jl")

0 commit comments

Comments
 (0)