Skip to content

Operator Norms #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/benchmarks/algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ BENCHMARK_DEFINE_F(AlgoFixture, Padmm)(benchmark::State &state) {
m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3, 1e-2, 50,
1.0, 1.0);
1.0);

while (state.KeepRunning()) {
auto start = std::chrono::high_resolution_clock::now();
Expand All @@ -92,7 +92,7 @@ BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackward)(benchmark::State &state) {
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3,
1e-2, 50, 1.0);
1e-2, 50);

while (state.KeepRunning()) {
auto start = std::chrono::high_resolution_clock::now();
Expand Down
8 changes: 4 additions & 4 deletions cpp/benchmarks/algorithms_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeImage)(benchmark::State &state
m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
factory::algo_distribution::mpi_distributed, m_measurements_distribute_image, wavelets,
m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
false, 1e-3, 1e-2, 50, 1.0, 1.0);
false, 1e-3, 1e-2, 50, 1.0);

// Benchmark the application of the algorithm
while (state.KeepRunning()) {
Expand All @@ -111,7 +111,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeGrid)(benchmark::State &state)
m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
factory::algo_distribution::mpi_distributed, m_measurements_distribute_grid, wavelets,
m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
false, 1e-3, 1e-2, 50, 1.0, 1.0);
false, 1e-3, 1e-2, 50, 1.0);

// Benchmark the application of the algorithm
while (state.KeepRunning()) {
Expand All @@ -135,7 +135,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeImage)(benchmark::State &state) {
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data,
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
false, 1e-3, 1e-2, 50, 1.0);
false, 1e-3, 1e-2, 50);

// Benchmark the application of the algorithm
while (state.KeepRunning()) {
Expand All @@ -159,7 +159,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeGrid)(benchmark::State &state) {
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, m_measurements_distribute_grid, wavelets, m_uv_data,
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
false, 1e-3, 1e-2, 50, 1.0);
false, 1e-3, 1e-2, 50);

