
add fb algorithm, pnp-prior functionality, abstract classes, and kera… #12

Open: wants to merge 2 commits into master
9 changes: 3 additions & 6 deletions notebooks/custom_operators.ipynb
@@ -96,7 +96,7 @@
},
"outputs": [],
"source": [
"class custom_phi:\n",
"class custom_phi(linear_operators.LinearOperator):\n",
" \"\"\"A custom linear operator e.g. a custom measurement operator\"\"\"\n",
"\n",
" def __init__(self, dim, masking):\n",
@@ -173,8 +173,7 @@
},
"outputs": [],
"source": [
"g = grad_operators.l2_norm(sigma, y, phi)\n",
"g.beta = 1.0 / sigma ** 2"
"g = grad_operators.l2_norm(sigma, y, phi)"
]
},
{
@@ -222,7 +221,6 @@
"outputs": [],
"source": [
"h = prox_operators.l1_norm(np.max(np.abs(psi.dir_op(phi.adj_op(y)))) * reg_param, psi)\n",
"h.beta = 1.0\n",
"f = prox_operators.real_prox()"
]
},
@@ -242,8 +240,7 @@
},
"outputs": [],
"source": [
"# Note that phi_adj_op(y) is a dirty first estimate. In practice one may wish \n",
"# to begin the optimisation from a better first guess!\n",
"# Note that phi_adj_op(y) is a dirty first estimate. In practice one may wish to begin the optimisation from a better first guess!\n",
"best_estimate, diagnostics = primal_dual.FBPD(phi.adj_op(y), options, g, f, h)"
]
},
298 changes: 298 additions & 0 deletions notebooks/fb-proximal.ipynb

Large diffs are not rendered by default.
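Since this notebook's diff is not rendered, the following is a minimal sketch (not taken from the PR) of how the new forward_backward.FB solver added below might be driven from a notebook. It assumes phi, psi, y, sigma, reg_param and options are already defined, exactly as in the custom_operators.ipynb example above; the step size alpha = sigma ** 2 is a hypothetical choice based on the 1 / sigma ** 2 Lipschitz constant of the l2 data term.

import numpy as np
from optimusprimal import grad_operators, prox_operators, forward_backward

# Data fidelity, sparsity prior and reality constraint, as in custom_operators.ipynb
g = grad_operators.l2_norm(sigma, y, phi)
h = prox_operators.l1_norm(np.max(np.abs(psi.dir_op(phi.adj_op(y)))) * reg_param, psi)
f = prox_operators.real_prox()

# Start from the dirty image phi.adj_op(y) and run the new forward-backward solver;
# alpha = sigma ** 2 is an assumed step size (inverse of the gradient's Lipschitz constant)
best_estimate, diagnostics = forward_backward.FB(phi.adj_op(y), options, g, f, h, alpha=sigma ** 2)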

2 changes: 2 additions & 0 deletions optimusprimal/__init__.py
@@ -5,7 +5,9 @@
from . import linear_operators
from . import map_uncertainty
from . import primal_dual
+ from . import forward_backward
from . import prox_operators
+ from . import ai_operators
from . import Empty

# create logger
144 changes: 144 additions & 0 deletions optimusprimal/ai_operators.py
@@ -0,0 +1,144 @@
import abc
import numpy as np
import tensorflow as tf


class LearntPrior(metaclass=abc.ABCMeta):
    """Base abstract class for learnt prior models"""

    @abc.abstractmethod
    def __init__(self, tf_model):
        """Constructor setting the hyper-parameters and domains of the model.

        Must be implemented by the derived class (currently abstract).

        Args:
            tf_model (KerasTensor): network pre-trained as a prior model
        """
        self.model = tf_model

    @abc.abstractmethod
    def prox(self, x, gamma):
        """Evaluates the proximal operator of the learnt prior at x

        Args:
            x (np.ndarray): Array at which to evaluate the proximal operator
            gamma (float): weighting of the proximal operator
        """
        return self.model(x)

    def fun(self, x):
        """Placeholder for the loss of the functional term

        Args:
            x (np.ndarray): Array at which to evaluate the model loss
        """
        return 0

    def dir_op(self, x):
        """Evaluates the forward sensing operator

        Args:
            x (np.ndarray): Array to transform

        Returns:
            Forward sensing operator applied to x
        """
        return x

    def adj_op(self, x):
        """Evaluates the adjoint sensing operator

        Args:
            x (np.ndarray): Array to adjoint transform

        Returns:
            Adjoint sensing operator applied to x
        """
        return x


class PnpDenoiser(LearntPrior):
    """Integrates machine learning operators into plug-and-play (PnP) algorithms"""

    def __init__(self, tf_model, sigma):
        """Initialises a pre-trained TensorFlow model

        Args:
            tf_model (KerasTensor): network trained as a denoising prior
            sigma (float): noise standard deviation of the observed data
        """

        self.model = tf_model

        # Normalisation specific parameters
        self.maxtmp = 0
        self.mintmp = 0
        self.scale_range = 1.0 + sigma / 2.0
        self.scale_shift = (1 - self.scale_range) / 2.0

    def prox(self, x, gamma=1):
        """Applies a Keras model as a backward (denoising) projection step

        Args:
            x (np.ndarray): Array on which to execute the learnt backward denoising step
            gamma (float): proximal weighting (unused, kept for interface compatibility)

        Returns:
            Denoising plug & play model applied to the input
        """
        out = x.numpy()
        out = self.__normalise(out)
        out = self.__sigma_correction(out)
        out = self.model(out)
        out = self.__invert_sigma_correction(out)
        return self.__invert_normalise(out)

    def __normalise(self, x):
        """Maps a tensor from [a, b] to [0, 1]

        Args:
            x (np.ndarray): Array to normalise
        """
        self.maxtmp, self.mintmp = x.max(), x.min()
        return (x - self.mintmp) / (self.maxtmp - self.mintmp)

    def __invert_normalise(self, x):
        """Maps a tensor from [0, 1] back to [a, b]

        Args:
            x (np.ndarray): Array to un-normalise
        """
        return x * (self.maxtmp - self.mintmp) + self.mintmp

    def __sigma_correction(self, x):
        """Corrects the [a, b] to [0, 1] normalisation for the noise level

        Args:
            x (np.ndarray): Array to apply sigma shifting to
        """
        return x * self.scale_range + self.scale_shift

    def __invert_sigma_correction(self, x):
        """Inverts the noise-level correction of the [0, 1] normalisation

        Args:
            x (np.ndarray): Array to invert sigma shifting of
        """
        return (x - self.scale_shift) / self.scale_range
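As a rough usage sketch for the PnpDenoiser class above (not part of the PR): the network below is a hypothetical identity stand-in, and the tensor shape and noise level are made up for illustration; in practice a real pre-trained Keras denoiser, e.g. one restored with tf.keras.models.load_model, would be passed in.

import numpy as np
import tensorflow as tf
from optimusprimal import ai_operators

# Hypothetical stand-in for a pre-trained denoising network
denoiser = tf.keras.Sequential([tf.keras.layers.Lambda(lambda t: t)])

prior = ai_operators.PnpDenoiser(denoiser, sigma=0.1)  # sigma: assumed noise level

# prox expects a TensorFlow tensor, since it calls x.numpy() internally
x_noisy = tf.constant(np.random.randn(1, 32, 32, 1), dtype=tf.float32)
x_denoised = prior.prox(x_noisy)  # normalise -> sigma-correct -> denoise -> un-normalise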
103 changes: 103 additions & 0 deletions optimusprimal/forward_backward.py
@@ -0,0 +1,103 @@
import optimusprimal.Empty as Empty
import logging
import numpy as np
import time

logger = logging.getLogger("Optimus Primal")


def FB(x_init, options=None, g=None, f=None, h=None, alpha=1, tau=1, viewer=None):
    """Evaluates the base forward-backward optimisation

    Note that currently this only supports real positive semi-definite
    fields.

    Args:
        x_init (np.ndarray): First estimate of the solution
        options (dict): Python dictionary of optimisation configuration parameters
        g (Grad Class): Unconstrained data-fidelity class
        f (Prox Class): Reality constraint
        h (Prox/AI Class): Proximal or learnt regularisation constraint
        alpha (float): regularisation parameter / step size
        tau (float): custom weighting of the proximal operator
        viewer (function): Plotting function for real-time viewing (must accept: x, iteration)
    """
    if f is None:
        f = Empty.EmptyProx()
    if g is None:
        g = Empty.EmptyGrad()
    if h is None:
        h = Empty.EmptyProx()

    if options is None:
        options = {"tol": 1e-4, "iter": 500, "update_iter": 100, "record_iters": False}

    # algorithmic parameters
    tol = options["tol"]
    max_iter = options["iter"]
    update_iter = options["update_iter"]
    record_iters = options["record_iters"]

    # initialization
    x = np.copy(x_init)

    logger.info("Running Base Forward Backward")
    timing = np.zeros(max_iter)
    criter = np.zeros(max_iter)

    # algorithm loop
    for it in range(0, max_iter):

        t = time.time()
        # forward step
        x_old = np.copy(x)
        x = x - alpha * g.grad(x)
        x = f.prox(x, tau)

        # backward step
        u = h.dir_op(x)
        x = x + h.adj_op(h.prox(u, tau) - u)

        # time and criterion
        if record_iters:
            timing[it] = time.time() - t
            criter[it] = f.fun(x) + g.fun(x) + h.fun(h.dir_op(x))

        if np.allclose(x, 0):
            x = x_old
            logger.info("[Forward Backward] converged to 0 in %d iterations", it)
            break
        # stopping rule
        if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10:
            logger.info("[Forward Backward] converged in %d iterations", it)
            break
        if update_iter >= 0:
            if it % update_iter == 0:
                logger.info(
                    "[Forward Backward] %d out of %d iterations, tol = %f",
                    it,
                    max_iter,
                    np.linalg.norm(x - x_old) / np.linalg.norm(x_old),
                )
                if viewer is not None:
                    viewer(x, it)
        logger.debug(
            "[Forward Backward] %d out of %d iterations, tol = %f",
            it,
            max_iter,
            np.linalg.norm(x - x_old) / np.linalg.norm(x_old),
        )

    criter = criter[0 : it + 1]
    timing = np.cumsum(timing[0 : it + 1])
    solution = x
    diagnostics = {
        "max_iter": it,
        "times": timing,
        "Obj_vals": criter,
        "x": x,
    }
    return solution, diagnostics
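For reference, each pass of the loop above appears to implement the following update, writing \Psi for h.dir_op, \Psi^{*} for h.adj_op, and \nabla g for g.grad:

$$
x^{k+1/2} = \operatorname{prox}_{\tau f}\big(x^{k} - \alpha \nabla g(x^{k})\big), \qquad
x^{k+1} = x^{k+1/2} + \Psi^{*}\big(\operatorname{prox}_{\tau h}(\Psi x^{k+1/2}) - \Psi x^{k+1/2}\big)
$$

iterating until the relative change \|x^{k+1} - x^{k}\| / \|x^{k}\| falls below options["tol"] (after at least 10 iterations) or options["iter"] iterations are reached.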