Skip to content

Add options to terminate early when there's little progress #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using std::vector;

#include "Options.hpp"
#include "NoProgressGuard.hpp"
#include "Solution.hpp"

/**
Expand Down
16 changes: 14 additions & 2 deletions src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ SolutionData IDAKLUSolverOpenMP<ExprSet>::solve(

sunrealtype t_val = t0;
sunrealtype t_prev = t0;
sunrealtype dt;
int i_eval = 0;

sunrealtype t_interp_next;
Expand Down Expand Up @@ -448,6 +449,12 @@ SolutionData IDAKLUSolverOpenMP<ExprSet>::solve(
// Progress one step. This must be done before the while loop to ensure
// that we can run IDAGetDky at t0 for dky = 1
int retval = IDASolve(ida_mem, tf, &t_val, yy, yyp, IDA_ONE_STEP);
dt = t_val - t_prev;

// Optional method to fail the simulation if the solver is not making progress.
NoProgressGuard no_progression(solver_opts.num_steps_no_progress, solver_opts.t_no_progress);
no_progression.Initialize();
no_progression.AddDt(dt);

// Store consistent initialization
CheckErrors(IDAGetDky(ida_mem, t0, 0, yy));
Expand All @@ -466,7 +473,7 @@ SolutionData IDAKLUSolverOpenMP<ExprSet>::solve(
if (retval < 0) {
// failed
break;
} else if (t_prev == t_val) {
} else if (t_prev == t_val || no_progression.Violated()) {
// IDA sometimes returns an identical time point twice
// instead of erroring. Assign a retval and break
retval = IDA_ERR_FAIL;
Expand Down Expand Up @@ -518,16 +525,21 @@ SolutionData IDAKLUSolverOpenMP<ExprSet>::solve(
i_eval++;
t_eval_next = t_eval[i_eval];
CheckErrors(IDASetStopTime(ida_mem, t_eval_next));

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

// Reinitialize the solver to deal with the discontinuity at t = t_val.
ReinitializeIntegrator(t_val);
ConsistentInitialization(t_val, t_eval_next, IDA_YA_YDP_INIT);
// Reset the no-progress guard
no_progression.Initialize();
}

t_prev = t_val;

// Progress one step
retval = IDASolve(ida_mem, tf, &t_val, yy, yyp, IDA_ONE_STEP);

dt = t_val - t_prev;
no_progression.AddDt(dt);
}

int const length_of_final_sv_slice = save_outputs_only ? number_of_states : 0;
Expand Down
70 changes: 70 additions & 0 deletions src/pybammsolvers/idaklu_source/NoProgressGuard.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#ifndef PYBAMM_IDAKLU_NOPROGRESS_GUARD_HPP
#define PYBAMM_IDAKLU_NOPROGRESS_GUARD_HPP

#include "common.hpp"
#include <vector>
#include <numeric>
#include <algorithm>

/**
* @brief Utility for checking lack-of-progress over a fixed-size sliding window
*/
class NoProgressGuard {
public:
NoProgressGuard(size_t window_size, sunrealtype threshold_sec)
: window_size_(window_size), threshold_sec_(threshold_sec), idx_(0) {
if (!Disabled()) {
dt_window_.assign(window_size_, threshold_sec_);
}
}

inline bool Disabled() const {
return window_size_ == 0 || threshold_sec_ == SUN_RCONST(0.0);
}

// initialize with a full window of threshold values to avoid immediate triggering
inline void Initialize() {
if (Disabled()) {
return;
}
idx_ = 0;
dt_window_.assign(window_size_, threshold_sec_);
}

// insert a new dt into the circular buffer
inline void AddDt(sunrealtype dt) {
if (Disabled()) {
return;
}

dt_window_[idx_] = dt;
idx_ = (idx_ + 1) % window_size_;
}

// violation if the running sum across the window remains below the threshold
// early exit: as soon as we reach/exceed threshold, we are not violated
inline bool Violated() const {
if (Disabled()) {
return false;
}

sunrealtype sum_dt = SUN_RCONST(0.0);
for (const auto &dt : dt_window_) {
sum_dt += dt;
if (sum_dt >= threshold_sec_) {
return false;
}
}
return true;
}

private:
const size_t window_size_;
const sunrealtype threshold_sec_;
std::vector<sunrealtype> dt_window_;
size_t idx_;
};

#endif // PYBAMM_IDAKLU_NOPROGRESS_GUARD_HPP


13 changes: 12 additions & 1 deletion src/pybammsolvers/idaklu_source/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,15 @@ SolverOptions::SolverOptions(py::dict &py_opts)
linear_solution_scaling(py_opts["linear_solution_scaling"].cast<sunbooleantype>()),
epsilon_linear_tolerance(SUN_RCONST(py_opts["epsilon_linear_tolerance"].cast<double>())),
increment_factor(SUN_RCONST(py_opts["increment_factor"].cast<double>()))
{}
{
// Early termination. Key checks enable backward compatibility with previous versions
// of pybamm.
num_steps_no_progress = 0;
t_no_progress = SUN_RCONST(0.0);
if (py_opts.contains("num_steps_no_progress")) {
num_steps_no_progress = py_opts["num_steps_no_progress"].cast<size_t>();
}
if (py_opts.contains("t_no_progress")) {
t_no_progress = SUN_RCONST(py_opts["t_no_progress"].cast<sunrealtype>());
}
}
3 changes: 3 additions & 0 deletions src/pybammsolvers/idaklu_source/Options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ struct SolverOptions {
sunbooleantype linear_solution_scaling;
double epsilon_linear_tolerance;
double increment_factor;
// Early termination
size_t num_steps_no_progress;
sunrealtype t_no_progress;
explicit SolverOptions(py::dict &py_opts);
};

Expand Down
Loading