Skip to content

Commit 08a3855

Browse files
Merge pull request #259 from katzlabbrandeis/214-improved-drift-plot1d-visualization
214 improved drift plot1d visualization
2 parents 21ce58b + b779d10 commit 08a3855

File tree

7 files changed

+71
-25
lines changed

7 files changed

+71
-25
lines changed

blech_autosort.sh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ echo === Quality Assurance ===
1313
bash blech_run_QA.sh $DIR &&
1414
echo === Units Plot ===
1515
python blech_units_plot.py $DIR &&
16-
echo === Make PSTHs ===
17-
python blech_make_psth.py $DIR &&
18-
echo === Palatability Identity Setup ===
19-
python blech_palatability_identity_setup.py $DIR &&
20-
echo === Overlay PSTHs ===
21-
python blech_overlay_psth.py $DIR &&
16+
echo === Get unit characteristics ===
17+
python blech_units_characteristics.py $DIR &&
2218
echo === Done ===

blech_clust_post.sh

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
DIR=$1
22
BLECH_DIR=$HOME/Desktop/blech_clust
3-
echo Running Units Plot
4-
python $BLECH_DIR/blech_units_plot.py $DIR;
5-
echo Running Make Arrays
6-
python $BLECH_DIR/blech_make_arrays.py $DIR;
7-
echo Running Quality Assurance
8-
bash $BLECH_DIR/blech_run_QA.sh $DIR;
9-
echo Running Make PSTHs
10-
python $BLECH_DIR/blech_make_psth.py $DIR;
11-
echo Running Palatability Identity Setup
12-
python $BLECH_DIR/blech_palatability_identity_setup.py $DIR;
13-
echo Running Overlay PSTH
14-
python $BLECH_DIR/blech_overlay_psth.py $DIR;
3+
echo === Make Arrays ===
4+
python $BLECH_DIR/blech_make_arrays.py $DIR &&
5+
echo === Quality Assurance ===
6+
bash $BLECH_DIR/blech_run_QA.sh $DIR &&
7+
echo === Units Plot ===
8+
python $BLECH_DIR/blech_units_plot.py $DIR &&
9+
echo === Get unit characteristics ===
10+
python $BLECH_DIR/blech_units_characteristics.py $DIR &&
11+
echo === Done ===

blech_post_process.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@
135135

136136

137137
# Delete the raw node, if it exists in the hdf5 file, to cut down on file size
138-
if args.keep_raw == 'False':
138+
if args.keep_raw == False:
139139
repacked_bool = post_utils.delete_raw_recordings(hdf5_name)
140140
else:
141141
repacked_bool = False
@@ -415,6 +415,9 @@
415415
'autosort_outputs'
416416
)
417417

418+
# Create output directory if needed
419+
if not os.path.exists(autosort_output_dir):
420+
os.makedirs(autosort_output_dir)
418421

419422
# Since this needs classifier output to run, check if it exists
420423
clf_list = glob('./spike_waveforms/electrode*/clf_prob.npy')

pipeline_testing/prefect_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,9 @@ def run_spike_test():
352352
select_clusters(data_dir)
353353
post_process(data_dir)
354354

355+
make_arrays(data_dir)
355356
quality_assurance(data_dir)
356357
units_plot(data_dir)
357-
make_arrays(data_dir)
358358
units_characteristics(data_dir)
359359

360360
@flow(log_prints=True)

utils/blech_post_process_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,9 +1449,6 @@ def auto_process_electrode(
14491449
new_clust_names,
14501450
)
14511451

1452-
# Create output directory if needed
1453-
if not os.path.exists(autosort_output_dir):
1454-
os.makedirs(autosort_output_dir)
14551452

