Skip to content

Commit 27820a9

Browse files
mmcleod89Michael McLeod20DM
authored
ONNX Testing (#343)
* add onnx to fb_factory * Fix Psi compatibility issue * add onnx test * Access wavelet operator through algo directly * make sure ORT lib location is known to cmake * Fix ANN flags and add tf test * Use proper data directory from sopt * Update ANN model path * Don't write output images by default in tests * Linting! * Copy models to purify so independent from sopt tests * linting * Remove sopt directory dependency * Add onnxrt guards for tests * Replace strict regression test with mse check on previous solution * Linting that makes things harder to read * Bring MPI test in line with serial * Remove test on exact number of iterations --------- Co-authored-by: Michael McLeod <michaelmcleod@ucl.ac.uk> Co-authored-by: Christian Gutschow <chris.g@cern.ch>
1 parent 4ad17e3 commit 27820a9

8 files changed

+171
-34
lines changed

cpp/purify/algorithm_factory.h

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,27 @@
1515
#include <sopt/mpi/communicator.h>
1616
#endif
1717

18+
#include <sopt/differentiable_func.h>
1819
#include <sopt/imaging_forward_backward.h>
1920
#include <sopt/imaging_padmm.h>
2021
#include <sopt/imaging_primal_dual.h>
2122
#include <sopt/joint_map.h>
22-
#include <sopt/l1_g_proximal.h>
23+
#include <sopt/l1_non_diff_function.h>
24+
#include <sopt/non_differentiable_func.h>
25+
#include <sopt/real_indicator.h>
2326
#include <sopt/relative_variation.h>
2427
#include <sopt/utilities.h>
2528
#include <sopt/wavelets.h>
2629
#include <sopt/wavelets/sara.h>
2730
#ifdef PURIFY_ONNXRT
28-
#include <sopt/tf_g_proximal.h>
31+
#include <sopt/tf_non_diff_function.h>
2932
#endif
3033

3134
namespace purify {
3235
namespace factory {
3336
enum class algorithm { padmm, primal_dual, sdmm, forward_backward };
3437
enum class algo_distribution { serial, mpi_serial, mpi_distributed, mpi_random_updates };
35-
enum class g_proximal_type { L1GProximal, TFGProximal };
38+
enum class g_proximal_type { L1GProximal, TFGProximal, Indicator };
3639
const std::map<std::string, algo_distribution> algo_distribution_string = {
3740
{"none", algo_distribution::serial},
3841
{"serial-equivalent", algo_distribution::mpi_serial},
@@ -161,7 +164,8 @@ fb_factory(const algo_distribution dist,
161164
const bool tight_frame = false, const t_real relative_variation = 1e-3,
162165
const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50,
163166
const t_real op_norm = 1, const std::string model_path = "",
164-
const g_proximal_type g_proximal = g_proximal_type::L1GProximal) {
167+
const g_proximal_type g_proximal = g_proximal_type::L1GProximal,
168+
std::shared_ptr<DifferentiableFunc<typename Algorithm::Scalar>> f_function = nullptr) {
165169
typedef typename Algorithm::Scalar t_scalar;
166170
if (sara_size > 1 and tight_frame)
167171
throw std::runtime_error(
@@ -177,7 +181,8 @@ fb_factory(const algo_distribution dist,
177181
.nu(op_norm * op_norm)
178182
.Phi(*measurements);
179183

180-
std::shared_ptr<GProximal<t_scalar>> gp;
184+
if (f_function) fb->f_function(f_function); // only override f_function default if non-null
185+
std::shared_ptr<NonDifferentiableFunc<t_scalar>> g;
181186

182187
switch (g_proximal) {
183188
case (g_proximal_type::L1GProximal): {
@@ -197,25 +202,29 @@ fb_factory(const algo_distribution dist,
197202
l1_gp->l1_proximal_direct_space_comm(comm);
198203
}
199204
#endif
200-
gp = l1_gp;
205+
g = l1_gp;
201206
break;
202207
}
203208
case (g_proximal_type::TFGProximal): {
204209
#ifdef PURIFY_ONNXRT
205210
// Create a shared pointer to an instance of the TFGProximal class
206-
gp = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
211+
g = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
207212
break;
208213
#else
209214
throw std::runtime_error(
210215
"Type TFGProximal not recognized because purify was built with onnxrt=off");
211216
#endif
212217
}
218+
case (g_proximal_type::Indicator): {
219+
g = std::make_shared<RealIndicator<t_scalar>>();
220+
break;
221+
}
213222
default: {
214223
throw std::runtime_error("Type of g_proximal operator not recognised.");
215224
}
216225
}
217226

218-
fb->g_proximal(gp);
227+
fb->g_function(g);
219228

220229
switch (dist) {
221230
case (algo_distribution::serial): {

cpp/purify/config.in.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
//! Whether PURIFY is running with casacore
2626
#cmakedefine PURIFY_CASACORE
2727

28+
//! Whether PURIFY is using (and SOPT was built with) onnxrt support
29+
#cmakedefine PURIFY_ONNXRT
30+
2831
#include <string>
2932
#include <tuple>
3033
#include <cstdint>

cpp/purify/update_factory.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void add_updater(std::weak_ptr<Algo> const algo_weak, const t_real step_size_sca
4545
auto algo = algo_weak.lock();
4646
if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", algo->gamma());
4747
if (algo->gamma() > 0) {
48-
Vector<t_complex> const alpha = algo->g_proximal()->Psi().adjoint() * x;
48+
Vector<t_complex> const alpha = algo->Psi().adjoint() * x;
4949
const t_real new_gamma =
5050
comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) *
5151
step_size_scale;
@@ -88,7 +88,7 @@ void add_updater(std::weak_ptr<Algo> const algo_weak, const t_real step_size_sca
8888
auto algo = algo_weak.lock();
8989
if (algo->gamma() > 0) {
9090
PURIFY_MEDIUM_LOG("Step size γ {}", algo->gamma());
91-
Vector<T> const alpha = algo->g_proximal()->Psi().adjoint() * x;
91+
Vector<T> const alpha = algo->Psi().adjoint() * x;
9292
const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale;
9393
PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma);
9494
// updating parameter

cpp/tests/algo_factory.cc

Lines changed: 141 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
#include "purify/algorithm_factory.h"
1313
#include "purify/measurement_operator_factory.h"
1414
#include "purify/wavelet_operator_factory.h"
15+
16+
#ifdef PURIFY_ONNXRT
17+
#include <sopt/onnx_differentiable_func.h>
18+
#endif
19+
1520
#include <sopt/power_method.h>
1621

1722
#include "purify/test_data.h"
@@ -136,6 +141,7 @@ TEST_CASE("fb_factory") {
136141
notinstalled::data_filename(test_dir + "solution.fits");
137142
const std::string &expected_residual_path =
138143
notinstalled::data_filename(test_dir + "residual.fits");
144+
const std::string &result_path = notinstalled::data_filename(test_dir + "fb_result.fits");
139145

140146
const auto solution = pfitsio::read2d(expected_solution_path);
141147
const auto residual = pfitsio::read2d(expected_residual_path);
@@ -170,20 +176,144 @@ TEST_CASE("fb_factory") {
170176

171177
auto const diagnostic = (*fb)();
172178
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
173-
// pfitsio::write2d(image.real(), expected_solution_path);
174-
CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
175-
CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
176-
CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
177-
CHECK(image.isApprox(solution, 1e-4));
179+
// pfitsio::write2d(image.real(), result_path);
180+
// pfitsio::write2d(residual_image.real(), expected_residual_path);
178181

179-
const Vector<t_complex> residuals = measurements_transform->adjoint() *
180-
(uv_data.vis - ((*measurements_transform) * diagnostic.x));
181-
const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
182+
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
183+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
184+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
185+
.real()
186+
.squaredNorm() /
187+
solution.size();
188+
SOPT_HIGH_LOG("MSE = {}", mse);
189+
CHECK(mse <= average_intensity * 1e-3);
190+
}
191+
192+
#ifdef PURIFY_ONNXRT
193+
TEST_CASE("tf_fb_factory") {
194+
const std::string &test_dir = "expected/fb/";
195+
const std::string &input_data_path = notinstalled::data_filename(test_dir + "input_data.vis");
196+
const std::string &expected_solution_path =
197+
notinstalled::data_filename(test_dir + "solution.fits");
198+
const std::string &expected_residual_path =
199+
notinstalled::data_filename(test_dir + "residual.fits");
200+
const std::string &result_path = notinstalled::data_filename(test_dir + "tf_result.fits");
201+
202+
const auto solution = pfitsio::read2d(expected_solution_path);
203+
const auto residual = pfitsio::read2d(expected_residual_path);
204+
205+
auto uv_data = utilities::read_visibility(input_data_path, false);
206+
uv_data.units = utilities::vis_units::radians;
207+
CAPTURE(uv_data.vis.head(5));
208+
REQUIRE(uv_data.size() == 13107);
209+
210+
t_uint const imsizey = 128;
211+
t_uint const imsizex = 128;
212+
213+
Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
214+
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
215+
factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
216+
kernels::kernel_from_string.at("kb"), 4, 4);
217+
auto const power_method_stuff =
218+
sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
219+
const t_real op_norm = std::get<0>(power_method_stuff);
220+
std::vector<std::tuple<std::string, t_uint>> const sara{
221+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
222+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
223+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
224+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
225+
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
226+
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
227+
t_real const beta = sigma * sigma;
228+
t_real const gamma = 0.0001;
229+
230+
std::string tf_model_path =
231+
purify::notinstalled::data_directory() + "/models/snr_15_model_dynamic.onnx";
232+
233+
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
234+
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
235+
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm,
236+
tf_model_path, factory::g_proximal_type::TFGProximal);
237+
238+
auto const diagnostic = (*fb)();
239+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
240+
// pfitsio::write2d(image.real(), result_path);
182241
// pfitsio::write2d(residual_image.real(), expected_residual_path);
183-
CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
184-
CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
185-
CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
242+
243+
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
244+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
245+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
246+
.real()
247+
.squaredNorm() /
248+
solution.size();
249+
SOPT_HIGH_LOG("MSE = {}", mse);
250+
CHECK(mse <= average_intensity * 1e-3);
251+
}
252+
253+
TEST_CASE("onnx_fb_factory") {
254+
const std::string &test_dir = "expected/fb/";
255+
const std::string &input_data_path = notinstalled::data_filename(test_dir + "input_data.vis");
256+
const std::string &expected_solution_path =
257+
notinstalled::data_filename(test_dir + "solution.fits");
258+
const std::string &expected_residual_path =
259+
notinstalled::data_filename(test_dir + "residual.fits");
260+
const std::string &result_path = notinstalled::data_filename(test_dir + "onnx_result.fits");
261+
const auto solution = pfitsio::read2d(expected_solution_path);
262+
const auto residual = pfitsio::read2d(expected_residual_path);
263+
264+
auto uv_data = utilities::read_visibility(input_data_path, false);
265+
uv_data.units = utilities::vis_units::radians;
266+
CAPTURE(uv_data.vis.head(5));
267+
REQUIRE(uv_data.size() == 13107);
268+
269+
t_uint const imsizey = 128;
270+
t_uint const imsizex = 128;
271+
272+
Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
273+
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
274+
factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
275+
kernels::kernel_from_string.at("kb"), 4, 4);
276+
auto const power_method_stuff =
277+
sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
278+
const t_real op_norm = std::get<0>(power_method_stuff);
279+
std::vector<std::tuple<std::string, t_uint>> const sara{
280+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
281+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
282+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
283+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
284+
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
285+
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
286+
t_real const beta = sigma * sigma;
287+
t_real const gamma = 0.0001;
288+
289+
std::string const prior_path =
290+
purify::notinstalled::data_directory() + "/models/example_cost_dynamic_CRR_sigma_5_t_5.onnx";
291+
std::string const prior_gradient_path =
292+
purify::notinstalled::data_directory() + "/models/example_grad_dynamic_CRR_sigma_5_t_5.onnx";
293+
std::shared_ptr<sopt::ONNXDifferentiableFunc<t_complex>> diff_function =
294+
std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
295+
prior_path, prior_gradient_path, sigma, 20, 5e4, *measurements_transform);
296+
297+
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
298+
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
299+
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, "",
300+
factory::g_proximal_type::Indicator, diff_function);
301+
302+
auto const diagnostic = (*fb)();
303+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
304+
// pfitsio::write2d(image.real(), result_path);
305+
// pfitsio::write2d(residual_image.real(), expected_residual_path);
306+
307+
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
308+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
309+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
310+
.real()
311+
.squaredNorm() /
312+
solution.size();
313+
SOPT_HIGH_LOG("MSE = {}", mse);
314+
CHECK(mse <= average_intensity * 1e-3);
186315
}
316+
#endif
187317

188318
TEST_CASE("joint_map_factory") {
189319
const std::string &test_dir = "expected/joint_map/";

cpp/tests/mpi_algo_factory.cc

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,6 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
348348
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);
349349

350350
auto const diagnostic = (*fb)();
351-
CHECK(diagnostic.niters == 11);
352351

353352
const std::string &expected_solution_path =
354353
notinstalled::data_filename(test_dir + "solution.fits");
@@ -358,16 +357,12 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
358357
const auto solution = pfitsio::read2d(expected_solution_path);
359358
const auto residual = pfitsio::read2d(expected_residual_path);
360359

361-
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
362-
CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
363-
CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
364-
CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
365-
CHECK(image.isApprox(solution, 1e-4));
366-
367-
const Vector<t_complex> residuals = measurements_transform->adjoint() *
368-
(uv_data.vis - ((*measurements_transform) * diagnostic.x));
369-
const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
370-
CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
371-
CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
372-
CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
360+
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
361+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
362+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
363+
.real()
364+
.squaredNorm() /
365+
solution.size();
366+
SOPT_HIGH_LOG("MSE = {}", mse);
367+
CHECK(mse <= average_intensity * 1e-3);
373368
}
Binary file not shown.
Binary file not shown.

data/models/snr_15_model_dynamic.onnx

2.13 MB
Binary file not shown.

0 commit comments

Comments
 (0)