Skip to content

Commit e361f1b

Browse files
author
Michael McLeod
committed
linting
1 parent e244d16 commit e361f1b

File tree

3 files changed

+106
-84
lines changed

3 files changed

+106
-84
lines changed

cpp/purify/setup_utils.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
#ifdef PURIFY_ONNXRT
1515
#include <sopt/tf_non_diff_function.h>
16-
#endif
16+
#endif
1717

1818
using namespace purify;
1919

@@ -315,13 +315,14 @@ void setupCostFunctions(const YamlParser &params, std::unique_ptr<Differentiable
315315
f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(sigma, Phi);
316316
break;
317317
case purify::diff_func_type::L2Norm_with_CRR:
318-
#ifdef PURIFY_ONNXRT
318+
#ifdef PURIFY_ONNXRT
319319
f = std::make_unique<sopt::ONNXDifferentiableFunc<t_complex>>(
320320
params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, params.CRR_mu(),
321321
params.CRR_lambda(), Phi);
322-
#else
323-
throw std::runtime_error("To use the CRR you must compile with ONNX runtime turned on. (-Donnxrt=on)");
324-
#endif
322+
#else
323+
throw std::runtime_error(
324+
"To use the CRR you must compile with ONNX runtime turned on. (-Donnxrt=on)");
325+
#endif
325326
break;
326327
}
327328

@@ -330,12 +331,13 @@ void setupCostFunctions(const YamlParser &params, std::unique_ptr<Differentiable
330331
g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();
331332
break;
332333
case purify::nondiff_func_type::Denoiser:
333-
#ifdef PURIFY_ONNXRT
334+
#ifdef PURIFY_ONNXRT
334335
g = std::make_unique<sopt::algorithm::TFGProximal<t_complex>>(params.model_path());
335336
break;
336-
#else
337-
throw std::runtime_error("To use the Denoiser you must compile with ONNX runtime turned on. (-Donnxrt=on)");
338-
#endif
337+
#else
338+
throw std::runtime_error(
339+
"To use the Denoiser you must compile with ONNX runtime turned on. (-Donnxrt=on)");
340+
#endif
339341

340342
case purify::nondiff_func_type::RealIndicator:
341343
g = std::make_unique<sopt::algorithm::RealIndicator<t_complex>>();

cpp/tests/algo_factory.cc

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -205,41 +205,42 @@ TEST_CASE("fb_factory_stochastic") {
205205
// This functor would be defined in Purify
206206
std::mt19937 rng(0);
207207
const size_t N = 1000;
208-
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater = [&input_data_path, imsizex, imsizey, &rng, &N]() {
209-
utilities::vis_params uv_data = utilities::read_visibility(input_data_path, false);
210-
uv_data.units = utilities::vis_units::radians;
211-
212-
// Get random subset
213-
std::vector<size_t> indices(uv_data.size());
214-
size_t i = 0;
215-
for(auto &x : indices)
216-
{
217-
x = i++;
218-
}
219-
220-
std::shuffle(indices.begin(), indices.end(), rng);
221-
Vector<t_real> u_fragment(N);
222-
Vector<t_real> v_fragment(N);
223-
Vector<t_real> w_fragment(N);
224-
Vector<t_complex> vis_fragment(N);
225-
Vector<t_complex> weights_fragment(N);
226-
for(i = 0; i < N; i++)
227-
{
228-
size_t j = indices[i];
229-
u_fragment[i] = uv_data.u[j];
230-
v_fragment[i] = uv_data.v[j];
231-
w_fragment[i] = uv_data.w[j];
232-
vis_fragment[i] = uv_data.vis[j];
233-
weights_fragment[i] = uv_data.weights[j];
234-
}
235-
utilities::vis_params uv_data_fragment(u_fragment, v_fragment, w_fragment, vis_fragment, weights_fragment, uv_data.units, uv_data.ra, uv_data.dec, uv_data.average_frequency);
236-
237-
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
238-
factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex, 1, 1, 2,
239-
kernels::kernel_from_string.at("kb"), 4, 4);
240-
241-
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis, phi);
242-
};
208+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
209+
[&input_data_path, imsizex, imsizey, &rng, &N]() {
210+
utilities::vis_params uv_data = utilities::read_visibility(input_data_path, false);
211+
uv_data.units = utilities::vis_units::radians;
212+
213+
// Get random subset
214+
std::vector<size_t> indices(uv_data.size());
215+
size_t i = 0;
216+
for (auto &x : indices) {
217+
x = i++;
218+
}
219+
220+
std::shuffle(indices.begin(), indices.end(), rng);
221+
Vector<t_real> u_fragment(N);
222+
Vector<t_real> v_fragment(N);
223+
Vector<t_real> w_fragment(N);
224+
Vector<t_complex> vis_fragment(N);
225+
Vector<t_complex> weights_fragment(N);
226+
for (i = 0; i < N; i++) {
227+
size_t j = indices[i];
228+
u_fragment[i] = uv_data.u[j];
229+
v_fragment[i] = uv_data.v[j];
230+
w_fragment[i] = uv_data.w[j];
231+
vis_fragment[i] = uv_data.vis[j];
232+
weights_fragment[i] = uv_data.weights[j];
233+
}
234+
utilities::vis_params uv_data_fragment(u_fragment, v_fragment, w_fragment, vis_fragment,
235+
weights_fragment, uv_data.units, uv_data.ra,
236+
uv_data.dec, uv_data.average_frequency);
237+
238+
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
239+
factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex,
240+
1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
241+
242+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis, phi);
243+
};
243244

