Commit 53182d3

Merge pull request #256 from katzlabbrandeis/236-parallelize-autosort-processing
feat: Parallelize electrode processing with multiprocessing
2 parents 6ceb470 + b92a996

12 files changed: +519 -267 lines

.github/workflows/python_workflow_test.yml

Lines changed: 12 additions & 0 deletions

@@ -7,6 +7,9 @@ on: [pull_request]
 jobs:
   Preamble:
     runs-on: self-hosted
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }} preamble
+      cancel-in-progress: true
     steps:
       - run: pwd
       - run: which python
@@ -26,6 +29,9 @@ jobs:
   Spike-Only:
     runs-on: self-hosted
     needs: Preamble
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }} spike
+      cancel-in-progress: true
     steps:
       - name: Prefect SPIKE only test
         shell: bash
@@ -37,6 +43,9 @@
   EMG-Only:
     runs-on: self-hosted
     needs: Preamble
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }} emg
+      cancel-in-progress: true
     steps:
       - name: Prefect EMG only test
         shell: bash
@@ -48,6 +57,9 @@
   Spike-EMG:
     runs-on: self-hosted
     needs: Preamble
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }} spike+emg
+      cancel-in-progress: true
     steps:
       - name: Prefect SPIKE then EMG test
         shell: bash

blech_clust_pre.sh

Lines changed: 1 addition & 1 deletion

@@ -4,4 +4,4 @@ python blech_clust.py $DIR &&
 echo Running Common Average Reference
 python blech_common_avg_reference.py $DIR &&
 echo Running Jetstream Bash
-for x in $(seq 10);do bash blech_run_process.sh $DIR;done
+bash blech_run_process.sh $DIR

blech_post_process.py

Lines changed: 58 additions & 185 deletions

@@ -65,10 +65,13 @@
         description = 'Spike extraction and sorting script')
 parser.add_argument('dir_name',
                     help = 'Directory containing data files')
-parser.add_argument('--show-plot', '-p',
-                    help = 'Show waveforms while iterating (True/False)', default = 'True')
 parser.add_argument('--sort-file', '-f', help = 'CSV with sorted units',
                     default = None)
+parser.add_argument('--show-plot',
+                    help = 'Show waveforms while iterating',
+                    action = 'store_true')
+parser.add_argument('--keep-raw', help = 'Keep raw data in hdf5 file',
+                    action = 'store_true')
 args = parser.parse_args()

 ############################################################
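
Note: with action = 'store_true', argparse now stores real booleans rather than
the old 'True'/'False' strings, which is what lets the string comparison further
down be replaced by a plain truthiness check. A minimal sketch of the new flag
semantics (standard-library argparse only, flag names as in the diff above):

import argparse

parser = argparse.ArgumentParser(
        description = 'Spike extraction and sorting script')
parser.add_argument('--show-plot',
                    help = 'Show waveforms while iterating',
                    action = 'store_true')
parser.add_argument('--keep-raw', help = 'Keep raw data in hdf5 file',
                    action = 'store_true')

# No flags given: both attributes are the boolean False, not the string 'False'
args = parser.parse_args([])
assert args.show_plot is False and args.keep_raw is False

# Flag present: the boolean True is stored
args = parser.parse_args(['--show-plot'])
assert args.show_plot is True
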
@@ -83,6 +86,8 @@
 import matplotlib
 from glob import glob
 import re
+from functools import partial
+from multiprocessing import Pool, cpu_count

 matplotlib.rcParams['font.size'] = 6

@@ -104,7 +109,6 @@
 else:
     metadata_handler = imp_metadata([])

-
 # Extract parameters for automatic processing
 params_dict = metadata_handler.params_dict
 sampling_rate = params_dict['sampling_rate']
@@ -131,7 +135,11 @@


 # Delete the raw node, if it exists in the hdf5 file, to cut down on file size
-repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
+if not args.keep_raw:
+    repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
+else:
+    repacked_bool = False
+    print('=== Keeping raw data in hdf5 file ===')

 # Open the hdf5 file
 if repacked_bool:
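
Note: deleting a node does not shrink an HDF5 file on disk until the file is
rewritten, which is presumably why delete_raw_recordings reports a repacked_bool
that decides which file gets opened next. A rough sketch of the general
delete-and-repack pattern with PyTables; the '/raw' node path, the repacked-file
naming, and the ptrepack options here are illustrative assumptions, not the
repo's actual implementation:

import subprocess
import tables

def delete_and_repack(hdf5_name):
    # Removing the node frees no disk space by itself
    with tables.open_file(hdf5_name, 'r+') as h5:
        if '/raw' not in h5:
            return False  # nothing to delete, no repack performed
        h5.remove_node('/raw', recursive = True)

    # Rewriting the file with ptrepack reclaims the freed space
    repacked_name = hdf5_name.replace('.h5', '_repacked.h5')
    subprocess.run(
        ['ptrepack', '--chunkshape=auto', '--propindexes',
         '--complevel=9', '--complib=blosc',
         hdf5_name, repacked_name],
        check = True,
    )
    return True
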
@@ -167,7 +175,7 @@
 ############################################################

 print()
-print('======================================')
+print('==== Manual Post-Processing ====\n')
 print()

 # If sort_file given, iterate through that, otherwise ask user
# If sort_file given, iterate through that, otherwise ask user
@@ -192,12 +200,12 @@
192200
energy,
193201
amplitudes,
194202
predictions,
195-
) = post_utils.load_data_from_disk(electrode_num, num_clusters)
203+
) = post_utils.load_data_from_disk(dir_name, electrode_num, num_clusters)
196204

197205
# Re-show images of neurons so dumb people like Abu can make sure they
198206
# picked the right ones
199207
#if ast.literal_eval(args.show_plot):
200-
if args.show_plot == 'True':
208+
if args.show_plot:
201209
post_utils.gen_select_cluster_plot(electrode_num, num_clusters, clusters)
202210

203211
############################################################
@@ -423,191 +431,56 @@
                          for this_electrode in electrode_list]
 electrode_num_list.sort()

-for electrode_num in electrode_num_list:
-    ############################################################
-    # Get unit details and load data
-    ############################################################
-
-    print()
-    print('======================================')
-    print()
-
-    # Iterate over electrodes and pull out spikes
-    # Get classifier probabilities for each spike and use only
-    # "good" spikes
-
-    # Print out selections
-    print(f'=== Processing Electrode {electrode_num:02} ===')
-
-    # Load data from the chosen electrode
-    # We can pick any soluation, but need to know what
-    # solutions are present
-
-    (
-        spike_waveforms,
-        spike_times,
-        pca_slices,
-        energy,
-        amplitudes,
-        split_predictions,
-    ) = post_utils.load_data_from_disk(electrode_num, max_autosort_clusters)
-
-    clf_data_paths = [
-        f'./spike_waveforms/electrode{electrode_num:02}/clf_prob.npy',
-        f'./spike_waveforms/electrode{electrode_num:02}/clf_pred.npy',
-    ]
-    clf_prob, clf_pred = [np.load(this_path) for this_path in clf_data_paths]
-
-    # If auto-clustering was done, data has already been trimmed
-    # Only clf_pred needs to be trimmed
-    clf_prob = clf_prob[clf_pred]
-    clf_pred = clf_pred[clf_pred]
-
-    ##############################
-    # Calculate whether the cluster is a wanted_unit
-    # This will be useful for merging, so we only merge units
-
-    ##############################
-    # Merge clusters using mahalanobis distance
-    # If min( mahal a->b, mahal b->a ) < threshold, merge
-    # Unless ISI violations are > threshold
-
-    mahal_thresh = auto_params['mahalanobis_merge_thresh']
-    isi_threshs = auto_params['ISI_violations_thresholds']
-
-    mahal_mat_path = os.path.join(
-        '.',
-        'clustering_results',
-        f'electrode{electrode_num:02}',
-        f'clusters{max_autosort_clusters:02}',
-        'mahalanobis_distances.npy',
-    )
-    mahal_mat = np.load(mahal_mat_path)
-
-    unique_clusters = np.unique(split_predictions)
-    assert len(unique_clusters) == len(mahal_mat), \
-        'Mahalanobis matrix does not match number of clusters'
-
-    (
-        final_merge_sets,
-        new_clust_names,
-    )= post_utils.calculate_merge_sets(
-        mahal_mat,
-        mahal_thresh,
-        isi_threshs,
-        split_predictions,
-        spike_waveforms,
-        spike_times,
-        clf_prob,
-        chi_square_alpha,
-        count_threshold,
-        sampling_rate,
-    )
-
-
-    if len(final_merge_sets) > 0:
-        # Create names for merged clusters
-        # Rename both to max_clusters
-
-        # Print out merge sets
-        print(f'=== Merging {len(final_merge_sets)} Clusters ===')
-        for this_merge_set, new_name in zip(final_merge_sets, new_clust_names):
-            print(f'==== {this_merge_set} => {new_name} ====')
-
-        fig, ax = post_utils.gen_plot_auto_merged_clusters(
-            spike_waveforms,
-            spike_times,
-            split_predictions,
-            sampling_rate,
-            final_merge_sets,
-            new_clust_names,
-        )
-
-        # In case first unit is merged, we need to create the autosort_output_dir
-        if not os.path.exists(autosort_output_dir):
-            os.makedirs(autosort_output_dir)
-
-        fig.savefig(
-            os.path.join(
-                autosort_output_dir,
-                f'{electrode_num:02}_merged_units.png',
-            ),
-            bbox_inches = 'tight',
-        )
-        plt.close(fig)
-
-        # Update split_predictions
-        for this_set, this_name in zip(final_merge_sets, new_clust_names):
-            for this_cluster in this_set:
-                split_predictions[split_predictions == this_cluster] = this_name
-
-    ##############################
-
-    # Take everything
-    data = post_utils.prepare_data(
-        np.arange(len(spike_waveforms)),
-        pca_slices,
-        energy,
-        amplitudes,
-    )
+# Create processing parameters tuple
+process_params = (
+    max_autosort_clusters,
+    auto_params,
+    chi_square_alpha,
+    count_threshold,
+    sampling_rate,
+    metadata_handler.dir_name,
+)

-    (
-        subcluster_inds,
-        subcluster_waveforms,
-        subcluster_prob,
-        subcluster_times,
-        mean_waveforms,
-        std_waveforms,
-        chi_out,
-        fin_bool,
-        fin_bool_dict,
-    ) = \
-        post_utils.get_cluster_props(
-            split_predictions,
-            spike_waveforms,
-            clf_prob,
-            spike_times,
-            chi_square_alpha,
-            count_threshold,
+# Use multiprocessing to process electrodes in parallel
+n_cores = np.min(
+    (
+        len(electrode_num_list),
+        cpu_count() - 1,
+        params_dict['max_parallel_cpu']
     )
-
-
-    ##############################
-    # Generate plots for each subcluster
-    ##############################
-
-    post_utils.gen_autosort_plot(
-        subcluster_prob,
-        subcluster_waveforms,
-        chi_out,
-        mean_waveforms,
-        std_waveforms,
-        subcluster_times,
-        fin_bool,
-        np.unique(split_predictions),
-        electrode_num,
-        sampling_rate,
-        autosort_output_dir,
-        n_max_plot=5000,
+) # Leave one core free
+print(f"Processing {len(electrode_num_list)} electrodes using {n_cores} cores")
+
+# Create partial function
+auto_process_partial = partial(
+    post_utils.auto_process_electrode,
+    process_params = process_params
+)
+
+print(f'== Saving to {autosort_output_dir} ==')
+with Pool(n_cores) as pool:
+    result = pool.starmap(
+        auto_process_partial,
+        [(electrode_num,) for electrode_num in electrode_num_list]
     )

-    ############################################################
-    # Finally, save the unit to the HDF5 file
-    ############################################################
-
-    ############################################################
-    # Subsetting this set of waveforms to include only the chosen split
-
+# This last part cannot be incorporated in auto_process_electrode as it
+# needs passing of classes (descriptor_handler, and sort_file_handler) to
+# the processes.
+# Get pickling errors when they are included
+# It is also a quick process so it doesn't need to be parallelized
+print('Writing sorted units to file...')
+for subcluster_waveforms, subcluster_times, fin_bool, electrode_num in result:
     for this_sub in range(len(subcluster_waveforms)):
         if fin_bool[this_sub]:
             continue_bool, unit_name = this_descriptor_handler.save_unit(
-            subcluster_waveforms[this_sub],
-            subcluster_times[this_sub],
-            electrode_num,
-            this_sort_file_handler,
-            split_or_merge = None,
-            override_ask = True,
-            )
+                subcluster_waveforms[this_sub],
+                subcluster_times[this_sub],
+                electrode_num,
+                this_sort_file_handler,
+                split_or_merge=None,
+                override_ask=True,
+            )
         else:
             continue_bool = True
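
The heart of the change, in isolation: per-electrode work moves into worker
processes, while writing units to the HDF5 file stays serial because the
descriptor and sort-file handlers cannot be pickled across process boundaries.
A self-contained sketch of the same functools.partial + Pool.starmap pattern;
process_one_electrode is a hypothetical stand-in for
post_utils.auto_process_electrode:

from functools import partial
from multiprocessing import Pool, cpu_count

def process_one_electrode(electrode_num, process_params):
    # Stand-in for the heavy per-electrode work; only picklable values
    # (numbers, arrays, plain tuples) can cross the process boundary
    scale, = process_params
    return electrode_num, electrode_num * scale

if __name__ == '__main__':
    electrode_num_list = [0, 1, 2, 3]

    # Freeze the shared parameters into the worker, as the diff does
    worker = partial(process_one_electrode, process_params = (10,))

    # Same capping logic as the diff, with a floor of one core
    n_cores = max(1, min(len(electrode_num_list), cpu_count() - 1))

    with Pool(n_cores) as pool:
        # starmap unpacks each 1-tuple into the worker's positional argument
        result = pool.starmap(worker,
                              [(num,) for num in electrode_num_list])

    # Serial pass in the parent process: anything unpicklable
    # (open files, handler classes) stays here
    for electrode_num, value in result:
        print(f'electrode {electrode_num:02} -> {value}')

Returning only plain, picklable tuples from the workers is what lets the final
save_unit loop run unchanged in the parent process.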