|
12 | 12 | #include "purify/algorithm_factory.h"
|
13 | 13 | #include "purify/measurement_operator_factory.h"
|
14 | 14 | #include "purify/wavelet_operator_factory.h"
|
| 15 | + |
| 16 | +#ifdef PURIFY_ONNXRT |
| 17 | +#include <sopt/onnx_differentiable_func.h> |
| 18 | +#endif |
| 19 | + |
15 | 20 | #include <sopt/power_method.h>
|
16 | 21 |
|
17 | 22 | #include "purify/test_data.h"
|
@@ -136,6 +141,7 @@ TEST_CASE("fb_factory") {
|
136 | 141 | notinstalled::data_filename(test_dir + "solution.fits");
|
137 | 142 | const std::string &expected_residual_path =
|
138 | 143 | notinstalled::data_filename(test_dir + "residual.fits");
|
| 144 | + const std::string &result_path = notinstalled::data_filename(test_dir + "fb_result.fits"); |
139 | 145 |
|
140 | 146 | const auto solution = pfitsio::read2d(expected_solution_path);
|
141 | 147 | const auto residual = pfitsio::read2d(expected_residual_path);
|
@@ -170,20 +176,144 @@ TEST_CASE("fb_factory") {
|
170 | 176 |
|
171 | 177 | auto const diagnostic = (*fb)();
|
172 | 178 | const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
|
173 |
| - // pfitsio::write2d(image.real(), expected_solution_path); |
174 |
| - CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10)); |
175 |
| - CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10)); |
176 |
| - CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10)); |
177 |
| - CHECK(image.isApprox(solution, 1e-4)); |
| 179 | + // pfitsio::write2d(image.real(), result_path); |
| 180 | + // pfitsio::write2d(residual_image.real(), expected_residual_path); |
178 | 181 |
|
179 |
| - const Vector<t_complex> residuals = measurements_transform->adjoint() * |
180 |
| - (uv_data.vis - ((*measurements_transform) * diagnostic.x)); |
181 |
| - const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex); |
| 182 | + double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size(); |
| 183 | + SOPT_HIGH_LOG("Average intensity = {}", average_intensity); |
| 184 | + double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x) |
| 185 | + .real() |
| 186 | + .squaredNorm() / |
| 187 | + solution.size(); |
| 188 | + SOPT_HIGH_LOG("MSE = {}", mse); |
| 189 | + CHECK(mse <= average_intensity * 1e-3); |
| 190 | +} |
| 191 | + |
| 192 | +#ifdef PURIFY_ONNXRT |
| 193 | +TEST_CASE("tf_fb_factory") { |
| 194 | + const std::string &test_dir = "expected/fb/"; |
| 195 | + const std::string &input_data_path = notinstalled::data_filename(test_dir + "input_data.vis"); |
| 196 | + const std::string &expected_solution_path = |
| 197 | + notinstalled::data_filename(test_dir + "solution.fits"); |
| 198 | + const std::string &expected_residual_path = |
| 199 | + notinstalled::data_filename(test_dir + "residual.fits"); |
| 200 | + const std::string &result_path = notinstalled::data_filename(test_dir + "tf_result.fits"); |
| 201 | + |
| 202 | + const auto solution = pfitsio::read2d(expected_solution_path); |
| 203 | + const auto residual = pfitsio::read2d(expected_residual_path); |
| 204 | + |
| 205 | + auto uv_data = utilities::read_visibility(input_data_path, false); |
| 206 | + uv_data.units = utilities::vis_units::radians; |
| 207 | + CAPTURE(uv_data.vis.head(5)); |
| 208 | + REQUIRE(uv_data.size() == 13107); |
| 209 | + |
| 210 | + t_uint const imsizey = 128; |
| 211 | + t_uint const imsizex = 128; |
| 212 | + |
| 213 | + Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey); |
| 214 | + auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>( |
| 215 | + factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, |
| 216 | + kernels::kernel_from_string.at("kb"), 4, 4); |
| 217 | + auto const power_method_stuff = |
| 218 | + sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init); |
| 219 | + const t_real op_norm = std::get<0>(power_method_stuff); |
| 220 | + std::vector<std::tuple<std::string, t_uint>> const sara{ |
| 221 | + std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), |
| 222 | + std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), |
| 223 | + std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)}; |
| 224 | + auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>( |
| 225 | + factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex); |
| 226 | + t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file |
| 227 | + t_real const beta = sigma * sigma; |
| 228 | + t_real const gamma = 0.0001; |
| 229 | + |
| 230 | + std::string tf_model_path = |
| 231 | + purify::notinstalled::data_directory() + "/models/snr_15_model_dynamic.onnx"; |
| 232 | + |
| 233 | + auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>( |
| 234 | + factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, |
| 235 | + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, |
| 236 | + tf_model_path, factory::g_proximal_type::TFGProximal); |
| 237 | + |
| 238 | + auto const diagnostic = (*fb)(); |
| 239 | + const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex); |
| 240 | + // pfitsio::write2d(image.real(), result_path); |
182 | 241 | // pfitsio::write2d(residual_image.real(), expected_residual_path);
|
183 |
| - CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10)); |
184 |
| - CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10)); |
185 |
| - CHECK(residual_image.real().isApprox(residual.real(), 1e-4)); |
| 242 | + |
| 243 | + double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size(); |
| 244 | + SOPT_HIGH_LOG("Average intensity = {}", average_intensity); |
| 245 | + double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x) |
| 246 | + .real() |
| 247 | + .squaredNorm() / |
| 248 | + solution.size(); |
| 249 | + SOPT_HIGH_LOG("MSE = {}", mse); |
| 250 | + CHECK(mse <= average_intensity * 1e-3); |
| 251 | +} |
| 252 | + |
| 253 | +TEST_CASE("onnx_fb_factory") { |
| 254 | + const std::string &test_dir = "expected/fb/"; |
| 255 | + const std::string &input_data_path = notinstalled::data_filename(test_dir + "input_data.vis"); |
| 256 | + const std::string &expected_solution_path = |
| 257 | + notinstalled::data_filename(test_dir + "solution.fits"); |
| 258 | + const std::string &expected_residual_path = |
| 259 | + notinstalled::data_filename(test_dir + "residual.fits"); |
| 260 | + const std::string &result_path = notinstalled::data_filename(test_dir + "onnx_result.fits"); |
| 261 | + const auto solution = pfitsio::read2d(expected_solution_path); |
| 262 | + const auto residual = pfitsio::read2d(expected_residual_path); |
| 263 | + |
| 264 | + auto uv_data = utilities::read_visibility(input_data_path, false); |
| 265 | + uv_data.units = utilities::vis_units::radians; |
| 266 | + CAPTURE(uv_data.vis.head(5)); |
| 267 | + REQUIRE(uv_data.size() == 13107); |
| 268 | + |
| 269 | + t_uint const imsizey = 128; |
| 270 | + t_uint const imsizex = 128; |
| 271 | + |
| 272 | + Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey); |
| 273 | + auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>( |
| 274 | + factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2, |
| 275 | + kernels::kernel_from_string.at("kb"), 4, 4); |
| 276 | + auto const power_method_stuff = |
| 277 | + sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init); |
| 278 | + const t_real op_norm = std::get<0>(power_method_stuff); |
| 279 | + std::vector<std::tuple<std::string, t_uint>> const sara{ |
| 280 | + std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), |
| 281 | + std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), |
| 282 | + std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)}; |
| 283 | + auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>( |
| 284 | + factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex); |
| 285 | + t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file |
| 286 | + t_real const beta = sigma * sigma; |
| 287 | + t_real const gamma = 0.0001; |
| 288 | + |
| 289 | + std::string const prior_path = |
| 290 | + purify::notinstalled::data_directory() + "/models/example_cost_dynamic_CRR_sigma_5_t_5.onnx"; |
| 291 | + std::string const prior_gradient_path = |
| 292 | + purify::notinstalled::data_directory() + "/models/example_grad_dynamic_CRR_sigma_5_t_5.onnx"; |
| 293 | + std::shared_ptr<sopt::ONNXDifferentiableFunc<t_complex>> diff_function = |
| 294 | + std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>( |
| 295 | + prior_path, prior_gradient_path, sigma, 20, 5e4, *measurements_transform); |
| 296 | + |
| 297 | + auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>( |
| 298 | + factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, |
| 299 | + gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, "", |
| 300 | + factory::g_proximal_type::Indicator, diff_function); |
| 301 | + |
| 302 | + auto const diagnostic = (*fb)(); |
| 303 | + const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex); |
| 304 | + // pfitsio::write2d(image.real(), result_path); |
| 305 | + // pfitsio::write2d(residual_image.real(), expected_residual_path); |
| 306 | + |
| 307 | + double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size(); |
| 308 | + SOPT_HIGH_LOG("Average intensity = {}", average_intensity); |
| 309 | + double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x) |
| 310 | + .real() |
| 311 | + .squaredNorm() / |
| 312 | + solution.size(); |
| 313 | + SOPT_HIGH_LOG("MSE = {}", mse); |
| 314 | + CHECK(mse <= average_intensity * 1e-3); |
186 | 315 | }
|
| 316 | +#endif |
187 | 317 |
|
188 | 318 | TEST_CASE("joint_map_factory") {
|
189 | 319 | const std::string &test_dir = "expected/joint_map/";
|
|
0 commit comments