diff --git a/cpp/tests/algo_factory.cc b/cpp/tests/algo_factory.cc index 4ba724ef..4c559c60 100644 --- a/cpp/tests/algo_factory.cc +++ b/cpp/tests/algo_factory.cc @@ -81,13 +81,12 @@ TEST_CASE("padmm_factory") { CHECK(residual_image.real().isApprox(residual.real(), 1e-4)); } -// This test does not converge and is therefore set to shouldfail. -// See https://github.com/astro-informatics/purify/issues/317 for details. -TEST_CASE("primal_dual_factory", "[!shouldfail]") { +TEST_CASE("primal_dual_factory") { const std::string &test_dir = "expected/primal_dual/"; const std::string &input_data_path = data_filename(test_dir + "input_data.vis"); const std::string &expected_solution_path = data_filename(test_dir + "solution.fits"); const std::string &expected_residual_path = data_filename(test_dir + "residual.fits"); + const std::string &result_path = data_filename(test_dir + "pd_result.fits"); const auto solution = pfitsio::read2d(expected_solution_path); const auto residual = pfitsio::read2d(expected_residual_path); @@ -119,23 +118,20 @@ TEST_CASE("primal_dual_factory", "[!shouldfail]") { auto const primaldual = factory::primaldual_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, - imsizey, imsizex, sara.size(), 20, true, true, 1e-2, 1); + imsizey, imsizex, sara.size(), 1000, true, true, 1e-3, 1); auto const diagnostic = (*primaldual)(); + const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); - // pfitsio::write2d(image.real(), expected_solution_path); - CAPTURE(Vector::Map(solution.data(), solution.size()).real().head(10)); - CAPTURE(Vector::Map(image.data(), image.size()).real().head(10)); - CAPTURE(Vector::Map((image / solution).eval().data(), image.size()).real().head(10)); - CHECK(image.isApprox(solution, 1e-4)); + // pfitsio::write2d(image.real(), result_path); - const Vector residuals = measurements_transform->adjoint() * - (uv_data.vis - ((*measurements_transform) * diagnostic.x)); - const Image residual_image = Image::Map(residuals.data(), imsizey, imsizex); - // pfitsio::write2d(residual_image.real(), expected_residual_path); - CAPTURE(Vector::Map(residual.data(), residual.size()).real().head(10)); - CAPTURE(Vector::Map(residuals.data(), residuals.size()).real().head(10)); - CHECK(residual_image.real().isApprox(residual.real(), 1e-4)); + double brightness = solution.real().cwiseAbs().maxCoeff(); + double mse = (Vector::Map(solution.data(), solution.size()) - diagnostic.x) + .real() + .squaredNorm() / + solution.size(); + double rms = sqrt(mse); + CHECK(rms <= brightness * 5e-2); } TEST_CASE("fb_factory") { @@ -171,9 +167,11 @@ TEST_CASE("fb_factory") { std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)}; auto const wavelets = factory::wavelet_operator_factory>( factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex); + t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file t_real const beta = sigma * sigma; t_real const gamma = 0.0001; + auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50); @@ -183,14 +181,13 @@ TEST_CASE("fb_factory") { pfitsio::write2d(image.real(), result_path); // pfitsio::write2d(residual_image.real(), expected_residual_path); - double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size(); - SOPT_HIGH_LOG("Average intensity = {}", average_intensity); + double brightness = solution.real().cwiseAbs().maxCoeff(); double mse = (Vector::Map(solution.data(), solution.size()) - diagnostic.x) .real() .squaredNorm() / solution.size(); - SOPT_HIGH_LOG("MSE = {}", mse); - CHECK(mse <= average_intensity * 1e-3); + double rms = sqrt(mse); + CHECK(rms <= brightness * 5e-2); } #ifdef PURIFY_H5 @@ -294,11 +291,10 @@ TEST_CASE("fb_factory_stochastic") { // pfitsio::write2d(residual_image.real(), expected_residual_path); auto soln_flat = Vector::Map(solution.data(), solution.size()); - double average_intensity = soln_flat.real().sum() / soln_flat.size(); - SOPT_HIGH_LOG("Average intensity = {}", average_intensity); + double brightness = soln_flat.real().cwiseAbs().maxCoeff(); double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size(); SOPT_HIGH_LOG("MSE = {}", mse); - CHECK(mse <= average_intensity * 1e-3); + CHECK(mse <= brightness * 5e-2); } #endif @@ -352,14 +348,13 @@ TEST_CASE("tf_fb_factory") { // pfitsio::write2d(image.real(), result_path); // pfitsio::write2d(residual_image.real(), expected_residual_path); - double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size(); - SOPT_HIGH_LOG("Average intensity = {}", average_intensity); + double brightness = solution.real().cwiseAbs().maxCoeff(); double mse = (Vector::Map(solution.data(), solution.size()) - diagnostic.x) .real() .squaredNorm() / solution.size(); - SOPT_HIGH_LOG("MSE = {}", mse); - CHECK(mse <= average_intensity * 1e-3); + double rms = sqrt(mse); + CHECK(rms <= brightness * 5e-2); } TEST_CASE("onnx_fb_factory") { @@ -408,7 +403,7 @@ TEST_CASE("onnx_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, - gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, "", + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-3, 1e-3, 50, "", nondiff_func_type::RealIndicator, diff_function); auto const diagnostic = (*fb)(); @@ -416,14 +411,13 @@ TEST_CASE("onnx_fb_factory") { // pfitsio::write2d(image.real(), result_path); // pfitsio::write2d(residual_image.real(), expected_residual_path); - double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size(); - SOPT_HIGH_LOG("Average intensity = {}", average_intensity); + double brightness = solution.real().cwiseAbs().maxCoeff(); double mse = (Vector::Map(solution.data(), solution.size()) - diagnostic.x) .real() .squaredNorm() / solution.size(); - SOPT_HIGH_LOG("MSE = {}", mse); - CHECK(mse <= average_intensity * 1e-3); + double rms = sqrt(mse); + CHECK(rms <= brightness * 5e-2); } #endif