@@ -81,13 +81,12 @@ TEST_CASE("padmm_factory") {
81
81
CHECK (residual_image.real ().isApprox (residual.real (), 1e-4 ));
82
82
}
83
83
84
- // This test does not converge and is therefore set to shouldfail.
85
- // See https://github.com/astro-informatics/purify/issues/317 for details.
86
- TEST_CASE (" primal_dual_factory" , " [!shouldfail]" ) {
84
+ TEST_CASE (" primal_dual_factory" ) {
87
85
const std::string &test_dir = " expected/primal_dual/" ;
88
86
const std::string &input_data_path = data_filename (test_dir + " input_data.vis" );
89
87
const std::string &expected_solution_path = data_filename (test_dir + " solution.fits" );
90
88
const std::string &expected_residual_path = data_filename (test_dir + " residual.fits" );
89
+ const std::string &result_path = data_filename (test_dir + " pd_result.fits" );
91
90
92
91
const auto solution = pfitsio::read2d (expected_solution_path);
93
92
const auto residual = pfitsio::read2d (expected_residual_path);
@@ -119,23 +118,20 @@ TEST_CASE("primal_dual_factory", "[!shouldfail]") {
119
118
auto const primaldual =
120
119
factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
121
120
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma,
122
- imsizey, imsizex, sara.size (), 20 , true , true , 1e-2 , 1 );
121
+ imsizey, imsizex, sara.size (), 1000 , true , true , 1e-3 , 1 );
123
122
124
123
auto const diagnostic = (*primaldual)();
124
+
125
125
const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
126
- // pfitsio::write2d(image.real(), expected_solution_path);
127
- CAPTURE (Vector<t_complex>::Map (solution.data (), solution.size ()).real ().head (10 ));
128
- CAPTURE (Vector<t_complex>::Map (image.data (), image.size ()).real ().head (10 ));
129
- CAPTURE (Vector<t_complex>::Map ((image / solution).eval ().data (), image.size ()).real ().head (10 ));
130
- CHECK (image.isApprox (solution, 1e-4 ));
126
+ // pfitsio::write2d(image.real(), result_path);
131
127
132
- const Vector<t_complex> residuals = measurements_transform-> adjoint () *
133
- (uv_data. vis - ((*measurements_transform) * diagnostic.x ));
134
- const Image<t_complex> residual_image = Image<t_complex>:: Map (residuals. data (), imsizey, imsizex);
135
- // pfitsio::write2d(residual_image.real(), expected_residual_path);
136
- CAPTURE (Vector<t_complex>:: Map (residual. data (), residual .size ()). real (). head ( 10 ) );
137
- CAPTURE (Vector<t_complex>:: Map (residuals. data (), residuals. size ()). real (). head ( 10 ) );
138
- CHECK (residual_image. real (). isApprox (residual. real (), 1e-4 ) );
128
+ double brightness = solution. real (). cwiseAbs (). maxCoeff ();
129
+ double mse = (Vector<t_complex>:: Map (solution. data (), solution. size ()) - diagnostic.x )
130
+ . real ()
131
+ . squaredNorm () /
132
+ solution .size ();
133
+ double rms = sqrt (mse );
134
+ CHECK (rms <= brightness * 5e-2 );
139
135
}
140
136
141
137
TEST_CASE (" fb_factory" ) {
@@ -171,9 +167,11 @@ TEST_CASE("fb_factory") {
171
167
std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
172
168
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
173
169
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
170
+
174
171
t_real const sigma = 0.016820222945913496 * std::sqrt (2 ); // see test_parameters file
175
172
t_real const beta = sigma * sigma;
176
173
t_real const gamma = 0.0001 ;
174
+
177
175
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
178
176
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
179
177
gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-2 , 1e-3 , 50 );
@@ -183,14 +181,13 @@ TEST_CASE("fb_factory") {
183
181
pfitsio::write2d (image.real (), result_path);
184
182
// pfitsio::write2d(residual_image.real(), expected_residual_path);
185
183
186
- double average_intensity = diagnostic.x .real ().sum () / diagnostic.x .size ();
187
- SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
184
+ double brightness = solution.real ().cwiseAbs ().maxCoeff ();
188
185
double mse = (Vector<t_complex>::Map (solution.data (), solution.size ()) - diagnostic.x )
189
186
.real ()
190
187
.squaredNorm () /
191
188
solution.size ();
192
- SOPT_HIGH_LOG ( " MSE = {} " , mse);
193
- CHECK (mse <= average_intensity * 1e-3 );
189
+ double rms = sqrt ( mse);
190
+ CHECK (rms <= brightness * 5e-2 );
194
191
}
195
192
196
193
#ifdef PURIFY_H5
@@ -294,11 +291,10 @@ TEST_CASE("fb_factory_stochastic") {
294
291
// pfitsio::write2d(residual_image.real(), expected_residual_path);
295
292
296
293
auto soln_flat = Vector<t_complex>::Map (solution.data (), solution.size ());
297
- double average_intensity = soln_flat.real ().sum () / soln_flat.size ();
298
- SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
294
+ double brightness = soln_flat.real ().cwiseAbs ().maxCoeff ();
299
295
double mse = (soln_flat - diagnostic.x ).real ().squaredNorm () / solution.size ();
300
296
SOPT_HIGH_LOG (" MSE = {}" , mse);
301
- CHECK (mse <= average_intensity * 1e-3 );
297
+ CHECK (mse <= brightness * 5e-2 );
302
298
}
303
299
#endif
304
300
@@ -352,14 +348,13 @@ TEST_CASE("tf_fb_factory") {
352
348
// pfitsio::write2d(image.real(), result_path);
353
349
// pfitsio::write2d(residual_image.real(), expected_residual_path);
354
350
355
- double average_intensity = diagnostic.x .real ().sum () / diagnostic.x .size ();
356
- SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
351
+ double brightness = solution.real ().cwiseAbs ().maxCoeff ();
357
352
double mse = (Vector<t_complex>::Map (solution.data (), solution.size ()) - diagnostic.x )
358
353
.real ()
359
354
.squaredNorm () /
360
355
solution.size ();
361
- SOPT_HIGH_LOG ( " MSE = {} " , mse);
362
- CHECK (mse <= average_intensity * 1e-3 );
356
+ double rms = sqrt ( mse);
357
+ CHECK (rms <= brightness * 5e-2 );
363
358
}
364
359
365
360
TEST_CASE (" onnx_fb_factory" ) {
@@ -408,22 +403,21 @@ TEST_CASE("onnx_fb_factory") {
408
403
409
404
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
410
405
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
411
- gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-2 , 1e-3 , 50 , " " ,
406
+ gamma, imsizey, imsizex, sara.size (), 1000 , true , true , false , 1e-3 , 1e-3 , 50 , " " ,
412
407
nondiff_func_type::RealIndicator, diff_function);
413
408
414
409
auto const diagnostic = (*fb)();
415
410
const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
416
411
// pfitsio::write2d(image.real(), result_path);
417
412
// pfitsio::write2d(residual_image.real(), expected_residual_path);
418
413
419
- double average_intensity = diagnostic.x .real ().sum () / diagnostic.x .size ();
420
- SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
414
+ double brightness = solution.real ().cwiseAbs ().maxCoeff ();
421
415
double mse = (Vector<t_complex>::Map (solution.data (), solution.size ()) - diagnostic.x )
422
416
.real ()
423
417
.squaredNorm () /
424
418
solution.size ();
425
- SOPT_HIGH_LOG ( " MSE = {} " , mse);
426
- CHECK (mse <= average_intensity * 1e-3 );
419
+ double rms = sqrt ( mse);
420
+ CHECK (rms <= brightness * 5e-2 );
427
421
}
428
422
#endif
429
423
0 commit comments