Skip to content

Commit 4946555

Browse files
committed
Type fixes
1 parent 1eed1fe commit 4946555

23 files changed

+355
-355
lines changed

src/pybammsolvers/idaklu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ casadi::Function generate_casadi_function(const std::string &data)
2828
namespace py = pybind11;
2929

3030
PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
31-
PYBIND11_MAKE_OPAQUE(std::vector<np_array_realtype>);
31+
PYBIND11_MAKE_OPAQUE(std::vector<np_array_sunrealtype>);
3232
PYBIND11_MAKE_OPAQUE(std::vector<Solution>);
3333

3434
PYBIND11_MODULE(idaklu, m)
3535
{
3636
m.doc() = "sundials solvers"; // optional module docstring
3737

3838
py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
39-
py::bind_vector<std::vector<np_array_realtype>>(m, "VectorRealtypeNdArray");
39+
py::bind_vector<std::vector<np_array_sunrealtype>>(m, "VectorsunrealtypeNdArray");
4040
py::bind_vector<std::vector<Solution>>(m, "VectorSolution");
4141

4242
py::class_<IDAKLUSolverGroup>(m, "IDAKLUSolverGroup")

src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class Expression {
2121
* @brief Evaluation operator (supplying data references)
2222
*/
2323
virtual void operator()(
24-
const std::vector<realtype*>& inputs,
25-
const std::vector<realtype*>& results) = 0;
24+
const std::vector<sunrealtype*>& inputs,
25+
const std::vector<sunrealtype*>& results) = 0;
2626

2727
/**
2828
* @brief The maximum number of elements returned by the k'th output

src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,16 @@ class ExpressionSet
6969

7070
std::vector<int64_t> jac_times_cjmass_rowvals; // cppcheck-suppress unusedStructMember
7171
std::vector<int64_t> jac_times_cjmass_colptrs; // cppcheck-suppress unusedStructMember
72-
std::vector<realtype> inputs; // cppcheck-suppress unusedStructMember
72+
std::vector<sunrealtype> inputs; // cppcheck-suppress unusedStructMember
7373

7474
SetupOptions setup_opts;
7575

76-
virtual realtype *get_tmp_state_vector() = 0;
77-
virtual realtype *get_tmp_sparse_jacobian_data() = 0;
76+
virtual sunrealtype *get_tmp_state_vector() = 0;
77+
virtual sunrealtype *get_tmp_sparse_jacobian_data() = 0;
7878

7979
protected:
80-
std::vector<realtype> tmp_state_vector;
81-
std::vector<realtype> tmp_sparse_jacobian_data;
80+
std::vector<sunrealtype> tmp_state_vector;
81+
std::vector<sunrealtype> tmp_sparse_jacobian_data;
8282
};
8383

8484
#endif // PYBAMM_IDAKLU_EXPRESSION_SET_HPP

src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ const std::vector<expr_int>& CasadiFunction::get_col() {
6161
return m_cols;
6262
}
6363

64-
void CasadiFunction::operator()(const std::vector<realtype*>& inputs,
65-
const std::vector<realtype*>& results)
64+
void CasadiFunction::operator()(const std::vector<sunrealtype*>& inputs,
65+
const std::vector<sunrealtype*>& results)
6666
{
6767
DEBUG("CasadiFunction operator() with inputs and results: " << m_func.name());
6868

src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class CasadiFunction : public Expression
2222

2323
// Method overrides
2424
void operator()() override;
25-
void operator()(const std::vector<realtype*>& inputs,
26-
const std::vector<realtype*>& results) override;
25+
void operator()(const std::vector<sunrealtype*>& inputs,
26+
const std::vector<sunrealtype*>& results) override;
2727
expr_int out_shape(int k) override;
2828
expr_int nnz() override;
2929
expr_int nnz_out() override;
@@ -144,10 +144,10 @@ class CasadiFunctions : public ExpressionSet<CasadiFunction>
144144
std::vector<CasadiFunction> dvar_dy_fcns_casadi;
145145
std::vector<CasadiFunction> dvar_dp_fcns_casadi;
146146

147-
realtype* get_tmp_state_vector() override {
147+
sunrealtype* get_tmp_state_vector() override {
148148
return tmp_state_vector.data();
149149
}
150-
realtype* get_tmp_sparse_jacobian_data() override {
150+
sunrealtype* get_tmp_sparse_jacobian_data() override {
151151
return tmp_sparse_jacobian_data.data();
152152
}
153153
};

src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class IREEFunction : public Expression
2222

2323
// Method overrides
2424
void operator()() override;
25-
void operator()(const std::vector<realtype*>& inputs,
26-
const std::vector<realtype*>& results) override;
25+
void operator()(const std::vector<sunrealtype*>& inputs,
26+
const std::vector<sunrealtype*>& results) override;
2727
expr_int out_shape(int k) override;
2828
expr_int nnz() override;
2929
expr_int nnz_out() override;

src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ void IREEFunction::evaluate(int n_outputs) {
177177
// Copy results to output array
178178
for(size_t k=0; k<n_outputs; k++) {
179179
for(size_t j=0; j<result[k].size(); j++) {
180-
m_res[k][j] = static_cast<realtype>(result[k][j]);
180+
m_res[k][j] = static_cast<sunrealtype>(result[k][j]);
181181
}
182182
}
183183

@@ -213,8 +213,8 @@ const std::vector<expr_int>& IREEFunction::get_col() {
213213
return m_func.col;
214214
}
215215

216-
void IREEFunction::operator()(const std::vector<realtype*>& inputs,
217-
const std::vector<realtype*>& results)
216+
void IREEFunction::operator()(const std::vector<sunrealtype*>& inputs,
217+
const std::vector<sunrealtype*>& results)
218218
{
219219
DEBUG("IreeFunction operator() with inputs and results");
220220
// Set-up input arguments, provide result vector, then execute function

src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ class IREEFunctions : public ExpressionSet<IREEFunction>
130130
std::vector<IREEFunction> dvar_dy_fcns_iree;
131131
std::vector<IREEFunction> dvar_dp_fcns_iree;
132132

133-
realtype* get_tmp_state_vector() override {
133+
sunrealtype* get_tmp_state_vector() override {
134134
return tmp_state_vector.data();
135135
}
136-
realtype* get_tmp_sparse_jacobian_data() override {
136+
sunrealtype* get_tmp_sparse_jacobian_data() override {
137137
return tmp_sparse_jacobian_data.data();
138138
}
139139

src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ class IDAKLUSolver
2828
* @brief Abstract solver method that executes the solver
2929
*/
3030
virtual SolutionData solve(
31-
const std::vector<realtype> &t_eval,
32-
const std::vector<realtype> &t_interp,
33-
const realtype *y0,
34-
const realtype *yp0,
35-
const realtype *inputs,
31+
const std::vector<sunrealtype> &t_eval,
32+
const std::vector<sunrealtype> &t_interp,
33+
const sunrealtype *y0,
34+
const sunrealtype *yp0,
35+
const sunrealtype *inputs,
3636
bool save_adaptive_steps,
3737
bool save_interp_steps
3838
) = 0;

src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ std::vector<Solution> IDAKLUSolverGroup::solve(
1313
// If t_interp is empty, save all adaptive steps
1414
bool save_adaptive_steps = t_interp_np.size() == 0;
1515

16-
const realtype* t_eval_begin = t_eval_np.data();
17-
const realtype* t_eval_end = t_eval_begin + t_eval_np.size();
18-
const realtype* t_interp_begin = t_interp_np.data();
19-
const realtype* t_interp_end = t_interp_begin + t_interp_np.size();
16+
const sunrealtype* t_eval_begin = t_eval_np.data();
17+
const sunrealtype* t_eval_end = t_eval_begin + t_eval_np.size();
18+
const sunrealtype* t_interp_begin = t_interp_np.data();
19+
const sunrealtype* t_interp_end = t_interp_begin + t_interp_np.size();
2020

2121
// Process the time inputs
2222
// 1. Get the sorted and unique t_eval vector
@@ -97,9 +97,9 @@ std::vector<Solution> IDAKLUSolverGroup::solve(
9797
const std::size_t solves_per_thread = number_of_groups / m_solvers.size();
9898
const std::size_t remainder_solves = number_of_groups % m_solvers.size();
9999

100-
const realtype *y0 = y0_np.data();
101-
const realtype *yp0 = yp0_np.data();
102-
const realtype *inputs_data = inputs.data();
100+
const sunrealtype *y0 = y0_np.data();
101+
const sunrealtype *yp0 = yp0_np.data();
102+
const sunrealtype *inputs_data = inputs.data();
103103

104104
std::vector<SolutionData> results(number_of_groups);
105105

@@ -111,9 +111,9 @@ std::vector<Solution> IDAKLUSolverGroup::solve(
111111
try {
112112
for (int j = 0; j < solves_per_thread; j++) {
113113
const std::size_t index = i * solves_per_thread + j;
114-
const realtype *y = y0 + index * y0_np.shape(1);
115-
const realtype *yp = yp0 + index * yp0_np.shape(1);
116-
const realtype *input = inputs_data + index * inputs.shape(1);
114+
const sunrealtype *y = y0 + index * y0_np.shape(1);
115+
const sunrealtype *yp = yp0 + index * yp0_np.shape(1);
116+
const sunrealtype *input = inputs_data + index * inputs.shape(1);
117117
results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps);
118118
}
119119
} catch (std::exception &e) {
@@ -132,9 +132,9 @@ std::vector<Solution> IDAKLUSolverGroup::solve(
132132

133133
for (int i = 0; i < remainder_solves; i++) {
134134
const std::size_t index = number_of_groups - remainder_solves + i;
135-
const realtype *y = y0 + index * y0_np.shape(1);
136-
const realtype *yp = yp0 + index * yp0_np.shape(1);
137-
const realtype *input = inputs_data + index * inputs.shape(1);
135+
const sunrealtype *y = y0 + index * y0_np.shape(1);
136+
const sunrealtype *yp = yp0 + index * yp0_np.shape(1);
137+
const sunrealtype *input = inputs_data + index * inputs.shape(1);
138138
results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps);
139139
}
140140

0 commit comments

Comments
 (0)