Skip to content

Commit 75c8197

Browse files
committed
update backends to incorporate node weight directly
1 parent 4ce9034 commit 75c8197

File tree

15 files changed

+567
-571
lines changed

15 files changed

+567
-571
lines changed

include/operon/core/dispatch.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ auto Fill(Backend::View<T, S> view, int idx, T value) {
5353
// detect missing specializations for functions
5454
template<typename T, Operon::NodeType N = Operon::NodeTypes::NoType, bool C = false, std::size_t S = Backend::BatchSize<T>>
5555
struct Func {
56-
auto operator()(Backend::View<T, S> /*primal*/, std::integral auto /*node index*/, std::integral auto... /*child indices*/) {
56+
auto operator()(std::vector<Operon::Node> const& /*nodes*/, Backend::View<T, S> /*primal*/, std::integral auto /*node index*/, std::integral auto... /*child indices*/) {
5757
throw std::runtime_error(fmt::format("backend error: missing specialization for function: {}\n", Operon::Node{N}.Name()));
5858
}
5959
};
@@ -94,8 +94,8 @@ static void NaryOp(Operon::Vector<Node> const& nodes, Backend::View<T, S> data,
9494
bool continued = false;
9595

9696
auto const call = [&](bool continued, int result, auto... args) {
97-
if (continued) { Func<T, Type, true , S>{}(data, result, args...); }
98-
else { Func<T, Type, false, S>{}(data, result, args...); }
97+
if (continued) { Func<T, Type, true , S>{}(nodes, data, result, args...); }
98+
else { Func<T, Type, false, S>{}(nodes, data, result, args...); }
9999
};
100100

101101
int arity = nodes[parentIndex].Arity;
@@ -139,14 +139,14 @@ static void BinaryOp(Operon::Vector<Node> const& nodes, Backend::View<T, S> m, s
139139
{
140140
auto j = i - 1;
141141
auto k = j - nodes[j].Length - 1;
142-
Func<T, Type, false>{}(m, i, j, k);
142+
Func<T, Type, false>{}(nodes, m, i, j, k);
143143
}
144144

145145
template<NodeType Type, typename T, std::size_t S>
146146
requires Node::IsUnary<Type>
147-
static void UnaryOp(Operon::Vector<Node> const& /*unused*/, Backend::View<T, S> m, size_t i, Operon::Range /*unused*/)
147+
static void UnaryOp(Operon::Vector<Node> const& nodes, Backend::View<T, S> m, size_t i, Operon::Range /*unused*/)
148148
{
149-
Func<T, Type, false>{}(m, i, i-1);
149+
Func<T, Type, false>{}(nodes, m, i, i-1);
150150
}
151151

152152
struct Noop {

include/operon/interpreter/backend/arma/functions.hpp

Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#define ARMA_DONT_USE_WRAPPER
88
#include <armadillo>
9-
#include "operon/interpreter/backend/backend.hpp"
9+
#include "operon/core/dispatch.hpp"
1010

1111
namespace Operon::Backend {
1212
template<typename T, std::size_t S>
@@ -32,173 +32,173 @@ namespace Operon::Backend {
3232

3333
// utility
3434
template<typename T, std::size_t S>
35-
auto Fill(T* res, T value) {
35+
auto Fill(T* res, T weight, T value) {
3636
Map<T, S>(res).fill(value);
3737
}
3838

3939
// n-ary functions
4040
template<typename T, std::size_t S>
41-
auto Add(T* res, auto const*... args) {
42-
Map<T, S>(res) = (Map<T const, S>(args) + ...);
41+
auto Add(T* res, T weight, auto const*... args) {
42+
Map<T, S>(res) = weight * (Map<T const, S>(args) + ...);
4343
}
4444

4545
template<typename T, std::size_t S>
46-
auto Mul(T* res, auto const*... args) {
47-
Map<T, S>(res) = (Map<T const, S>(args) % ...);
46+
auto Mul(T* res, T weight, auto const*... args) {
47+
Map<T, S>(res) = weight * (Map<T const, S>(args) % ...);
4848
}
4949

5050
template<typename T, std::size_t S>
51-
auto Sub(T* res, auto* first, auto const*... rest) {
51+
auto Sub(T* res, T weight, auto* first, auto const*... rest) {
5252
static_assert(sizeof...(rest) > 0);
53-
Map<T, S>(res) = Map<T const, S>(first) - (Map<T const, S>(rest) + ...);
53+
Map<T, S>(res) = weight * Map<T const, S>(first) - (Map<T const, S>(rest) + ...);
5454
}
5555

5656
template<typename T, std::size_t S>
57-
auto Div(T* res, auto const* first, auto const*... rest) {
57+
auto Div(T* res, T weight, auto const* first, auto const*... rest) {
5858
static_assert(sizeof...(rest) > 0);
59-
Map<T, S>(res) = Map<T const, S>(first) / (Map<T const, S>(rest) * ...);
59+
Map<T, S>(res) = weight * Map<T const, S>(first) / (Map<T const, S>(rest) * ...);
6060
}
6161

6262
template<typename T, std::size_t S>
63-
auto Min(T* res, auto const* first, auto const*... args) {
63+
auto Min(T* res, T weight, auto const* first, auto const*... args) {
6464
static_assert(sizeof...(args) > 0);
6565
for (auto i = 0UL; i < S; ++i) {
66-
res[i] = std::min({first[i], args[i]...});
66+
res[i] = weight * std::min({first[i], args[i]...});
6767
}
6868
}
6969

7070
template<typename T, std::size_t S>
71-
auto Max(T* res, auto* first, auto const*... args) {
71+
auto Max(T* res, T weight, auto* first, auto const*... args) {
7272
for (auto i = 0UL; i < S; ++i) {
73-
res[i] = std::max({first[i], args[i]...});
73+
res[i] = weight * std::max({first[i], args[i]...});
7474
}
7575
}
7676

7777
// binary functions
7878
template<typename T, std::size_t S>
79-
auto Aq(T* res, T const* a, T const* b) {
80-
Map<T, S>(res) = Map<T const, S>(a) / arma::sqrt((T{1} + arma::square(Map<T const, S>(b))));
79+
auto Aq(T* res, T weight, T const* a, T const* b) {
80+
Map<T, S>(res) = weight * Map<T const, S>(a) / arma::sqrt((T{1} + arma::square(Map<T const, S>(b))));
8181
}
8282

8383
template<typename T, std::size_t S>
84-
auto Pow(T* res, T const* a, T const* b) {
85-
Map<T, S>(res) = arma::pow(Map<T const, S>(a), Map<T const, S>(b));
84+
auto Pow(T* res, T weight, T const* a, T const* b) {
85+
Map<T, S>(res) = weight * arma::pow(Map<T const, S>(a), Map<T const, S>(b));
8686
}
8787

8888
// unary functions
8989
template<typename T, std::size_t S>
90-
auto Cpy(T* res, T const* arg) {
91-
Map<T, S>(res) = Map<T const, S>(arg);
90+
auto Cpy(T* res, T weight, T const* arg) {
91+
Map<T, S>(res) = weight * Map<T const, S>(arg);
9292
}
9393

9494
template<typename T, std::size_t S>
95-
auto Neg(T* res, T const* arg) {
96-
Map<T, S>(res) = -Map<T const, S>(arg);
95+
auto Neg(T* res, T weight, T const* arg) {
96+
Map<T, S>(res) = weight * -Map<T const, S>(arg);
9797
}
9898

9999
template<typename T, std::size_t S>
100-
auto Inv(T* res, T const* arg) {
101-
Map<T, S>(res) = T{1} / Map<T const, S>(arg);
100+
auto Inv(T* res, T weight, T const* arg) {
101+
Map<T, S>(res) = weight / Map<T const, S>(arg);
102102
}
103103

104104
template<typename T, std::size_t S>
105-
auto Abs(T* res, T const* arg) {
106-
Map<T, S>(res) = arma::abs(Map<T const, S>(arg));
105+
auto Abs(T* res, T weight, T const* arg) {
106+
Map<T, S>(res) = weight * arma::abs(Map<T const, S>(arg));
107107
}
108108

109109
template<typename T, std::size_t S>
110-
auto Ceil(T* res, T const* arg) {
111-
Map<T, S>(res) = arma::ceil(Map<T const, S>(arg));
110+
auto Ceil(T* res, T weight, T const* arg) {
111+
Map<T, S>(res) = weight * arma::ceil(Map<T const, S>(arg));
112112
}
113113

114114
template<typename T, std::size_t S>
115-
auto Floor(T* res, T const* arg) {
116-
Map<T, S>(res) = arma::floor(Map<T const, S>(arg));
115+
auto Floor(T* res, T weight, T const* arg) {
116+
Map<T, S>(res) = weight * arma::floor(Map<T const, S>(arg));
117117
}
118118

119119
template<typename T, std::size_t S>
120-
auto Square(T* res, T const* arg) {
121-
Map<T, S>(res) = arma::square(Map<T const, S>(arg));
120+
auto Square(T* res, T weight, T const* arg) {
121+
Map<T, S>(res) = weight * arma::square(Map<T const, S>(arg));
122122
}
123123

124124
template<typename T, std::size_t S>
125-
auto Exp(T* res, T const* arg) {
126-
Map<T, S>(res) = arma::exp(Map<T const, S>(arg));
125+
auto Exp(T* res, T weight, T const* arg) {
126+
Map<T, S>(res) = weight * arma::exp(Map<T const, S>(arg));
127127
}
128128

129129
template<typename T, std::size_t S>
130-
auto Log(T* res, T const* arg) {
131-
Map<T, S>(res) = arma::log(Map<T const, S>(arg));
130+
auto Log(T* res, T weight, T const* arg) {
131+
Map<T, S>(res) = weight * arma::log(Map<T const, S>(arg));
132132
}
133133

134134
template<typename T, std::size_t S>
135-
auto Log1p(T* res, T const* arg) {
136-
Map<T, S>(res) = arma::log1p(Map<T const, S>(arg));
135+
auto Log1p(T* res, T weight, T const* arg) {
136+
Map<T, S>(res) = weight * arma::log1p(Map<T const, S>(arg));
137137
}
138138

139139
template<typename T, std::size_t S>
140-
auto Logabs(T* res, T const* arg) {
141-
Map<T, S>(res) = arma::log(arma::abs(Map<T const, S>(arg)));
140+
auto Logabs(T* res, T weight, T const* arg) {
141+
Map<T, S>(res) = weight * arma::log(arma::abs(Map<T const, S>(arg)));
142142
}
143143

144144
template<typename T, std::size_t S>
145-
auto Sin(T* res, T const* arg) {
146-
Map<T, S>(res) = arma::sin(Map<T const, S>(arg));
145+
auto Sin(T* res, T weight, T const* arg) {
146+
Map<T, S>(res) = weight * arma::sin(Map<T const, S>(arg));
147147
}
148148

149149
template<typename T, std::size_t S>
150-
auto Cos(T* res, T const* arg) {
151-
Map<T, S>(res) = arma::cos(Map<T const, S>(arg));
150+
auto Cos(T* res, T weight, T const* arg) {
151+
Map<T, S>(res) = weight * arma::cos(Map<T const, S>(arg));
152152
}
153153

154154
template<typename T, std::size_t S>
155-
auto Tan(T* res, T const* arg) {
156-
Map<T, S>(res) = arma::tan(Map<T const, S>(arg));
155+
auto Tan(T* res, T weight, T const* arg) {
156+
Map<T, S>(res) = weight * arma::tan(Map<T const, S>(arg));
157157
}
158158

159159
template<typename T, std::size_t S>
160-
auto Asin(T* res, T const* arg) {
161-
Map<T, S>(res) = arma::asin(Map<T const, S>(arg));
160+
auto Asin(T* res, T weight, T const* arg) {
161+
Map<T, S>(res) = weight * arma::asin(Map<T const, S>(arg));
162162
}
163163

164164
template<typename T, std::size_t S>
165-
auto Acos(T* res, T const* arg) {
166-
Map<T, S>(res) = arma::acos(Map<T const, S>(arg));
165+
auto Acos(T* res, T weight, T const* arg) {
166+
Map<T, S>(res) = weight * arma::acos(Map<T const, S>(arg));
167167
}
168168

169169
template<typename T, std::size_t S>
170-
auto Atan(T* res, T const* arg) {
171-
Map<T, S>(res) = arma::atan(Map<T const, S>(arg));
170+
auto Atan(T* res, T weight, T const* arg) {
171+
Map<T, S>(res) = weight * arma::atan(Map<T const, S>(arg));
172172
}
173173

174174
template<typename T, std::size_t S>
175-
auto Sinh(T* res, T const* arg) {
176-
Map<T, S>(res) = arma::sinh(Map<T const, S>(arg));
175+
auto Sinh(T* res, T weight, T const* arg) {
176+
Map<T, S>(res) = weight * arma::sinh(Map<T const, S>(arg));
177177
}
178178

179179
template<typename T, std::size_t S>
180-
auto Cosh(T* res, T const* arg) {
181-
Map<T, S>(res) = arma::cosh(Map<T const, S>(arg));
180+
auto Cosh(T* res, T weight, T const* arg) {
181+
Map<T, S>(res) = weight * arma::cosh(Map<T const, S>(arg));
182182
}
183183

184184
template<typename T, std::size_t S>
185-
auto Tanh(T* res, T const* arg) {
186-
Map<T, S>(res) = arma::tanh(Map<T const, S>(arg));
185+
auto Tanh(T* res, T weight, T const* arg) {
186+
Map<T, S>(res) = weight * arma::tanh(Map<T const, S>(arg));
187187
}
188188

189189
template<typename T, std::size_t S>
190-
auto Sqrt(T* res, T const* arg) {
191-
Map<T, S>(res) = arma::sqrt(Map<T const, S>(arg));
190+
auto Sqrt(T* res, T weight, T const* arg) {
191+
Map<T, S>(res) = weight * arma::sqrt(Map<T const, S>(arg));
192192
}
193193

194194
template<typename T, std::size_t S>
195-
auto Sqrtabs(T* res, T const* arg) {
196-
Map<T, S>(res) = arma::sqrt(arma::abs(Map<T const, S>(arg)));
195+
auto Sqrtabs(T* res, T weight, T const* arg) {
196+
Map<T, S>(res) = weight * arma::sqrt(arma::abs(Map<T const, S>(arg)));
197197
}
198198

199199
template<typename T, std::size_t S>
200-
auto Cbrt(T* res, T const* arg) {
201-
Map<T, S>(res) = arma::cbrt(Map<T const, S>(arg));
200+
auto Cbrt(T* res, T weight, T const* arg) {
201+
Map<T, S>(res) = weight * arma::cbrt(Map<T const, S>(arg));
202202
}
203203
} // namespace Operon::Backend
204204
#endif

0 commit comments

Comments
 (0)