65 | 65 |     description = 'Spike extraction and sorting script')
66 | 66 | parser.add_argument('dir_name',
67 | 67 |     help = 'Directory containing data files')
68 | | -parser.add_argument('--show-plot', '-p',
69 | | -    help = 'Show waveforms while iterating (True/False)', default = 'True')
70 | 68 | parser.add_argument('--sort-file', '-f', help = 'CSV with sorted units',
71 | 69 |     default = None)
| 70 | +parser.add_argument('--show-plot',
| 71 | +    help = 'Show waveforms while iterating',
| 72 | +    action = 'store_true')
| 73 | +parser.add_argument('--keep-raw', help = 'Keep raw data in hdf5 file',
| 74 | +    action = 'store_true')
72 | 75 | args = parser.parse_args()
73 | 76 |
74 | 77 | ############################################################
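
A note on the flag semantics this hunk relies on: with action = 'store_true', argparse stores a real boolean (False when the flag is absent, True when present), so the old string default of 'True' and the downstream string comparisons can go away. A standalone sketch of the difference, not taken from this repo:

    import argparse

    parser = argparse.ArgumentParser()
    # Old style: the stored value is the *string* 'True' or 'False',
    # so truthiness checks mislead ('False' is a non-empty string)
    parser.add_argument('--show-plot-str', default='True')
    # New style: a genuine boolean, no parsing needed
    parser.add_argument('--show-plot', action='store_true')

    args = parser.parse_args(['--show-plot'])
    assert args.show_plot is True   # real boolean
    assert bool('False')            # string flags are always truthy
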
83 | 86 | import matplotlib
84 | 87 | from glob import glob
85 | 88 | import re
| 89 | +from functools import partial
| 90 | +from multiprocessing import Pool, cpu_count
86 | 91 |
87 | 92 | matplotlib.rcParams['font.size'] = 6
88 | 93 |
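
functools.partial, imported here, is what lets a multi-argument worker run under Pool further down: it freezes the arguments shared by every task so each worker call only needs the varying one, and the frozen callable stays picklable as long as it wraps a module-level function. A tiny illustration with made-up names:

    from functools import partial

    def process(electrode_num, process_params):
        return electrode_num, process_params

    # Freeze the shared argument; the result can be sent to pool workers
    process_frozen = partial(process, process_params=('a', 'b'))
    print(process_frozen(7))  # -> (7, ('a', 'b'))
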
104 | 109 | else:
105 | 110 |     metadata_handler = imp_metadata([])
106 | 111 |
107 | | -
108 | 112 | # Extract parameters for automatic processing
109 | 113 | params_dict = metadata_handler.params_dict
110 | 114 | sampling_rate = params_dict['sampling_rate']
131 | 135 |
132 | 136 |
133 | 137 | # Delete the raw node, if it exists in the hdf5 file, to cut down on file size
134 | | -repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
| 138 | +if not args.keep_raw:
| 139 | +    repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
| 140 | +else:
| 141 | +    repacked_bool = False
| 142 | +    print('=== Keeping raw data in hdf5 file ===')
135 | 143 |
136 | 144 | # Open the hdf5 file
137 | 145 | if repacked_bool:
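
For context on repacked_bool: removing a node from an HDF5 file frees space only inside the file, so the helper presumably repacks into a new file, and the branch below picks the right filename afterwards. A rough sketch of what such a helper might look like (assuming PyTables, a /raw group, and the ptrepack utility on PATH; this is not the actual body of post_utils.delete_raw_recordings):

    import os
    import subprocess
    import tables

    def delete_raw_recordings_sketch(hdf5_name):
        # Drop the raw waveform group if present
        with tables.open_file(hdf5_name, 'r+') as hf5:
            if '/raw' in hf5:
                hf5.remove_node('/raw', recursive=True)
            else:
                return False  # nothing deleted, no repack needed
        # ptrepack copies live nodes into a fresh, smaller file
        repacked_name = hdf5_name.replace('.h5', '_repacked.h5')
        subprocess.run(['ptrepack', hdf5_name, repacked_name], check=True)
        os.remove(hdf5_name)
        return True
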
167 | 175 | ############################################################
168 | 176 |
169 | 177 |     print()
170 | | -    print('======================================')
| 178 | +    print('==== Manual Post-Processing ====\n')
171 | 179 |     print()
172 | 180 |
173 | 181 |     # If sort_file given, iterate through that, otherwise ask user
192 | 200 |         energy,
193 | 201 |         amplitudes,
194 | 202 |         predictions,
195 | | -    ) = post_utils.load_data_from_disk(electrode_num, num_clusters)
| 203 | +    ) = post_utils.load_data_from_disk(dir_name, electrode_num, num_clusters)
196 | 204 |
197 | 205 |     # Re-show images of neurons so dumb people like Abu can make sure they
198 | 206 |     # picked the right ones
199 | 207 |     #if ast.literal_eval(args.show_plot):
200 | | -    if args.show_plot == 'True':
| 208 | +    if args.show_plot:
201 | 209 |         post_utils.gen_select_cluster_plot(electrode_num, num_clusters, clusters)
202 | 210 |
203 | 211 | ############################################################
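
Passing dir_name into load_data_from_disk, instead of relying on the current working directory, is what makes the same loader safe to call from pool workers later in the file, since workers should not depend on a shared os.chdir. An illustrative shape for such a loader; the file layout is inferred from the ./spike_waveforms paths visible below, and the names are guesses rather than the real implementation:

    import os
    import numpy as np

    def load_data_from_disk_sketch(dir_name, electrode_num, num_clusters):
        # Anchor every path to dir_name so the call behaves the same
        # in any process, regardless of the working directory
        electrode_dir = os.path.join(
            dir_name, 'spike_waveforms', f'electrode{electrode_num:02}')
        clf_prob = np.load(os.path.join(electrode_dir, 'clf_prob.npy'))
        return clf_prob
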
423 | 431 |     for this_electrode in electrode_list]
424 | 432 | electrode_num_list.sort()
425 | 433 |
426 | | -for electrode_num in electrode_num_list:
427 | | -    ############################################################
428 | | -    # Get unit details and load data
429 | | -    ############################################################
430 | | -
431 | | -    print()
432 | | -    print('======================================')
433 | | -    print()
434 | | -
435 | | -    # Iterate over electrodes and pull out spikes
436 | | -    # Get classifier probabilities for each spike and use only
437 | | -    # "good" spikes
438 | | -
439 | | -    # Print out selections
440 | | -    print(f'=== Processing Electrode {electrode_num:02} ===')
441 | | -
442 | | -    # Load data from the chosen electrode
443 | | -    # We can pick any soluation, but need to know what
444 | | -    # solutions are present
445 | | -
446 | | -    (
447 | | -        spike_waveforms,
448 | | -        spike_times,
449 | | -        pca_slices,
450 | | -        energy,
451 | | -        amplitudes,
452 | | -        split_predictions,
453 | | -    ) = post_utils.load_data_from_disk(electrode_num, max_autosort_clusters)
454 | | -
455 | | -    clf_data_paths = [
456 | | -        f'./spike_waveforms/electrode{electrode_num:02}/clf_prob.npy',
457 | | -        f'./spike_waveforms/electrode{electrode_num:02}/clf_pred.npy',
458 | | -    ]
459 | | -    clf_prob, clf_pred = [np.load(this_path) for this_path in clf_data_paths]
460 | | -
461 | | -    # If auto-clustering was done, data has already been trimmed
462 | | -    # Only clf_pred needs to be trimmed
463 | | -    clf_prob = clf_prob[clf_pred]
464 | | -    clf_pred = clf_pred[clf_pred]
465 | | -
466 | | -    ##############################
467 | | -    # Calculate whether the cluster is a wanted_unit
468 | | -    # This will be useful for merging, so we only merge units
469 | | -
470 | | -    ##############################
471 | | -    # Merge clusters using mahalanobis distance
472 | | -    # If min( mahal a->b, mahal b->a ) < threshold, merge
473 | | -    # Unless ISI violations are > threshold
474 | | -
475 | | -    mahal_thresh = auto_params['mahalanobis_merge_thresh']
476 | | -    isi_threshs = auto_params['ISI_violations_thresholds']
477 | | -
478 | | -    mahal_mat_path = os.path.join(
479 | | -        '.',
480 | | -        'clustering_results',
481 | | -        f'electrode{electrode_num:02}',
482 | | -        f'clusters{max_autosort_clusters:02}',
483 | | -        'mahalanobis_distances.npy',
484 | | -    )
485 | | -    mahal_mat = np.load(mahal_mat_path)
486 | | -
487 | | -    unique_clusters = np.unique(split_predictions)
488 | | -    assert len(unique_clusters) == len(mahal_mat), \
489 | | -        'Mahalanobis matrix does not match number of clusters'
490 | | -
491 | | -    (
492 | | -        final_merge_sets,
493 | | -        new_clust_names,
494 | | -    )= post_utils.calculate_merge_sets(
495 | | -        mahal_mat,
496 | | -        mahal_thresh,
497 | | -        isi_threshs,
498 | | -        split_predictions,
499 | | -        spike_waveforms,
500 | | -        spike_times,
501 | | -        clf_prob,
502 | | -        chi_square_alpha,
503 | | -        count_threshold,
504 | | -        sampling_rate,
505 | | -    )
506 | | -
507 | | -
508 | | -    if len(final_merge_sets) > 0:
509 | | -        # Create names for merged clusters
510 | | -        # Rename both to max_clusters
511 | | -
512 | | -        # Print out merge sets
513 | | -        print(f'=== Merging {len(final_merge_sets)} Clusters ===')
514 | | -        for this_merge_set, new_name in zip(final_merge_sets, new_clust_names):
515 | | -            print(f'==== {this_merge_set} => {new_name} ====')
516 | | -
517 | | -        fig, ax = post_utils.gen_plot_auto_merged_clusters(
518 | | -            spike_waveforms,
519 | | -            spike_times,
520 | | -            split_predictions,
521 | | -            sampling_rate,
522 | | -            final_merge_sets,
523 | | -            new_clust_names,
524 | | -        )
525 | | -
526 | | -        # In case first unit is merged, we need to create the autosort_output_dir
527 | | -        if not os.path.exists(autosort_output_dir):
528 | | -            os.makedirs(autosort_output_dir)
529 | | -
530 | | -        fig.savefig(
531 | | -            os.path.join(
532 | | -                autosort_output_dir,
533 | | -                f'{electrode_num:02}_merged_units.png',
534 | | -            ),
535 | | -            bbox_inches = 'tight',
536 | | -        )
537 | | -        plt.close(fig)
538 | | -
539 | | -        # Update split_predictions
540 | | -        for this_set, this_name in zip(final_merge_sets, new_clust_names):
541 | | -            for this_cluster in this_set:
542 | | -                split_predictions[split_predictions == this_cluster] = this_name
543 | | -
544 | | -    ##############################
545 | | -
546 | | -    # Take everything
547 | | -    data = post_utils.prepare_data(
548 | | -        np.arange(len(spike_waveforms)),
549 | | -        pca_slices,
550 | | -        energy,
551 | | -        amplitudes,
552 | | -    )
| 434 | +# Create processing parameters tuple
| 435 | +process_params = (
| 436 | +    max_autosort_clusters,
| 437 | +    auto_params,
| 438 | +    chi_square_alpha,
| 439 | +    count_threshold,
| 440 | +    sampling_rate,
| 441 | +    metadata_handler.dir_name,
| 442 | +)
553 | 443 |
554 | | -    (
555 | | -        subcluster_inds,
556 | | -        subcluster_waveforms,
557 | | -        subcluster_prob,
558 | | -        subcluster_times,
559 | | -        mean_waveforms,
560 | | -        std_waveforms,
561 | | -        chi_out,
562 | | -        fin_bool,
563 | | -        fin_bool_dict,
564 | | -    ) = \
565 | | -        post_utils.get_cluster_props(
566 | | -            split_predictions,
567 | | -            spike_waveforms,
568 | | -            clf_prob,
569 | | -            spike_times,
570 | | -            chi_square_alpha,
571 | | -            count_threshold,
| 444 | +# Use multiprocessing to process electrodes in parallel
| 445 | +n_cores = np.min(
| 446 | +    (
| 447 | +        len(electrode_num_list),
| 448 | +        cpu_count() - 1,
| 449 | +        params_dict['max_parallel_cpu']
572 | 450 |         )
573 | | -
574 | | -
575 | | -    ##############################
576 | | -    # Generate plots for each subcluster
577 | | -    ##############################
578 | | -
579 | | -    post_utils.gen_autosort_plot(
580 | | -        subcluster_prob,
581 | | -        subcluster_waveforms,
582 | | -        chi_out,
583 | | -        mean_waveforms,
584 | | -        std_waveforms,
585 | | -        subcluster_times,
586 | | -        fin_bool,
587 | | -        np.unique(split_predictions),
588 | | -        electrode_num,
589 | | -        sampling_rate,
590 | | -        autosort_output_dir,
591 | | -        n_max_plot=5000,
| 451 | +) # Leave one core free
| 452 | +print(f"Processing {len(electrode_num_list)} electrodes using {n_cores} cores")
| 453 | +
| 454 | +# Create partial function
| 455 | +auto_process_partial = partial(
| 456 | +    post_utils.auto_process_electrode,
| 457 | +    process_params = process_params
| 458 | +)
| 459 | +
| 460 | +print(f'== Saving to {autosort_output_dir} ==')
| 461 | +with Pool(n_cores) as pool:
| 462 | +    result = pool.starmap(
| 463 | +        auto_process_partial,
| 464 | +        [(electrode_num,) for electrode_num in electrode_num_list]
592 | 465 |         )
593 | 466 |
594 | | -    ############################################################
595 | | -    # Finally, save the unit to the HDF5 file
596 | | -    ############################################################
597 | | -
598 | | -    ############################################################
599 | | -    # Subsetting this set of waveforms to include only the chosen split
600 | | -
| 467 | +# This last step is not folded into auto_process_electrode because it
| 468 | +# would require passing class instances (descriptor_handler and
| 469 | +# sort_file_handler) to the worker processes,
| 470 | +# which raises pickling errors.
| 471 | +# It is also quick, so it does not need to be parallelized.
| 472 | +print('Writing sorted units to file...')
| 473 | +for subcluster_waveforms, subcluster_times, fin_bool, electrode_num in result:
601 | 474 |     for this_sub in range(len(subcluster_waveforms)):
602 | 475 |         if fin_bool[this_sub]:
603 | 476 |             continue_bool, unit_name = this_descriptor_handler.save_unit(
604 | | -            subcluster_waveforms[this_sub],
605 | | -            subcluster_times[this_sub],
606 | | -            electrode_num,
607 | | -            this_sort_file_handler,
608 | | -            split_or_merge = None,
609 | | -            override_ask = True,
610 | | -            )
| 477 | +                subcluster_waveforms[this_sub],
| 478 | +                subcluster_times[this_sub],
| 479 | +                electrode_num,
| 480 | +                this_sort_file_handler,
| 481 | +                split_or_merge=None,
| 482 | +                override_ask=True,
| 483 | +            )
611 | 484 |         else:
612 | 485 |             continue_bool = True
613 | 486 |
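
Why the commit splits the work into a parallel map followed by a serial write loop: the per-electrode computation is independent, but this_descriptor_handler and this_sort_file_handler hold state that cannot be pickled and shipped to pool workers, so only plain, picklable results come back from the pool and the parent writes them out serially. A minimal, self-contained sketch of the same pattern, with all names hypothetical rather than taken from post_utils:

    from functools import partial
    from multiprocessing import Pool, cpu_count

    def heavy_step(electrode_num, process_params):
        # Stand-in for the per-electrode work: pure compute that
        # returns only picklable data (no handles, no class instances)
        (scale,) = process_params
        return electrode_num, [electrode_num * scale]

    if __name__ == '__main__':
        electrode_num_list = [0, 1, 2, 5]
        # Cap workers at the job count and leave one core free
        n_cores = max(1, min(len(electrode_num_list), cpu_count() - 1))
        heavy_partial = partial(heavy_step, process_params=(10,))
        with Pool(n_cores) as pool:
            results = pool.map(heavy_partial, electrode_num_list)
        # Serial phase in the parent: unpicklable writers never
        # cross a process boundary
        for electrode_num, units in results:
            print(f'electrode {electrode_num:02}: {units}')

Since only the electrode number varies per task, plain pool.map suffices in the sketch; the commit's pool.starmap over one-element tuples is equivalent.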