46
46
from matplotlib .patches import Circle , Rectangle
47
47
# code from our projects
48
48
import mcsim .analysis .analysis_tools as tools
49
- from mcsim .analysis .optimize import Optimizer , soft_threshold , tv_prox , _to_cpu
49
+ from mcsim .analysis .optimize import Optimizer , soft_threshold , tv_prox , to_cpu
50
50
from mcsim .analysis .fft import ft2 , ift2
51
51
from localize_psf .rois import get_centered_rois , cut_roi
52
52
from localize_psf .fit_psf import circ_aperture_otf , blur_img_psf , oversample_voxel
@@ -1332,11 +1332,11 @@ def proc_bands(bands, phases, amps, frqs, dy, dx, upsample_fact, use_gpu):
1332
1332
1333
1333
# if cupy array, move off GPU
1334
1334
if isinstance (attr , cp .ndarray ):
1335
- setattr (self , attr_name , _to_cpu (attr ))
1335
+ setattr (self , attr_name , to_cpu (attr ))
1336
1336
1337
1337
# if dask array, move off GPU delayed
1338
1338
if isinstance (attr , da .core .Array ):
1339
- on_cpu = da .map_blocks (_to_cpu , attr , dtype = attr .dtype )
1339
+ on_cpu = da .map_blocks (to_cpu , attr , dtype = attr .dtype )
1340
1340
setattr (self , attr_name , on_cpu )
1341
1341
1342
1342
self .print_log (f"reconstruction took { time .perf_counter () - tstart_recon :.2f} s" )
@@ -2370,13 +2370,28 @@ def show_sim_napari(fname_zarr: str,
2370
2370
2371
2371
dxy = imgz .attrs ["dx" ]
2372
2372
dxy_sim = dxy / imgz .attrs ["upsample_factor" ]
2373
+
2374
+ # translate to put FFT zero coordinates at origin
2373
2375
translate_wf = (- (wf .shape [- 2 ] // 2 ) * dxy , - (wf .shape [- 1 ] // 2 ) * dxy )
2374
2376
translate_sim = (- ((2 * wf .shape [- 2 ]) // 2 ) * dxy_sim , - ((2 * wf .shape [- 1 ]) // 2 ) * dxy_sim )
2377
+ translate_pattern_2x = [a - 0.25 * dxy for a in translate_wf ]
2375
2378
2376
2379
if viewer is None :
2377
2380
viewer = napari .Viewer ()
2378
2381
2379
- # translate to put FFT zero coordinates at origin
2382
+ if hasattr (imgz , "patterns" ):
2383
+ viewer .add_image (imgz .patterns ,
2384
+ scale = (dxy , dxy ),
2385
+ translate = translate_wf ,
2386
+ name = "patterns" )
2387
+
2388
+ if hasattr (imgz , "patterns_2x" ):
2389
+ viewer .add_image (imgz .patterns_2x ,
2390
+ scale = (dxy_sim , dxy_sim ),
2391
+ # translate=translate_sim,
2392
+ translate = translate_pattern_2x ,
2393
+ name = "patterns upsampled" )
2394
+
2380
2395
if hasattr (imgz , "sim_os" ):
2381
2396
sim_os = np .expand_dims (imgz .sim_os , axis = - 3 )
2382
2397
@@ -2392,6 +2407,18 @@ def show_sim_napari(fname_zarr: str,
2392
2407
name = "wf deconvolved" ,
2393
2408
visible = False )
2394
2409
2410
+ if hasattr (imgz , "sim_fista_forward_model" ):
2411
+ viewer .add_image (imgz .sim_fista_forward_model ,
2412
+ scale = (dxy , dxy ),
2413
+ translate = translate_wf ,
2414
+ name = "FISTA forward model" )
2415
+
2416
+ if hasattr (imgz , "sim_sr_fista" ):
2417
+ viewer .add_image (np .expand_dims (imgz .sim_sr_fista , axis = - 3 ),
2418
+ scale = (dxy_sim , dxy_sim ),
2419
+ translate = translate_pattern_2x ,
2420
+ name = "SIM-SR FISTA" )
2421
+
2395
2422
if hasattr (imgz , "sim_sr" ):
2396
2423
viewer .add_image (np .expand_dims (imgz .sim_sr , axis = - 3 ),
2397
2424
scale = (dxy_sim , dxy_sim ),
@@ -2420,19 +2447,6 @@ def show_sim_napari(fname_zarr: str,
2420
2447
translate = translate_wf ,
2421
2448
name = "raw images" )
2422
2449
2423
- if hasattr (imgz , "patterns" ):
2424
- viewer .add_image (imgz .patterns ,
2425
- scale = (dxy , dxy ),
2426
- translate = translate_wf ,
2427
- name = "patterns" )
2428
-
2429
- if hasattr (imgz , "patterns_2x" ):
2430
- viewer .add_image (imgz .patterns_2x ,
2431
- scale = (dxy_sim , dxy_sim ),
2432
- # translate=translate_sim,
2433
- translate = [a - 0.25 * dxy for a in translate_wf ],
2434
- name = "patterns upsampled" )
2435
-
2436
2450
viewer .show (block = block )
2437
2451
2438
2452
return viewer
@@ -2582,10 +2596,10 @@ def fit_modulation_frq(mft1: np.ndarray,
2582
2596
raise ValueError ("must have ft1.shape = ft2.shape" )
2583
2597
2584
2598
# must be on CPU for this function to work
2585
- mft1 = _to_cpu (mft1 )
2586
- mft2 = _to_cpu (mft2 )
2599
+ mft1 = to_cpu (mft1 )
2600
+ mft2 = to_cpu (mft2 )
2587
2601
if otf is not None :
2588
- otf = _to_cpu (otf )
2602
+ otf = to_cpu (otf )
2589
2603
2590
2604
# mask
2591
2605
if mask is None :
0 commit comments