Skip to content

Add FunctionMap as default algorithm for DiscreteProblem #2829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions lib/OrdinaryDiffEqFunctionMap/src/OrdinaryDiffEqFunctionMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ include("interpolants.jl")
include("functionmap_perform_step.jl")
include("fixed_timestep_perform_step.jl")

# Default algorithm for DiscreteProblem
function SciMLBase.__solve(prob::SciMLBase.DiscreteProblem, ::Nothing, args...;
kwargs...)
SciMLBase.__solve(prob, FunctionMap(), args...; kwargs...)
end

function SciMLBase.__init(prob::SciMLBase.DiscreteProblem, ::Nothing, args...;
kwargs...)
SciMLBase.__init(prob, FunctionMap(), args...; kwargs...)
end

export FunctionMap

end
91 changes: 91 additions & 0 deletions lib/OrdinaryDiffEqFunctionMap/test/discrete_problem_defaults.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import OrdinaryDiffEqFunctionMap
import OrdinaryDiffEqCore
using Test
using SciMLBase
using SciMLBase: solve, init, DiscreteProblem

const FunctionMap = OrdinaryDiffEqFunctionMap.FunctionMap

# Helper functions to check algorithm properties regardless of module context
is_functionmap(alg) = typeof(alg).name.name == :FunctionMap
function get_scale_by_time(alg)
# Access the type parameter directly since it's FunctionMap{scale_by_time}
T = typeof(alg)
# The parameter is stored as a type parameter
return T.parameters[1]
end

@testset "DiscreteProblem Default Algorithm" begin
# Test scalar DiscreteProblem
f(u, p, t) = 1.01 * u
prob_scalar = DiscreteProblem(f, 0.5, (0.0, 1.0))

@testset "Scalar DiscreteProblem" begin
# Test solve without explicit algorithm
sol = solve(prob_scalar)
@test is_functionmap(sol.alg)
@test get_scale_by_time(sol.alg) == false
@test length(sol.u) > 1

# Test init without explicit algorithm
integrator = init(prob_scalar)
@test is_functionmap(integrator.alg)
@test get_scale_by_time(integrator.alg) == false
end

# Test array DiscreteProblem
function f_array!(du, u, p, t)
du[1] = 1.01 * u[1]
du[2] = 0.99 * u[2]
end
prob_array = DiscreteProblem(f_array!, [0.5, 1.0], (0.0, 1.0))

@testset "Array DiscreteProblem" begin
# Test solve without explicit algorithm
sol = solve(prob_array)
@test is_functionmap(sol.alg)
@test get_scale_by_time(sol.alg) == false
@test length(sol.u) > 1

# Test init without explicit algorithm
integrator = init(prob_array)
@test is_functionmap(integrator.alg)
@test get_scale_by_time(integrator.alg) == false
end

# Test that explicit algorithm specification still works
@testset "Explicit FunctionMap specification" begin
sol1 = solve(prob_scalar, FunctionMap())
@test is_functionmap(sol1.alg)
@test get_scale_by_time(sol1.alg) == false

sol2 = solve(prob_scalar, FunctionMap(scale_by_time=true), dt=0.1)
@test is_functionmap(sol2.alg)
@test get_scale_by_time(sol2.alg) == true

integrator1 = init(prob_scalar, FunctionMap())
@test is_functionmap(integrator1.alg)
@test get_scale_by_time(integrator1.alg) == false

integrator2 = init(prob_scalar, FunctionMap(scale_by_time=true), dt=0.1)
@test is_functionmap(integrator2.alg)
@test get_scale_by_time(integrator2.alg) == true
end

# Test that the default behaves correctly with different problem types
@testset "DiscreteProblem with integer time" begin
henon_map!(u_next, u, p, t) = begin
u_next[1] = 1 + u[2] - 1.4 * u[1]^2
u_next[2] = 0.3 * u[1]
end

prob_int = DiscreteProblem(henon_map!, [0.5, 0.5], (0, 10))

sol = solve(prob_int)
@test is_functionmap(sol.alg)
@test eltype(sol.t) <: Integer

integrator = init(prob_int)
@test is_functionmap(integrator.alg)
end
end
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqFunctionMap/test/qa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Aqua

@testset "Aqua" begin
Aqua.test_all(
OrdinaryDiffEqFunctionMap
OrdinaryDiffEqFunctionMap;
piracies = false # Piracy is necessary for default algorithm dispatch
)
end
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqFunctionMap/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using SafeTestsets

@time @safetestset "JET Tests" include("jet.jl")
@time @safetestset "Aqua" include("qa.jl")
@time @safetestset "Aqua" include("qa.jl")
@time @safetestset "DiscreteProblem Defaults" include("discrete_problem_defaults.jl")
Loading