Skip to content

Commit 62f89f9

Browse files
committed
corrected more GPU logic
1 parent 7a601a1 commit 62f89f9

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

mcsim/analysis/sim_reconstruction.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,9 @@ def estimate_parameters(self,
747747
raise ValueError(f"frq_estimation_mode=`band-correlation`, but this requires phase guesses,"
748748
f"and no phase guesses were provided")
749749

750-
mempool = cp.get_default_memory_pool()
751-
memory_start = mempool.used_bytes()
750+
if self.use_gpu:
751+
mempool = cp.get_default_memory_pool()
752+
memory_start = mempool.used_bytes()
752753

753754
self.print_log("starting parameter estimation...")
754755

@@ -913,10 +914,11 @@ def estimate_parameters(self,
913914
phases = np.array(phases)
914915
amps = np.array(amps)
915916

916-
# find this is necessary, else mempool gets too big for 8GB GPU's
917-
mempool.free_all_blocks()
918-
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
919-
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
917+
if self.use_gpu:
918+
# find this is necessary, else mempool gets too big for 8GB GPU's
919+
mempool.free_all_blocks()
920+
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
921+
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
920922

921923
elif self._recon_settings["phase_estimation_mode"] == "real-space":
922924
phase_guess = self.phases_guess
@@ -941,10 +943,11 @@ def estimate_parameters(self,
941943
phases = np.array(results).reshape((self.nangles, self.nphases))
942944
amps = np.ones((self.nangles, self.nphases))
943945

944-
# find this is necessary, else mempool gets too big for 8GB GPU's
945-
mempool.free_all_blocks()
946-
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
947-
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
946+
if self.use_gpu:
947+
# find this is necessary, else mempool gets too big for 8GB GPU's
948+
mempool.free_all_blocks()
949+
# self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
950+
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
948951
else:
949952
raise ValueError(f"phase_estimation_mode must be one of {self.allowed_phase_estimation_modes}"
950953
f" but was '{self._recon_settings['phase_estimation_mode']:s}'")
@@ -997,9 +1000,10 @@ def estimate_parameters(self,
9971000
(self.dy, self.dx),
9981001
self.upsample_fact)
9991002

1000-
# self.print_log(f"after upsampling and band shifting used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1001-
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1002-
mempool.free_all_blocks()
1003+
if self.use_gpu:
1004+
# self.print_log(f"after upsampling and band shifting used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1005+
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1006+
mempool.free_all_blocks()
10031007

10041008
# upsample and shift OTFs
10051009
otf_us = resample_bandlimited_ft(self.otf,
@@ -1058,9 +1062,10 @@ def estimate_parameters(self,
10581062

10591063
self.print_log(f"estimated global phases and modulation depths in {time.perf_counter() - tstart_mod_depth:.2f}s")
10601064

1061-
# self.print_log(f"after phase correction used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1062-
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1063-
mempool.free_all_blocks()
1065+
if self.use_gpu:
1066+
# self.print_log(f"after phase correction used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1067+
# self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1068+
mempool.free_all_blocks()
10641069

10651070
del mask
10661071
del otf_shifted

0 commit comments

Comments
 (0)