Skip to content

Fixing an error occurring when Heavisides are used and a model is initialised from a previous state #5075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from

Conversation

isaacbasil
Copy link
Contributor

Description

  • When models that include a Heaviside function of time are initialised from a previous state, they fail at the very beginning of the integration.
  • This is not a problem when solved from time t=0, because there is a feature in base_solver.py which removes an unwanted discontinuity at time 0.
  • However, when initialised from a previous state the integration does not start from t=0, and thus this discontinuity is not removed. The solver returns an error since only 1 time value is passed for the integration.
  • This happens for both IDAKLU and Casadi solvers.
  • I have made a small modification which removes the discontinuity at the beginning of the t_eval array, rather than explicitly at t=0.
  • This solves the problem, and I can't see any reason why it would create other problems.
  • I have provided a MWE below. You can verify that with the previous base_solver.py script, the solver will fail, and with the modified version it solves successfully.

Notes

  • I created another issue on the github page with a similar MWE - these are not related.
  • Sorry, I didn't make an issue about this first. Since it was just a single line of the code which changed I thought I would just try directly making a PR so we could discuss it here.

Fixes # (issue)

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)

Important checks:

Please confirm the following before marking the PR as ready for review:

  • No style issues: nox -s pre-commit
  • All tests pass: nox -s tests
  • The documentation builds: nox -s doctests
  • Code is commented for hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

MWE

  • This solves Yang's model (which involves a Heaviside) and the DFN for a half-cell discharge followed by a relaxation.
  • Yang's model is built from the Basic DFN script
  • The relaxation period is initialised from the results of the discharge.
  • The previous base_solver.py script fails at the switch from discharge to relaxation.
#
# Basic Yang Half Cell Model
#
import pybamm
import numpy as np
import matplotlib.pyplot as plt
from pybamm.models.full_battery_models.lithium_ion.base_lithium_ion_model import BaseModel


