Skip to content

Commit b7121e9

Browse files
author
Michael McLeod
committed
Unreadble linting
1 parent baad633 commit b7121e9

File tree

1 file changed

+31
-40
lines changed

1 file changed

+31
-40
lines changed

cpp/main.cc

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
#include <array>
55
#include <cstddef>
66
#include <ctime>
7+
#include <memory>
78
#include <random>
89
#include "purify/algorithm_factory.h"
910
#include "purify/cimg.h"
1011
#include "purify/logging.h"
1112
#include "purify/measurement_operator_factory.h"
1213
#include "purify/pfitsio.h"
1314
#include "purify/read_measurements.h"
15+
#include "purify/setup_utils.h"
1416
#include "purify/update_factory.h"
1517
#include "purify/wavelet_operator_factory.h"
1618
#include "purify/wide_field_utilities.h"
@@ -20,16 +22,13 @@
2022
#include <sopt/power_method.h>
2123
#include <sopt/relative_variation.h>
2224
#include <sopt/reweighted.h>
23-
#include "purify/setup_utils.h"
24-
#include <memory>
2525

2626
#ifdef PURIFY_ONNXRT
2727
#include <sopt/onnx_differentiable_func.h>
2828
#endif
2929

3030
using namespace purify;
3131

32-
3332
int main(int argc, const char **argv) {
3433
std::srand(static_cast<t_uint>(std::time(0)));
3534
std::mt19937 mersnne(std::time(0));
@@ -58,20 +57,13 @@ int main(int argc, const char **argv) {
5857
purify::logging::set_level(params.logging());
5958

6059
// Read or generate input data
61-
auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = getInputData(params,
62-
mop_algo,
63-
wop_algo,
64-
using_mpi);
60+
auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
61+
getInputData(params, mop_algo, wop_algo, using_mpi);
6562

6663
// create measurement operator
67-
auto [measurements_transform, operator_norm] = createMeasurementOperator(params,
68-
mop_algo,
69-
wop_algo,
70-
using_mpi,
71-
image_index,
72-
w_stacks,
73-
uv_data,
74-
measurement_op_eigen_vector);
64+
auto [measurements_transform, operator_norm] =
65+
createMeasurementOperator(params, mop_algo, wop_algo, using_mpi, image_index, w_stacks,
66+
uv_data, measurement_op_eigen_vector);
7567

7668
// create wavelet operator
7769
const waveletInfo wavelets = createWaveletOperator(params, wop_algo);
@@ -86,10 +78,10 @@ int main(int argc, const char **argv) {
8678

8779
// Creating header for saving output images during iterations
8880
const auto [update_header_sol, update_header_res, def_header] = genHeaders(params, uv_data);
89-
81+
9082
// the eigenvector
9183
saveMeasurementEigenVector(params, measurement_op_eigen_vector);
92-
84+
9385
// the psf
9486
t_real beam_units = 1.0;
9587
if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
@@ -103,40 +95,36 @@ int main(int argc, const char **argv) {
10395
beam_units = uv_data.size() / flux_scale / flux_scale;
10496
}
10597

106-
savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm, beam_units);
98+
savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm,
99+
beam_units);
107100

108101
// the dirty image
109102
saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units);
110103

111-
112104
// Create algorithm
113105
std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm;
114106
std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> fb;
115107
std::shared_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> primaldual;
116108
if (params.algorithm() == "padmm")
117109
padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
118110
params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data,
119-
sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), wavelets.sara_size,
120-
params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(),
111+
sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
112+
wavelets.sara_size, params.iterations(), params.realValueConstraint(),
113+
params.positiveValueConstraint(),
121114
(params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
122115
(not params.positiveValueConstraint()),
123116
params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
124117
params.epsilonConvergenceScaling(), operator_norm);
125-
if (params.algorithm() == "fb")
126-
{
118+
if (params.algorithm() == "fb") {
127119
std::shared_ptr<DifferentiableFunc<t_complex>> f;
128-
if(params.diffFuncType() == diff_func_type::L2Norm_with_CRR)
129-
{
130-
#ifdef PURIFY_ONNXRT
131-
f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(params.CRR_function_model_path(),
132-
params.CRR_gradient_model_path(),
133-
sigma,
134-
params.CRR_mu(),
135-
params.CRR_lambda(),
136-
*measurements_transform);
137-
#else
138-
throw std::runtime_error("CRR approach cannot be used with ONNXRT off");
139-
#endif
120+
if (params.diffFuncType() == diff_func_type::L2Norm_with_CRR) {
121+
#ifdef PURIFY_ONNXRT
122+
f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
123+
params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma,
124+
params.CRR_mu(), params.CRR_lambda(), *measurements_transform);
125+
#else
126+
throw std::runtime_error("CRR approach cannot be used with ONNXRT off");
127+
#endif
140128
}
141129

142130
fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
@@ -153,9 +141,10 @@ int main(int argc, const char **argv) {
153141
if (params.algorithm() == "primaldual")
154142
primaldual = factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
155143
params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data,
156-
sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), wavelets.sara_size,
157-
params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(),
158-
params.relVarianceConvergence(), params.epsilonConvergenceScaling(), operator_norm);
144+
sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
145+
wavelets.sara_size, params.iterations(), params.realValueConstraint(),
146+
params.positiveValueConstraint(), params.relVarianceConvergence(),
147+
params.epsilonConvergenceScaling(), operator_norm);
159148
// Add primal dual preconditioning
160149
if (params.algorithm() == "primaldual" and params.precondition_iters() > 0) {
161150
PURIFY_HIGH_LOG(
@@ -181,14 +170,16 @@ int main(int argc, const char **argv) {
181170
// Adding step size update to algorithm
182171
factory::add_updater<t_complex, sopt::algorithm::ImagingProximalADMM<t_complex>>(
183172
algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
184-
update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units);
173+
update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi,
174+
beam_units);
185175
}
186176
if (params.algorithm() == "primaldual") {
187177
const std::weak_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> algo_weak(primaldual);
188178
// Adding step size update to algorithm
189179
factory::add_updater<t_complex, sopt::algorithm::ImagingPrimalDual<t_complex>>(
190180
algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
191-
update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units);
181+
update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi,
182+
beam_units);
192183
}
193184
if (params.algorithm() == "fb") {
194185
const std::weak_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> algo_weak(fb);

0 commit comments

Comments
 (0)