Skip to content

Commit d6de438

Browse files
committed
(1) now using DCT unwrap implementation even without weights, which allows full Rytov calculation on the GPU and is much faster. (2) renamed _to_cpu() to to_cpu(). (3) other minor improvements
1 parent 71e1d9d commit d6de438

File tree

3 files changed

+108
-58
lines changed

3 files changed

+108
-58
lines changed

mcsim/analysis/optimize.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Tools for solving inverse problems using accelerated proximal gradient methods
33
"""
44

5+
import warnings
56
import numpy as np
67
import time
78
import random
@@ -25,7 +26,7 @@
2526
array = Union[np.ndarray, cp.ndarray]
2627

2728

28-
def _to_cpu(m):
29+
def to_cpu(m):
2930
"""
3031
Ensure array is CPU/NumPy
3132
:param m:
@@ -204,7 +205,7 @@ def run(self,
204205
"niterations": max_iterations,
205206
"use_fista": use_fista,
206207
"use_gpu": use_gpu,
207-
"x_init": _to_cpu(xp.array(x_start, copy=True)),
208+
"x_init": to_cpu(xp.array(x_start, copy=True)),
208209
"prox_parameters": self.prox_parameters,
209210
"stop_condition": "ok"
210211
}
@@ -252,9 +253,9 @@ def run(self,
252253

253254
if compute_cost:
254255
if compute_all_costs:
255-
costs[ii] = _to_cpu(self.cost(x))
256+
costs[ii] = to_cpu(self.cost(x))
256257
else:
257-
costs[ii, inds] = _to_cpu(self.cost(x, inds=inds))
258+
costs[ii, inds] = to_cpu(self.cost(x, inds=inds))
258259

259260
timing["cost"] = np.concatenate((timing["cost"], np.array([time.perf_counter() - tstart_err])))
260261

@@ -282,11 +283,11 @@ def run(self,
282283

283284
if compute_all_costs:
284285
c_all = self.cost(x)
285-
costs[ii] = _to_cpu(c_all)
286+
costs[ii] = to_cpu(c_all)
286287
cx = xp.mean(c_all[inds], axis=0)
287288
else:
288289
c_now = self.cost(x, inds=inds)
289-
costs[ii, inds] = _to_cpu(c_now)
290+
costs[ii, inds] = to_cpu(c_now)
290291
cx = xp.mean(c_now, axis=0)
291292

292293
timing["cost"] = np.concatenate((timing["cost"], np.array([time.perf_counter() - tstart_err])))
@@ -347,15 +348,19 @@ def lipschitz_condition_violated(y, cx, gx):
347348

348349
# print information
349350
if verbose:
350-
status = f"iteration {ii + 1:d}/{max_iterations:d}," \
351-
f" cost={np.nanmean(costs[ii]):.3g}," \
352-
f" step={steps[ii]:.3g}," \
353-
f" line search iters={line_search_iters[ii]:d}," \
354-
f" grad={timing['grad'][ii]:.3f}s," \
355-
f" prox={timing['prox'][ii]:.3f}s," \
356-
f" cost={timing['cost'][ii]:.3f}s," \
357-
f" iter={timing['iteration'][ii]:.3f}s," \
358-
f" total={time.perf_counter() - tstart:.3f}s"
351+
with warnings.catch_warnings():
352+
warnings.simplefilter("ignore", category=RuntimeWarning)
353+
354+
status = f"iteration {ii + 1:d}/{max_iterations:d}," \
355+
f" cost={np.nanmean(costs[ii]):.3g}," \
356+
f" step={steps[ii]:.3g}," \
357+
f" line search iters={line_search_iters[ii]:d}," \
358+
f" grad={timing['grad'][ii]:.3f}s," \
359+
f" prox={timing['prox'][ii]:.3f}s," \
360+
f" cost={timing['cost'][ii]:.3f}s," \
361+
f" iter={timing['iteration'][ii]:.3f}s," \
362+
f" total={time.perf_counter() - tstart:.3f}s"
363+
359364
if use_gpu:
360365
status += f", GPU={mempool.used_bytes()/1e9:.3}GB"
361366

@@ -369,9 +374,9 @@ def lipschitz_condition_violated(y, cx, gx):
369374
# compute final cost
370375
if compute_cost:
371376
if compute_all_costs:
372-
costs[ii + 1] = _to_cpu(self.cost(x))
377+
costs[ii + 1] = to_cpu(self.cost(x))
373378
else:
374-
costs[ii + 1, inds] = _to_cpu(self.cost(x, inds=inds))
379+
costs[ii + 1, inds] = to_cpu(self.cost(x, inds=inds))
375380

376381
# store results
377382
results.update({"timing": timing,

mcsim/analysis/sim_reconstruction.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from matplotlib.patches import Circle, Rectangle
4747
# code from our projects
4848
import mcsim.analysis.analysis_tools as tools
49-
from mcsim.analysis.optimize import Optimizer, soft_threshold, tv_prox, _to_cpu
49+
from mcsim.analysis.optimize import Optimizer, soft_threshold, tv_prox, to_cpu
5050
from mcsim.analysis.fft import ft2, ift2
5151
from localize_psf.rois import get_centered_rois, cut_roi
5252
from localize_psf.fit_psf import circ_aperture_otf, blur_img_psf, oversample_voxel
@@ -1332,11 +1332,11 @@ def proc_bands(bands, phases, amps, frqs, dy, dx, upsample_fact, use_gpu):
13321332

13331333
# if cupy array, move off GPU
13341334
if isinstance(attr, cp.ndarray):
1335-
setattr(self, attr_name, _to_cpu(attr))
1335+
setattr(self, attr_name, to_cpu(attr))
13361336

13371337
# if dask array, move off GPU delayed
13381338
if isinstance(attr, da.core.Array):
1339-
on_cpu = da.map_blocks(_to_cpu, attr, dtype=attr.dtype)
1339+
on_cpu = da.map_blocks(to_cpu, attr, dtype=attr.dtype)
13401340
setattr(self, attr_name, on_cpu)
13411341

13421342
self.print_log(f"reconstruction took {time.perf_counter() - tstart_recon:.2f}s")
@@ -2370,13 +2370,28 @@ def show_sim_napari(fname_zarr: str,
23702370

23712371
dxy = imgz.attrs["dx"]
23722372
dxy_sim = dxy / imgz.attrs["upsample_factor"]
2373+
2374+
# translate to put FFT zero coordinates at origin
23732375
translate_wf = (-(wf.shape[-2] // 2) * dxy, -(wf.shape[-1] // 2) * dxy)
23742376
translate_sim = (-((2 * wf.shape[-2]) // 2) * dxy_sim, -((2 * wf.shape[-1]) // 2) * dxy_sim)
2377+
translate_pattern_2x = [a - 0.25 * dxy for a in translate_wf]
23752378

23762379
if viewer is None:
23772380
viewer = napari.Viewer()
23782381

2379-
# translate to put FFT zero coordinates at origin
2382+
if hasattr(imgz, "patterns"):
2383+
viewer.add_image(imgz.patterns,
2384+
scale=(dxy, dxy),
2385+
translate=translate_wf,
2386+
name="patterns")
2387+
2388+
if hasattr(imgz, "patterns_2x"):
2389+
viewer.add_image(imgz.patterns_2x,
2390+
scale=(dxy_sim, dxy_sim),
2391+
# translate=translate_sim,
2392+
translate=translate_pattern_2x,
2393+
name="patterns upsampled")
2394+
23802395
if hasattr(imgz, "sim_os"):
23812396
sim_os = np.expand_dims(imgz.sim_os, axis=-3)
23822397

@@ -2392,6 +2407,18 @@ def show_sim_napari(fname_zarr: str,
23922407
name="wf deconvolved",
23932408
visible=False)
23942409

2410+
if hasattr(imgz, "sim_fista_forward_model"):
2411+
viewer.add_image(imgz.sim_fista_forward_model,
2412+
scale=(dxy, dxy),
2413+
translate=translate_wf,
2414+
name="FISTA forward model")
2415+
2416+
if hasattr(imgz, "sim_sr_fista"):
2417+
viewer.add_image(np.expand_dims(imgz.sim_sr_fista, axis=-3),
2418+
scale=(dxy_sim, dxy_sim),
2419+
translate=translate_pattern_2x,
2420+
name="SIM-SR FISTA")
2421+
23952422
if hasattr(imgz, "sim_sr"):
23962423
viewer.add_image(np.expand_dims(imgz.sim_sr, axis=-3),
23972424
scale=(dxy_sim, dxy_sim),
@@ -2420,19 +2447,6 @@ def show_sim_napari(fname_zarr: str,
24202447
translate=translate_wf,
24212448
name="raw images")
24222449

2423-
if hasattr(imgz, "patterns"):
2424-
viewer.add_image(imgz.patterns,
2425-
scale=(dxy, dxy),
2426-
translate=translate_wf,
2427-
name="patterns")
2428-
2429-
if hasattr(imgz, "patterns_2x"):
2430-
viewer.add_image(imgz.patterns_2x,
2431-
scale=(dxy_sim, dxy_sim),
2432-
# translate=translate_sim,
2433-
translate=[a - 0.25 * dxy for a in translate_wf],
2434-
name="patterns upsampled")
2435-
24362450
viewer.show(block=block)
24372451

24382452
return viewer
@@ -2582,10 +2596,10 @@ def fit_modulation_frq(mft1: np.ndarray,
25822596
raise ValueError("must have ft1.shape = ft2.shape")
25832597

25842598
# must be on CPU for this function to work
2585-
mft1 = _to_cpu(mft1)
2586-
mft2 = _to_cpu(mft2)
2599+
mft1 = to_cpu(mft1)
2600+
mft2 = to_cpu(mft2)
25872601
if otf is not None:
2588-
otf = _to_cpu(otf)
2602+
otf = to_cpu(otf)
25892603

25902604
# mask
25912605
if mask is None:

0 commit comments

Comments
 (0)