class BasicYangModel(BaseModel):
    """
    Rebuilt from basic DFN
    Parameters
    ----------
    options : dict
        A dictionary of options to be passed to the model. For the half cell it should
        include which is the working electrode.
    name : str, optional
        The name of the model.

    """

    def __init__(self, I_cell, options=None, name="Yang Model"):
        options = {"working electrode": "positive"}
        super().__init__(options, name)
        pybamm.citations.register("Marquis2019")
        # `param` is a class containing all the relevant parameters and functions for
        # this model. These are purely symbolic at this stage, and will be set by the
        # `ParameterValues` class when the model is processed.

        ######################
        # Variables
        ######################
        # Variables that depend on time only are created without a domain
        Q = pybamm.Variable("Discharge capacity [A.h]")

        # Variables that vary spatially are created with a domain.
        c_e_s = pybamm.Variable(
            "Separator electrolyte concentration [mol.m-3]", domain="separator"
        )
        c_e_w = pybamm.Variable(
            "Positive electrolyte concentration [mol.m-3]", domain="positive electrode"
        )

        c_s_surf_w = pybamm.Variable("Positive particle surface concentration [mol.m-3]", domain="positive electrode")

        c_e = pybamm.concatenation(c_e_s, c_e_w)

        c_s_w = pybamm.Variable(
            "Positive particle concentration [mol.m-3]", domain="positive electrode",
        )
        phi_s_w = pybamm.Variable(
            "Positive electrode potential [V]", domain="positive electrode"
        )
        phi_e_s = pybamm.Variable(
            "Separator electrolyte potential [V]", domain="separator"
        )
        phi_e_w = pybamm.Variable(
            "Positive electrolyte potential [V]", domain="positive electrode"
        )
        phi_e = pybamm.concatenation(phi_e_s, phi_e_w)

        # Constant temperature
        T = self.param.T_init

        ######################
        # Other set-up
        ######################

        # Porosity and Transport_efficiency
        eps_s = pybamm.PrimaryBroadcast(
            pybamm.Parameter("Separator porosity"), "separator"
        )
        eps_w = pybamm.PrimaryBroadcast(
            pybamm.Parameter("Positive electrode porosity"), "positive electrode"
        )
        b_e_s = self.param.s.b_e
        b_e_w = self.param.p.b_e

        # Interfacial reactions
        j0_w = self.param.p.prim.j0(c_e_w, c_s_surf_w, T)
        U_w = self.param.p.prim.U
        ne_w = self.param.p.prim.ne

        # Particle diffusion parameters
        D_w = self.param.p.prim.D
        c_w_init = pybamm.surf(self.param.p.prim.c_init)

        # Electrode equation parameters
        eps_s_w = pybamm.Parameter("Positive electrode active material volume fraction")
        b_s_w = self.param.p.b_s
        sigma_w = self.param.p.sigma

        # Other parameters (for outputs)
        c_w_max = self.param.p.prim.c_max
        L_w = self.param.p.L

        i_cell = I_cell / self.param.p.main_param.A_cc

        eps = pybamm.concatenation(eps_s, eps_w)
        tor = pybamm.concatenation(eps_s**b_e_s, eps_w**b_e_w)

        F_RT = self.param.F / (self.param.R * T)
        RT_F = self.param.R * T / self.param.F
        sto_surf_w = c_s_surf_w / c_w_max
        j_w = (
            2
            * j0_w
            * pybamm.sinh(ne_w / 2 * F_RT * (phi_s_w - phi_e_w - U_w(sto_surf_w, T)))
        )
        R_w = self.param.p.prim.R_typ
        a_w = 3 * eps_s_w / R_w
        a_j_w = a_w * j_w
        a_j_s = pybamm.PrimaryBroadcast(0, "separator")
        a_j = pybamm.concatenation(a_j_s, a_j_w)

        l_s_p = R_w
        
        D_s = pybamm.surf(D_w(c_s_w, T))
        
        a_d_p = 1.0 # a fitting parameter from Yang's model, usually 1
        small_perturbation = 1e-200
        t_dif_p = l_s_p ** 2 / D_s
        t_cut_p = t_dif_p / (a_d_p ** 2)
        regime_p = (pybamm.t < t_cut_p)

        l_dif_p = (a_d_p * (D_s * (pybamm.t + small_perturbation))**0.5) * regime_p + l_s_p * (1 - regime_p)


        yang_corr_term = - (l_dif_p / (2 * l_s_p) - l_dif_p**2 / (6 * l_s_p**2)) * (j_w * l_s_p / (D_s * self.param.F)) 

        self.algebraic[c_s_surf_w] = c_s_surf_w - (c_s_w + yang_corr_term)

        ######################
        # State of Charge
        ######################
        I = self.param.current_with_time
        # The `rhs` dictionary contains differential equations, with the key being the
        # variable in the d/dt
        self.rhs[Q] = I / 3600
        # Initial conditions must be provided for the ODEs
        self.initial_conditions[Q] = pybamm.Scalar(0)

        ######################
        # Particles
        ######################
        # The div and grad operators will be converted to the appropriate matrix
        # multiplication at the discretisation stage

        self.rhs[c_s_w] = - a_j_w / (self.param.F * eps_s_w)

        # Boundary conditions must be provided for equations with spatial
        # derivatives
        self.boundary_conditions[c_s_w] = {
            "left": (pybamm.Scalar(0), "Neumann"),
            #"right": (-j_w / pybamm.surf(D_w(c_s_w, T)) / self.param.F, "Neumann"),
            "right": (pybamm.Scalar(0), "Neumann"),
        }
        self.initial_conditions[c_s_w] = c_w_init

        self.boundary_conditions[c_s_surf_w] = {
            "left": (pybamm.Scalar(0), "Neumann"),
            "right": (pybamm.Scalar(0), "Neumann"),
        }

        self.initial_conditions[c_s_surf_w] = c_w_init

        # Events specify points at which a solution should terminate
        self.events += [
            pybamm.Event(
                "Minimum positive particle surface concentration",
                pybamm.min(sto_surf_w) - 0.01,
            ),
            pybamm.Event(
                "Maximum positive particle surface concentration",
                (1 - 0.01) - pybamm.max(sto_surf_w),
            ),
        ]

        ######################
        # Current in the solid
        ######################
        sigma_eff_w = sigma_w(T) * eps_s_w**b_s_w
        i_s_w = -sigma_eff_w * pybamm.grad(phi_s_w)
        self.boundary_conditions[phi_s_w] = {
            "left": (pybamm.Scalar(0), "Neumann"),
            "right": (
                i_cell / pybamm.boundary_value(-sigma_eff_w, "right"),
                "Neumann",
            ),
        }
        # multiply by Lx**2 to improve conditioning
        self.algebraic[phi_s_w] = self.param.L_x**2 * (pybamm.div(i_s_w) + a_j_w)

        self.initial_conditions[phi_s_w] = self.param.p.prim.U_init

        ######################
        # Electrolyte concentration
        ######################
        N_e = -tor * self.param.D_e(c_e, T) * pybamm.grad(c_e)
        self.rhs[c_e] = (1 / eps) * (
            -pybamm.div(N_e) + (1 - self.param.t_plus(c_e, T)) * a_j / self.param.F
        )
        dce_dx = (
            -(1 - self.param.t_plus(c_e, T))
            * i_cell
            / (tor * self.param.F * self.param.D_e(c_e, T))
        )

        self.boundary_conditions[c_e] = {
            "left": (pybamm.boundary_value(dce_dx, "left"), "Neumann"),
            "right": (pybamm.Scalar(0), "Neumann"),
        }

        self.initial_conditions[c_e] = self.param.c_e_init
        self.events.append(
            pybamm.Event(
                "Zero electrolyte concentration cut-off", pybamm.min(c_e) - 0.002
            )
        )

        ######################
        # Current in the electrolyte
        ######################
        i_e = (self.param.kappa_e(c_e, T) * tor) * (
            self.param.chiRT_over_Fc(c_e, T) * pybamm.grad(c_e) - pybamm.grad(phi_e)
        )
        # multiply by Lx**2 to improve conditioning
        self.algebraic[phi_e] = self.param.L_x**2 * (pybamm.div(i_e) - a_j)

        # reference potential
        L_Li = self.param.n.L
        sigma_Li = self.param.n.sigma
        j_Li = self.param.j0_Li_metal(pybamm.boundary_value(c_e, "left"), c_w_max, T)
        eta_Li = 2 * RT_F * pybamm.arcsinh(i_cell / (2 * j_Li))

        phi_s_cn = 0
        delta_phi = eta_Li
        delta_phis_Li = L_Li * i_cell / sigma_Li(T)
        ref_potential = phi_s_cn - delta_phis_Li - delta_phi

        self.boundary_conditions[phi_e] = {
            "left": (ref_potential, "Dirichlet"),
            "right": (pybamm.Scalar(0), "Neumann"),
        }

        self.initial_conditions[phi_e] = -self.param.n.prim.U_init

        ######################
        # (Some) variables
        ######################
        vdrop_cell = pybamm.boundary_value(phi_s_w, "right") - ref_potential
        vdrop_Li = -eta_Li - delta_phis_Li
        voltage = vdrop_cell + vdrop_Li
        num_cells = pybamm.Parameter(
            "Number of cells connected in series to make a battery"
        )

        c_e_total = pybamm.x_average(eps * c_e)
        c_s_surf_w_av = pybamm.x_average(c_s_surf_w)

        c_s_rav = pybamm.r_average(c_s_w)
        c_s_vol_av = pybamm.x_average(eps_s_w * c_s_rav)

        # Cut-off voltage
        self.events.append(
            pybamm.Event("Minimum voltage [V]", voltage - self.param.voltage_low_cut)
        )
        self.events.append(
            pybamm.Event("Maximum voltage [V]", self.param.voltage_high_cut - voltage)
        )

        # Cut-off open-circuit voltage (for event switch with casadi 'fast with events'
        # mode)
        tol = 0.1
        self.events.append(
            pybamm.Event(
                "Minimum voltage switch",
                voltage - (self.param.voltage_low_cut - tol),
                pybamm.EventType.SWITCH,
            )
        )
        self.events.append(
            pybamm.Event(
                "Maximum voltage switch",
                voltage - (self.param.voltage_high_cut + tol),
                pybamm.EventType.SWITCH,
            )
        )

        self.variables = {
            "Time [s]": pybamm.t,
            "Discharge capacity [A.h]": Q,
            "Positive particle surface concentration [mol.m-3]": c_s_surf_w,
            "X-averaged positive particle surface concentration "
            "[mol.m-3]": c_s_surf_w_av,
            "Positive particle concentration [mol.m-3]": c_s_w,
            "Total lithium in positive electrode [mol]": c_s_vol_av
            * L_w
            * self.param.A_cc,
            "Electrolyte concentration [mol.m-3]": c_e,
            "Separator electrolyte concentration [mol.m-3]": c_e_s,
            "Positive electrolyte concentration [mol.m-3]": c_e_w,
            "Total lithium in electrolyte [mol]": c_e_total
            * self.param.L_x
            * self.param.A_cc,
            "Current [A]": I,
            "Current variable [A]": I,  # for compatibility with pybamm.Experiment
            "Current density [A.m-2]": i_cell,
            "Positive electrode potential [V]": phi_s_w,
            "Positive electrode open-circuit potential [V]": U_w(sto_surf_w, T),
            "Electrolyte potential [V]": phi_e,
            "Separator electrolyte potential [V]": phi_e_s,
            "Positive electrolyte potential [V]": phi_e_w,
            "Voltage drop in the cell [V]": vdrop_cell,
            "Negative electrode exchange current density [A.m-2]": j_Li,
            "Negative electrode reaction overpotential [V]": eta_Li,
            "Negative electrode potential drop [V]": delta_phis_Li,
            "Voltage [V]": voltage,
            "Battery voltage [V]": voltage * num_cells,
            "Instantaneous power [W.m-2]": i_cell * voltage,
        }



