diff --git a/cpp/benchmarks/algorithms.cc b/cpp/benchmarks/algorithms.cc index 253a5c233..8da2b6163 100644 --- a/cpp/benchmarks/algorithms.cc +++ b/cpp/benchmarks/algorithms.cc @@ -71,7 +71,7 @@ BENCHMARK_DEFINE_F(AlgoFixture, Padmm)(benchmark::State &state) { m_padmm = factory::padmm_factory>( factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3, 1e-2, 50, - 1.0, 1.0); + 1.0); while (state.KeepRunning()) { auto start = std::chrono::high_resolution_clock::now(); @@ -92,7 +92,7 @@ BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackward)(benchmark::State &state) { m_fb = factory::fb_factory>( factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3, - 1e-2, 50, 1.0); + 1e-2, 50); while (state.KeepRunning()) { auto start = std::chrono::high_resolution_clock::now(); diff --git a/cpp/benchmarks/algorithms_mpi.cc b/cpp/benchmarks/algorithms_mpi.cc index 7c31422c1..ed7386560 100644 --- a/cpp/benchmarks/algorithms_mpi.cc +++ b/cpp/benchmarks/algorithms_mpi.cc @@ -90,7 +90,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeImage)(benchmark::State &state m_padmm = factory::padmm_factory>( factory::algo_distribution::mpi_distributed, m_measurements_distribute_image, wavelets, m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, - false, 1e-3, 1e-2, 50, 1.0, 1.0); + false, 1e-3, 1e-2, 50, 1.0); // Benchmark the application of the algorithm while (state.KeepRunning()) { @@ -111,7 +111,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeGrid)(benchmark::State &state) m_padmm = factory::padmm_factory>( factory::algo_distribution::mpi_distributed, m_measurements_distribute_grid, wavelets, m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, - false, 1e-3, 1e-2, 50, 1.0, 1.0); + false, 1e-3, 1e-2, 50, 1.0); // Benchmark the application of the algorithm while (state.KeepRunning()) { @@ -135,7 +135,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeImage)(benchmark::State &state) { m_fb = factory::fb_factory>( factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data, m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, - false, 1e-3, 1e-2, 50, 1.0); + false, 1e-3, 1e-2, 50); // Benchmark the application of the algorithm while (state.KeepRunning()) { @@ -159,7 +159,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeGrid)(benchmark::State &state) { m_fb = factory::fb_factory>( factory::algo_distribution::mpi_serial, m_measurements_distribute_grid, wavelets, m_uv_data, m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, - false, 1e-3, 1e-2, 50, 1.0); + false, 1e-3, 1e-2, 50); // Benchmark the application of the algorithm while (state.KeepRunning()) { diff --git a/cpp/example/padmm_mpi_random_coverage.cc b/cpp/example/padmm_mpi_random_coverage.cc index d2d68b175..107902bdc 100644 --- a/cpp/example/padmm_mpi_random_coverage.cc +++ b/cpp/example/padmm_mpi_random_coverage.cc @@ -116,7 +116,6 @@ std::shared_ptr> padmm_factory( .l1_proximal_real_constraint(true) .residual_tolerance(epsilon) .lagrange_update_scale(0.9) - .sq_op_norm(1e0) .Psi(Psi) .Phi(*measurements); sopt::ScalarRelativeVariation conv(padmm->relative_variation(), diff --git a/cpp/example/padmm_mpi_real_data.cc b/cpp/example/padmm_mpi_real_data.cc index 876232d35..93a64b65d 100644 --- a/cpp/example/padmm_mpi_real_data.cc +++ b/cpp/example/padmm_mpi_real_data.cc @@ -89,7 +89,6 @@ std::shared_ptr> padmm_factory( .l1_proximal_real_constraint(true) .residual_tolerance(epsilon) .lagrange_update_scale(0.9) - .sq_op_norm(1e0) .Psi(Psi) .Phi(*measurements); sopt::ScalarRelativeVariation conv(padmm->relative_variation(), diff --git a/cpp/example/padmm_random_coverage.cc b/cpp/example/padmm_random_coverage.cc index 310de0384..32822027c 100644 --- a/cpp/example/padmm_random_coverage.cc +++ b/cpp/example/padmm_random_coverage.cc @@ -108,7 +108,6 @@ void padmm(const std::string &name, const Image &M31, const std::stri #ifdef PURIFY_CImg .is_converged(show_image) #endif - .sq_op_norm(1e0) .Psi(Psi) .Phi(*measurements_transform); diff --git a/cpp/example/padmm_real_data.cc b/cpp/example/padmm_real_data.cc index de6d90d61..f9c93b067 100644 --- a/cpp/example/padmm_real_data.cc +++ b/cpp/example/padmm_real_data.cc @@ -91,7 +91,6 @@ void padmm(const std::string &name, const t_uint &imsizex, const t_uint &imsizey .l1_proximal_real_constraint(true) .residual_convergence(epsilon) .lagrange_update_scale(0.9) - .sq_op_norm(1e0) .Psi(Psi) .Phi(*measurements_transform); diff --git a/cpp/example/padmm_reweighted_simulation.cc b/cpp/example/padmm_reweighted_simulation.cc index c4ee3bf5d..8008c811c 100644 --- a/cpp/example/padmm_reweighted_simulation.cc +++ b/cpp/example/padmm_reweighted_simulation.cc @@ -102,7 +102,6 @@ int main(int nargs, char const **args) { .l1_proximal_real_constraint(true) .residual_convergence(epsilon * 1.001) .lagrange_update_scale(0.9) - .sq_op_norm(1e0) .Psi(Psi) .Phi(measurements_transform); // Timing reconstruction diff --git a/cpp/example/padmm_simulation.cc b/cpp/example/padmm_simulation.cc index 49ab58b92..2dd43a137 100644 --- a/cpp/example/padmm_simulation.cc +++ b/cpp/example/padmm_simulation.cc @@ -117,7 +117,6 @@ int main(int nargs, char const **args) { .l1_proximal_real_constraint(true) .residual_convergence(epsilon * 1.001) .lagrange_update_scale(0.9) - .sq_op_norm(1e0) .Psi(Psi) .itermax(100) .is_converged(convergence_function) diff --git a/cpp/example/sara_padmm_random_coverage.cc b/cpp/example/sara_padmm_random_coverage.cc index 3de5c8472..0f9530004 100644 --- a/cpp/example/sara_padmm_random_coverage.cc +++ b/cpp/example/sara_padmm_random_coverage.cc @@ -97,7 +97,6 @@ int main(int, char **) { .l1_proximal_real_constraint(true) .residual_convergence(epsilon * 1.001) .lagrange_update_scale(0.9) - .sq_op_norm(1e0) .Psi(Psi) .Phi(measurements_transform); diff --git a/cpp/main.cc b/cpp/main.cc index 420bd5e4a..a95d87bf8 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -61,14 +61,14 @@ int main(int argc, const char **argv) { getInputData(params, mop_algo, wop_algo, using_mpi); // create measurement operator - auto [measurements_transform, operator_norm] = + auto measurements_transform = createMeasurementOperator(params, mop_algo, wop_algo, using_mpi, image_index, w_stacks, uv_data, measurement_op_eigen_vector); // create wavelet operator const waveletInfo wavelets = createWaveletOperator(params, wop_algo); - PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm); + PURIFY_LOW_LOG("Value of operator norm is {}", measurements_transform->norm()); t_real const flux_scale = 1.; uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; @@ -95,8 +95,7 @@ int main(int argc, const char **argv) { beam_units = uv_data.size() / flux_scale / flux_scale; } - savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm, - beam_units); + savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, beam_units); // the dirty image saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units); @@ -114,7 +113,7 @@ int main(int argc, const char **argv) { (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, - params.epsilonConvergenceScaling(), operator_norm); + params.epsilonConvergenceScaling()); if (params.algorithm() == "fb") { std::shared_ptr> f; if (params.diffFuncType() == diff_func_type::L2Norm_with_CRR) { @@ -135,7 +134,7 @@ int main(int argc, const char **argv) { params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), - params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, operator_norm, + params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, params.model_path(), params.nondiffFuncType(), f); } if (params.algorithm() == "primaldual") @@ -144,7 +143,7 @@ int main(int argc, const char **argv) { sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), wavelets.sara_size, params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), params.relVarianceConvergence(), - params.epsilonConvergenceScaling(), operator_norm); + params.epsilonConvergenceScaling()); // Add primal dual preconditioning if (params.algorithm() == "primaldual" and params.precondition_iters() > 0) { PURIFY_HIGH_LOG( diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index 9f90677d2..a6a245d99 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -60,7 +60,7 @@ padmm_factory(const algo_distribution dist, const bool tight_frame = false, const t_real relative_variation = 1e-3, const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50, - const t_real residual_tolerance_scaling = 1, const t_real op_norm = 1) { + const t_real residual_tolerance_scaling = 1) { typedef typename Algorithm::Scalar t_scalar; if (sara_size > 1 and tight_frame) throw std::runtime_error( @@ -78,7 +78,6 @@ padmm_factory(const algo_distribution dist, .l1_proximal_positivity_constraint(positive_constraint) .l1_proximal_real_constraint(real_constraint) .lagrange_update_scale(0.9) - .sq_op_norm(op_norm * op_norm) .Psi(*wavelets) .Phi(*measurements); #ifdef PURIFY_MPI @@ -162,7 +161,7 @@ fb_factory(const algo_distribution dist, const bool real_constraint = true, const bool positive_constraint = true, const bool tight_frame = false, const t_real relative_variation = 1e-3, const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50, - const t_real op_norm = 1, const std::string model_path = "", + const std::string model_path = "", const nondiff_func_type g_proximal = nondiff_func_type::L1Norm, std::shared_ptr> f_function = nullptr) { typedef typename Algorithm::Scalar t_scalar; @@ -178,7 +177,6 @@ fb_factory(const algo_distribution dist, .step_size(step_size * std::sqrt(2)) .relative_variation(relative_variation) .tight_frame(tight_frame) - .sq_op_norm(op_norm * op_norm) .Phi(*measurements); if (f_function) fb->f_function(f_function); // only override f_function default if non-null @@ -262,8 +260,7 @@ primaldual_factory( const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations = 500, const bool real_constraint = true, const bool positive_constraint = true, - const t_real relative_variation = 1e-3, const t_real residual_tolerance_scaling = 1, - const t_real op_norm = 1) { + const t_real relative_variation = 1e-3, const t_real residual_tolerance_scaling = 1) { typedef typename Algorithm::Scalar t_scalar; PURIFY_INFO("Constructing Primal Dual algorithm"); auto epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size())) * sigma; @@ -274,7 +271,7 @@ primaldual_factory( .positivity_constraint(positive_constraint) .Psi(*wavelets) .Phi(*measurements) - .tau(0.5 / (op_norm * op_norm + 1)) + .tau(0.5 / (measurements->sq_norm() + 1)) .xi(1.) .sigma(1.); #ifdef PURIFY_MPI diff --git a/cpp/purify/setup_utils.cc b/cpp/purify/setup_utils.cc index 12186e1dd..1c6a550c7 100644 --- a/cpp/purify/setup_utils.cc +++ b/cpp/purify/setup_utils.cc @@ -247,7 +247,7 @@ inputData getInputData(const YamlParser ¶ms, return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks}; } -measurementOpInfo createMeasurementOperator( +std::shared_ptr>> createMeasurementOperator( const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector &image_index, const std::vector &w_stacks, @@ -294,6 +294,7 @@ measurementOpInfo createMeasurementOperator( params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval()); measurement_op_eigen_vector = std::get<1>(power_method_result); operator_norm = std::get<0>(power_method_result); + measurements_transform->set_norm(operator_norm); } else #endif { @@ -302,9 +303,10 @@ measurementOpInfo createMeasurementOperator( measurement_op_eigen_vector); measurement_op_eigen_vector = std::get<1>(power_method_result); operator_norm = std::get<0>(power_method_result); + measurements_transform->set_norm(operator_norm); } - return {measurements_transform, operator_norm}; + return measurements_transform; } void setupCostFunctions(const YamlParser ¶ms, std::unique_ptr> &f, @@ -402,7 +404,7 @@ void savePSF( const YamlParser ¶ms, const pfitsio::header_params &def_header, const std::shared_ptr>> &measurements_transform, const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, - const t_real operator_norm, const t_real beam_units) { + const t_real beam_units) { pfitsio::header_params psf_header = def_header; psf_header.fits_name = params.output_path() + "/psf.fits"; psf_header.pix_units = "Jy/Pixel"; @@ -417,7 +419,7 @@ void savePSF( auto const world = sopt::mpi::Communicator::World(); PURIFY_LOW_LOG( "Expected image domain residual RMS is {} jy/beam", - sigma * params.epsilonScaling() * operator_norm / + sigma * params.epsilonScaling() * measurements_transform->norm() / (std::sqrt(params.width() * params.height()) * world.all_sum_all(uv_data.size()))); if (world.is_root()) #else @@ -426,7 +428,7 @@ void savePSF( pfitsio::write2d(psf_image, psf_header, true); } else { PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam", - sigma * params.epsilonScaling() * operator_norm / + sigma * params.epsilonScaling() * measurements_transform->norm() / (std::sqrt(params.width() * params.height()) * uv_data.size())); pfitsio::write2d(psf_image, psf_header, true); } diff --git a/cpp/purify/setup_utils.h b/cpp/purify/setup_utils.h index fc8f40833..4198b54d4 100644 --- a/cpp/purify/setup_utils.h +++ b/cpp/purify/setup_utils.h @@ -41,12 +41,7 @@ inputData getInputData(const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi); -struct measurementOpInfo { - std::shared_ptr>> measurement_transform; - t_real operator_norm; -}; - -measurementOpInfo createMeasurementOperator( +std::shared_ptr>> createMeasurementOperator( const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector &image_index, const std::vector &w_stacks, @@ -73,7 +68,7 @@ void savePSF( const YamlParser ¶ms, const pfitsio::header_params &def_header, const std::shared_ptr>> &measurements_transform, const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, - const t_real operator_norm, const t_real beam_units); + const t_real beam_units); void saveDirtyImage( const YamlParser ¶ms, const pfitsio::header_params &def_header, diff --git a/cpp/tests/algo_factory.cc b/cpp/tests/algo_factory.cc index dbd12e9cb..4ba724efb 100644 --- a/cpp/tests/algo_factory.cc +++ b/cpp/tests/algo_factory.cc @@ -21,6 +21,7 @@ #include "purify/h5reader.h" #endif +#include #include #include "purify/test_data.h" @@ -44,12 +45,14 @@ TEST_CASE("padmm_factory") { t_uint const imsizey = 128; t_uint const imsizex = 128; Vector const init = Vector::Ones(imsizex * imsizey); - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -59,7 +62,7 @@ TEST_CASE("padmm_factory") { t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file auto const padmm = factory::padmm_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, imsizey, - imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1, op_norm); + imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1); auto const diagnostic = (*padmm)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -98,12 +101,14 @@ TEST_CASE("primal_dual_factory", "[!shouldfail]") { t_uint const imsizex = 128; Vector const init = Vector::Ones(imsizex * imsizey); - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -114,7 +119,7 @@ TEST_CASE("primal_dual_factory", "[!shouldfail]") { auto const primaldual = factory::primaldual_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, - imsizey, imsizex, sara.size(), 20, true, true, 1e-2, 1, op_norm); + imsizey, imsizex, sara.size(), 20, true, true, 1e-2, 1); auto const diagnostic = (*primaldual)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -152,12 +157,14 @@ TEST_CASE("fb_factory") { t_uint const imsizex = 128; Vector const init = Vector::Ones(imsizex * imsizey); - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -169,7 +176,7 @@ TEST_CASE("fb_factory") { t_real const gamma = 0.0001; auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, - gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm); + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -239,16 +246,15 @@ TEST_CASE("fb_factory_stochastic") { factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); + Vector const init = Vector::Ones(imsizex * imsizey); + auto const power_method_stuff = + sopt::algorithm::power_method>(*phi, 1000, 1e-5, init); + const t_real op_norm = std::get<0>(power_method_stuff); + phi->set_norm(op_norm); + return std::make_shared>>(uv_data_fragment.vis, phi); }; - Vector const init = Vector::Ones(imsizex * imsizey); - auto IS = random_updater(); - auto Phi = IS->Phi(); - auto const power_method_stuff = - sopt::algorithm::power_method>(Phi, 1000, 1e-5, init); - const t_real op_norm = std::get<0>(power_method_stuff); - const auto solution = pfitsio::read2d(expected_solution_path); // wavelets @@ -271,8 +277,7 @@ TEST_CASE("fb_factory_stochastic") { .regulariser_strength(gamma) .relative_variation(1e-3) .residual_tolerance(0) - .tight_frame(true) - .sq_op_norm(op_norm * op_norm); + .tight_frame(true); auto gp = std::make_shared>(false); gp->l1_proximal_tolerance(1e-4) @@ -317,12 +322,14 @@ TEST_CASE("tf_fb_factory") { t_uint const imsizex = 128; Vector const init = Vector::Ones(imsizex * imsizey); - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -337,8 +344,8 @@ TEST_CASE("tf_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, - gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, - tf_model_path, nondiff_func_type::Denoiser); + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, tf_model_path, + nondiff_func_type::Denoiser); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -373,12 +380,14 @@ TEST_CASE("onnx_fb_factory") { t_uint const imsizex = 128; Vector const init = Vector::Ones(imsizex * imsizey); - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -399,7 +408,7 @@ TEST_CASE("onnx_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, - gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, "", + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, "", nondiff_func_type::RealIndicator, diff_function); auto const diagnostic = (*fb)(); @@ -436,12 +445,14 @@ TEST_CASE("joint_map_factory") { t_uint const imsizex = 128; Vector const init = Vector::Ones(imsizex * imsizey); - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -453,7 +464,7 @@ TEST_CASE("joint_map_factory") { t_real const gamma = 1; auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, - gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm); + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50); auto const l1_norm = [wavelets](const Vector &x) { auto val = sopt::l1_norm(wavelets->adjoint() * x); return val; diff --git a/cpp/tests/mpi_algo_factory.cc b/cpp/tests/mpi_algo_factory.cc index dae442df5..f54ad890c 100644 --- a/cpp/tests/mpi_algo_factory.cc +++ b/cpp/tests/mpi_algo_factory.cc @@ -57,13 +57,15 @@ TEST_CASE("Serial vs. Serial with MPI PADMM") { t_uint const imsizey = 128; t_uint const imsizex = 128; - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); Vector const init = Vector::Ones(imsizex * imsizey).eval(); auto const power_method_stuff = sopt::algorithm::power_method>(*measurements_transform, 1000, 1e-5, init); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -75,7 +77,7 @@ TEST_CASE("Serial vs. Serial with MPI PADMM") { SECTION("global") { auto const padmm = factory::padmm_factory>( factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma, - imsizey, imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1, op_norm); + imsizey, imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1); auto const diagnostic = (*padmm)(); CHECK(diagnostic.niters == 10); @@ -103,7 +105,7 @@ TEST_CASE("Serial vs. Serial with MPI PADMM") { SECTION("local") { auto const padmm = factory::padmm_factory>( factory::algo_distribution::mpi_distributed, measurements_transform, wavelets, uv_data, - sigma, imsizey, imsizex, sara.size(), 500, true, true, false, 1e-2, 1e-3, 50, 1, op_norm); + sigma, imsizey, imsizex, sara.size(), 500, true, true, false, 1e-2, 1e-3, 50, 1); auto const diagnostic = (*padmm)(); t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()), @@ -166,6 +168,8 @@ TEST_CASE("Serial vs. Serial with MPI Primal Dual", "[!shouldfail]") { *measurements_transform, 1000, 1e-5, world.broadcast(Vector::Ones(imsizex * imsizey).eval())); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -178,7 +182,7 @@ TEST_CASE("Serial vs. Serial with MPI Primal Dual", "[!shouldfail]") { auto const primaldual = factory::primaldual_factory>( factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, - sigma, imsizey, imsizex, sara.size(), 500, true, true, 1e-2, 1, op_norm); + sigma, imsizey, imsizex, sara.size(), 500, true, true, 1e-2, 1); auto const diagnostic = (*primaldual)(); CHECK(diagnostic.niters == 16); @@ -207,7 +211,7 @@ TEST_CASE("Serial vs. Serial with MPI Primal Dual", "[!shouldfail]") { auto const primaldual = factory::primaldual_factory>( factory::algo_distribution::mpi_distributed, measurements_transform, wavelets, uv_data, - sigma, imsizey, imsizex, sara.size(), 500, true, true, 1e-2, 1, op_norm); + sigma, imsizey, imsizex, sara.size(), 500, true, true, 1e-2, 1); auto const diagnostic = (*primaldual)(); t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()), @@ -254,6 +258,8 @@ TEST_CASE("Serial vs. Serial with MPI Primal Dual", "[!shouldfail]") { world, *measurements_transform, 1000, 1e-5, world.broadcast(Vector::Ones(imsizex * imsizey).eval())); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + auto sara_dist = sopt::wavelets::distribute_sara(sara, world); auto const wavelets_serial = factory::wavelet_operator_factory>( factory::distributed_wavelet_operator::serial, sara_dist, imsizey, imsizex); @@ -262,7 +268,7 @@ TEST_CASE("Serial vs. Serial with MPI Primal Dual", "[!shouldfail]") { factory::primaldual_factory>( factory::algo_distribution::mpi_random_updates, measurements_transform_serial, wavelets_serial, uv_data, sigma, imsizey, imsizex, sara_dist.size(), 500, true, true, - 1e-2, 1, op_norm); + 1e-2, 1); auto const diagnostic = (*primaldual)(); t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()), @@ -327,13 +333,15 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") { t_uint const imsizey = 128; t_uint const imsizex = 128; - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>( *measurements_transform, 1000, 1e-5, world.broadcast(Vector::Ones(imsizex * imsizey).eval())); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -346,7 +354,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") { t_real const gamma = 0.0001; auto const fb = factory::fb_factory>( factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma, - beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm); + beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -391,13 +399,15 @@ TEST_CASE("MPI_fb_factory_hdf5") { t_uint const imsizey = 128; t_uint const imsizex = 128; - auto const measurements_transform = factory::measurement_operator_factory>( + auto measurements_transform = factory::measurement_operator_factory>( factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); auto const power_method_stuff = sopt::algorithm::power_method>( *measurements_transform, 1000, 1e-5, world.broadcast(Vector::Ones(imsizex * imsizey).eval())); const t_real op_norm = std::get<0>(power_method_stuff); + measurements_transform->set_norm(op_norm); + std::vector> const sara{ std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), @@ -410,7 +420,7 @@ TEST_CASE("MPI_fb_factory_hdf5") { t_real const gamma = 0.0001; auto const fb = factory::fb_factory>( factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma, - beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm); + beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -449,23 +459,22 @@ TEST_CASE("fb_factory_stochastic") { // This functor would be defined in Purify std::function>>()> random_updater = - [&f = h5file, &N]() { + [&h5file, &N, &comm]() { utilities::vis_params uv_data = - H5::stochread_visibility(f, N, false); // no w-term in this data-set + H5::stochread_visibility(h5file, N, false); // no w-term in this data-set uv_data.units = utilities::vis_units::radians; auto phi = factory::measurement_operator_factory( factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); + auto const power_method_stuff = sopt::algorithm::power_method>( + *phi, 1000, 1e-5, comm.broadcast(Vector::Ones(128 * 128).eval())); + const t_real op_norm = std::get<0>(power_method_stuff); + phi->set_norm(op_norm); + return std::make_shared>>(uv_data.vis, phi); }; - auto IS = random_updater(); - auto Phi = IS->Phi(); - auto const power_method_stuff = sopt::algorithm::power_method>( - Phi, 1000, 1e-5, comm.broadcast(Vector::Ones(128 * 128).eval())); - const t_real op_norm = std::get<0>(power_method_stuff); - const auto solution = pfitsio::read2d(expected_solution_path); t_uint const imsizey = 128; @@ -492,7 +501,6 @@ TEST_CASE("fb_factory_stochastic") { .relative_variation(1e-3) .residual_tolerance(0) .tight_frame(true) - .sq_op_norm(op_norm * op_norm) .obj_comm(comm); auto gp = std::make_shared>(false); diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc index 4ff42f92e..1bdaa2658 100644 --- a/cpp/uncertainty_quantification/uq_main.cc +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -73,7 +73,7 @@ int main(int argc, char **argv) { auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData(purify_config, mop_algo, wop_algo, using_mpi); - auto [transform, operator_norm] = + auto transform = createMeasurementOperator(purify_config, mop_algo, wop_algo, using_mpi, image_index, w_stacks, uv_data, measurement_op_eigen_vector);