Skip to content

Commit 7bbd786

Browse files
authored
Merge pull request #380 from astro-informatics/mm/stochastic_testing
Stochastic test cases
2 parents 2e748a9 + a76ec4a commit 7bbd786

File tree

4 files changed

+293
-6
lines changed

4 files changed

+293
-6
lines changed

cpp/purify/setup_utils.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@
33
#include <sopt/l1_non_diff_function.h>
44
#include <sopt/l2_differentiable_func.h>
55
#include <sopt/non_differentiable_func.h>
6+
7+
#ifdef PURIFY_ONNXRT
68
#include <sopt/onnx_differentiable_func.h>
9+
#endif
10+
711
#include <sopt/power_method.h>
812
#include <sopt/real_indicator.h>
13+
14+
#ifdef PURIFY_ONNXRT
915
#include <sopt/tf_non_diff_function.h>
16+
#endif
1017

1118
using namespace purify;
1219

@@ -308,9 +315,14 @@ void setupCostFunctions(const YamlParser &params, std::unique_ptr<Differentiable
308315
f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(sigma, Phi);
309316
break;
310317
case purify::diff_func_type::L2Norm_with_CRR:
318+
#ifdef PURIFY_ONNXRT
311319
f = std::make_unique<sopt::ONNXDifferentiableFunc<t_complex>>(
312320
params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, params.CRR_mu(),
313321
params.CRR_lambda(), Phi);
322+
#else
323+
throw std::runtime_error(
324+
"To use the CRR you must compile with ONNX runtime turned on. (-Donnxrt=on)");
325+
#endif
314326
break;
315327
}
316328