parameter_values = pybamm.ParameterValues("Xu2019")

I_cell = parameter_values["Current function [A]"]  # current used in discharge

# define discharge and relaxation epochs
epochs = {"discharge": {"t_eval": [0, 4000], "Current function [A]": I_cell}, "relax": {"t_eval": [4000, 6000], "Current function [A]": 0.0}}

previous_solutions = {"Yang": None, "DFN": None} # initialise previous solution dict 

for key, epoch in epochs.items():
    
    epoch["Models"] = {}
    epoch["Sims"] = {}

    # ==== define models =======
    epoch["Models"]["Yang"] = BasicYangModel(I_cell=epoch["Current function [A]"])
    
    parameter_values["Current function [A]"] = epoch["Current function [A]"]
    epoch["Models"]["DFN"] = pybamm.lithium_ion.DFN(options={"working electrode": "positive"})
    #===========================

    # ====solve for each epoch =====
    for model_key, model in epoch["Models"].items():
        if previous_solutions[model_key] is not None:
            model.set_initial_conditions_from(previous_solutions[model_key])

        geom = model.default_geometry

        parameter_values["Current function [A]"] = epoch["Current function [A]"]
    
        parameter_values.process_model(model)
        parameter_values.process_geometry(geom)
        
        mesh = pybamm.Mesh(geom, model.default_submesh_types, model.default_var_pts)
        disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
        disc.process_model(model)
        
        solver = pybamm.CasadiSolver()
        sim = solver.solve(model, t_eval=epoch["t_eval"])
        epoch["Sims"][model_key] = sim

        previous_solutions[model_key] = sim
    # ===========================



