Merged
27 commits
b298df0
feat: Parallelize electrode processing with multiprocessing
abuzarmahmood Nov 21, 2024
505d583
Fix metadata getting logic for saving units
abuzarmahmood Nov 21, 2024
ff5343e
feat: Add script to toggle waveform classifier parameter in params file
abuzarmahmood Nov 22, 2024
6aeac80
refactor: Simplify spike_emg_test by extracting spike test logic into…
abuzarmahmood Nov 22, 2024
777beb9
feat: Add jetstream bash runs with and without waveform classifier
abuzarmahmood Nov 22, 2024
343c225
feat: Add --delete-log flag to force deletion of results.log
abuzarmahmood Nov 25, 2024
c33982c
Add --delete-log flag to force log deletion in blech_run_process
abuzarmahmood Nov 25, 2024
d7cc48f
Write out sorting table during pipeline testing
abuzarmahmood Nov 25, 2024
8d6d2d9
Update workflow to have concurrency groups
abuzarmahmood Nov 25, 2024
bcce20f
Update workflow to have concurrency groups
abuzarmahmood Nov 25, 2024
6bd798c
Merge branch '236-parallelize-autosort-processing' of https://github.…
abuzarmahmood Nov 25, 2024
ee82e09
Fix bug in blech_process if not using classifier
abuzarmahmood Nov 25, 2024
593e1ce
Also add concurrency group to preamble
abuzarmahmood Nov 25, 2024
d75c548
feat: Add electrode processing log with start/end timestamps and stat…
abuzarmahmood Nov 25, 2024
dd4ee8e
Check arguments into blech_process.py
abuzarmahmood Nov 26, 2024
fd2ebb3
feat: Add log completion status check to blech_run_process.sh
abuzarmahmood Nov 26, 2024
bc71186
feat: Add script to toggle auto clustering and post-processing parame…
abuzarmahmood Nov 26, 2024
e9e8879
refactor: Update toggle_auto_params to use dynamic params file path
abuzarmahmood Nov 26, 2024
1c17e6c
feat: Add toggle_auto_params.py for pipeline testing configuration
abuzarmahmood Nov 26, 2024
8e15626
feat: Add command-line arguments to set auto clustering parameters ex…
abuzarmahmood Nov 26, 2024
37a0ed5
Working version of multiprocess autosort with 1) autosort testing, 2)…
abuzarmahmood Nov 26, 2024
0504a1a
Add max_parallel_cpu to parallel autosort
abuzarmahmood Nov 26, 2024
80ff07d
Add file to change auto-sorting + auto-post-processing params
abuzarmahmood Nov 26, 2024
60408b2
Fix bugs in testing auto-sorting
abuzarmahmood Nov 26, 2024
4a05e62
Update change_auto_params.py to use data dir
abuzarmahmood Nov 26, 2024
aa918b8
Update prefect_pipeline.py to direct auto params changer to data dir
abuzarmahmood Nov 26, 2024
b92a996
Fix bug in
abuzarmahmood Nov 27, 2024
12 changes: 12 additions & 0 deletions .github/workflows/python_workflow_test.yml
@@ -7,6 +7,9 @@ on: [pull_request]
jobs:
Preamble:
runs-on: self-hosted
concurrency:
group: ${{ github.workflow }}-${{ github.ref }} preamble
cancel-in-progress: true
steps:
- run: pwd
- run: which python
@@ -26,6 +29,9 @@
Spike-Only:
runs-on: self-hosted
needs: Preamble
concurrency:
group: ${{ github.workflow }}-${{ github.ref }} spike
cancel-in-progress: true
steps:
- name: Prefect SPIKE only test
shell: bash
@@ -37,6 +43,9 @@
EMG-Only:
runs-on: self-hosted
needs: Preamble
concurrency:
group: ${{ github.workflow }}-${{ github.ref }} emg
cancel-in-progress: true
steps:
- name: Prefect EMG only test
shell: bash
@@ -48,6 +57,9 @@
Spike-EMG:
runs-on: self-hosted
needs: Preamble
concurrency:
group: ${{ github.workflow }}-${{ github.ref }} spike+emg
cancel-in-progress: true
steps:
- name: Prefect SPIKE then EMG test
shell: bash
2 changes: 1 addition & 1 deletion blech_clust_pre.sh
@@ -4,4 +4,4 @@ python blech_clust.py $DIR &&
echo Running Common Average Reference
python blech_common_avg_reference.py $DIR &&
echo Running Jetstream Bash
for x in $(seq 10);do bash blech_run_process.sh $DIR;done
bash blech_run_process.sh $DIR
243 changes: 58 additions & 185 deletions blech_post_process.py
@@ -65,10 +65,13 @@
description = 'Spike extraction and sorting script')
parser.add_argument('dir_name',
help = 'Directory containing data files')
parser.add_argument('--show-plot', '-p',
help = 'Show waveforms while iterating (True/False)', default = 'True')
parser.add_argument('--sort-file', '-f', help = 'CSV with sorted units',
default = None)
parser.add_argument('--show-plot',
help = 'Show waveforms while iterating',
action = 'store_true')
parser.add_argument('--keep-raw', help = 'Keep raw data in hdf5 file',
action = 'store_true')
args = parser.parse_args()

############################################################
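A note on the argument changes above: `--show-plot` and `--keep-raw` are now argparse `store_true` flags, so they default to `False` and no longer take a `'True'`/`'False'` string. A minimal sketch of the assumed behavior (not the full blech_post_process.py parser):

```python
import argparse

# Sketch of the store_true flag handling assumed by the new code:
# the flags default to False and become True only when passed on the command line.
parser = argparse.ArgumentParser(description='Spike extraction and sorting script')
parser.add_argument('dir_name', help='Directory containing data files')
parser.add_argument('--sort-file', '-f', help='CSV with sorted units', default=None)
parser.add_argument('--show-plot', help='Show waveforms while iterating', action='store_true')
parser.add_argument('--keep-raw', help='Keep raw data in hdf5 file', action='store_true')

args = parser.parse_args(['/path/to/data', '--keep-raw'])
assert args.keep_raw is True and args.show_plot is False
```

Downstream checks can therefore test the flag directly (`if args.show_plot:`) rather than comparing against the string `'True'`.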
@@ -83,6 +86,8 @@
import matplotlib
from glob import glob
import re
from functools import partial
from multiprocessing import Pool, cpu_count

matplotlib.rcParams['font.size'] = 6

@@ -104,7 +109,6 @@
else:
metadata_handler = imp_metadata([])


# Extract parameters for automatic processing
params_dict = metadata_handler.params_dict
sampling_rate = params_dict['sampling_rate']
@@ -131,7 +135,11 @@


# Delete the raw node, if it exists in the hdf5 file, to cut down on file size
repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
if not args.keep_raw:
repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
else:
repacked_bool = False
print('=== Keeping raw data in hdf5 file ===')

# Open the hdf5 file
if repacked_bool:
@@ -167,7 +175,7 @@
############################################################

print()
print('======================================')
print('==== Manual Post-Processing ====\n')
print()

# If sort_file given, iterate through that, otherwise ask user
@@ -192,12 +200,12 @@
energy,
amplitudes,
predictions,
) = post_utils.load_data_from_disk(electrode_num, num_clusters)
) = post_utils.load_data_from_disk(dir_name, electrode_num, num_clusters)

# Re-show images of neurons so dumb people like Abu can make sure they
# picked the right ones
#if ast.literal_eval(args.show_plot):
if args.show_plot == 'True':
if args.show_plot:
post_utils.gen_select_cluster_plot(electrode_num, num_clusters, clusters)

############################################################
@@ -423,191 +431,56 @@
for this_electrode in electrode_list]
electrode_num_list.sort()

for electrode_num in electrode_num_list:
############################################################
# Get unit details and load data
############################################################

print()
print('======================================')
print()

# Iterate over electrodes and pull out spikes
# Get classifier probabilities for each spike and use only
# "good" spikes

# Print out selections
print(f'=== Processing Electrode {electrode_num:02} ===')

# Load data from the chosen electrode
# We can pick any solution, but need to know what
# solutions are present

(
spike_waveforms,
spike_times,
pca_slices,
energy,
amplitudes,
split_predictions,
) = post_utils.load_data_from_disk(electrode_num, max_autosort_clusters)

clf_data_paths = [
f'./spike_waveforms/electrode{electrode_num:02}/clf_prob.npy',
f'./spike_waveforms/electrode{electrode_num:02}/clf_pred.npy',
]
clf_prob, clf_pred = [np.load(this_path) for this_path in clf_data_paths]

# If auto-clustering was done, data has already been trimmed
# Only clf_pred needs to be trimmed
clf_prob = clf_prob[clf_pred]
clf_pred = clf_pred[clf_pred]

##############################
# Calculate whether the cluster is a wanted_unit
# This will be useful for merging, so we only merge units

##############################
# Merge clusters using mahalanobis distance
# If min( mahal a->b, mahal b->a ) < threshold, merge
# Unless ISI violations are > threshold

mahal_thresh = auto_params['mahalanobis_merge_thresh']
isi_threshs = auto_params['ISI_violations_thresholds']

mahal_mat_path = os.path.join(
'.',
'clustering_results',
f'electrode{electrode_num:02}',
f'clusters{max_autosort_clusters:02}',
'mahalanobis_distances.npy',
)
mahal_mat = np.load(mahal_mat_path)

unique_clusters = np.unique(split_predictions)
assert len(unique_clusters) == len(mahal_mat), \
'Mahalanobis matrix does not match number of clusters'

(
final_merge_sets,
new_clust_names,
) = post_utils.calculate_merge_sets(
mahal_mat,
mahal_thresh,
isi_threshs,
split_predictions,
spike_waveforms,
spike_times,
clf_prob,
chi_square_alpha,
count_threshold,
sampling_rate,
)


if len(final_merge_sets) > 0:
# Create names for merged clusters
# Rename both to max_clusters

# Print out merge sets
print(f'=== Merging {len(final_merge_sets)} Clusters ===')
for this_merge_set, new_name in zip(final_merge_sets, new_clust_names):
print(f'==== {this_merge_set} => {new_name} ====')

fig, ax = post_utils.gen_plot_auto_merged_clusters(
spike_waveforms,
spike_times,
split_predictions,
sampling_rate,
final_merge_sets,
new_clust_names,
)

# In case first unit is merged, we need to create the autosort_output_dir
if not os.path.exists(autosort_output_dir):
os.makedirs(autosort_output_dir)

fig.savefig(
os.path.join(
autosort_output_dir,
f'{electrode_num:02}_merged_units.png',
),
bbox_inches = 'tight',
)
plt.close(fig)

# Update split_predictions
for this_set, this_name in zip(final_merge_sets, new_clust_names):
for this_cluster in this_set:
split_predictions[split_predictions == this_cluster] = this_name

##############################

# Take everything
data = post_utils.prepare_data(
np.arange(len(spike_waveforms)),
pca_slices,
energy,
amplitudes,
)
# Create processing parameters tuple
process_params = (
max_autosort_clusters,
auto_params,
chi_square_alpha,
count_threshold,
sampling_rate,
metadata_handler.dir_name,
)

(
subcluster_inds,
subcluster_waveforms,
subcluster_prob,
subcluster_times,
mean_waveforms,
std_waveforms,
chi_out,
fin_bool,
fin_bool_dict,
) = \
post_utils.get_cluster_props(
split_predictions,
spike_waveforms,
clf_prob,
spike_times,
chi_square_alpha,
count_threshold,
##############################
# Generate plots for each subcluster
##############################

post_utils.gen_autosort_plot(
subcluster_prob,
subcluster_waveforms,
chi_out,
mean_waveforms,
std_waveforms,
subcluster_times,
fin_bool,
np.unique(split_predictions),
electrode_num,
sampling_rate,
autosort_output_dir,
n_max_plot=5000,
)

# Use multiprocessing to process electrodes in parallel
n_cores = np.min(
(
len(electrode_num_list),
cpu_count() - 1,
params_dict['max_parallel_cpu']
)
) # Leave one core free
print(f"Processing {len(electrode_num_list)} electrodes using {n_cores} cores")

# Create partial function
auto_process_partial = partial(
post_utils.auto_process_electrode,
process_params = process_params
)

print(f'== Saving to {autosort_output_dir} ==')
with Pool(n_cores) as pool:
result = pool.starmap(
auto_process_partial,
[(electrode_num,) for electrode_num in electrode_num_list]
)
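The dispatch above is a standard `functools.partial` + `multiprocessing.Pool` pattern: the fixed processing parameters are bound once and only the electrode number varies per worker. A self-contained sketch of the same pattern, with a stand-in worker in place of `post_utils.auto_process_electrode` (the real worker needs the per-electrode data on disk):

```python
from functools import partial
from multiprocessing import Pool, cpu_count

def auto_process_electrode(electrode_num, process_params):
    # Stand-in: the real worker loads the electrode's waveforms, merges and
    # filters clusters, and returns (subcluster_waveforms, subcluster_times,
    # fin_bool, electrode_num) for the parent process to save.
    return electrode_num, len(process_params)

if __name__ == '__main__':
    electrode_num_list = list(range(8))
    process_params = ('max_clusters', 'auto_params')  # fixed arguments, bound once
    # Most restrictive of: number of electrodes, available cores minus one,
    # and a configured ceiling (max_parallel_cpu in the params file)
    max_parallel_cpu = 4
    n_cores = max(1, min(len(electrode_num_list), cpu_count() - 1, max_parallel_cpu))

    worker = partial(auto_process_electrode, process_params=process_params)
    with Pool(n_cores) as pool:
        # starmap takes an iterable of argument tuples; each tuple has a single
        # element because only electrode_num differs between calls
        result = pool.starmap(worker, [(num,) for num in electrode_num_list])
    print(result)
```

Since only one argument varies, `pool.map(worker, electrode_num_list)` would behave the same; `starmap` with one-element tuples just mirrors the call signature used in the script.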

############################################################
# Finally, save the unit to the HDF5 file
############################################################

############################################################
# Subsetting this set of waveforms to include only the chosen split

# This last part cannot be incorporated in auto_process_electrode as it
# needs passing of classes (descriptor_handler, and sort_file_handler) to
# the processes.
# Get pickling errors when they are included
# It is also a quick process so it doesn't need to be parallelized
print('Writing sorted units to file...')
for subcluster_waveforms, subcluster_times, fin_bool, electrode_num in result:
for this_sub in range(len(subcluster_waveforms)):
if fin_bool[this_sub]:
continue_bool, unit_name = this_descriptor_handler.save_unit(
subcluster_waveforms[this_sub],
subcluster_times[this_sub],
electrode_num,
this_sort_file_handler,
split_or_merge=None,
override_ask=True,
)
else:
continue_bool = True
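The serial write-out loop above reflects a pickling constraint: multiprocessing ships worker arguments through pickle, and the `descriptor_handler` / `sort_file_handler` objects cannot cross the process boundary (per the comment, they raise pickling errors). A minimal illustration of that failure mode, assuming the handlers hold an unpicklable resource such as an open file handle:

```python
import pickle

class DescriptorHandler:
    """Stand-in for the real handler; keeps an open file handle, like an open HDF5 table."""
    def __init__(self, path):
        self.fileobj = open(path, 'ab')  # open file objects are not picklable

handler = DescriptorHandler('example.h5')
try:
    pickle.dumps(handler)  # this is what Pool would do to send it to a worker
except TypeError as err:
    # Hence the design: workers return plain arrays, and the parent process
    # performs the quick save_unit() calls serially.
    print(f'Cannot send handler to a worker process: {err}')
```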
