blech_post_process.py (25 changes: 16 additions & 9 deletions)
@@ -87,9 +87,16 @@


# Make the sorted_units group in the hdf5 file if it doesn't already exist
-if not '/sorted_units' in hf5:
+if '/sorted_units' in hf5:
+    overwrite_hf5 = input('Saved units detected; remove them? (y/[n]): ') or 'n'
+    if overwrite_hf5.lower() == 'y':
+        hf5.remove_node('/sorted_units', recursive=True)
+        hf5.create_group('/', 'sorted_units')
+        print('==== Cleared saved units. ====\n')
+else:
    hf5.create_group('/', 'sorted_units')



############################################################
# Main Processing Loop
############################################################
@@ -158,7 +165,7 @@
##############################
# Get clustering parameters from user
continue_bool, n_clusters, n_iter, thresh, n_restarts = \
-        post_utils.get_clustering_params()
+        post_utils.get_clustering_params(this_sort_file_handler)
if not continue_bool: continue

# Make data array to be put through the GMM - 5 components:
@@ -332,8 +339,8 @@
hf5.flush()


-print('==== {} Complete ===\n'.format(unit_name))
-print('==== Iteration Ended ===\n')
+print('==== {} Complete ====\n'.format(unit_name))
+print('==== Iteration Ended ====\n')

# Run auto-processing only if clustering was ALSO automatic
# As currently written, this does not have functionality to determine
@@ -376,7 +383,7 @@
# "good" spikes

# Print out selections
-print(f'=== Processing Electrode {electrode_num:02} ===')
+print(f'==== Processing Electrode {electrode_num:02} ====\n')

# Load data from the chosen electrode
# We can pick any solution, but need to know what
@@ -449,9 +456,9 @@
# Rename both to max_clusters

# Print out merge sets
-print(f'=== Merging {len(final_merge_sets)} Clusters ===')
+print(f'==== Merging {len(final_merge_sets)} Clusters ====\n')
for this_merge_set, new_name in zip(final_merge_sets, new_clust_names):
-    print(f'==== {this_merge_set} => {new_name} ====')
+    print(f'==== {this_merge_set} => {new_name} ====\n')

fig, ax = post_utils.gen_plot_auto_merged_clusters(
        spike_waveforms,
@@ -570,6 +577,6 @@


print()
-print('== Post-processing exiting ==')
+print('==== Post-processing exiting ====\n')
# Close the hdf5 file
hf5.close()
utils/blech_post_process_utils.py (18 changes: 13 additions & 5 deletions)
@@ -34,6 +34,8 @@ def __init__(self, sort_file_path):
sort_table.sort_values(
        ['len_cluster','Split'],
        ascending=False, inplace=True)
+if 'level_0' in sort_table.columns:
Collaborator Author commented:
level_0 is added by the .reset_index() function: I guess when pandas resets the index, it creates that column as a record of the old index.

The issue is that it then permanently writes that column to the output csv file. If you're using the csv as the input for cell sorting, you create it manually, yes, but on the first run through post_process it writes in that column. If you run the same CSV through post-process AGAIN, it tries to write level_0 in again, can't write it over the existing level_0, and throws an error.

So this is a bit of a niche problem: you need to be running post_process with a csv input and then re-run the same csv to create the error. I mainly ran into the issue while testing my code as I familiarize myself with the pipeline. That said, I can imagine scenarios where you accidentally added the wrong cell to the spreadsheet, or you're not happy with a split/merge outcome, and want to run post_process again; if you're using the csv input (which I like a lot, being very pro-automation), this saves you from needing to manually delete level_0 from the spreadsheet.
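A minimal sketch of the failure mode, using plain pandas (the DataFrame and column names here are illustrative stand-ins, not the real sort-table schema):

```python
import pandas as pd

# Stand-in for a sort file that has already been through post_process once
df = pd.DataFrame({'electrode': [0, 1], 'Split': ['2', '']})
df.reset_index(inplace=True)  # inserts an 'index' column recording the old index
df.reset_index(inplace=True)  # 'index' is taken, so pandas falls back to 'level_0'

# A further reset would now fail:
# df.reset_index(inplace=True)  # ValueError: cannot insert level_0, already exists

# The guard added in this PR drops the stale column first, so re-runs are safe:
if 'level_0' in df.columns:
    df.drop(columns=['level_0'], inplace=True)
df.reset_index(inplace=True)
```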

+    sort_table.drop(columns=['level_0'], inplace=True)
sort_table.reset_index(inplace=True)
sort_table['unit_saved'] = False
self.sort_table = sort_table
@@ -160,7 +162,7 @@ def gen_select_cluster_plot(electrode_num, num_clusters, clusters):
    ax[cluster_num,0].axis('off')
    ax[cluster_num, 1].imshow(waveform_plot,aspect='auto');
    ax[cluster_num,1].axis('off')
-fig.suptitle('Are these the neurons you want to select?')
+fig.suptitle('Are these the neurons you want to select? Press q to exit plot')
fig.tight_layout()
plt.show()

@@ -216,12 +218,18 @@ def generate_cluster_plots(
plt.tight_layout()
plt.show()

-def get_clustering_params():
+def get_clustering_params(this_sort_file_handler):
    """
    Ask user for clustering parameters
    """
    # Get clustering parameters from user
-    n_clusters = int(input('Number of clusters (default=5): ') or "5")
+    if (this_sort_file_handler.sort_table is not None):
+        dat_row = this_sort_file_handler.current_row
+        split_val = int(re.findall('[0-9]+', str(dat_row.Split))[0])
+        n_clusters = int(input(f'Number of clusters (sort file={split_val}): ') or split_val)
+    else:
+        n_clusters = int(input('Number of clusters (default=5): ') or "5")

    fields = [
        'Max iterations',
        'Convergence criterion',
@@ -982,15 +990,15 @@ def ask_split(self):
            check_func = lambda x: x in ['y','n'],
            fail_response = 'Please enter (y/n)')
    if continue_bool:
-        if msg == 'y':
+        if msg == 'y':
            self.split = True
        elif msg == 'n':
            self.split = False

def check_split_sort_file(self):
    if self.this_sort_file_handler.sort_table is not None:
        dat_row = self.this_sort_file_handler.current_row
-        if len(dat_row.Split) > 0:
+        if not (dat_row.Split == ''):
            self.split = True
        else:
            self.split = False