Skip to content

return nothing & helper functions for custom time integrators #2517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 10 additions & 31 deletions src/time_integration/methods_2N.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,7 @@ function init(ode::ODEProblem, alg::SimpleAlgorithm2N;
SimpleIntegratorOptions(callback, ode.tspan;
kwargs...), false)

# initialize callbacks
if callback isa CallbackSet
foreach(callback.continuous_callbacks) do cb
throw(ArgumentError("Continuous callbacks are unsupported with the 2N storage time integration methods."))
end
foreach(callback.discrete_callbacks) do cb
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
end
initialize_callbacks!(callback, integrator)

return integrator
end
Expand All @@ -145,12 +137,7 @@ function step!(integrator::SimpleIntegrator2N)
error("time step size `dt` is NaN")
end

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
limit_dt!(integrator, t_end)

# one time step
integrator.u_tmp .= 0
Expand All @@ -171,23 +158,11 @@ function step!(integrator::SimpleIntegrator2N)
integrator.iter += 1
integrator.t += integrator.dt

@trixi_timeit timer() "Step-Callbacks" begin
# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
end
end
@trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
check_max_iter!(integrator)

return nothing
end

# get a cache where the RHS can be stored
Expand All @@ -200,12 +175,16 @@ u_modified!(integrator::SimpleIntegrator2N, ::Bool) = false
function terminate!(integrator::SimpleIntegrator2N)
integrator.finalstep = true
empty!(integrator.opts.tstops)

return nothing
end

# used for AMR
function Base.resize!(integrator::SimpleIntegrator2N, new_size)
resize!(integrator.u, new_size)
resize!(integrator.du, new_size)
resize!(integrator.u_tmp, new_size)

return nothing
end
end # @muladd
43 changes: 11 additions & 32 deletions src/time_integration/methods_3Sstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,7 @@ function init(ode::ODEProblem, alg::SimpleAlgorithm3Sstar;
ode.tspan;
kwargs...), false)

# initialize callbacks
if callback isa CallbackSet
foreach(callback.continuous_callbacks) do cb
throw(ArgumentError("Continuous callbacks are unsupported with the 3 star time integration methods."))
end
foreach(callback.discrete_callbacks) do cb
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
end
initialize_callbacks!(callback, integrator)

return integrator
end
Expand All @@ -187,12 +179,7 @@ function step!(integrator::SimpleIntegrator3Sstar)
error("time step size `dt` is NaN")
end

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
limit_dt!(integrator, t_end)

# one time step
integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1))
Expand All @@ -219,28 +206,16 @@ function step!(integrator::SimpleIntegrator3Sstar)
integrator.iter += 1
integrator.t += integrator.dt

@trixi_timeit timer() "Step-Callbacks" begin
# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
end
end
@trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
check_max_iter!(integrator)

return nothing
end

# get a cache where the RHS can be stored
function get_tmp_cache(integrator::SimpleIntegrator3Sstar)
(integrator.u_tmp1, integrator.u_tmp2)
return (integrator.u_tmp1, integrator.u_tmp2)
end

# some algorithms from DiffEq like FSAL-ones need to be informed when a callback has modified u
Expand All @@ -250,6 +225,8 @@ u_modified!(integrator::SimpleIntegrator3Sstar, ::Bool) = false
function terminate!(integrator::SimpleIntegrator3Sstar)
integrator.finalstep = true
empty!(integrator.opts.tstops)

return nothing
end

# used for AMR
Expand All @@ -258,5 +235,7 @@ function Base.resize!(integrator::SimpleIntegrator3Sstar, new_size)
resize!(integrator.du, new_size)
resize!(integrator.u_tmp1, new_size)
resize!(integrator.u_tmp2, new_size)