# ===========concatenate solutions to plot ===============
V_data = {}
t_data = {}

for model_key, value in epochs["discharge"]["Models"].items():
    for epoch_key, epoch in epochs.items():
        if epoch_key == 'discharge':
            V_data[model_key] = epoch["Sims"][model_key]["Voltage [V]"].entries
            t_data[model_key] = epoch["Sims"][model_key]["Time [s]"].entries

        else:
            V_data[model_key] = np.concatenate((V_data[model_key], epoch["Sims"][model_key]["Voltage [V]"].entries))
            t_data[model_key] = np.concatenate((t_data[model_key], epoch["Sims"][model_key]["Time [s]"].entries))

    plt.plot(t_data[model_key], V_data[model_key], label=model_key)

plt.legend()
plt.xlabel("Time [s]")
plt.ylabel("Voltage [V]")
plt.show()

…at it works when a model is initialised from a previous state
@isaacbasil isaacbasil requested a review from a team as a code owner June 25, 2025 08:17
Copy link

codecov bot commented Jun 25, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.12%. Comparing base (1ce4ef2) to head (0c6c292).
Report is 1 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #5075   +/-   ##
========================================
  Coverage    99.12%   99.12%           
========================================
  Files          304      305    +1     
  Lines        23576    23608   +32     
