Skip to content

Commit

Permalink
ONNX Testing (#343)
Browse files Browse the repository at this point in the history
* add onnx to fb_factory

* Fix Psi compatibility issue

* add onnx test

* Access wavelet operator through algo directly

* make sure ORT lib location is known to cmake

* Fix ANN flags and add tf test

* Use proper data directory from sopt

* Update ANN model path

* Don't write output images by default in tests

* Linting!

* Copy models to purify so independent from sopt tests

* linting

* Remove sopt directory dependency

* Add onnxrt guards for tests

* Replace strict regression test with mse check on previous solution

* Linting that makes things harder to read

* Bring MPI test in line with serial

* Remove test on exact number of iterations

---------

Co-authored-by: Michael McLeod <[email protected]>
Co-authored-by: Christian Gutschow <[email protected]>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 4ad17e3 commit 27820a9
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 34 deletions.
25 changes: 17 additions & 8 deletions cpp/purify/algorithm_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,27 @@
#include <sopt/mpi/communicator.h>
#endif

#include <sopt/differentiable_func.h>
#include <sopt/imaging_forward_backward.h>
#include <sopt/imaging_padmm.h>
#include <sopt/imaging_primal_dual.h>
#include <sopt/joint_map.h>
#include <sopt/l1_g_proximal.h>
#include <sopt/l1_non_diff_function.h>
#include <sopt/non_differentiable_func.h>
#include <sopt/real_indicator.h>
#include <sopt/relative_variation.h>
#include <sopt/utilities.h>
#include <sopt/wavelets.h>
#include <sopt/wavelets/sara.h>
#ifdef PURIFY_ONNXRT
#include <sopt/tf_g_proximal.h>
#include <sopt/tf_non_diff_function.h>
#endif

namespace purify {
namespace factory {
enum class algorithm { padmm, primal_dual, sdmm, forward_backward };
enum class algo_distribution { serial, mpi_serial, mpi_distributed, mpi_random_updates };
enum class g_proximal_type { L1GProximal, TFGProximal };
enum class g_proximal_type { L1GProximal, TFGProximal, Indicator };
const std::map<std::string, algo_distribution> algo_distribution_string = {
{"none", algo_distribution::serial},
{"serial-equivalent", algo_distribution::mpi_serial},
Expand Down Expand Up @@ -161,7 +164,8 @@ fb_factory(const algo_distribution dist,
const bool tight_frame = false, const t_real relative_variation = 1e-3,
const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50,
const t_real op_norm = 1, const std::string model_path = "",
const g_proximal_type g_proximal = g_proximal_type::L1GProximal) {
const g_proximal_type g_proximal = g_proximal_type::L1GProximal,
std::shared_ptr<DifferentiableFunc<typename Algorithm::Scalar>> f_function = nullptr) {
typedef typename Algorithm::Scalar t_scalar;
if (sara_size > 1 and tight_frame)
throw std::runtime_error(
Expand All @@ -177,7 +181,8 @@ fb_factory(const algo_distribution dist,
.nu(op_norm * op_norm)
.Phi(*measurements);

std::shared_ptr<GProximal<t_scalar>> gp;
if (f_function) fb->f_function(f_function); // only override f_function default if non-null
std::shared_ptr<NonDifferentiableFunc<t_scalar>> g;

switch (g_proximal) {
case (g_proximal_type::L1GProximal): {
Expand All @@ -197,25 +202,29 @@ fb_factory(const algo_distribution dist,
l1_gp->l1_proximal_direct_space_comm(comm);
}
#endif
gp = l1_gp;
g = l1_gp;
break;
}
case (g_proximal_type::TFGProximal): {
#ifdef PURIFY_ONNXRT
// Create a shared pointer to an instance of the TFGProximal class
gp = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
g = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
break;
#else
throw std::runtime_error(
"Type TFGProximal not recognized because purify was built with onnxrt=off");
#endif
}
case (g_proximal_type::Indicator): {
g = std::make_shared<RealIndicator<t_scalar>>();
break;
}
default: {
throw std::runtime_error("Type of g_proximal operator not recognised.");
}
}

fb->g_proximal(gp);
fb->g_function(g);

switch (dist) {
case (algo_distribution::serial): {
Expand Down
3 changes: 3 additions & 0 deletions cpp/purify/config.in.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
//! Whether PURIFY is running with casacore
#cmakedefine PURIFY_CASACORE

//! Whether PURIFY is using (and SOPT was built with) onnxrt support
#cmakedefine PURIFY_ONNXRT

#include <string>
#include <tuple>
#include <cstdint>
Expand Down
4 changes: 2 additions & 2 deletions cpp/purify/update_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void add_updater(std::weak_ptr<Algo> const algo_weak, const t_real step_size_sca
auto algo = algo_weak.lock();
if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", algo->gamma());
if (algo->gamma() > 0) {
Vector<t_complex> const alpha = algo->g_proximal()->Psi().adjoint() * x;
Vector<t_complex> const alpha = algo->Psi().adjoint() * x;
const t_real new_gamma =
comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) *
step_size_scale;
Expand Down Expand Up @@ -88,7 +88,7 @@ void add_updater(std::weak_ptr<Algo> const algo_weak, const t_real step_size_sca
auto algo = algo_weak.lock();
if (algo->gamma() > 0) {
PURIFY_MEDIUM_LOG("Step size γ {}", algo->gamma());
Vector<T> const alpha = algo->g_proximal()->Psi().adjoint() * x;
Vector<T> const alpha = algo->Psi().adjoint() * x;
const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale;
PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma);
// updating parameter
Expand Down
152 changes: 141 additions & 11 deletions cpp/tests/algo_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
#include "purify/algorithm_factory.h"
#include "purify/measurement_operator_factory.h"
#include "purify/wavelet_operator_factory.h"

#ifdef PURIFY_ONNXRT
#include <sopt/onnx_differentiable_func.h>
#endif

#include <sopt/power_method.h>

#include "purify/test_data.h"
Expand Down Expand Up @@ -136,6 +141,7 @@ TEST_CASE("fb_factory") {
notinstalled::data_filename(test_dir + "solution.fits");
const std::string &expected_residual_path =
notinstalled::data_filename(test_dir + "residual.fits");
const std::string &result_path = notinstalled::data_filename(test_dir + "fb_result.fits");

const auto solution = pfitsio::read2d(expected_solution_path);
const auto residual = pfitsio::read2d(expected_residual_path);
Expand Down Expand Up @@ -170,20 +176,144 @@ TEST_CASE("fb_factory") {

auto const diagnostic = (*fb)();
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// pfitsio::write2d(image.real(), expected_solution_path);
CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
CHECK(image.isApprox(solution, 1e-4));
// pfitsio::write2d(image.real(), result_path);
// pfitsio::write2d(residual_image.real(), expected_residual_path);

const Vector<t_complex> residuals = measurements_transform->adjoint() *
(uv_data.vis - ((*measurements_transform) * diagnostic.x));
const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
.real()
.squaredNorm() /
solution.size();
SOPT_HIGH_LOG("MSE = {}", mse);
CHECK(mse <= average_intensity * 1e-3);
}

#ifdef PURIFY_ONNXRT
TEST_CASE("tf_fb_factory") {
const std::string &test_dir = "expected/fb/";
const std::string &input_data_path = notinstalled::data_filename(test_dir + "input_data.vis");
const std::string &expected_solution_path =
notinstalled::data_filename(test_dir + "solution.fits");
const std::string &expected_residual_path =
notinstalled::data_filename(test_dir + "residual.fits");
const std::string &result_path = notinstalled::data_filename(test_dir + "tf_result.fits");

const auto solution = pfitsio::read2d(expected_solution_path);
const auto residual = pfitsio::read2d(expected_residual_path);

auto uv_data = utilities::read_visibility(input_data_path, false);
uv_data.units = utilities::vis_units::radians;
CAPTURE(uv_data.vis.head(5));
REQUIRE(uv_data.size() == 13107);

t_uint const imsizey = 128;
t_uint const imsizex = 128;

Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
kernels::kernel_from_string.at("kb"), 4, 4);
auto const power_method_stuff =
sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
const t_real op_norm = std::get<0>(power_method_stuff);
std::vector<std::tuple<std::string, t_uint>> const sara{
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
t_real const beta = sigma * sigma;
t_real const gamma = 0.0001;

std::string tf_model_path =
purify::notinstalled::data_directory() + "/models/snr_15_model_dynamic.onnx";

auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm,
tf_model_path, factory::g_proximal_type::TFGProximal);

auto const diagnostic = (*fb)();
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// pfitsio::write2d(image.real(), result_path);
// pfitsio::write2d(residual_image.real(), expected_residual_path);
CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
CHECK(residual_image.real().isApprox(residual.real(), 1e-4));

double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
.real()
.squaredNorm() /
solution.size();
SOPT_HIGH_LOG("MSE = {}", mse);
CHECK(mse <= average_intensity * 1e-3);
}

TEST_CASE("onnx_fb_factory") {
const std::string &test_dir = "expected/fb/";
const std::string &input_data_path = notinstalled::data_filename(test_dir + "input_data.vis");
const std::string &expected_solution_path =
notinstalled::data_filename(test_dir + "solution.fits");
const std::string &expected_residual_path =
notinstalled::data_filename(test_dir + "residual.fits");
const std::string &result_path = notinstalled::data_filename(test_dir + "onnx_result.fits");
const auto solution = pfitsio::read2d(expected_solution_path);
const auto residual = pfitsio::read2d(expected_residual_path);

auto uv_data = utilities::read_visibility(input_data_path, false);
uv_data.units = utilities::vis_units::radians;
CAPTURE(uv_data.vis.head(5));
REQUIRE(uv_data.size() == 13107);

t_uint const imsizey = 128;
t_uint const imsizex = 128;

Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
kernels::kernel_from_string.at("kb"), 4, 4);
auto const power_method_stuff =
sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
const t_real op_norm = std::get<0>(power_method_stuff);
std::vector<std::tuple<std::string, t_uint>> const sara{
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
t_real const beta = sigma * sigma;
t_real const gamma = 0.0001;

std::string const prior_path =
purify::notinstalled::data_directory() + "/models/example_cost_dynamic_CRR_sigma_5_t_5.onnx";
std::string const prior_gradient_path =
purify::notinstalled::data_directory() + "/models/example_grad_dynamic_CRR_sigma_5_t_5.onnx";
std::shared_ptr<sopt::ONNXDifferentiableFunc<t_complex>> diff_function =
std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
prior_path, prior_gradient_path, sigma, 20, 5e4, *measurements_transform);

auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, "",
factory::g_proximal_type::Indicator, diff_function);

auto const diagnostic = (*fb)();
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// pfitsio::write2d(image.real(), result_path);
// pfitsio::write2d(residual_image.real(), expected_residual_path);

double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
.real()
.squaredNorm() /
solution.size();
SOPT_HIGH_LOG("MSE = {}", mse);
CHECK(mse <= average_intensity * 1e-3);
}
#endif

TEST_CASE("joint_map_factory") {
const std::string &test_dir = "expected/joint_map/";
Expand Down
21 changes: 8 additions & 13 deletions cpp/tests/mpi_algo_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);

auto const diagnostic = (*fb)();
CHECK(diagnostic.niters == 11);

const std::string &expected_solution_path =
notinstalled::data_filename(test_dir + "solution.fits");
Expand All @@ -358,16 +357,12 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
const auto solution = pfitsio::read2d(expected_solution_path);
const auto residual = pfitsio::read2d(expected_residual_path);

const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
CHECK(image.isApprox(solution, 1e-4));

const Vector<t_complex> residuals = measurements_transform->adjoint() *
(uv_data.vis - ((*measurements_transform) * diagnostic.x));
const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
.real()
.squaredNorm() /
solution.size();
SOPT_HIGH_LOG("MSE = {}", mse);
CHECK(mse <= average_intensity * 1e-3);
}
Binary file not shown.
Binary file not shown.
Binary file added data/models/snr_15_model_dynamic.onnx
Binary file not shown.

0 comments on commit 27820a9

Please sign in to comment.