@@ -1,13 +1,7 @@
 #include "IdakluJax.hpp"
-
-#include <pybind11/functional.h>
 #include <pybind11/numpy.h>
-#include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <pybind11/stl_bind.h>
-
-#include <vector>
-#include <iostream>
 #include <functional>

 // Initialise static variable
@@ -18,7 +12,7 @@ std::map<std::int64_t, IdakluJax*> idaklu_jax_instances;

 // Create a new IdakluJax object, assign identifier, add to the objects list and return as pointer
 IdakluJax *create_idaklu_jax() {
-  IdakluJax *p = new IdakluJax();
+  auto *p = new IdakluJax();
   idaklu_jax_instances[p->get_index()] = p;
   return p;
 }
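The factory above hands back a raw owning pointer and records it in a module-level map, so later callbacks can recover the instance from a plain integer id. A self-contained sketch of that registry pattern (the names below are hypothetical, and the id assignment is assumed to happen in the constructor, as `get_index()` suggests):

```cpp
// Hypothetical sketch of the registry pattern used above: objects live on
// the heap, keyed by a unique id, and are looked up by that id later from
// callbacks that cannot capture state.
#include <cstdint>
#include <map>

struct Obj {
  std::int64_t index;
  explicit Obj(std::int64_t i) : index(i) {}
};

static std::map<std::int64_t, Obj*> registry;
static std::int64_t next_index = 0;

Obj *create_obj() {
  auto *p = new Obj(next_index++);   // heap-allocate so the pointer stays valid
  registry[p->index] = p;            // register by id for later lookup
  return p;
}
```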
@@ -32,21 +26,21 @@ IdakluJax::~IdakluJax() {
 }

 void IdakluJax::register_callback_eval(CallbackEval h) {
-  callback_eval = h;
+  callback_eval = std::move(h);
 }

 void IdakluJax::register_callback_jvp(CallbackJvp h) {
-  callback_jvp = h;
+  callback_jvp = std::move(h);
 }

 void IdakluJax::register_callback_vjp(CallbackVjp h) {
-  callback_vjp = h;
+  callback_vjp = std::move(h);
 }

 void IdakluJax::register_callbacks(CallbackEval h_eval, CallbackJvp h_jvp, CallbackVjp h_vjp) {
-  register_callback_eval(h_eval);
-  register_callback_jvp(h_jvp);
-  register_callback_vjp(h_vjp);
+  register_callback_eval(std::move(h_eval));
+  register_callback_jvp(std::move(h_jvp));
+  register_callback_vjp(std::move(h_vjp));
 }

 void IdakluJax::cpu_idaklu_eval(void *out_tuple, const void **in) {
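The `register_callback_*` changes above take each `std::function` by value and move it into place, which avoids copying whatever state the Python-side callable captures. A minimal sketch of the pattern, with a hypothetical `double(double)` signature standing in for the real `CallbackEval`/`CallbackJvp`/`CallbackVjp` aliases (their exact signatures are not shown in this diff):

```cpp
// Sketch: take the callback by value, then move it into the member.
// The caller's argument is moved into the parameter, and the parameter is
// moved into the member, so the callable's captured state is never copied.
#include <functional>
#include <utility>

class Holder {
public:
  void register_callback(std::function<double(double)> h) {
    callback_ = std::move(h);
  }
  double call(double x) const { return callback_(x); }
private:
  std::function<double(double)> callback_;
};
```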
@@ -55,10 +49,11 @@ void IdakluJax::cpu_idaklu_eval(void *out_tuple, const void **in) {
   const std::int64_t n_t = *reinterpret_cast<const std::int64_t *>(in[k++]);
   const std::int64_t n_vars = *reinterpret_cast<const std::int64_t *>(in[k++]);
   const std::int64_t n_inputs = *reinterpret_cast<const std::int64_t *>(in[k++]);
-  const realtype *t = reinterpret_cast<const realtype *>(in[k++]);
-  realtype *inputs = new realtype(n_inputs);
-  for (int i = 0; i < n_inputs; i++)
+  const auto *t = reinterpret_cast<const realtype *>(in[k++]);
+  auto *inputs = new realtype[n_inputs];
+  for (int i = 0; i < n_inputs; i++) {
     inputs[i] = reinterpret_cast<const realtype *>(in[k++])[0];
+  }
   void *out = reinterpret_cast<realtype *>(out_tuple);

   // Log
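The unpacking above follows XLA's CPU custom-call convention: `in` is an array of opaque pointers whose leading entries carry the sizes (`n_t`, `n_vars`, `n_inputs`) and whose remaining entries point at the time vector and at one scalar per input. Note that the array form `new realtype[n_inputs]` matters here; the parenthesised `new realtype(n_inputs)` would allocate a single value initialised to `n_inputs`, and the fill loop would write past it. A hypothetical standalone sketch of the same unpacking, using `std::vector` so no manual `delete[]` is needed:

```cpp
// Hypothetical sketch of XLA CPU custom-call argument unpacking:
// leading in[] entries are sizes, the rest point at the actual data.
#include <cstdint>
#include <vector>

void unpack_args(const void **in) {
  std::size_t k = 0;
  const std::int64_t n_t = *reinterpret_cast<const std::int64_t *>(in[k++]);
  const std::int64_t n_inputs = *reinterpret_cast<const std::int64_t *>(in[k++]);
  const double *t = reinterpret_cast<const double *>(in[k++]);  // n_t entries
  std::vector<double> inputs(n_inputs);  // correctly sized, freed automatically
  for (std::int64_t i = 0; i < n_inputs; i++) {
    inputs[i] = reinterpret_cast<const double *>(in[k++])[0];   // one scalar each
  }
  (void)t;  // a real kernel would now use t and inputs
}
```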
@@ -75,16 +70,16 @@ void IdakluJax::cpu_idaklu_eval(void *out_tuple, const void **in) {
   PyGILState_STATE state = PyGILState_Ensure();

   // Convert time vector to an np_array
-  py::capsule t_capsule(t, "t_capsule");
-  np_array t_np = np_array({n_t}, {sizeof(realtype)}, t, t_capsule);
+  const py::capsule t_capsule(t, "t_capsule");
+  const auto t_np = np_array({n_t}, {sizeof(realtype)}, t, t_capsule);

   // Convert inputs to an np_array
-  py::capsule in_capsule(inputs, "in_capsule");
-  np_array in_np = np_array({n_inputs}, {sizeof(realtype)}, inputs, in_capsule);
+  const py::capsule in_capsule(inputs, "in_capsule");
+  const auto in_np = np_array({n_inputs}, {sizeof(realtype)}, inputs, in_capsule);

   // Call solve function in python to obtain an np_array
-  np_array out_np = callback_eval(t_np, in_np);
-  auto out_buf = out_np.request();
+  const np_array out_np = callback_eval(t_np, in_np);
+  const auto out_buf = out_np.request();
   const realtype *out_ptr = reinterpret_cast<realtype *>(out_buf.ptr);

   // Arrange into 'out' array
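Two details carry this block. `PyGILState_Ensure()` is needed because XLA may invoke the callback on a thread that does not hold the GIL; and passing a `py::capsule` as the base object makes `np_array` borrow the existing buffer rather than copy it (a capsule constructed without a destructor leaves ownership with the caller). A self-contained sketch of the zero-copy wrap, with a hypothetical helper name:

```cpp
// Sketch: wrap an existing C buffer as a 1-D NumPy array without copying.
// The capsule becomes the array's base object; since it carries no
// destructor, the buffer's lifetime remains the caller's responsibility.
#include <pybind11/numpy.h>
namespace py = pybind11;

py::array_t<double> wrap_buffer(const double *data, py::ssize_t n) {
  const py::capsule base(data, "wrap_capsule");
  return py::array_t<double>({n}, {sizeof(double)}, data, base);
}
```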
@@ -100,14 +95,16 @@ void IdakluJax::cpu_idaklu_jvp(void *out_tuple, const void **in) {
   const std::int64_t n_t = *reinterpret_cast<const std::int64_t *>(in[k++]);
   const std::int64_t n_vars = *reinterpret_cast<const std::int64_t *>(in[k++]);
   const std::int64_t n_inputs = *reinterpret_cast<const std::int64_t *>(in[k++]);
-  const realtype *primal_t = reinterpret_cast<const realtype *>(in[k++]);
-  realtype *primal_inputs = new realtype(n_inputs);
-  for (int i = 0; i < n_inputs; i++)
+  const auto *primal_t = reinterpret_cast<const realtype *>(in[k++]);
+  auto *primal_inputs = new realtype[n_inputs];
+  for (int i = 0; i < n_inputs; i++) {
     primal_inputs[i] = reinterpret_cast<const realtype *>(in[k++])[0];
-  const realtype *tangent_t = reinterpret_cast<const realtype *>(in[k++]);
-  realtype *tangent_inputs = new realtype(n_inputs);
-  for (int i = 0; i < n_inputs; i++)
+  }
+  const auto *tangent_t = reinterpret_cast<const realtype *>(in[k++]);
+  auto *tangent_inputs = new realtype[n_inputs];
+  for (int i = 0; i < n_inputs; i++) {
     tangent_inputs[i] = reinterpret_cast<const realtype *>(in[k++])[0];
+  }
   void *out = reinterpret_cast<realtype *>(out_tuple);

   // Log
@@ -125,8 +122,8 @@ void IdakluJax::cpu_idaklu_jvp(void *out_tuple, const void **in) {
   PyGILState_STATE state = PyGILState_Ensure();

   // Form primals time vector as np_array
-  py::capsule primal_t_capsule(primal_t, "primal_t_capsule");
-  np_array primal_t_np = np_array(
+  const py::capsule primal_t_capsule(primal_t, "primal_t_capsule");
+  const auto primal_t_np = np_array(
     {n_t},
     {sizeof(realtype)},
     primal_t,
@@ -135,25 +132,25 @@ void IdakluJax::cpu_idaklu_jvp(void *out_tuple, const void **in) {

   // Pack primals as np_array
   py::capsule primal_inputs_capsule(primal_inputs, "primal_inputs_capsule");
-  np_array primal_inputs_np = np_array(
+  const auto primal_inputs_np = np_array(
     {n_inputs},
     {sizeof(realtype)},
     primal_inputs,
     primal_inputs_capsule
   );

   // Form tangents time vector as np_array
-  py::capsule tangent_t_capsule(tangent_t, "tangent_t_capsule");
-  np_array tangent_t_np = np_array(
+  const py::capsule tangent_t_capsule(tangent_t, "tangent_t_capsule");
+  const auto tangent_t_np = np_array(
     {n_t},
     {sizeof(realtype)},
     tangent_t,
     tangent_t_capsule
   );

   // Pack tangents as np_array
-  py::capsule tangent_inputs_capsule(tangent_inputs, "tangent_inputs_capsule");
-  np_array tangent_inputs_np = np_array(
+  const py::capsule tangent_inputs_capsule(tangent_inputs, "tangent_inputs_capsule");
+  const auto tangent_inputs_np = np_array(
     {n_inputs},
     {sizeof(realtype)},
     tangent_inputs,
@@ -165,7 +162,7 @@ void IdakluJax::cpu_idaklu_jvp(void *out_tuple, const void **in) {
     primal_t_np, primal_inputs_np,
     tangent_t_np, tangent_inputs_np
   );
-  auto buf = y_dot.request();
+  const auto buf = y_dot.request();
   const realtype *ptr = reinterpret_cast<realtype *>(buf.ptr);

   // Arrange into 'out' array
@@ -182,13 +179,14 @@ void IdakluJax::cpu_idaklu_vjp(void *out_tuple, const void **in) {
   const std::int64_t n_y_bar0 = *reinterpret_cast<const std::int64_t *>(in[k++]);
   const std::int64_t n_y_bar1 = *reinterpret_cast<const std::int64_t *>(in[k++]);
   const std::int64_t n_y_bar = (n_y_bar1 > 0) ? (n_y_bar0*n_y_bar1) : n_y_bar0;
-  const realtype *y_bar = reinterpret_cast<const realtype *>(in[k++]);
-  const std::int64_t *invar = reinterpret_cast<const std::int64_t *>(in[k++]);
-  const realtype *t = reinterpret_cast<const realtype *>(in[k++]);
-  realtype *inputs = new realtype(n_inputs);
-  for (int i = 0; i < n_inputs; i++)
+  const auto *y_bar = reinterpret_cast<const realtype *>(in[k++]);
+  const auto *invar = reinterpret_cast<const std::int64_t *>(in[k++]);
+  const auto *t = reinterpret_cast<const realtype *>(in[k++]);
+  auto *inputs = new realtype[n_inputs];
+  for (int i = 0; i < n_inputs; i++) {
     inputs[i] = reinterpret_cast<const realtype *>(in[k++])[0];
-  realtype *out = reinterpret_cast<realtype *>(out_tuple);
+  }
+  auto *out = reinterpret_cast<realtype *>(out_tuple);

   // Log
   DEBUG("cpu_idaklu_vjp");
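What this diff does not show is how `cpu_idaklu_eval`/`cpu_idaklu_jvp`/`cpu_idaklu_vjp` are exposed to XLA. A common jaxlib pattern, sketched below under that assumption, wraps a free-function trampoline in a `py::capsule` tagged `xla._CUSTOM_CALL_TARGET`; Python then registers it, roughly via `jax.lib.xla_client.register_custom_call_target(b"cpu_idaklu_eval", capsule, platform="cpu")` (the trampoline and module names here are hypothetical):

```cpp
// Hypothetical sketch: exposing a CPU custom-call target to JAX via a capsule.
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Stand-in trampoline with the XLA CPU custom-call signature; a real one
// would look up the right IdakluJax instance and forward to its member.
void cpu_eval_trampoline(void *out_tuple, const void **in) {
  (void)out_tuple;
  (void)in;
}

template <typename T>
py::capsule encapsulate(T *fn) {
  // "xla._CUSTOM_CALL_TARGET" is the capsule tag jaxlib expects.
  return py::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}

PYBIND11_MODULE(idaklu_jax_example, m) {
  m.def("cpu_idaklu_eval_capsule",
        []() { return encapsulate(cpu_eval_trampoline); });
}
```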