Skip to content

Add unstable_check for custom SSP method #2445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Trixi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ export PositivityPreservingLimiterZhangShu, EntropyBoundedLimiter
export trixi_include, examples_dir, get_examples, default_example,
default_example_unstructured, ode_default_options

export ode_norm, ode_unstable_check
export ode_norm, ode_unstable_check, unstable_check
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we export it, it becomes part of our public API (although it is not documented...). However, I do not think this is a good idea right now since it can be easily confused with ode_unstable_check. I see two options right now:

  • Do something specific for your subcell limiter ecosystem
  • Make it general so that it works nicely with the other parts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, that's true.
I'm completely fine with the first option. But how would that look like in your opinion? Rename it to something like unstable_check_SSPRK, have an optional parameter unstable_check = nothing for solve() of the custom time integrators and add unstable_check = unstable_check_SSRK to every elixir, where it's needed?
Or really removing the parameter unstable_check from the solve function and hardcode it into the SSPRK method?


export convergence_test, jacobian_fd, jacobian_ad_forward, linear_structure

Expand Down
24 changes: 17 additions & 7 deletions src/time_integration/methods_SSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,27 @@ struct SimpleSSPRK33{StageCallbacks} <: SimpleAlgorithmSSP
end

# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L1
mutable struct SimpleIntegratorSSPOptions{Callback, TStops}
mutable struct SimpleIntegratorSSPOptions{Callback, UnstableCheck, TStops}
callback::Callback # callbacks; used in Trixi
unstable_check::UnstableCheck # unstable check function
adaptive::Bool # whether the algorithm is adaptive; ignored
dtmax::Float64 # ignored
maxiters::Int # maximal number of time steps
tstops::TStops # tstops from https://diffeq.sciml.ai/v6.8/basics/common_solver_opts/#Output-Control-1; ignored
end

function SimpleIntegratorSSPOptions(callback, tspan; maxiters = typemax(Int), kwargs...)
function SimpleIntegratorSSPOptions(callback, tspan; maxiters = typemax(Int),
unstable_check = unstable_check, kwargs...)
tstops_internal = BinaryHeap{eltype(tspan)}(FasterForward())
# We add last(tspan) to make sure that the time integration stops at the end time
push!(tstops_internal, last(tspan))
# We add 2 * last(tspan) because add_tstop!(integrator, t) is only called by DiffEqCallbacks.jl if tstops contains a time that is larger than t
# (https://github.com/SciML/DiffEqCallbacks.jl/blob/025dfe99029bd0f30a2e027582744528eb92cd24/src/iterative_and_periodic.jl#L92)
push!(tstops_internal, 2 * last(tspan))
SimpleIntegratorSSPOptions{typeof(callback), typeof(tstops_internal)}(callback,
false, Inf,
maxiters,
tstops_internal)
SimpleIntegratorSSPOptions{typeof(callback), typeof(unstable_check),
typeof(tstops_internal)}(callback, unstable_check,
false, Inf, maxiters,
tstops_internal)
end

# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L77
Expand Down Expand Up @@ -132,7 +134,8 @@ The following structures and methods provide the infrastructure for SSP Runge-Ku
of type `SimpleAlgorithmSSP`.
"""
function solve(ode::ODEProblem, alg = SimpleSSPRK33()::SimpleAlgorithmSSP;
dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
dt, callback::Union{CallbackSet, Nothing} = nothing,
unstable_check = unstable_check, kwargs...)
u = copy(ode.u0)
du = similar(u)
r0 = similar(u)
Expand All @@ -142,6 +145,7 @@ function solve(ode::ODEProblem, alg = SimpleSSPRK33()::SimpleAlgorithmSSP;
integrator = SimpleIntegratorSSP(u, du, r0, t, tdir, dt, dt, iter, ode.p,
(prob = ode,), ode.f, alg,
SimpleIntegratorSSPOptions(callback, ode.tspan;
unstable_check = unstable_check,
kwargs...),
false, true, false)

Expand Down Expand Up @@ -170,6 +174,7 @@ function solve!(integrator::SimpleIntegratorSSP)
@unpack alg = integrator
t_end = last(prob.tspan)
callbacks = integrator.opts.callback
(; unstable_check) = integrator.opts

integrator.finalstep = false
@trixi_timeit timer() "main loop" while !integrator.finalstep
Expand Down Expand Up @@ -220,6 +225,11 @@ function solve!(integrator::SimpleIntegratorSSP)
end
end

if unstable_check(integrator.dt, integrator.u, integrator, integrator.t)
@warn "Instability detected. Aborting"
terminate!(integrator)
end

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
Expand Down
9 changes: 9 additions & 0 deletions src/time_integration/time_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ function DiffEqBase.get_tstops_max(integrator::AbstractTimeIntegrator)
return maximum(get_tstops_array(integrator))
end

function unstable_check(dt, u_ode, integrator::AbstractTimeIntegrator, t)
if mpi_isparallel()
u_isfinite = MPI.Allreduce!(Ref(all(isfinite, u_ode)), Base.min, mpi_comm())[]
else
u_isfinite = all(isfinite, u_ode)
end
return !isfinite(dt) || !u_isfinite
end

function finalize_callbacks(integrator::AbstractTimeIntegrator)
callbacks = integrator.opts.callback

Expand Down
Loading