return nothing
end
end # @muladd
45 changes: 14 additions & 31 deletions src/time_integration/methods_SSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,10 @@ function init(ode::ODEProblem, alg::SimpleAlgorithmSSP;
resize!(integrator.p, integrator.p.solver.volume_integral,
nelements(integrator.p.solver, integrator.p.cache))

# initialize callbacks
if callback isa CallbackSet
foreach(callback.continuous_callbacks) do cb
throw(ArgumentError("Continuous callbacks are unsupported with the SSP time integration methods."))
end
foreach(callback.discrete_callbacks) do cb
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
end
# Standard callbacks
initialize_callbacks!(callback, integrator)

# Addition for `SimpleAlgorithmSSP` which may have stage callbacks
for stage_callback in alg.stage_callbacks
init_callback(stage_callback, integrator.p)
end
Expand Down Expand Up @@ -187,12 +181,7 @@ function step!(integrator::SimpleIntegratorSSP)

modify_dt_for_tstops!(integrator)

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
limit_dt!(integrator, t_end)

@. integrator.u_tmp = integrator.u
for stage in eachindex(alg.c)
Expand All @@ -215,23 +204,11 @@ function step!(integrator::SimpleIntegratorSSP)
integrator.iter += 1
integrator.t += integrator.dt

@trixi_timeit timer() "Step-Callbacks" begin
# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
end
end
@trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
check_max_iter!(integrator)

return nothing
end

# get a cache where the RHS can be stored
Expand All @@ -243,6 +220,8 @@ u_modified!(integrator::SimpleIntegratorSSP, ::Bool) = false
# stop the time integration
function terminate!(integrator::SimpleIntegratorSSP)
integrator.finalstep = true

return nothing
end

"""
Expand All @@ -267,6 +246,8 @@ function modify_dt_for_tstops!(integrator::SimpleIntegratorSSP)
min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end
end
end

return nothing
end

# used for AMR
Expand All @@ -279,5 +260,7 @@ function Base.resize!(integrator::SimpleIntegratorSSP, new_size)
# new_size = n_variables * n_nodes^n_dims * n_elements
n_elements = nelements(integrator.p.solver, integrator.p.cache)
resize!(integrator.p, integrator.p.solver.volume_integral, n_elements)

return nothing
end
end # @muladd
37 changes: 6 additions & 31 deletions src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,7 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2;
false, true, false,
k1)

# initialize callbacks
if callback isa CallbackSet
for cb in callback.continuous_callbacks
throw(ArgumentError("Continuous callbacks are unsupported with paired explicit Runge-Kutta methods."))
end
for cb in callback.discrete_callbacks
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
end
initialize_callbacks!(callback, integrator)

return integrator
end
Expand All @@ -260,12 +252,7 @@ function step!(integrator::PairedExplicitRK2Integrator)

modify_dt_for_tstops!(integrator)

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
limit_dt!(integrator, t_end)

@trixi_timeit timer() "Paired Explicit Runge-Kutta ODE integration step" begin
# First and second stage are identical across all single/standalone PERK methods
Expand All @@ -287,22 +274,10 @@ function step!(integrator::PairedExplicitRK2Integrator)
integrator.iter += 1
integrator.t += integrator.dt

@trixi_timeit timer() "Step-Callbacks" begin
# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
end
end
@trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
check_max_iter!(integrator)

return nothing
end
end # @muladd
39 changes: 8 additions & 31 deletions src/time_integration/paired_explicit_runge_kutta/methods_PERK3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,7 @@ function init(ode::ODEProblem, alg::PairedExplicitRK3;
false, true, false,
k1, kS1)

# initialize callbacks
if callback isa CallbackSet
for cb in callback.continuous_callbacks
throw(ArgumentError("Continuous callbacks are unsupported with paired explicit Runge-Kutta methods."))
end
for cb in callback.discrete_callbacks
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
end
initialize_callbacks!(callback, integrator)

return integrator
end
Expand All @@ -256,12 +248,7 @@ function step!(integrator::PairedExplicitRK3Integrator)

modify_dt_for_tstops!(integrator)

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
limit_dt!(integrator, t_end)

@trixi_timeit timer() "Paired Explicit Runge-Kutta ODE integration step" begin
# First and second stage are identical across all single/standalone PERK methods
Expand Down Expand Up @@ -292,23 +279,11 @@ function step!(integrator::PairedExplicitRK3Integrator)
integrator.iter += 1
integrator.t += integrator.dt

@trixi_timeit timer() "Step-Callbacks" begin
# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
end
end
@trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
check_max_iter!(integrator)

return nothing
end

function Base.resize!(integrator::PairedExplicitRK3Integrator, new_size)
Expand All @@ -318,5 +293,7 @@ function Base.resize!(integrator::PairedExplicitRK3Integrator, new_size)

resize!(integrator.k1, new_size)
resize!(integrator.kS1, new_size)

return nothing
end
end # @muladd
Loading
Loading