========================================
+ Hits         23369    23401   +32     
  Misses         207      207           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Contributor

@kratman kratman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@isaacbasil Can you also add a test to catch this issue and then add an entry in the bugs section of the changelog?

@isaacbasil
Copy link
Contributor Author

@kratman just added this, does it look ok to you?

Comment on lines +437 to +451
def test_discontinuity_removed_at_nonzero_initial_time(self):
# Test that discontinuity caused by Heaviside(t) is removed when solver called with non-zero initial time
model = pybamm.BaseModel()
u = pybamm.Variable("u")
v = pybamm.Variable("v")
model.rhs = {v: -1 * (pybamm.t < 1)}
model.algebraic = {u: v - 1 - u}
model.initial_conditions = {v: 1, u: 0}
disc = pybamm.Discretisation()
disc.process_model(model)
solver = pybamm.IDAKLUSolver()
sol1 = solver.solve(model, t_eval=[0, 1])

model.set_initial_conditions_from(sol1)
sol2 = solver.solve(model, t_eval=[1, 2])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unit tests should assert something at the end.

Options:

  • Assert that there are no errors or warnings raised by the code (if the tested code is just catching an issue)
  • Check that a value matches what you would expect

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry it's the first time I've made one. I just added an assert statement to assert that there are no errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is quite what we want to cover this issue. The problem is that it just runs a simulation and confirms nothing crashes, but it does not actually check for the error you are fixing. For this to even test anything, incorrect discontinuities would have to force the solver to return None.

A better approach is to probably extract the following code to a function:

        # make sure they are increasing in time
        discontinuities = sorted(discontinuities)

        # remove any identical discontinuities
        discontinuities = [
            v
            for i, v in enumerate(discontinuities)
            if (
                i == len(discontinuities) - 1
                or discontinuities[i] < discontinuities[i + 1]
            )
            and v > t_eval[0]
        ]

        # remove any discontinuities after end of t_eval
        discontinuities = [v for v in discontinuities if v < t_eval[-1]]

This does not really depend on the model or inputs, so you can pass in an array of discontinuities and confirm the case for t_eval[0] works as expected. What you want to do is write the test, get it working, then temporarily undo your change and confirm that the test starts to fail.

A good rule of thumb for testing is that if the section of the code you are testing is difficult to test, then you probably need to make a function that is easy to test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I understand. I'm not sure how I would test something meaningful with the function extracted like that? In other words, I don't know how I would make the test interact with the base_solver.py file.

Another idea I had was to do something like this, which tests that the _get_discontinuity_start_end_indices function correctly removes the discontinuity:

model = pybamm.BaseModel()
u = pybamm.Variable("u")
v = pybamm.Variable("v")
model.rhs = {v: -1 * (pybamm.t < 1)}
model.algebraic = {u: v - 1 - u}
model.initial_conditions = {v: 1, u: 0}
disc = pybamm.Discretisation()
disc.process_model(model)
solver = pybamm.BaseSolver()
t_eval = [1, 2]

solver.set_up(model, inputs=None, t_eval=t_eval, ics_only=False)

start_indices, end_indices, t_eval = solver._get_discontinuity_start_end_indices(model, inputs=None, t_eval=t_eval)

assert len(t_eval) == 2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kratman would you be able to give me another pointer so that I can finish this off?

sol1 = solver.solve(model, t_eval=[0, 1])

model.set_initial_conditions_from(sol1)
sol2 = solver.solve(model, t_eval=[1, 2])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: The pre-commit failures are from here, since you don't use sol2 for anything, it is a violation of the style/formatting guidelines. Fixing the test will also fix the pre-commit failure

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants