Skip to content

Commit 10c7fab

Browse files
Merge pull request #435 from katzlabbrandeis/434-interlocking-of-classes-in-blech_processpy-and-blech_process_utilspy-is-getting-too-complicated
refactor: Improve modularity and reduce class interdependencies in spike processing pipeline
2 parents 11e2908 + 7f770dd commit 10c7fab

File tree

5 files changed

+235
-225
lines changed

5 files changed

+235
-225
lines changed

blech_autosort.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ else
7171
fi
7272

7373
echo === Post Process ===
74-
python blech_post_process.py $DIR &&
74+
if [ $FORCE -eq 1 ]; then
75+
python blech_post_process.py $DIR --delete-existing &&
76+
else
77+
python blech_post_process.py $DIR &&
78+
fi
7579

7680
bash blech_clust_post.sh $DIR

blech_process.py

Lines changed: 54 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,32 @@
2828
# Imports
2929
############################################################
3030
import argparse # noqa
31-
parser = argparse.ArgumentParser(
32-
description='Process single electrode waveforms')
33-
parser.add_argument('data_dir', type=str, help='Path to data directory')
34-
parser.add_argument('electrode_num', type=int,
35-
help='Electrode number to process')
36-
args = parser.parse_args()
31+
import os # noqa
32+
from utils.blech_utils import imp_metadata, pipeline_graph_check # noqa
33+
34+
test_bool = False
35+
if test_bool:
36+
args = argparse.Namespace(
37+
data_dir='/media/storage/abu_resorted/gc_only/AM34_4Tastes_201216_105150/',
38+
electrode_num=0
39+
)
40+
else:
41+
parser = argparse.ArgumentParser(
42+
description='Process single electrode waveforms')
43+
parser.add_argument('data_dir', type=str, help='Path to data directory')
44+
parser.add_argument('electrode_num', type=int,
45+
help='Electrode number to process')
46+
args = parser.parse_args()
47+
48+
# Perform pipeline graph check
49+
script_path = os.path.realpath(__file__)
50+
this_pipeline_check = pipeline_graph_check(args.data_dir)
51+
this_pipeline_check.check_previous(script_path)
52+
this_pipeline_check.write_to_log(script_path, 'attempted')
53+
3754

3855
# Set environment variables to limit the number of threads used by various libraries
3956
# Do it at the start of the script to ensure it applies to all imported libraries
40-
import os # noqa
4157
os.environ['OMP_NUM_THREADS'] = '1' # noqa
4258
os.environ['MKL_NUM_THREADS'] = '1' # noqa
4359
os.environ['OPENBLAS_NUM_THREADS'] = '1' # noqa
@@ -50,7 +66,7 @@
5066
import json # noqa
5167
import pylab as plt # noqa
5268
import utils.blech_process_utils as bpu # noqa
53-
from utils.blech_utils import imp_metadata, pipeline_graph_check # noqa
69+
from itertools import product # noqa
5470

5571
# Confirm args.data_dir is a path that exists
5672
if not os.path.exists(args.data_dir):
@@ -74,12 +90,6 @@
7490
blech_clust_dir = path_handler.blech_clust_dir
7591
data_dir_name = args.data_dir
7692

77-
# Perform pipeline graph check
78-
script_path = os.path.realpath(__file__)
79-
this_pipeline_check = pipeline_graph_check(data_dir_name)
80-
this_pipeline_check.check_previous(script_path)
81-
this_pipeline_check.write_to_log(script_path, 'attempted')
82-
8393
metadata_handler = imp_metadata([[], data_dir_name])
8494
os.chdir(metadata_handler.dir_name)
8595

@@ -126,39 +136,30 @@
126136
electrode_num,
127137
params_dict)
128138

129-
electrode.filter_electrode()
130-
131-
# Calculate the 3 voltage parameters
132-
electrode.cut_to_int_seconds()
133-
electrode.calc_recording_cutoff()
134-
135-
# Dump a plot showing where the recording was cut off at
136-
electrode.make_cutoff_plot()
137-
138-
# Then cut the recording accordingly
139-
electrode.cutoff_electrode()
139+
# Run complete preprocessing pipeline
140+
filtered_data = electrode.preprocess_electrode()
140141

141142
#############################################################
142143
# Process Spikes
143144
#############################################################
144145

