16
16
#include < sopt/power_method.h>
17
17
#include < sopt/wavelets.h>
18
18
19
+ #ifdef PURIFY_H5
20
+ #include " purify/h5reader.h"
21
+ #endif
22
+
19
23
#include " purify/algorithm_factory.h"
20
24
#include " purify/measurement_operator_factory.h"
21
25
#include " purify/wavelet_operator_factory.h"
@@ -311,6 +315,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
311
315
312
316
const std::string &test_dir = " expected/fb/" ;
313
317
const std::string &input_data_path = data_filename (test_dir + " input_data.vis" );
318
+ const std::string &result_path = data_filename (test_dir + " mpi_fb_result.fits" );
314
319
315
320
auto uv_data = dirty_visibilities ({input_data_path}, world);
316
321
uv_data.units = utilities::vis_units::radians;
@@ -344,6 +349,12 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
344
349
beta, gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-2 , 1e-3 , 50 , op_norm);
345
350
346
351
auto const diagnostic = (*fb)();
352
+ const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
353
+ if (world.is_root ())
354
+ {
355
+ pfitsio::write2d (image.real (), result_path);
356
+ // pfitsio::write2d(residual_image.real(), expected_residual_path);
357
+ }
347
358
348
359
const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
349
360
const std::string &expected_residual_path = data_filename (test_dir + " residual.fits" );
@@ -360,3 +371,149 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
360
371
SOPT_HIGH_LOG (" MSE = {}" , mse);
361
372
CHECK (mse <= average_intensity * 1e-3 );
362
373
}
374
+
375
+ TEST_CASE (" MPI_fb_factory_hdf5" ) {
376
+ auto const world = sopt::mpi::Communicator::World ();
377
+ const size_t N = 13107 ;
378
+
379
+ const std::string &test_dir = " expected/fb/" ;
380
+ const std::string &input_data_path = data_filename (test_dir + " input_data.h5" );
381
+ const std::string &result_path = data_filename (test_dir + " mpi_fb_result.fits" );
382
+ H5::H5Handler h5file (input_data_path, world);
383
+
384
+ auto uv_data = H5::stochread_visibility (h5file, 6000 , false );
385
+ uv_data.units = utilities::vis_units::radians;
386
+ if (world.is_root ()) {
387
+ CAPTURE (uv_data.vis .head (5 ));
388
+ }
389
+ // REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
390
+
391
+ t_uint const imsizey = 128 ;
392
+ t_uint const imsizex = 128 ;
393
+
394
+ auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
395
+ factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1 ,
396
+ 1 , 2 , kernels::kernel_from_string.at (" kb" ), 4 , 4 );
397
+ auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
398
+ *measurements_transform, 1000 , 1e-5 ,
399
+ world.broadcast (Vector<t_complex>::Ones (imsizex * imsizey).eval ()));
400
+ const t_real op_norm = std::get<0 >(power_method_stuff);
401
+ std::vector<std::tuple<std::string, t_uint>> const sara{
402
+ std::make_tuple (" Dirac" , 3u ), std::make_tuple (" DB1" , 3u ), std::make_tuple (" DB2" , 3u ),
403
+ std::make_tuple (" DB3" , 3u ), std::make_tuple (" DB4" , 3u ), std::make_tuple (" DB5" , 3u ),
404
+ std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
405
+ auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
406
+ factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
407
+ t_real const sigma =
408
+ world.broadcast (0.016820222945913496 ) * std::sqrt (2 ); // see test_parameters file
409
+ t_real const beta = sigma * sigma;
410
+ t_real const gamma = 0.0001 ;
411
+ auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
412
+ factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
413
+ beta, gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-2 , 1e-3 , 50 , op_norm);
414
+
415
+ auto const diagnostic = (*fb)();
416
+ const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
417
+ if (world.is_root ())
418
+ {
419
+ pfitsio::write2d (image.real (), result_path);
420
+ // pfitsio::write2d(residual_image.real(), expected_residual_path);
421
+ }
422
+
423
+ const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
424
+ const std::string &expected_residual_path = data_filename (test_dir + " residual.fits" );
425
+
426
+ const auto solution = pfitsio::read2d (expected_solution_path);
427
+ const auto residual = pfitsio::read2d (expected_residual_path);
428
+
429
+ double average_intensity = diagnostic.x .real ().sum () / diagnostic.x .size ();
430
+ SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
431
+ double mse = (Vector<t_complex>::Map (solution.data (), solution.size ()) - diagnostic.x )
432
+ .real ()
433
+ .squaredNorm () /
434
+ solution.size ();
435
+ SOPT_HIGH_LOG (" MSE = {}" , mse);
436
+ CHECK (mse <= average_intensity * 1e-3 );
437
+ }
438
+
439
+ #ifdef PURIFY_H5
440
+ TEST_CASE (" fb_factory_stochastic" ) {
441
+ const std::string &test_dir = " expected/fb/" ;
442
+ const std::string &input_data_path = data_filename (test_dir + " input_data.h5" );
443
+ const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
444
+ const std::string &result_path = data_filename (test_dir + " fb_stochastic_result_mpi.fits" );
445
+
446
+ // HDF5
447
+ auto const comm = sopt::mpi::Communicator::World ();
448
+ const size_t N = 2000 ;
449
+ H5::H5Handler h5file (input_data_path, comm); // length 13107
450
+ using t_complexVec = Vector<t_complex>;
451
+
452
+ // This functor would be defined in Purify
453
+ std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater = [&f = h5file, &N]() {
454
+ utilities::vis_params uv_data = H5::stochread_visibility (f, N, false ); // no w-term in this data-set
455
+ uv_data.units = utilities::vis_units::radians;
456
+ auto phi = factory::measurement_operator_factory<t_complexVec>(
457
+ factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128 , 128 , 1 , 1 , 2 ,
458
+ kernels::kernel_from_string.at (" kb" ), 4 , 4 );
459
+
460
+ return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis , phi);
461
+ };
462
+
463
+ auto IS = random_updater ();
464
+ auto Phi = IS->Phi ();
465
+ auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
466
+ Phi, 1000 , 1e-5 ,
467
+ comm.broadcast (Vector<t_complex>::Ones (128 * 128 ).eval ()));
468
+ const t_real op_norm = std::get<0 >(power_method_stuff);
469
+
470
+ const auto solution = pfitsio::read2d (expected_solution_path);
471
+
472
+ t_uint const imsizey = 128 ;
473
+ t_uint const imsizex = 128 ;
474
+
475
+ // wavelets
476
+ std::vector<std::tuple<std::string, t_uint>> const sara{
477
+ std::make_tuple (" Dirac" , 3u ), std::make_tuple (" DB1" , 3u ), std::make_tuple (" DB2" , 3u ),
478
+ std::make_tuple (" DB3" , 3u ), std::make_tuple (" DB4" , 3u ), std::make_tuple (" DB5" , 3u ),
479
+ std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
480
+ auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
481
+ factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
482
+
483
+ // algorithm
484
+ t_real const sigma = 0.016820222945913496 * std::sqrt (2 ); // see test_parameters file
485
+ t_real const beta = sigma * sigma;
486
+ t_real const gamma = 0.0001 ;
487
+
488
+ sopt::algorithm::ImagingForwardBackward<t_complex> fb (random_updater);
489
+ fb.itermax (1000 ).step_size (beta*sqrt (2 )).sigma (sigma*sqrt (2 )).regulariser_strength (gamma).relative_variation (1e-3 ).residual_tolerance (0 ).tight_frame (true ).sq_op_norm (op_norm*op_norm).obj_comm (comm);
490
+
491
+ auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false );
492
+ gp->l1_proximal_tolerance (1e-4 ).l1_proximal_nu (1 ).l1_proximal_itermax (50 ).l1_proximal_positivity_constraint (true ).l1_proximal_real_constraint (true ).Psi (*wavelets);
493
+ fb.g_function (gp);
494
+
495
+ auto const diagnostic = fb ();
496
+ const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
497
+ SOPT_HIGH_LOG (" God help me." );
498
+ if (comm.is_root ())
499
+ {
500
+ SOPT_HIGH_LOG (" Root write file" );
501
+ pfitsio::write2d (image.real (), result_path);
502
+ // pfitsio::write2d(residual_image.real(), expected_residual_path);
503
+ }
504
+ else
505
+ {
506
+ SOPT_HIGH_LOG (" Worker has nowt to do." );
507
+ }
508
+
509
+ auto soln_flat = Vector<t_complex>::Map (solution.data (), solution.size ());
510
+ double average_intensity = soln_flat.real ().sum () / soln_flat.size ();
511
+ SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
512
+ double mse = (soln_flat - diagnostic.x )
513
+ .real ()
514
+ .squaredNorm () /
515
+ solution.size ();
516
+ SOPT_HIGH_LOG (" MSE = {}" , mse);
517
+ CHECK (mse <= average_intensity * 1e-3 );
518
+ }
519
+ #endif
0 commit comments