244245
Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
245246
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
@@ -257,38 +258,47 @@ TEST_CASE("fb_factory_stochastic") {
257258

258259
const auto solution = pfitsio::read2d(expected_solution_path);
259260

260-
//wavelets
261+
// wavelets
261262
std::vector<std::tuple<std::string, t_uint>> const sara{
262263
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
263264
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
264265
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
265266
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
266267
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
267-
268-
//algorithm
268+
269+
// algorithm
269270
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
270271
t_real const beta = sigma * sigma;
271272
t_real const gamma = 0.0001;
272273

273274
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
274-
fb.itermax(1000).step_size(beta*sqrt(2)).sigma(sigma*sqrt(2)).regulariser_strength(gamma).relative_variation(1e-3).residual_tolerance(0).tight_frame(true).sq_op_norm(op_norm*op_norm);
275+
fb.itermax(1000)
276+
.step_size(beta * sqrt(2))
277+
.sigma(sigma * sqrt(2))
278+
.regulariser_strength(gamma)
279+
.relative_variation(1e-3)
280+
.residual_tolerance(0)
281+
.tight_frame(true)
282+
.sq_op_norm(op_norm * op_norm);
275283

276284
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
277-
gp->l1_proximal_tolerance(1e-4).l1_proximal_nu(1).l1_proximal_itermax(50).l1_proximal_positivity_constraint(true).l1_proximal_real_constraint(true).Psi(*wavelets);
285+
gp->l1_proximal_tolerance(1e-4)
286+
.l1_proximal_nu(1)
287+
.l1_proximal_itermax(50)
288+
.l1_proximal_positivity_constraint(true)
289+
.l1_proximal_real_constraint(true)
290+
.Psi(*wavelets);
278291
fb.g_function(gp);
279292

280293
auto const diagnostic = fb();
281294
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
282-
//pfitsio::write2d(image.real(), result_path);
283-
//pfitsio::write2d(residual_image.real(), expected_residual_path);
295+
// pfitsio::write2d(image.real(), result_path);
296+
// pfitsio::write2d(residual_image.real(), expected_residual_path);
284297

285298
auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
286299
double average_intensity = soln_flat.real().sum() / soln_flat.size();
287300
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
288-
double mse = (soln_flat - diagnostic.x)
289-
.real()
290-
.squaredNorm() /
291-
solution.size();
301+
double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
292302
SOPT_HIGH_LOG("MSE = {}", mse);
293303
CHECK(mse <= average_intensity * 1e-3);
294304
}

cpp/tests/mpi_algo_factory.cc

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,9 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
350350

351351
auto const diagnostic = (*fb)();
352352
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
353-
if (world.is_root())
354-
{
353+
if (world.is_root()) {
355354
pfitsio::write2d(image.real(), result_path);
356-
//pfitsio::write2d(residual_image.real(), expected_residual_path);
355+
// pfitsio::write2d(residual_image.real(), expected_residual_path);
357356
}
358357

359358
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
@@ -376,7 +375,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
376375
TEST_CASE("MPI_fb_factory_hdf5") {
377376
auto const world = sopt::mpi::Communicator::World();
378377
const size_t N = 13107;
379-
378+
380379
const std::string &test_dir = "expected/fb/";
381380
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
382381
const std::string &result_path = data_filename(test_dir + "mpi_fb_result_hdf5.fits");
@@ -387,7 +386,7 @@ TEST_CASE("MPI_fb_factory_hdf5") {
387386
if (world.is_root()) {
388387
CAPTURE(uv_data.vis.head(5));
389388
}
390-
//REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
389+
// REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
391390

392391
t_uint const imsizey = 128;
393392
t_uint const imsizex = 128;
@@ -415,7 +414,7 @@ TEST_CASE("MPI_fb_factory_hdf5") {
415414

416415
auto const diagnostic = (*fb)();
417416
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
418-
//if (world.is_root())
417+
// if (world.is_root())
419418
//{
420419
// pfitsio::write2d(image.real(), result_path);
421420
//}
@@ -441,70 +440,81 @@ TEST_CASE("fb_factory_stochastic") {
441440
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
442441
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
443442
const std::string &result_path = data_filename(test_dir + "fb_stochastic_result_mpi.fits");
444-
443+
445444
// HDF5
446445
auto const comm = sopt::mpi::Communicator::World();
447446
const size_t N = 2000;
448447
H5::H5Handler h5file(input_data_path, comm); // length 13107
449448
using t_complexVec = Vector<t_complex>;
450449

451450
// This functor would be defined in Purify
452-
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater = [&f = h5file, &N]() {
453-
utilities::vis_params uv_data = H5::stochread_visibility(f, N, false); // no w-term in this data-set
454-
uv_data.units = utilities::vis_units::radians;
455-
auto phi = factory::measurement_operator_factory<t_complexVec>(
456-
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1, 1, 2,
457-
kernels::kernel_from_string.at("kb"), 4, 4);
458-
459-
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
460-
};
451+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
452+
[&f = h5file, &N]() {
453+
utilities::vis_params uv_data =
454+
H5::stochread_visibility(f, N, false); // no w-term in this data-set
455+
uv_data.units = utilities::vis_units::radians;
456+
auto phi = factory::measurement_operator_factory<t_complexVec>(
457+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1,
458+
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
459+
460+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
461+
};
461462

462463
auto IS = random_updater();
463464
auto Phi = IS->Phi();
464465
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
465-
Phi, 1000, 1e-5,
466-
comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
466+
Phi, 1000, 1e-5, comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
467467
const t_real op_norm = std::get<0>(power_method_stuff);
468468

469469
const auto solution = pfitsio::read2d(expected_solution_path);
470470

471471
t_uint const imsizey = 128;
472472
t_uint const imsizex = 128;
473-
474-
//wavelets
473+
474+
// wavelets
475475
std::vector<std::tuple<std::string, t_uint>> const sara{
476476
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
477477
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
478478
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
479479
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
480480
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
481-
482-
//algorithm
481+
482+
// algorithm
483483
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
484484
t_real const beta = sigma * sigma;
485485
t_real const gamma = 0.0001;
486486

487487
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
488-
fb.itermax(1000).step_size(beta*sqrt(2)).sigma(sigma*sqrt(2)).regulariser_strength(gamma).relative_variation(1e-3).residual_tolerance(0).tight_frame(true).sq_op_norm(op_norm*op_norm).obj_comm(comm);
488+
fb.itermax(1000)
489+
.step_size(beta * sqrt(2))
490+
.sigma(sigma * sqrt(2))
491+
.regulariser_strength(gamma)
492+
.relative_variation(1e-3)
493+
.residual_tolerance(0)
494+
.tight_frame(true)
495+
.sq_op_norm(op_norm * op_norm)
496+
.obj_comm(comm);
489497

490498
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
491-
gp->l1_proximal_tolerance(1e-4).l1_proximal_nu(1).l1_proximal_itermax(50).l1_proximal_positivity_constraint(true).l1_proximal_real_constraint(true).Psi(*wavelets);
499+
gp->l1_proximal_tolerance(1e-4)
500+
.l1_proximal_nu(1)
501+
.l1_proximal_itermax(50)
502+
.l1_proximal_positivity_constraint(true)
503+
.l1_proximal_real_constraint(true)
504+
.Psi(*wavelets);
492505
fb.g_function(gp);
493506

494507
auto const diagnostic = fb();
495508
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
496-
//if (comm.is_root())
509+
// if (comm.is_root())
497510
//{
498511
// //pfitsio::write2d(image.real(), result_path);
499512
//}
500513

501514
auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
502515
double average_intensity = soln_flat.real().sum() / soln_flat.size();
503516
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
504-
double mse = (soln_flat - diagnostic.x)
505-
.real()
506-
.squaredNorm() /
507-
solution.size();
517+
double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
508518
SOPT_HIGH_LOG("MSE = {}", mse);
509519
CHECK(mse <= average_intensity * 1e-3);
510520
}

0 commit comments

Comments
 (0)