145-
# Extract spike times and waveforms from filtered data
146-
spike_set = bpu.spike_handler(electrode.filt_el,
146+
# Extract and process spikes from filtered data
147+
spike_set = bpu.spike_handler(filtered_data,
147148
params_dict, data_dir_name, electrode_num)
148-
spike_set.extract_waveforms()
149+
slices_dejittered, times_dejittered, threshold, mean_val = spike_set.process_spikes()
149150

150151
############################################################
151152
# Extract windows from the filtered data and plot with threshold overlaid
152153
window_len = 0.2 # sec
153154
window_count = 10
154155
fig = bpu.gen_window_plots(
155-
electrode.filt_el,
156+
filtered_data,
156157
window_len,
157158
window_count,
158159
params_dict['sampling_rate'],
159-
spike_set.spike_times,
160-
spike_set.mean_val,
161-
spike_set.threshold,
160+
times_dejittered,
161+
mean_val,
162+
threshold,
162163
)
163164
fig.savefig(f'./Plots/{electrode_num:02}/bandapass_trace_snippets.png',
164165
bbox_inches='tight', dpi=300)
@@ -168,10 +169,6 @@
168169
# Delete filtered electrode from memory
169170
del electrode
170171

171-
# Dejitter these spike waveforms, and get their maximum amplitudes
172-
# Slices are returned sorted by amplitude polaity
173-
spike_set.dejitter_spikes()
174-
175172
############################################################
176173
# Load classifier if specified
177174
classifier_params_path = \
@@ -190,8 +187,8 @@
190187
classifier_handler.load_pipelines()
191188

192189
# If override_classifier_threshold is set, use that
193-
if classifier_params['override_classifier_threshold'] is not False:
194-
clf_threshold = classifier_params['threshold_override']
190+
if classifier_params['classifier_threshold_override']['override'] is not False:
191+
clf_threshold = classifier_params['classifier_threshold_override']['threshold']
195192
print(f' == Overriding classifier threshold with {clf_threshold} ==')
196193
classifier_handler.clf_threshold = clf_threshold
197194

@@ -245,35 +242,33 @@
245242
print('=== Performing manual clustering ===')
246243
# Run GMM, from 2 to max_clusters
247244
max_clusters = params_dict['clustering_params']['max_clusters']
248-
for cluster_num in range(2, max_clusters+1):
249-
cluster_handler = bpu.cluster_handler(
250-
params_dict,
251-
data_dir_name,
252-
electrode_num,
253-
cluster_num,
254-
spike_set,
255-
fit_type='manual',
256-
)
257-
cluster_handler.perform_prediction()
258-
cluster_handler.remove_outliers(params_dict)
259-
cluster_handler.calc_mahalanobis_distance_matrix()
260-
cluster_handler.save_cluster_labels()
261-
cluster_handler.create_output_plots(params_dict)
262-
if classifier_params['use_classifier'] and \
263-
classifier_params['use_neuRecommend']:
264-
cluster_handler.create_classifier_plots(classifier_handler)
245+
iters = product(
246+
np.arange(2, max_clusters+1),
247+
['manual']
248+
)
265249
else:
266250
print('=== Performing auto_clustering ===')
267251
max_clusters = auto_params['max_autosort_clusters']
252+
iters = [
253+
(max_clusters, 'auto')
254+
]
255+
256+
for cluster_num, fit_type in iters:
257+
# Pass specific data instead of the whole spike_set
268258
cluster_handler = bpu.cluster_handler(
269259
params_dict,
270260
data_dir_name,
271261
electrode_num,
272-
max_clusters,
273-
spike_set,
274-
fit_type='auto',
262+
cluster_num,
263+
spike_features=spike_set.spike_features,
264+
slices_dejittered=spike_set.slices_dejittered,
265+
times_dejittered=spike_set.times_dejittered,
266+
threshold=spike_set.threshold,
267+
feature_names=spike_set.feature_names,
268+
fit_type=fit_type,
275269
)
276-
cluster_handler.perform_prediction()
270+
# Use the new simplified clustering method
271+
cluster_handler.perform_clustering()
277272
cluster_handler.remove_outliers(params_dict)
278273
cluster_handler.calc_mahalanobis_distance_matrix()
279274
cluster_handler.save_cluster_labels()
@@ -282,7 +277,6 @@
282277
classifier_params['use_neuRecommend']:
283278
cluster_handler.create_classifier_plots(classifier_handler)
284279

285-
286280
print(f'Electrode {electrode_num} complete.')
287281

288282
# Update processing log with completion

params/_templates/waveform_classifier_params.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"use_classifier": true,
44
"throw_out_noise": false,
55
"min_suggestion_count": 2000,
6-
"override_classifier_threshold": false,
7-
"threshold_override": 0.8
6+
"classifier_threshold_override": {
7+
"override": false,
8+
"threshold": 0.8
9+
}
810
}

0 commit comments

Comments
 (0)