 28 |  28 | # Imports
 29 |  29 | ############################################################
 30 |  30 | import argparse  # noqa
 31 |     | -parser = argparse.ArgumentParser(
 32 |     | -    description='Process single electrode waveforms')
 33 |     | -parser.add_argument('data_dir', type=str, help='Path to data directory')
 34 |     | -parser.add_argument('electrode_num', type=int,
 35 |     | -                    help='Electrode number to process')
 36 |     | -args = parser.parse_args()
     |  31 | +import os  # noqa
     |  32 | +from utils.blech_utils import imp_metadata, pipeline_graph_check  # noqa
     |  33 | +
     |  34 | +test_bool = False
     |  35 | +if test_bool:
     |  36 | +    args = argparse.Namespace(
     |  37 | +        data_dir='/media/storage/abu_resorted/gc_only/AM34_4Tastes_201216_105150/',
     |  38 | +        electrode_num=0
     |  39 | +    )
     |  40 | +else:
     |  41 | +    parser = argparse.ArgumentParser(
     |  42 | +        description='Process single electrode waveforms')
     |  43 | +    parser.add_argument('data_dir', type=str, help='Path to data directory')
     |  44 | +    parser.add_argument('electrode_num', type=int,
     |  45 | +                        help='Electrode number to process')
     |  46 | +    args = parser.parse_args()
     |  47 | +
     |  48 | +    # Perform pipeline graph check
     |  49 | +    script_path = os.path.realpath(__file__)
     |  50 | +    this_pipeline_check = pipeline_graph_check(args.data_dir)
     |  51 | +    this_pipeline_check.check_previous(script_path)
     |  52 | +    this_pipeline_check.write_to_log(script_path, 'attempted')
     |  53 | +
 37 |  54 |
 38 |  55 | # Set environment variables to limit the number of threads used by various libraries
 39 |  56 | # Do it at the start of the script to ensure it applies to all imported libraries
 40 |     | -import os  # noqa
 41 |  57 | os.environ['OMP_NUM_THREADS'] = '1'  # noqa
 42 |  58 | os.environ['MKL_NUM_THREADS'] = '1'  # noqa
 43 |  59 | os.environ['OPENBLAS_NUM_THREADS'] = '1'  # noqa

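The new `test_bool` branch swaps direct CLI parsing for a hand-built `argparse.Namespace`, so the script can be run interactively (e.g. in an IPython session) without command-line arguments. A minimal sketch of the pattern; the path and electrode number here are placeholders, not the dataset referenced in the diff:

```python
import argparse

test_bool = True  # flip to False for normal command-line use
if test_bool:
    # argparse.Namespace mimics the object parse_args() would return
    args = argparse.Namespace(data_dir='/tmp/example_session/',  # placeholder path
                              electrode_num=0)
else:
    parser = argparse.ArgumentParser(description='Process single electrode waveforms')
    parser.add_argument('data_dir', type=str, help='Path to data directory')
    parser.add_argument('electrode_num', type=int, help='Electrode number to process')
    args = parser.parse_args()

print(args.data_dir, args.electrode_num)
```
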
 50 |  66 | import json  # noqa
 51 |  67 | import pylab as plt  # noqa
 52 |  68 | import utils.blech_process_utils as bpu  # noqa
 53 |     | -from utils.blech_utils import imp_metadata, pipeline_graph_check  # noqa
     |  69 | +from itertools import product  # noqa
 54 |  70 |
 55 |  71 | # Confirm sys.argv[1] is a path that exists
 56 |  72 | if not os.path.exists(args.data_dir):

 74 |  90 | blech_clust_dir = path_handler.blech_clust_dir
 75 |  91 | data_dir_name = args.data_dir
 76 |  92 |
 77 |     | -# Perform pipeline graph check
 78 |     | -script_path = os.path.realpath(__file__)
 79 |     | -this_pipeline_check = pipeline_graph_check(data_dir_name)
 80 |     | -this_pipeline_check.check_previous(script_path)
 81 |     | -this_pipeline_check.write_to_log(script_path, 'attempted')
 82 |     | -
 83 |  93 | metadata_handler = imp_metadata([[], data_dir_name])
 84 |  94 | os.chdir(metadata_handler.dir_name)
 85 |  95 |

126 | 136 |         electrode_num,
127 | 137 |         params_dict)
128 | 138 |
129 |     | -electrode.filter_electrode()
130 |     | -
131 |     | -# Calculate the 3 voltage parameters
132 |     | -electrode.cut_to_int_seconds()
133 |     | -electrode.calc_recording_cutoff()
134 |     | -
135 |     | -# Dump a plot showing where the recording was cut off at
136 |     | -electrode.make_cutoff_plot()
137 |     | -
138 |     | -# Then cut the recording accordingly
139 |     | -electrode.cutoff_electrode()
    | 139 | +# Run complete preprocessing pipeline
    | 140 | +filtered_data = electrode.preprocess_electrode()
140 | 141 |

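The five separate preprocessing calls (filter, cut to integer seconds, cutoff calculation, cutoff plot, cutoff) are collapsed into a single `preprocess_electrode()` call that returns the filtered, cut data. A hedged sketch of what such a wrapper presumably does, built only from the methods removed above; the real implementation lives in `utils/blech_process_utils.py`:

```python
def preprocess_electrode_sketch(electrode):
    """Hypothetical stand-in for electrode.preprocess_electrode(): chain the
    five calls that this diff removes and return the filtered trace."""
    electrode.filter_electrode()        # bandpass-filter the raw trace
    electrode.cut_to_int_seconds()      # trim to a whole number of seconds
    electrode.calc_recording_cutoff()   # locate where the recording degrades
    electrode.make_cutoff_plot()        # save a plot marking that cutoff
    electrode.cutoff_electrode()        # truncate the data at the cutoff
    return electrode.filt_el            # same array the old code read directly
```
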
141 | 142 | #############################################################
142 | 143 | # Process Spikes
143 | 144 | #############################################################
144 | 145 |
145 |     | -# Extract spike times and waveforms from filtered data
146 |     | -spike_set = bpu.spike_handler(electrode.filt_el,
    | 146 | +# Extract and process spikes from filtered data
    | 147 | +spike_set = bpu.spike_handler(filtered_data,
147 | 148 |                               params_dict, data_dir_name, electrode_num)
148 |     | -spike_set.extract_waveforms()
    | 149 | +slices_dejittered, times_dejittered, threshold, mean_val = spike_set.process_spikes()
149 | 150 |

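`process_spikes()` folds the old `extract_waveforms()` call (and the `dejitter_spikes()` call removed further down) into one method that hands back the dejittered slices, spike times, detection threshold, and mean voltage, so downstream code no longer reaches into `spike_set` attributes. A hedged sketch of that consolidation, assembled only from the calls visible in this diff:

```python
def process_spikes_sketch(spike_set):
    """Hypothetical equivalent of spike_set.process_spikes(): run the two old
    steps, then return the values the new code unpacks."""
    spike_set.extract_waveforms()   # threshold crossings -> spike waveforms
    spike_set.dejitter_spikes()     # align waveforms; sorted by amplitude polarity
    return (spike_set.slices_dejittered,
            spike_set.times_dejittered,
            spike_set.threshold,
            spike_set.mean_val)
```
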
150 | 151 | ############################################################
151 | 152 | # Extract windows from filt_el and plot with threshold overlayed
152 | 153 | window_len = 0.2  # sec
153 | 154 | window_count = 10
154 | 155 | fig = bpu.gen_window_plots(
155 |     | -    electrode.filt_el,
    | 156 | +    filtered_data,
156 | 157 |     window_len,
157 | 158 |     window_count,
158 | 159 |     params_dict['sampling_rate'],
159 |     | -    spike_set.spike_times,
160 |     | -    spike_set.mean_val,
161 |     | -    spike_set.threshold,
    | 160 | +    times_dejittered,
    | 161 | +    mean_val,
    | 162 | +    threshold,
162 | 163 | )
163 | 164 | fig.savefig(f'./Plots/{electrode_num:02}/bandapass_trace_snippets.png',
164 | 165 |             bbox_inches='tight', dpi=300)

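For orientation, each plotted snippet spans `window_len * sampling_rate` samples; the rate is read from `params_dict`, so the figure below is illustrative only (30 kHz is an assumed value, not one taken from this diff):

```python
sampling_rate = 30000   # Hz, assumed for illustration; the script uses params_dict['sampling_rate']
window_len = 0.2        # sec, as set above
window_count = 10

samples_per_window = int(window_len * sampling_rate)
print(samples_per_window)                  # 6000 samples per plotted snippet
print(samples_per_window * window_count)   # 60000 samples shown across all windows
```
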
168 | 169 | # Delete filtered electrode from memory
169 | 170 | del electrode
170 | 171 |
171 |     | -# Dejitter these spike waveforms, and get their maximum amplitudes
172 |     | -# Slices are returned sorted by amplitude polaity
173 |     | -spike_set.dejitter_spikes()
174 |     | -
175 | 172 | ############################################################
176 | 173 | # Load classifier if specificed
177 | 174 | classifier_params_path = \

190 | 187 |     classifier_handler.load_pipelines()
191 | 188 |
192 | 189 |     # If override_classifier_threshold is set, use that
193 |     | -    if classifier_params['override_classifier_threshold'] is not False:
194 |     | -        clf_threshold = classifier_params['threshold_override']
    | 190 | +    if classifier_params['classifier_threshold_override']['override'] is not False:
    | 191 | +        clf_threshold = classifier_params['classifier_threshold_override']['threshold']
195 | 192 |         print(f' == Overriding classifier threshold with {clf_threshold} ==')
196 | 193 |         classifier_handler.clf_threshold = clf_threshold
197 | 194 |

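The flat `override_classifier_threshold` / `threshold_override` keys are replaced by a nested `classifier_threshold_override` block. Judging only from the two lookups above, the classifier params presumably now take a shape like the following (key names come from the diff; the values are illustrative, not the project's defaults):

```python
# Illustrative shape of classifier_params, inferred from the lookups above.
classifier_params = {
    'use_classifier': True,
    'use_neuRecommend': True,
    'classifier_threshold_override': {
        'override': False,   # set truthy to bypass the classifier's default threshold
        'threshold': 0.5,    # illustrative value used when override is enabled
    },
}

if classifier_params['classifier_threshold_override']['override'] is not False:
    clf_threshold = classifier_params['classifier_threshold_override']['threshold']
    print(f' == Overriding classifier threshold with {clf_threshold} ==')
```
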
245 | 242 |     print('=== Performing manual clustering ===')
246 | 243 |     # Run GMM, from 2 to max_clusters
247 | 244 |     max_clusters = params_dict['clustering_params']['max_clusters']
248 |     | -    for cluster_num in range(2, max_clusters+1):
249 |     | -        cluster_handler = bpu.cluster_handler(
250 |     | -            params_dict,
251 |     | -            data_dir_name,
252 |     | -            electrode_num,
253 |     | -            cluster_num,
254 |     | -            spike_set,
255 |     | -            fit_type='manual',
256 |     | -        )
257 |     | -        cluster_handler.perform_prediction()
258 |     | -        cluster_handler.remove_outliers(params_dict)
259 |     | -        cluster_handler.calc_mahalanobis_distance_matrix()
260 |     | -        cluster_handler.save_cluster_labels()
261 |     | -        cluster_handler.create_output_plots(params_dict)
262 |     | -        if classifier_params['use_classifier'] and \
263 |     | -                classifier_params['use_neuRecommend']:
264 |     | -            cluster_handler.create_classifier_plots(classifier_handler)
    | 245 | +    iters = product(
    | 246 | +        np.arange(2, max_clusters+1),
    | 247 | +        ['manual']
    | 248 | +    )
265 | 249 | else:
266 | 250 |     print('=== Performing auto_clustering ===')
267 | 251 |     max_clusters = auto_params['max_autosort_clusters']
    | 252 | +    iters = [
    | 253 | +        (max_clusters, 'auto')
    | 254 | +    ]
    | 255 | +
    | 256 | +for cluster_num, fit_type in iters:
    | 257 | +    # Pass specific data instead of the whole spike_set
268 | 258 |     cluster_handler = bpu.cluster_handler(
269 | 259 |         params_dict,
270 | 260 |         data_dir_name,
271 | 261 |         electrode_num,
272 |     | -        max_clusters,
273 |     | -        spike_set,
274 |     | -        fit_type='auto',
    | 262 | +        cluster_num,
    | 263 | +        spike_features=spike_set.spike_features,
    | 264 | +        slices_dejittered=spike_set.slices_dejittered,
    | 265 | +        times_dejittered=spike_set.times_dejittered,
    | 266 | +        threshold=spike_set.threshold,
    | 267 | +        feature_names=spike_set.feature_names,
    | 268 | +        fit_type=fit_type,
275 | 269 |     )
276 |     | -    cluster_handler.perform_prediction()
    | 270 | +    # Use the new simplified clustering method
    | 271 | +    cluster_handler.perform_clustering()
277 | 272 |     cluster_handler.remove_outliers(params_dict)
278 | 273 |     cluster_handler.calc_mahalanobis_distance_matrix()
279 | 274 |     cluster_handler.save_cluster_labels()

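The manual and auto branches now only decide what `iters` contains; a single loop below does the clustering either way. A quick illustration of the two cases (the cluster counts are made up for the example):

```python
import numpy as np
from itertools import product

# Manual branch: one GMM fit per cluster count from 2 up to max_clusters
max_clusters = 5
iters = product(np.arange(2, max_clusters + 1), ['manual'])
for cluster_num, fit_type in iters:
    print(cluster_num, fit_type)   # 2 manual, 3 manual, 4 manual, 5 manual

# Auto branch: a single fit at the maximum autosort cluster count
max_autosort_clusters = 7
iters = [(max_autosort_clusters, 'auto')]
for cluster_num, fit_type in iters:
    print(cluster_num, fit_type)   # 7 auto
```
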
282 | 277 |             classifier_params['use_neuRecommend']:
283 | 278 |         cluster_handler.create_classifier_plots(classifier_handler)
284 | 279 |
285 |     | -
286 | 280 | print(f'Electrode {electrode_num} complete.')
287 | 281 |
288 | 282 | # Update processing log with completion