@@ -747,8 +747,9 @@ def estimate_parameters(self,
747
747
raise ValueError (f"frq_estimation_mode=`band-correlation`, but this requires phase guesses,"
748
748
f"and no phase guesses were provided" )
749
749
750
- mempool = cp .get_default_memory_pool ()
751
- memory_start = mempool .used_bytes ()
750
+ if self .use_gpu :
751
+ mempool = cp .get_default_memory_pool ()
752
+ memory_start = mempool .used_bytes ()
752
753
753
754
self .print_log ("starting parameter estimation..." )
754
755
@@ -913,10 +914,11 @@ def estimate_parameters(self,
913
914
phases = np .array (phases )
914
915
amps = np .array (amps )
915
916
916
- # find this is necessary, else mempool gets too big for 8GB GPU's
917
- mempool .free_all_blocks ()
918
- # self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
919
- # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
917
+ if self .use_gpu :
918
+ # find this is necessary, else mempool gets too big for 8GB GPU's
919
+ mempool .free_all_blocks ()
920
+ # self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
921
+ # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
920
922
921
923
elif self ._recon_settings ["phase_estimation_mode" ] == "real-space" :
922
924
phase_guess = self .phases_guess
@@ -941,10 +943,11 @@ def estimate_parameters(self,
941
943
phases = np .array (results ).reshape ((self .nangles , self .nphases ))
942
944
amps = np .ones ((self .nangles , self .nphases ))
943
945
944
- # find this is necessary, else mempool gets too big for 8GB GPU's
945
- mempool .free_all_blocks ()
946
- # self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
947
- # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
946
+ if self .use_gpu :
947
+ # find this is necessary, else mempool gets too big for 8GB GPU's
948
+ mempool .free_all_blocks ()
949
+ # self.print_log(f"after phase estimation used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
950
+ # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
948
951
else :
949
952
raise ValueError (f"phase_estimation_mode must be one of { self .allowed_phase_estimation_modes } "
950
953
f" but was '{ self ._recon_settings ['phase_estimation_mode' ]:s} '" )
@@ -997,9 +1000,10 @@ def estimate_parameters(self,
997
1000
(self .dy , self .dx ),
998
1001
self .upsample_fact )
999
1002
1000
- # self.print_log(f"after upsampling and band shifting used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1001
- # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1002
- mempool .free_all_blocks ()
1003
+ if self .use_gpu :
1004
+ # self.print_log(f"after upsampling and band shifting used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1005
+ # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1006
+ mempool .free_all_blocks ()
1003
1007
1004
1008
# upsample and shift OTFs
1005
1009
otf_us = resample_bandlimited_ft (self .otf ,
@@ -1058,9 +1062,10 @@ def estimate_parameters(self,
1058
1062
1059
1063
self .print_log (f"estimated global phases and modulation depths in { time .perf_counter () - tstart_mod_depth :.2f} s" )
1060
1064
1061
- # self.print_log(f"after phase correction used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1062
- # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1063
- mempool .free_all_blocks ()
1065
+ if self .use_gpu :
1066
+ # self.print_log(f"after phase correction used GPU memory = {(mempool.used_bytes() - memory_start) / 1e9:.3f}GB")
1067
+ # self.print_log(f"GPU memory pool = {(mempool.total_bytes()) / 1e9:.3f}GB")
1068
+ mempool .free_all_blocks ()
1064
1069
1065
1070
del mask
1066
1071
del otf_shifted
0 commit comments