// Benchmark the application of the algorithm
while (state.KeepRunning()) {
Expand Down
1 change: 0 additions & 1 deletion cpp/example/padmm_mpi_random_coverage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm_factory(
.l1_proximal_real_constraint(true)
.residual_tolerance(epsilon)
.lagrange_update_scale(0.9)
.sq_op_norm(1e0)
.Psi(Psi)
.Phi(*measurements);
sopt::ScalarRelativeVariation<t_complex> conv(padmm->relative_variation(),
Expand Down
1 change: 0 additions & 1 deletion cpp/example/padmm_mpi_real_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm_factory(
.l1_proximal_real_constraint(true)
.residual_tolerance(epsilon)
.lagrange_update_scale(0.9)
.sq_op_norm(1e0)
.Psi(Psi)
.Phi(*measurements);
sopt::ScalarRelativeVariation<t_complex> conv(padmm->relative_variation(),
Expand Down
1 change: 0 additions & 1 deletion cpp/example/padmm_random_coverage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ void padmm(const std::string &name, const Image<t_complex> &M31, const std::stri
#ifdef PURIFY_CImg
.is_converged(show_image)
#endif
.sq_op_norm(1e0)
.Psi(Psi)
.Phi(*measurements_transform);

Expand Down
1 change: 0 additions & 1 deletion cpp/example/padmm_real_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ void padmm(const std::string &name, const t_uint &imsizex, const t_uint &imsizey
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon)
.lagrange_update_scale(0.9)
.sq_op_norm(1e0)
.Psi(Psi)
.Phi(*measurements_transform);

Expand Down
1 change: 0 additions & 1 deletion cpp/example/padmm_reweighted_simulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ int main(int nargs, char const **args) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
.sq_op_norm(1e0)
.Psi(Psi)
.Phi(measurements_transform);
// Timing reconstruction
Expand Down
1 change: 0 additions & 1 deletion cpp/example/padmm_simulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ int main(int nargs, char const **args) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
.sq_op_norm(1e0)
.Psi(Psi)
.itermax(100)
.is_converged(convergence_function)
Expand Down
1 change: 0 additions & 1 deletion cpp/example/sara_padmm_random_coverage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ int main(int, char **) {
.l1_proximal_real_constraint(true)
.residual_convergence(epsilon * 1.001)
.lagrange_update_scale(0.9)
.sq_op_norm(1e0)
.Psi(Psi)
.Phi(measurements_transform);

Expand Down
13 changes: 6 additions & 7 deletions cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ int main(int argc, const char **argv) {
getInputData(params, mop_algo, wop_algo, using_mpi);

// create measurement operator
auto [measurements_transform, operator_norm] =
auto measurements_transform =
createMeasurementOperator(params, mop_algo, wop_algo, using_mpi, image_index, w_stacks,
uv_data, measurement_op_eigen_vector);

// create wavelet operator
const waveletInfo wavelets = createWaveletOperator(params, wop_algo);

PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm);
PURIFY_LOW_LOG("Value of operator norm is {}", measurements_transform->norm());
t_real const flux_scale = 1.;
uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;

Expand All @@ -95,8 +95,7 @@ int main(int argc, const char **argv) {
beam_units = uv_data.size() / flux_scale / flux_scale;
}

savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm,
beam_units);
savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, beam_units);

// the dirty image
saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units);
Expand All @@ -114,7 +113,7 @@ int main(int argc, const char **argv) {
(params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
(not params.positiveValueConstraint()),
params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
params.epsilonConvergenceScaling(), operator_norm);
params.epsilonConvergenceScaling());
if (params.algorithm() == "fb") {
std::shared_ptr<DifferentiableFunc<t_complex>> f;
if (params.diffFuncType() == diff_func_type::L2Norm_with_CRR) {
Expand All @@ -135,7 +134,7 @@ int main(int argc, const char **argv) {
params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(),
(params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
(not params.positiveValueConstraint()),
params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, operator_norm,
params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
params.model_path(), params.nondiffFuncType(), f);
}
if (params.algorithm() == "primaldual")
Expand All @@ -144,7 +143,7 @@ int main(int argc, const char **argv) {
sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
wavelets.sara_size, params.iterations(), params.realValueConstraint(),
params.positiveValueConstraint(), params.relVarianceConvergence(),
params.epsilonConvergenceScaling(), operator_norm);
params.epsilonConvergenceScaling());
// Add primal dual preconditioning
if (params.algorithm() == "primaldual" and params.precondition_iters() > 0) {
PURIFY_HIGH_LOG(
Expand Down
11 changes: 4 additions & 7 deletions cpp/purify/algorithm_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ padmm_factory(const algo_distribution dist,
const bool tight_frame = false, const t_real relative_variation = 1e-3,
const t_real l1_proximal_tolerance = 1e-2,
const t_uint maximum_proximal_iterations = 50,
const t_real residual_tolerance_scaling = 1, const t_real op_norm = 1) {
const t_real residual_tolerance_scaling = 1) {
typedef typename Algorithm::Scalar t_scalar;
if (sara_size > 1 and tight_frame)
throw std::runtime_error(
Expand All @@ -78,7 +78,6 @@ padmm_factory(const algo_distribution dist,
.l1_proximal_positivity_constraint(positive_constraint)
.l1_proximal_real_constraint(real_constraint)
.lagrange_update_scale(0.9)
.sq_op_norm(op_norm * op_norm)
.Psi(*wavelets)
.Phi(*measurements);
#ifdef PURIFY_MPI
Expand Down Expand Up @@ -162,7 +161,7 @@ fb_factory(const algo_distribution dist,
const bool real_constraint = true, const bool positive_constraint = true,
const bool tight_frame = false, const t_real relative_variation = 1e-3,
const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50,
const t_real op_norm = 1, const std::string model_path = "",
const std::string model_path = "",
const nondiff_func_type g_proximal = nondiff_func_type::L1Norm,
std::shared_ptr<DifferentiableFunc<typename Algorithm::Scalar>> f_function = nullptr) {
typedef typename Algorithm::Scalar t_scalar;
Expand All @@ -178,7 +177,6 @@ fb_factory(const algo_distribution dist,
.step_size(step_size * std::sqrt(2))
.relative_variation(relative_variation)
.tight_frame(tight_frame)
.sq_op_norm(op_norm * op_norm)
.Phi(*measurements);

if (f_function) fb->f_function(f_function); // only override f_function default if non-null
Expand Down Expand Up @@ -262,8 +260,7 @@ primaldual_factory(
const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey,
const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations = 500,
const bool real_constraint = true, const bool positive_constraint = true,
const t_real relative_variation = 1e-3, const t_real residual_tolerance_scaling = 1,
const t_real op_norm = 1) {
const t_real relative_variation = 1e-3, const t_real residual_tolerance_scaling = 1) {
typedef typename Algorithm::Scalar t_scalar;
PURIFY_INFO("Constructing Primal Dual algorithm");
auto epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size())) * sigma;
Expand All @@ -274,7 +271,7 @@ primaldual_factory(
.positivity_constraint(positive_constraint)
.Psi(*wavelets)
.Phi(*measurements)
.tau(0.5 / (op_norm * op_norm + 1))
.tau(0.5 / (measurements->sq_norm() + 1))
.xi(1.)
.sigma(1.);
#ifdef PURIFY_MPI
Expand Down
12 changes: 7 additions & 5 deletions cpp/purify/setup_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ inputData getInputData(const YamlParser &params,
return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks};
}

measurementOpInfo createMeasurementOperator(
std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> createMeasurementOperator(
const YamlParser &params, const factory::distributed_measurement_operator mop_algo,
const factory::distributed_wavelet_operator wop_algo, const bool using_mpi,
const std::vector<t_int> &image_index, const std::vector<t_real> &w_stacks,
Expand Down Expand Up @@ -294,6 +294,7 @@ measurementOpInfo createMeasurementOperator(
params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval());
measurement_op_eigen_vector = std::get<1>(power_method_result);
operator_norm = std::get<0>(power_method_result);
measurements_transform->set_norm(operator_norm);
} else
#endif
{
Expand All @@ -302,9 +303,10 @@ measurementOpInfo createMeasurementOperator(
measurement_op_eigen_vector);
measurement_op_eigen_vector = std::get<1>(power_method_result);
operator_norm = std::get<0>(power_method_result);
measurements_transform->set_norm(operator_norm);
}

return {measurements_transform, operator_norm};
return measurements_transform;
}

void setupCostFunctions(const YamlParser &params, std::unique_ptr<DifferentiableFunc<t_complex>> &f,
Expand Down Expand Up @@ -402,7 +404,7 @@ void savePSF(
const YamlParser &params, const pfitsio::header_params &def_header,
const std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> &measurements_transform,
const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma,
const t_real operator_norm, const t_real beam_units) {
const t_real beam_units) {
pfitsio::header_params psf_header = def_header;
psf_header.fits_name = params.output_path() + "/psf.fits";
psf_header.pix_units = "Jy/Pixel";
Expand All @@ -417,7 +419,7 @@ void savePSF(
auto const world = sopt::mpi::Communicator::World();
PURIFY_LOW_LOG(
"Expected image domain residual RMS is {} jy/beam",
sigma * params.epsilonScaling() * operator_norm /
sigma * params.epsilonScaling() * measurements_transform->norm() /
(std::sqrt(params.width() * params.height()) * world.all_sum_all(uv_data.size())));
if (world.is_root())
#else
Expand All @@ -426,7 +428,7 @@ void savePSF(
pfitsio::write2d(psf_image, psf_header, true);
} else {
PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam",
sigma * params.epsilonScaling() * operator_norm /
sigma * params.epsilonScaling() * measurements_transform->norm() /
(std::sqrt(params.width() * params.height()) * uv_data.size()));
pfitsio::write2d(psf_image, psf_header, true);
}
Expand Down
9 changes: 2 additions & 7 deletions cpp/purify/setup_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ inputData getInputData(const YamlParser &params,
const factory::distributed_measurement_operator mop_algo,
const factory::distributed_wavelet_operator wop_algo, const bool using_mpi);

struct measurementOpInfo {
std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> measurement_transform;
t_real operator_norm;
};
Comment on lines -44 to -47
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like the existence of this class was a good reason for this change :)


measurementOpInfo createMeasurementOperator(
std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> createMeasurementOperator(
const YamlParser &params, const factory::distributed_measurement_operator mop_algo,
const factory::distributed_wavelet_operator wop_algo, const bool using_mpi,
const std::vector<t_int> &image_index, const std::vector<t_real> &w_stacks,
Expand All @@ -73,7 +68,7 @@ void savePSF(
const YamlParser &params, const pfitsio::header_params &def_header,
const std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> &measurements_transform,
const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma,
const t_real operator_norm, const t_real beam_units);
const t_real beam_units);

void saveDirtyImage(
const YamlParser &params, const pfitsio::header_params &def_header,
Expand Down
Loading