4
4
#include < array>
5
5
#include < cstddef>
6
6
#include < ctime>
7
+ #include < memory>
7
8
#include < random>
8
9
#include " purify/algorithm_factory.h"
9
10
#include " purify/cimg.h"
10
11
#include " purify/logging.h"
11
12
#include " purify/measurement_operator_factory.h"
12
13
#include " purify/pfitsio.h"
13
14
#include " purify/read_measurements.h"
15
+ #include " purify/setup_utils.h"
14
16
#include " purify/update_factory.h"
15
17
#include " purify/wavelet_operator_factory.h"
16
18
#include " purify/wide_field_utilities.h"
20
22
#include < sopt/power_method.h>
21
23
#include < sopt/relative_variation.h>
22
24
#include < sopt/reweighted.h>
23
- #include " purify/setup_utils.h"
24
- #include < memory>
25
25
26
26
#ifdef PURIFY_ONNXRT
27
27
#include < sopt/onnx_differentiable_func.h>
28
28
#endif
29
29
30
30
using namespace purify ;
31
31
32
-
33
32
int main (int argc, const char **argv) {
34
33
std::srand (static_cast <t_uint>(std::time (0 )));
35
34
std::mt19937 mersnne (std::time (0 ));
@@ -58,20 +57,13 @@ int main(int argc, const char **argv) {
58
57
purify::logging::set_level (params.logging ());
59
58
60
59
// Read or generate input data
61
- auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData (params,
62
- mop_algo,
63
- wop_algo,
64
- using_mpi);
60
+ auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
61
+ getInputData (params, mop_algo, wop_algo, using_mpi);
65
62
66
63
// create measurement operator
67
- auto [measurements_transform, operator_norm] = createMeasurementOperator (params,
68
- mop_algo,
69
- wop_algo,
70
- using_mpi,
71
- image_index,
72
- w_stacks,
73
- uv_data,
74
- measurement_op_eigen_vector);
64
+ auto [measurements_transform, operator_norm] =
65
+ createMeasurementOperator (params, mop_algo, wop_algo, using_mpi, image_index, w_stacks,
66
+ uv_data, measurement_op_eigen_vector);
75
67
76
68
// create wavelet operator
77
69
const waveletInfo wavelets = createWaveletOperator (params, wop_algo);
@@ -86,10 +78,10 @@ int main(int argc, const char **argv) {
86
78
87
79
// Creating header for saving output images during iterations
88
80
const auto [update_header_sol, update_header_res, def_header] = genHeaders (params, uv_data);
89
-
81
+
90
82
// the eigenvector
91
83
saveMeasurementEigenVector (params, measurement_op_eigen_vector);
92
-
84
+
93
85
// the psf
94
86
t_real beam_units = 1.0 ;
95
87
if (params.mpiAlgorithm () != factory::algo_distribution::serial) {
@@ -103,40 +95,36 @@ int main(int argc, const char **argv) {
103
95
beam_units = uv_data.size () / flux_scale / flux_scale;
104
96
}
105
97
106
- savePSF (params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm, beam_units);
98
+ savePSF (params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm,
99
+ beam_units);
107
100
108
101
// the dirty image
109
102
saveDirtyImage (params, def_header, measurements_transform, uv_data, beam_units);
110
103
111
-
112
104
// Create algorithm
113
105
std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm;
114
106
std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> fb;
115
107
std::shared_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> primaldual;
116
108
if (params.algorithm () == " padmm" )
117
109
padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
118
110
params.mpiAlgorithm (), measurements_transform, wavelets.transform , uv_data,
119
- sigma * params.epsilonScaling () / flux_scale, params.height (), params.width (), wavelets.sara_size ,
120
- params.iterations (), params.realValueConstraint (), params.positiveValueConstraint (),
111
+ sigma * params.epsilonScaling () / flux_scale, params.height (), params.width (),
112
+ wavelets.sara_size , params.iterations (), params.realValueConstraint (),
113
+ params.positiveValueConstraint (),
121
114
(params.wavelet_basis ().size () < 2 ) and (not params.realValueConstraint ()) and
122
115
(not params.positiveValueConstraint ()),
123
116
params.relVarianceConvergence (), params.dualFBVarianceConvergence (), 50 ,
124
117
params.epsilonConvergenceScaling (), operator_norm);
125
- if (params.algorithm () == " fb" )
126
- {
118
+ if (params.algorithm () == " fb" ) {
127
119
std::shared_ptr<DifferentiableFunc<t_complex>> f;
128
- if (params.diffFuncType () == diff_func_type::L2Norm_with_CRR)
129
- {
130
- #ifdef PURIFY_ONNXRT
131
- f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(params.CRR_function_model_path (),
132
- params.CRR_gradient_model_path (),
133
- sigma,
134
- params.CRR_mu (),
135
- params.CRR_lambda (),
136
- *measurements_transform);
137
- #else
138
- throw std::runtime_error (" CRR approach cannot be used with ONNXRT off" );
139
- #endif
120
+ if (params.diffFuncType () == diff_func_type::L2Norm_with_CRR) {
121
+ #ifdef PURIFY_ONNXRT
122
+ f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
123
+ params.CRR_function_model_path (), params.CRR_gradient_model_path (), sigma,
124
+ params.CRR_mu (), params.CRR_lambda (), *measurements_transform);
125
+ #else
126
+ throw std::runtime_error (" CRR approach cannot be used with ONNXRT off" );
127
+ #endif
140
128
}
141
129
142
130
fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
@@ -153,9 +141,10 @@ int main(int argc, const char **argv) {
153
141
if (params.algorithm () == " primaldual" )
154
142
primaldual = factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
155
143
params.mpiAlgorithm (), measurements_transform, wavelets.transform , uv_data,
156
- sigma * params.epsilonScaling () / flux_scale, params.height (), params.width (), wavelets.sara_size ,
157
- params.iterations (), params.realValueConstraint (), params.positiveValueConstraint (),
158
- params.relVarianceConvergence (), params.epsilonConvergenceScaling (), operator_norm);
144
+ sigma * params.epsilonScaling () / flux_scale, params.height (), params.width (),
145
+ wavelets.sara_size , params.iterations (), params.realValueConstraint (),
146
+ params.positiveValueConstraint (), params.relVarianceConvergence (),
147
+ params.epsilonConvergenceScaling (), operator_norm);
159
148
// Add primal dual preconditioning
160
149
if (params.algorithm () == " primaldual" and params.precondition_iters () > 0 ) {
161
150
PURIFY_HIGH_LOG (
@@ -181,14 +170,16 @@ int main(int argc, const char **argv) {
181
170
// Adding step size update to algorithm
182
171
factory::add_updater<t_complex, sopt::algorithm::ImagingProximalADMM<t_complex>>(
183
172
algo_weak, 1e-3 , params.update_tolerance (), params.update_iters (), update_header_sol,
184
- update_header_res, params.height (), params.width (), wavelets.sara_size , using_mpi, beam_units);
173
+ update_header_res, params.height (), params.width (), wavelets.sara_size , using_mpi,
174
+ beam_units);
185
175
}
186
176
if (params.algorithm () == " primaldual" ) {
187
177
const std::weak_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> algo_weak (primaldual);
188
178
// Adding step size update to algorithm
189
179
factory::add_updater<t_complex, sopt::algorithm::ImagingPrimalDual<t_complex>>(
190
180
algo_weak, 1e-3 , params.update_tolerance (), params.update_iters (), update_header_sol,
191
- update_header_res, params.height (), params.width (), wavelets.sara_size , using_mpi, beam_units);
181
+ update_header_res, params.height (), params.width (), wavelets.sara_size , using_mpi,
182
+ beam_units);
192
183
}
193
184
if (params.algorithm () == " fb" ) {
194
185
const std::weak_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> algo_weak (fb);
0 commit comments