@@ -319,8 +331,14 @@ void setupCostFunctions(const YamlParser &params, std::unique_ptr<Differentiable
319331
g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();
320332
break;
321333
case purify::nondiff_func_type::Denoiser:
334+
#ifdef PURIFY_ONNXRT
322335
g = std::make_unique<sopt::algorithm::TFGProximal<t_complex>>(params.model_path());
323336
break;
337+
#else
338+
throw std::runtime_error(
339+
"To use the Denoiser you must compile with ONNX runtime turned on. (-Donnxrt=on)");
340+
#endif
341+
324342
case purify::nondiff_func_type::RealIndicator:
325343
g = std::make_unique<sopt::algorithm::RealIndicator<t_complex>>();
326344
break;

cpp/tests/algo_factory.cc

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#include <sopt/onnx_differentiable_func.h>
1818
#endif
1919

20+
#ifdef PURIFY_H5
21+
#include "purify/h5reader.h"
22+
#endif
23+
2024
#include <sopt/power_method.h>
2125

2226
#include "purify/test_data.h"
@@ -169,7 +173,7 @@ TEST_CASE("fb_factory") {
169173

170174
auto const diagnostic = (*fb)();
171175
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
172-
// pfitsio::write2d(image.real(), result_path);
176+
pfitsio::write2d(image.real(), result_path);
173177
// pfitsio::write2d(residual_image.real(), expected_residual_path);
174178

175179
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
@@ -182,6 +186,117 @@ TEST_CASE("fb_factory") {
182186
CHECK(mse <= average_intensity * 1e-3);
183187
}
184188

189+
#ifdef PURIFY_H5
190+
TEST_CASE("fb_factory_stochastic") {
191+
const std::string &test_dir = "expected/fb/";
192+
const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
193+
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
194+
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
195+
const std::string &result_path = data_filename(test_dir + "fb_result_stochastic.fits");
196+
197+
auto uv_data = utilities::read_visibility(input_data_path, false);
198+
uv_data.units = utilities::vis_units::radians;
199+
CAPTURE(uv_data.vis.head(5));
200+
REQUIRE(uv_data.size() == 13107);
201+
202+
t_uint const imsizey = 128;
203+
t_uint const imsizex = 128;
204+
205+
// This functor would be defined in Purify
206+
std::mt19937 rng(0);
207+
const size_t N = 1000;
208+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
209+
[&input_data_path, imsizex, imsizey, &rng, &N]() {
210+
utilities::vis_params uv_data = utilities::read_visibility(input_data_path, false);
211+
uv_data.units = utilities::vis_units::radians;
212+
213+
// Get random subset
214+
std::vector<size_t> indices(uv_data.size());
215+
size_t i = 0;
216+
for (auto &x : indices) {
217+
x = i++;
218+
}
219+
220+
std::shuffle(indices.begin(), indices.end(), rng);
221+
Vector<t_real> u_fragment(N);
222+
Vector<t_real> v_fragment(N);
223+
Vector<t_real> w_fragment(N);
224+
Vector<t_complex> vis_fragment(N);
225+
Vector<t_complex> weights_fragment(N);
226+
for (i = 0; i < N; i++) {
227+
size_t j = indices[i];
228+
u_fragment[i] = uv_data.u[j];
229+
v_fragment[i] = uv_data.v[j];
230+
w_fragment[i] = uv_data.w[j];
231+
vis_fragment[i] = uv_data.vis[j];
232+
weights_fragment[i] = uv_data.weights[j];
233+
}
234+
utilities::vis_params uv_data_fragment(u_fragment, v_fragment, w_fragment, vis_fragment,
235+
weights_fragment, uv_data.units, uv_data.ra,
236+
uv_data.dec, uv_data.average_frequency);
237+
238+
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
239+
factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex,
240+
1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
241+
242+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis, phi);
243+
};
244+
245+
Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
246+
auto IS = random_updater();
247+
auto Phi = IS->Phi();
248+
auto const power_method_stuff =
249+
sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000, 1e-5, init);
250+
const t_real op_norm = std::get<0>(power_method_stuff);
251+
252+
const auto solution = pfitsio::read2d(expected_solution_path);
253+
254+
// wavelets
255+
std::vector<std::tuple<std::string, t_uint>> const sara{
256+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
257+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
258+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
259+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
260+
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
261+
262+
// algorithm
263+
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
264+
t_real const beta = sigma * sigma;
265+
t_real const gamma = 0.0001;
266+
267+
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
268+
fb.itermax(1000)
269+
.step_size(beta * sqrt(2))
270+
.sigma(sigma * sqrt(2))
271+
.regulariser_strength(gamma)
272+
.relative_variation(1e-3)
273+
.residual_tolerance(0)
274+
.tight_frame(true)
275+
.sq_op_norm(op_norm * op_norm);
276+
277+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
278+
gp->l1_proximal_tolerance(1e-4)
279+
.l1_proximal_nu(1)
280+
.l1_proximal_itermax(50)
281+
.l1_proximal_positivity_constraint(true)
282+
.l1_proximal_real_constraint(true)
283+
.Psi(*wavelets);
284+
fb.g_function(gp);
285+
286+
auto const diagnostic = fb();
287+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
288+
// pfitsio::write2d(image.real(), result_path);
289+
// pfitsio::write2d(residual_image.real(), expected_residual_path);
290+
291+
auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
292+
double average_intensity = soln_flat.real().sum() / soln_flat.size();
293+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
294+
double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
295+
SOPT_HIGH_LOG("MSE = {}", mse);
296+
CHECK(mse <= average_intensity * 1e-3);
297+
}
298+
#endif
299+
185300
#ifdef PURIFY_ONNXRT
186301
TEST_CASE("tf_fb_factory") {
187302
const std::string &test_dir = "expected/fb/";

cpp/tests/mpi_algo_factory.cc

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
#include <sopt/power_method.h>
1717
#include <sopt/wavelets.h>
1818

19+
#ifdef PURIFY_H5
20+
#include "purify/h5reader.h"
21+
#endif
22+
1923
#include "purify/algorithm_factory.h"
2024
#include "purify/measurement_operator_factory.h"
2125
#include "purify/wavelet_operator_factory.h"
@@ -311,6 +315,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
311315

312316
const std::string &test_dir = "expected/fb/";
313317
const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
318+
const std::string &result_path = data_filename(test_dir + "mpi_fb_result.fits");
314319

315320
auto uv_data = dirty_visibilities({input_data_path}, world);
316321
uv_data.units = utilities::vis_units::radians;
@@ -344,6 +349,75 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
344349
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);
345350

346351
auto const diagnostic = (*fb)();
352+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
353+
if (world.is_root()) {
354+
pfitsio::write2d(image.real(), result_path);
355+
// pfitsio::write2d(residual_image.real(), expected_residual_path);
356+
}
357+
358+
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
359+
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
360+
361+
const auto solution = pfitsio::read2d(expected_solution_path);
362+
const auto residual = pfitsio::read2d(expected_residual_path);
363+
364+
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
365+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
366+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
367+
.real()
368+
.squaredNorm() /
369+
solution.size();
370+
SOPT_HIGH_LOG("MSE = {}", mse);
371+
CHECK(mse <= average_intensity * 1e-3);
372+
}
373+
374+
#ifdef PURIFY_H5
375+
TEST_CASE("MPI_fb_factory_hdf5") {
376+
auto const world = sopt::mpi::Communicator::World();
377+
const size_t N = 13107;
378+
379+
const std::string &test_dir = "expected/fb/";
380+
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
381+
const std::string &result_path = data_filename(test_dir + "mpi_fb_result_hdf5.fits");
382+
H5::H5Handler h5file(input_data_path, world);
383+
384+
auto uv_data = H5::stochread_visibility(h5file, 6000, false);
385+
uv_data.units = utilities::vis_units::radians;
386+
if (world.is_root()) {
387+
CAPTURE(uv_data.vis.head(5));
388+
}
389+
// REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
390+
391+
t_uint const imsizey = 128;
392+
t_uint const imsizex = 128;
393+
394+
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
395+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1,
396+
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
397+
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
398+
*measurements_transform, 1000, 1e-5,
399+
world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
400+
const t_real op_norm = std::get<0>(power_method_stuff);
401+
std::vector<std::tuple<std::string, t_uint>> const sara{
402+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
403+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
404+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
405+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
406+
factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
407+
t_real const sigma =
408+
world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
409+
t_real const beta = sigma * sigma;
410+
t_real const gamma = 0.0001;
411+
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
412+
factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
413+
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);
414+
415+
auto const diagnostic = (*fb)();
416+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
417+
// if (world.is_root())
418+
//{
419+
// pfitsio::write2d(image.real(), result_path);
420+
//}
347421

348422
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
349423
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
@@ -360,3 +434,88 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
360434
SOPT_HIGH_LOG("MSE = {}", mse);
361435
CHECK(mse <= average_intensity * 1e-3);
362436
}
437+
438+
TEST_CASE("fb_factory_stochastic") {
439+
const std::string &test_dir = "expected/fb/";
440+
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
441+
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
442+
const std::string &result_path = data_filename(test_dir + "fb_stochastic_result_mpi.fits");
443+
444+
// HDF5
445+
auto const comm = sopt::mpi::Communicator::World();
446+
const size_t N = 2000;
447+
H5::H5Handler h5file(input_data_path, comm); // length 13107
448+
using t_complexVec = Vector<t_complex>;
449+
450+
// This functor would be defined in Purify
451+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
452+
[&f = h5file, &N]() {
453+
utilities::vis_params uv_data =
454+
H5::stochread_visibility(f, N, false); // no w-term in this data-set
455+
uv_data.units = utilities::vis_units::radians;
456+
auto phi = factory::measurement_operator_factory<t_complexVec>(
457+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1,
458+
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
459+
460+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
461+
};
462+
463+
auto IS = random_updater();
464+
auto Phi = IS->Phi();
465+
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
466+
Phi, 1000, 1e-5, comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
467+
const t_real op_norm = std::get<0>(power_method_stuff);
468+
469+
const auto solution = pfitsio::read2d(expected_solution_path);
470+
471+
t_uint const imsizey = 128;
472+
t_uint const imsizex = 128;
473+
474+
// wavelets
475+
std::vector<std::tuple<std::string, t_uint>> const sara{
476+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
477+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
478+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
479+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
480+
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
481+
482+
// algorithm
483+
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
484+
t_real const beta = sigma * sigma;
485+
t_real const gamma = 0.0001;
486+
487+
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
488+
fb.itermax(1000)
489+
.step_size(beta * sqrt(2))
490+
.sigma(sigma * sqrt(2))
491+
.regulariser_strength(gamma)
492+
.relative_variation(1e-3)
493+
.residual_tolerance(0)
494+
.tight_frame(true)
495+
.sq_op_norm(op_norm * op_norm)
496+
.obj_comm(comm);
497+
498+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
499+
gp->l1_proximal_tolerance(1e-4)
500+
.l1_proximal_nu(1)
501+
.l1_proximal_itermax(50)
502+
.l1_proximal_positivity_constraint(true)
503+
.l1_proximal_real_constraint(true)
504+
.Psi(*wavelets);
505+
fb.g_function(gp);
506+
507+
auto const diagnostic = fb();
508+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
509+
// if (comm.is_root())
510+
//{
511+
// //pfitsio::write2d(image.real(), result_path);
512+
//}
513+
514+
auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
515+
double average_intensity = soln_flat.real().sum() / soln_flat.size();
516+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
517+
double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
518+
SOPT_HIGH_LOG("MSE = {}", mse);
519+
CHECK(mse <= average_intensity * 1e-3);
520+
}
521+
#endif

cpp/uncertainty_quantification/uq_main.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@
1515
#include <sopt/l1_non_diff_function.h>
1616
#include <sopt/l2_differentiable_func.h>
1717
#include <sopt/real_indicator.h>
18-
#include <sopt/tf_non_diff_function.h>
19-
20-
#ifdef PURIFY_ONNXRT
21-
#include <sopt/onnx_differentiable_func.h>
22-
#endif
2318

2419
using VectorC = sopt::Vector<std::complex<double>>;
2520

0 commit comments

Comments
 (0)