14561453
fig.savefig(
14571454
os.path.join(

utils/ephys_data/ephys_data.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,21 @@ def calc_firing_func(data):
517517
def get_firing_rates(self):
518518
"""
519519
Converts spikes to firing rates
520+
521+
Requires:
522+
- spikes
523+
- firing_rate_params
524+
525+
Generates:
526+
- firing_list : list of firing rates for each taste
527+
- each element is a 3D array of shape (n_trials, n_neurons, n_timepoints)
528+
- firing_array : 4D array of firing rates
529+
- normalized_firing : 4D array of normalized firing rates
530+
- all_firing_array : 3D array of all firing rates
531+
- all_normalized_firing : 3D array of all normalized firing rates
520532
"""
521533

522-
if self.spikes is None:
534+
if 'spikes' not in dir(self):
523535
# raise Exception('Run method "get_spikes" first')
524536
print('No spikes found, getting spikes ...')
525537
self.get_spikes()

utils/qa_utils/drift_check.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
import pingouin as pg
2222
import seaborn as sns
2323
import glob
24+
from sklearn.decomposition import PCA
25+
from umap import UMAP
2426
# Get script path
2527
script_path = os.path.realpath(__file__)
2628
script_dir_path = os.path.dirname(script_path)
2729
blech_path = os.path.dirname(os.path.dirname(script_dir_path))
2830
sys.path.append(blech_path)
2931
from utils.blech_utils import imp_metadata, pipeline_graph_check
32+
from utils.ephys_data import ephys_data
3033

3134
def get_spike_trains(hf5_path):
3235
"""
@@ -350,11 +353,49 @@ def array_to_df(array, dim_names):
350353
print('Post-stimulus limits: ' + str(stim_t) + ' to ' + str(stim_t+trial_duration) + ' ms', file=f)
351354
print('Trial Bin Count: ' + str(n_trial_bins), file=f)
352355
print('alpha: ' + str(alpha), file=f)
353-
print('\n', file=f)
356+
#print('\n', file=f)
354357
print(out_rows, file=f)
355358
print('\n', file=f)
356359
print('=== End Post-stimulus Drift Warning ===', file=f)
357360
print('\n', file=f)
358361

362+
############################################################
363+
# Perform PCA on firing rates across trials
364+
############################################################
365+
dat = ephys_data.ephys_data(dir_name)
366+
dat.get_firing_rates()
367+
# each element is a 3D array of shape (n_trials, n_neurons, n_timepoints)
368+
firing_list = dat.firing_list
369+
# Normalize for each neuron
370+
n_neurons = firing_list[0].shape[1]
371+
norm_firing_list = []
372+
for i in range(len(firing_list)):
373+
this_firing = firing_list[i]
374+
norm_firing = np.zeros_like(this_firing)
375+
for j in range(n_neurons):
376+
norm_firing[:,j,:] = zscore(this_firing[:,j,:], axis=None)
377+
norm_firing_list.append(norm_firing)
378+
379+
# shape: (n_trials, n_neurons * n_timepoints)
380+
long_firing_list = [x.reshape(x.shape[0],-1) for x in norm_firing_list]
381+
382+
# Perform PCA on long_firing_list
383+
pca_firing_list = [PCA(n_components=1, whiten=True).fit_transform(x) for x in long_firing_list]
384+
umap_firing_list = [UMAP(n_components=1).fit_transform(x) for x in long_firing_list]
385+
umap_zscore = [zscore(x, axis=None) for x in umap_firing_list]
386+
387+
# Plot PCA and UMAP results
388+
fig, ax = plt.subplots(2, 1, figsize=(5, 5), sharex=True)
389+
for i in range(len(pca_firing_list)):
390+
ax[0].plot(pca_firing_list[i], alpha=0.7)
391+
ax[1].plot(umap_zscore[i], alpha=0.7)
392+
ax[0].set_title('PCA')
393+
ax[1].set_title('UMAP')
394+
ax[-1].set_xlabel('Trial num')
395+
fig.suptitle('PCA and UMAP of Firing Rates \n' + basename)
396+
plt.tight_layout()
397+
plt.savefig(os.path.join(output_dir, 'pca_umap_firing_rates.png'))
398+
plt.close()
399+
359400
# Write successful execution to log
360401
this_pipeline_check.write_to_log(script_path, 'completed')

0 commit comments

Comments
 (0)