Skip to content

Commit 5d64ef4

Browse files
author
Michael McLeod
committed
MPI hdf5 and stochastic test
1 parent b190f74 commit 5d64ef4

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

cpp/tests/mpi_algo_factory.cc

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
#include <sopt/power_method.h>
1717
#include <sopt/wavelets.h>
1818

19+
#ifdef PURIFY_H5
20+
#include "purify/h5reader.h"
21+
#endif
22+
1923
#include "purify/algorithm_factory.h"
2024
#include "purify/measurement_operator_factory.h"
2125
#include "purify/wavelet_operator_factory.h"
@@ -311,6 +315,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
311315

312316
const std::string &test_dir = "expected/fb/";
313317
const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
318+
const std::string &result_path = data_filename(test_dir + "mpi_fb_result.fits");
314319

315320
auto uv_data = dirty_visibilities({input_data_path}, world);
316321
uv_data.units = utilities::vis_units::radians;
@@ -344,6 +349,12 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
344349
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);
345350

346351
auto const diagnostic = (*fb)();
352+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
353+
if (world.is_root())
354+
{
355+
pfitsio::write2d(image.real(), result_path);
356+
//pfitsio::write2d(residual_image.real(), expected_residual_path);
357+
}
347358

348359
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
349360
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
@@ -360,3 +371,149 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
360371
SOPT_HIGH_LOG("MSE = {}", mse);
361372
CHECK(mse <= average_intensity * 1e-3);
362373
}
374+
375+
TEST_CASE("MPI_fb_factory_hdf5") {
376+
auto const world = sopt::mpi::Communicator::World();
377+
const size_t N = 13107;
378+
379+
const std::string &test_dir = "expected/fb/";
380+
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
381+
const std::string &result_path = data_filename(test_dir + "mpi_fb_result.fits");
382+
H5::H5Handler h5file(input_data_path, world);
383+
384+
auto uv_data = H5::stochread_visibility(h5file, 6000, false);
385+
uv_data.units = utilities::vis_units::radians;
386+
if (world.is_root()) {
387+
CAPTURE(uv_data.vis.head(5));
388+
}
389+
//REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
390+
391+
t_uint const imsizey = 128;
392+
t_uint const imsizex = 128;
393+
394+
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
395+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1,
396+
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
397+
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
398+
*measurements_transform, 1000, 1e-5,
399+
world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
400+
const t_real op_norm = std::get<0>(power_method_stuff);
401+
std::vector<std::tuple<std::string, t_uint>> const sara{
402+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
403+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
404+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
405+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
406+
factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
407+
t_real const sigma =
408+
world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
409+
t_real const beta = sigma * sigma;
410+
t_real const gamma = 0.0001;
411+
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
412+
factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
413+
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);
414+
415+
auto const diagnostic = (*fb)();
416+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
417+
if (world.is_root())
418+
{
419+
pfitsio::write2d(image.real(), result_path);
420+
//pfitsio::write2d(residual_image.real(), expected_residual_path);
421+
}
422+
423+
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
424+
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
425+
426+
const auto solution = pfitsio::read2d(expected_solution_path);
427+
const auto residual = pfitsio::read2d(expected_residual_path);
428+
429+
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
430+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
431+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
432+
.real()
433+
.squaredNorm() /
434+
solution.size();
435+
SOPT_HIGH_LOG("MSE = {}", mse);
436+
CHECK(mse <= average_intensity * 1e-3);
437+
}
438+
439+
#ifdef PURIFY_H5
440+
TEST_CASE("fb_factory_stochastic") {
441+
const std::string &test_dir = "expected/fb/";
442+
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
443+
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
444+
const std::string &result_path = data_filename(test_dir + "fb_stochastic_result_mpi.fits");
445+
446+
// HDF5
447+
auto const comm = sopt::mpi::Communicator::World();
448+
const size_t N = 2000;
449+
H5::H5Handler h5file(input_data_path, comm); // length 13107
450+
using t_complexVec = Vector<t_complex>;
451+
452+
// This functor would be defined in Purify
453+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater = [&f = h5file, &N]() {
454+
utilities::vis_params uv_data = H5::stochread_visibility(f, N, false); // no w-term in this data-set
455+
uv_data.units = utilities::vis_units::radians;
456+
auto phi = factory::measurement_operator_factory<t_complexVec>(
457+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1, 1, 2,
458+
kernels::kernel_from_string.at("kb"), 4, 4);
459+
460+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
461+
};
462+
463+
auto IS = random_updater();
464+
auto Phi = IS->Phi();
465+
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
466+
Phi, 1000, 1e-5,
467+
comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
468+
const t_real op_norm = std::get<0>(power_method_stuff);
469+
470+
const auto solution = pfitsio::read2d(expected_solution_path);
471+
472+
t_uint const imsizey = 128;
473+
t_uint const imsizex = 128;
474+
475+
//wavelets
476+
std::vector<std::tuple<std::string, t_uint>> const sara{
477+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
478+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
479+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
480+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
481+
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
482+
483+
//algorithm
484+
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
485+
t_real const beta = sigma * sigma;
486+
t_real const gamma = 0.0001;
487+
488+
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
489+
fb.itermax(1000).step_size(beta*sqrt(2)).sigma(sigma*sqrt(2)).regulariser_strength(gamma).relative_variation(1e-3).residual_tolerance(0).tight_frame(true).sq_op_norm(op_norm*op_norm).obj_comm(comm);
490+
491+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
492+
gp->l1_proximal_tolerance(1e-4).l1_proximal_nu(1).l1_proximal_itermax(50).l1_proximal_positivity_constraint(true).l1_proximal_real_constraint(true).Psi(*wavelets);
493+
fb.g_function(gp);
494+
495+
auto const diagnostic = fb();
496+
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
497+
SOPT_HIGH_LOG("God help me.");
498+
if (comm.is_root())
499+
{
500+
SOPT_HIGH_LOG("Root write file");
501+
pfitsio::write2d(image.real(), result_path);
502+
//pfitsio::write2d(residual_image.real(), expected_residual_path);
503+
}
504+
else
505+
{
506+
SOPT_HIGH_LOG("Worker has nowt to do.");
507+
}
508+
509+
auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
510+
double average_intensity = soln_flat.real().sum() / soln_flat.size();
511+
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
512+
double mse = (soln_flat - diagnostic.x)
513+
.real()
514+
.squaredNorm() /
515+
solution.size();
516+
SOPT_HIGH_LOG("MSE = {}", mse);
517+
CHECK(mse <= average_intensity * 1e-3);
518+
}
519+
#endif

0 commit comments

Comments
 (0)