16
16
#include < sopt/power_method.h>
17
17
#include < sopt/wavelets.h>
18
18
19
+ #ifdef PURIFY_H5
20
+ #include " purify/h5reader.h"
21
+ #endif
22
+
19
23
#include " purify/algorithm_factory.h"
20
24
#include " purify/measurement_operator_factory.h"
21
25
#include " purify/wavelet_operator_factory.h"
@@ -311,6 +315,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
311
315
312
316
const std::string &test_dir = " expected/fb/" ;
313
317
const std::string &input_data_path = data_filename (test_dir + " input_data.vis" );
318
+ const std::string &result_path = data_filename (test_dir + " mpi_fb_result.fits" );
314
319
315
320
auto uv_data = dirty_visibilities ({input_data_path}, world);
316
321
uv_data.units = utilities::vis_units::radians;
@@ -344,6 +349,75 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
344
349
beta, gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-2 , 1e-3 , 50 , op_norm);
345
350
346
351
auto const diagnostic = (*fb)();
352
+ const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
353
+ if (world.is_root ()) {
354
+ pfitsio::write2d (image.real (), result_path);
355
+ // pfitsio::write2d(residual_image.real(), expected_residual_path);
356
+ }
357
+
358
+ const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
359
+ const std::string &expected_residual_path = data_filename (test_dir + " residual.fits" );
360
+
361
+ const auto solution = pfitsio::read2d (expected_solution_path);
362
+ const auto residual = pfitsio::read2d (expected_residual_path);
363
+
364
+ double average_intensity = diagnostic.x .real ().sum () / diagnostic.x .size ();
365
+ SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
366
+ double mse = (Vector<t_complex>::Map (solution.data (), solution.size ()) - diagnostic.x )
367
+ .real ()
368
+ .squaredNorm () /
369
+ solution.size ();
370
+ SOPT_HIGH_LOG (" MSE = {}" , mse);
371
+ CHECK (mse <= average_intensity * 1e-3 );
372
+ }
373
+
374
+ #ifdef PURIFY_H5
375
+ TEST_CASE (" MPI_fb_factory_hdf5" ) {
376
+ auto const world = sopt::mpi::Communicator::World ();
377
+ const size_t N = 13107 ;
378
+
379
+ const std::string &test_dir = " expected/fb/" ;
380
+ const std::string &input_data_path = data_filename (test_dir + " input_data.h5" );
381
+ const std::string &result_path = data_filename (test_dir + " mpi_fb_result_hdf5.fits" );
382
+ H5::H5Handler h5file (input_data_path, world);
383
+
384
+ auto uv_data = H5::stochread_visibility (h5file, 6000 , false );
385
+ uv_data.units = utilities::vis_units::radians;
386
+ if (world.is_root ()) {
387
+ CAPTURE (uv_data.vis .head (5 ));
388
+ }
389
+ // REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
390
+
391
+ t_uint const imsizey = 128 ;
392
+ t_uint const imsizex = 128 ;
393
+
394
+ auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
395
+ factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1 ,
396
+ 1 , 2 , kernels::kernel_from_string.at (" kb" ), 4 , 4 );
397
+ auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
398
+ *measurements_transform, 1000 , 1e-5 ,
399
+ world.broadcast (Vector<t_complex>::Ones (imsizex * imsizey).eval ()));
400
+ const t_real op_norm = std::get<0 >(power_method_stuff);
401
+ std::vector<std::tuple<std::string, t_uint>> const sara{
402
+ std::make_tuple (" Dirac" , 3u ), std::make_tuple (" DB1" , 3u ), std::make_tuple (" DB2" , 3u ),
403
+ std::make_tuple (" DB3" , 3u ), std::make_tuple (" DB4" , 3u ), std::make_tuple (" DB5" , 3u ),
404
+ std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
405
+ auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
406
+ factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
407
+ t_real const sigma =
408
+ world.broadcast (0.016820222945913496 ) * std::sqrt (2 ); // see test_parameters file
409
+ t_real const beta = sigma * sigma;
410
+ t_real const gamma = 0.0001 ;
411
+ auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
412
+ factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
413
+ beta, gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-2 , 1e-3 , 50 , op_norm);
414
+
415
+ auto const diagnostic = (*fb)();
416
+ const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
417
+ // if (world.is_root())
418
+ // {
419
+ // pfitsio::write2d(image.real(), result_path);
420
+ // }
347
421
348
422
const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
349
423
const std::string &expected_residual_path = data_filename (test_dir + " residual.fits" );
@@ -360,3 +434,88 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
360
434
SOPT_HIGH_LOG (" MSE = {}" , mse);
361
435
CHECK (mse <= average_intensity * 1e-3 );
362
436
}
437
+
438
+ TEST_CASE (" fb_factory_stochastic" ) {
439
+ const std::string &test_dir = " expected/fb/" ;
440
+ const std::string &input_data_path = data_filename (test_dir + " input_data.h5" );
441
+ const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
442
+ const std::string &result_path = data_filename (test_dir + " fb_stochastic_result_mpi.fits" );
443
+
444
+ // HDF5
445
+ auto const comm = sopt::mpi::Communicator::World ();
446
+ const size_t N = 2000 ;
447
+ H5::H5Handler h5file (input_data_path, comm); // length 13107
448
+ using t_complexVec = Vector<t_complex>;
449
+
450
+ // This functor would be defined in Purify
451
+ std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
452
+ [&f = h5file, &N]() {
453
+ utilities::vis_params uv_data =
454
+ H5::stochread_visibility (f, N, false ); // no w-term in this data-set
455
+ uv_data.units = utilities::vis_units::radians;
456
+ auto phi = factory::measurement_operator_factory<t_complexVec>(
457
+ factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128 , 128 , 1 ,
458
+ 1 , 2 , kernels::kernel_from_string.at (" kb" ), 4 , 4 );
459
+
460
+ return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis , phi);
461
+ };
462
+
463
+ auto IS = random_updater ();
464
+ auto Phi = IS->Phi ();
465
+ auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
466
+ Phi, 1000 , 1e-5 , comm.broadcast (Vector<t_complex>::Ones (128 * 128 ).eval ()));
467
+ const t_real op_norm = std::get<0 >(power_method_stuff);
468
+
469
+ const auto solution = pfitsio::read2d (expected_solution_path);
470
+
471
+ t_uint const imsizey = 128 ;
472
+ t_uint const imsizex = 128 ;
473
+
474
+ // wavelets
475
+ std::vector<std::tuple<std::string, t_uint>> const sara{
476
+ std::make_tuple (" Dirac" , 3u ), std::make_tuple (" DB1" , 3u ), std::make_tuple (" DB2" , 3u ),
477
+ std::make_tuple (" DB3" , 3u ), std::make_tuple (" DB4" , 3u ), std::make_tuple (" DB5" , 3u ),
478
+ std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
479
+ auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
480
+ factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
481
+
482
+ // algorithm
483
+ t_real const sigma = 0.016820222945913496 * std::sqrt (2 ); // see test_parameters file
484
+ t_real const beta = sigma * sigma;
485
+ t_real const gamma = 0.0001 ;
486
+
487
+ sopt::algorithm::ImagingForwardBackward<t_complex> fb (random_updater);
488
+ fb.itermax (1000 )
489
+ .step_size (beta * sqrt (2 ))
490
+ .sigma (sigma * sqrt (2 ))
491
+ .regulariser_strength (gamma)
492
+ .relative_variation (1e-3 )
493
+ .residual_tolerance (0 )
494
+ .tight_frame (true )
495
+ .sq_op_norm (op_norm * op_norm)
496
+ .obj_comm (comm);
497
+
498
+ auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false );
499
+ gp->l1_proximal_tolerance (1e-4 )
500
+ .l1_proximal_nu (1 )
501
+ .l1_proximal_itermax (50 )
502
+ .l1_proximal_positivity_constraint (true )
503
+ .l1_proximal_real_constraint (true )
504
+ .Psi (*wavelets);
505
+ fb.g_function (gp);
506
+
507
+ auto const diagnostic = fb ();
508
+ const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
509
+ // if (comm.is_root())
510
+ // {
511
+ // //pfitsio::write2d(image.real(), result_path);
512
+ // }
513
+
514
+ auto soln_flat = Vector<t_complex>::Map (solution.data (), solution.size ());
515
+ double average_intensity = soln_flat.real ().sum () / soln_flat.size ();
516
+ SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
517
+ double mse = (soln_flat - diagnostic.x ).real ().squaredNorm () / solution.size ();
518
+ SOPT_HIGH_LOG (" MSE = {}" , mse);
519
+ CHECK (mse <= average_intensity * 1e-3 );
520
+ }
521
+ #endif
0 commit comments