@@ -205,41 +205,42 @@ TEST_CASE("fb_factory_stochastic") {
205
205
// This functor would be defined in Purify
206
206
std::mt19937 rng (0 );
207
207
const size_t N = 1000 ;
208
- std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater = [&input_data_path, imsizex, imsizey, &rng, &N]() {
209
- utilities::vis_params uv_data = utilities::read_visibility (input_data_path, false );
210
- uv_data.units = utilities::vis_units::radians;
211
-
212
- // Get random subset
213
- std::vector<size_t > indices (uv_data.size ());
214
- size_t i = 0 ;
215
- for (auto &x : indices)
216
- {
217
- x = i++;
218
- }
219
-
220
- std::shuffle (indices.begin (), indices.end (), rng);
221
- Vector<t_real> u_fragment (N);
222
- Vector<t_real> v_fragment (N);
223
- Vector<t_real> w_fragment (N);
224
- Vector<t_complex> vis_fragment (N);
225
- Vector<t_complex> weights_fragment (N);
226
- for (i = 0 ; i < N; i++)
227
- {
228
- size_t j = indices[i];
229
- u_fragment[i] = uv_data.u [j];
230
- v_fragment[i] = uv_data.v [j];
231
- w_fragment[i] = uv_data.w [j];
232
- vis_fragment[i] = uv_data.vis [j];
233
- weights_fragment[i] = uv_data.weights [j];
234
- }
235
- utilities::vis_params uv_data_fragment (u_fragment, v_fragment, w_fragment, vis_fragment, weights_fragment, uv_data.units , uv_data.ra , uv_data.dec , uv_data.average_frequency );
236
-
237
- auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
238
- factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex, 1 , 1 , 2 ,
239
- kernels::kernel_from_string.at (" kb" ), 4 , 4 );
240
-
241
- return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis , phi);
242
- };
208
+ std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
209
+ [&input_data_path, imsizex, imsizey, &rng, &N]() {
210
+ utilities::vis_params uv_data = utilities::read_visibility (input_data_path, false );
211
+ uv_data.units = utilities::vis_units::radians;
212
+
213
+ // Get random subset
214
+ std::vector<size_t > indices (uv_data.size ());
215
+ size_t i = 0 ;
216
+ for (auto &x : indices) {
217
+ x = i++;
218
+ }
219
+
220
+ std::shuffle (indices.begin (), indices.end (), rng);
221
+ Vector<t_real> u_fragment (N);
222
+ Vector<t_real> v_fragment (N);
223
+ Vector<t_real> w_fragment (N);
224
+ Vector<t_complex> vis_fragment (N);
225
+ Vector<t_complex> weights_fragment (N);
226
+ for (i = 0 ; i < N; i++) {
227
+ size_t j = indices[i];
228
+ u_fragment[i] = uv_data.u [j];
229
+ v_fragment[i] = uv_data.v [j];
230
+ w_fragment[i] = uv_data.w [j];
231
+ vis_fragment[i] = uv_data.vis [j];
232
+ weights_fragment[i] = uv_data.weights [j];
233
+ }
234
+ utilities::vis_params uv_data_fragment (u_fragment, v_fragment, w_fragment, vis_fragment,
235
+ weights_fragment, uv_data.units , uv_data.ra ,
236
+ uv_data.dec , uv_data.average_frequency );
237
+
238
+ auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
239
+ factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex,
240
+ 1 , 1 , 2 , kernels::kernel_from_string.at (" kb" ), 4 , 4 );
241
+
242
+ return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis , phi);
243
+ };
243
244
244
245
Vector<t_complex> const init = Vector<t_complex>::Ones (imsizex * imsizey);
245
246
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
@@ -257,38 +258,47 @@ TEST_CASE("fb_factory_stochastic") {
257
258
258
259
const auto solution = pfitsio::read2d (expected_solution_path);
259
260
260
- // wavelets
261
+ // wavelets
261
262
std::vector<std::tuple<std::string, t_uint>> const sara{
262
263
std::make_tuple (" Dirac" , 3u ), std::make_tuple (" DB1" , 3u ), std::make_tuple (" DB2" , 3u ),
263
264
std::make_tuple (" DB3" , 3u ), std::make_tuple (" DB4" , 3u ), std::make_tuple (" DB5" , 3u ),
264
265
std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
265
266
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
266
267
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
267
-
268
- // algorithm
268
+
269
+ // algorithm
269
270
t_real const sigma = 0.016820222945913496 * std::sqrt (2 ); // see test_parameters file
270
271
t_real const beta = sigma * sigma;
271
272
t_real const gamma = 0.0001 ;
272
273
273
274
sopt::algorithm::ImagingForwardBackward<t_complex> fb (random_updater);
274
- fb.itermax (1000 ).step_size (beta*sqrt (2 )).sigma (sigma*sqrt (2 )).regulariser_strength (gamma).relative_variation (1e-3 ).residual_tolerance (0 ).tight_frame (true ).sq_op_norm (op_norm*op_norm);
275
+ fb.itermax (1000 )
276
+ .step_size (beta * sqrt (2 ))
277
+ .sigma (sigma * sqrt (2 ))
278
+ .regulariser_strength (gamma)
279
+ .relative_variation (1e-3 )
280
+ .residual_tolerance (0 )
281
+ .tight_frame (true )
282
+ .sq_op_norm (op_norm * op_norm);
275
283
276
284
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false );
277
- gp->l1_proximal_tolerance (1e-4 ).l1_proximal_nu (1 ).l1_proximal_itermax (50 ).l1_proximal_positivity_constraint (true ).l1_proximal_real_constraint (true ).Psi (*wavelets);
285
+ gp->l1_proximal_tolerance (1e-4 )
286
+ .l1_proximal_nu (1 )
287
+ .l1_proximal_itermax (50 )
288
+ .l1_proximal_positivity_constraint (true )
289
+ .l1_proximal_real_constraint (true )
290
+ .Psi (*wavelets);
278
291
fb.g_function (gp);
279
292
280
293
auto const diagnostic = fb ();
281
294
const Image<t_complex> image = Image<t_complex>::Map (diagnostic.x .data (), imsizey, imsizex);
282
- // pfitsio::write2d(image.real(), result_path);
283
- // pfitsio::write2d(residual_image.real(), expected_residual_path);
295
+ // pfitsio::write2d(image.real(), result_path);
296
+ // pfitsio::write2d(residual_image.real(), expected_residual_path);
284
297
285
298
auto soln_flat = Vector<t_complex>::Map (solution.data (), solution.size ());
286
299
double average_intensity = soln_flat.real ().sum () / soln_flat.size ();
287
300
SOPT_HIGH_LOG (" Average intensity = {}" , average_intensity);
288
- double mse = (soln_flat - diagnostic.x )
289
- .real ()
290
- .squaredNorm () /
291
- solution.size ();
301
+ double mse = (soln_flat - diagnostic.x ).real ().squaredNorm () / solution.size ();
292
302
SOPT_HIGH_LOG (" MSE = {}" , mse);
293
303
CHECK (mse <= average_intensity * 1e-3 );
294
304
}
0 commit comments