Skip to content

Commit 92488fb

Browse files
authored
Merge pull request #364 from astro-informatics/mm_primal_dual_test
Revert primal dual test to meaningful result
2 parents 7779be2 + 84d9e68 commit 92488fb

File tree

1 file changed

+26
-32
lines changed

1 file changed

+26
-32
lines changed

cpp/tests/algo_factory.cc

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,12 @@ TEST_CASE("padmm_factory") {
8181
CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
8282
}
8383

84-
// This test does not converge and is therefore set to shouldfail.
85-
// See https://github.com/astro-informatics/purify/issues/317 for details.
86-
TEST_CASE("primal_dual_factory", "[!shouldfail]") {
84+
TEST_CASE("primal_dual_factory") {
8785
const std::string &test_dir = "expected/primal_dual/";
8886
const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
8987
const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
9088
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
89+
const std::string &result_path = data_filename(test_dir + "pd_result.fits");
9190

9291
const auto solution = pfitsio::read2d(expected_solution_path);
9392
const auto residual = pfitsio::read2d(expected_residual_path);
@@ -119,23 +118,20 @@ TEST_CASE("primal_dual_factory", "[!shouldfail]") {
119118
auto const primaldual =
120119
factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
121120
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma,
122-
imsizey, imsizex, sara.size(), 20, true, true, 1e-2, 1);
121+
imsizey, imsizex, sara.size(), 1000, true, true, 1e-3, 1);
123122

124123
auto const diagnostic = (*primaldual)();
124+
125125
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
126-
// pfitsio::write2d(image.real(), expected_solution_path);
127-
CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
128-
CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
129-
CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
130-
CHECK(image.isApprox(solution, 1e-4));
126+
// pfitsio::write2d(image.real(), result_path);
131127

132-
const Vector<t_complex> residuals = measurements_transform->adjoint() *
133-
(uv_data.vis - ((*measurements_transform) * diagnostic.x));
134-
const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
135-
// pfitsio::write2d(residual_image.real(), expected_residual_path);
136-
CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
137-
CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
138-
CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
128+
double brightness = solution.real().cwiseAbs().maxCoeff();
129+
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
130+
.real()
131+
.squaredNorm() /
132+
solution.size();
133+
double rms = sqrt(mse);
134+
CHECK(rms <= brightness * 5e-2);
139135
}
140136

141137
TEST_CASE("fb_factory") {
@@ -171,9 +167,11 @@ TEST_CASE("fb_factory") {
171167
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
172168
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
173169
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
170+
174171
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
175172
t_real const beta = sigma * sigma;
176173
t_real const gamma = 0.0001;
174+
177175
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
178176
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
179177
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
@@ -183,14 +181,13 @@ TEST_CASE("fb_factory") {
183181
pfitsio::write2d(image.real(), result_path);
184182
// pfitsio::write2d(residual_image.real(), expected_residual_path);
185183

186-
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
187-
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
184+
double brightness = solution.real().cwiseAbs().maxCoeff();
188185
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
189186
.real()
190187
.squaredNorm() /
191188
solution.size();
192-
SOPT_HIGH_LOG("MSE = {}", mse);
193-
CHECK(mse <= average_intensity * 1e-3);
189+
double rms = sqrt(mse);
190+
CHECK(rms <= brightness * 5e-2);
194191
}
195192

196193
#ifdef PURIFY_H5
@@ -294,11 +291,10 @@ TEST_CASE("fb_factory_stochastic") {
294291
// pfitsio::write2d(residual_image.real(), expected_residual_path);
295292

296293
auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
297-
double average_intensity = soln_flat.real().sum() / soln_flat.size();
298-
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
294+
double brightness = soln_flat.real().cwiseAbs().maxCoeff();
299295
double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
300296
SOPT_HIGH_LOG("MSE = {}", mse);
301-
CHECK(mse <= average_intensity * 1e-3);
297+
CHECK(mse <= brightness * 5e-2);
302298
}
303299
#endif
304300

@@ -352,14 +348,13 @@ TEST_CASE("tf_fb_factory") {
352348
// pfitsio::write2d(image.real(), result_path);
353349
// pfitsio::write2d(residual_image.real(), expected_residual_path);
354350

355-
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
356-
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
351+
double brightness = solution.real().cwiseAbs().maxCoeff();
357352
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
358353
.real()
359354
.squaredNorm() /
360355
solution.size();
361-
SOPT_HIGH_LOG("MSE = {}", mse);
362-
CHECK(mse <= average_intensity * 1e-3);
356+
double rms = sqrt(mse);
357+
CHECK(rms <= brightness * 5e-2);
363358
}
364359

365360
TEST_CASE("onnx_fb_factory") {
@@ -408,22 +403,21 @@ TEST_CASE("onnx_fb_factory") {
408403

409404
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
410405
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
411-
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, "",
406+
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-3, 1e-3, 50, "",
412407
nondiff_func_type::RealIndicator, diff_function);
413408

414409
auto const diagnostic = (*fb)();
415410
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
416411
// pfitsio::write2d(image.real(), result_path);
417412
// pfitsio::write2d(residual_image.real(), expected_residual_path);
418413

419-
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
420-
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
414+
double brightness = solution.real().cwiseAbs().maxCoeff();
421415
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
422416
.real()
423417
.squaredNorm() /
424418
solution.size();
425-
SOPT_HIGH_LOG("MSE = {}", mse);
426-
CHECK(mse <= average_intensity * 1e-3);
419+
double rms = sqrt(mse);
420+
CHECK(rms <= brightness * 5e-2);
427421
}
428422
#endif
429423

0 commit comments

Comments
 (0)