5 changes: 5 additions & 0 deletions blech_clean_slate.py
@@ -4,6 +4,11 @@
Info file and sorting table (csv) as kept
"""
# Import stuff!
"""
Module for cleaning and resetting the Blech dataset.

This module provides utilities to clean and reset data for analysis.
"""
import os
import shutil
import sys
5 changes: 5 additions & 0 deletions blech_common_avg_reference.py
@@ -4,6 +4,11 @@

# Import stuff!
import tables
"""
Module for common average referencing in the Blech dataset.

This module includes functions for applying common average reference techniques.
"""
import numpy as np
import os
import easygui
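For context, common average referencing ordinarily means subtracting the mean signal across all recording electrodes from each channel at every sample. The snippet below is a minimal illustrative sketch of that idea, not the code added in this PR:

# Illustrative sketch of common average referencing (not from this PR):
# subtract the across-electrode mean from every channel at each sample.
import numpy as np

def common_average_reference(data):
    """data: array of shape (electrodes, samples); returns the re-referenced copy."""
    return data - data.mean(axis=0, keepdims=True)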
5 changes: 5 additions & 0 deletions blech_exp_info.py
@@ -17,6 +17,11 @@
X Misc Notes
"""

"""
Module for handling experimental information in the Blech dataset.

This module includes functions to manage and retrieve experimental metadata.
"""
import json
import numpy as np
import os
5 changes: 5 additions & 0 deletions blech_make_arrays.py
@@ -1,4 +1,9 @@
# Import stuff!
"""
Module for creating arrays in the Blech dataset.

This module includes functions to generate and manipulate data arrays.
"""
import numpy as np
import tables
import sys
5 changes: 5 additions & 0 deletions blech_post_process.py
@@ -28,6 +28,11 @@
import numpy as np
import pylab as plt
from sklearn.mixture import GaussianMixture
"""
Module for post-processing Blech dataset.

This module provides functions for post-processing and analysis of data.
"""
import pandas as pd
import matplotlib
from glob import glob
5 changes: 5 additions & 0 deletions blech_process.py
@@ -35,6 +35,11 @@
import pylab as plt
import json
import sys
"""
Module for processing Blech dataset.

This module includes functions for data processing and manipulation.
"""
import numpy as np
import warnings

5 changes: 5 additions & 0 deletions blech_units_characteristics.py
@@ -11,6 +11,11 @@
- Dynamic population (ANOVA over time on PCA/other latents)
"""

"""
Module for analyzing unit characteristics in the Blech dataset.

This module provides functions to compute and analyze characteristics of units.
"""
import numpy as np
import tables
import easygui
5 changes: 5 additions & 0 deletions blech_units_plot.py
@@ -4,6 +4,11 @@
import easygui
import sys
import os
"""
Module for plotting units in the Blech dataset.

This module provides functions to visualize and analyze unit data.
"""
import matplotlib.pyplot as plt
import shutil
from tqdm import tqdm, trange
208 changes: 109 additions & 99 deletions utils/infer_rnn_rates.py
@@ -1,113 +1,136 @@
"""
Use Auto-regressive RNN to infer firing rates from a given data set.
This module uses an Auto-regressive RNN to infer firing rates from a given data set.

It includes functions for parsing arguments, loading configurations, setting up paths,
and checking the existence of necessary directories. The main functionality involves
processing neural spike data, training an RNN model, and visualizing the results.
"""

import argparse
parser = argparse.ArgumentParser(description = 'Infer firing rates using RNN')
parser.add_argument('data_dir', help = 'Path to data directory')
parser.add_argument('--override_config', action = 'store_true',
help = 'Override config file and use provided arguments'+\
'(default: %(default)s)')
parser.add_argument('--train_steps', type = int, default = 15000,
help = 'Number of training steps (default: %(default)s)')
# Hidden size of 8 was tested to be optimal across multiple datasets
parser.add_argument('--hidden_size', type = int, default = 8,
help = 'Hidden size of RNN (default: %(default)s)')
parser.add_argument('--bin_size', type = int, default = 25,
help = 'Bin size for binning spikes (default: %(default)s)')
parser.add_argument('--train_test_split', type = float, default = 0.75,
help = 'Fraction of data to use for training (default: %(default)s)')
parser.add_argument('--no_pca', action = 'store_true',
help = 'Do not use PCA for preprocessing (default: %(default)s)')
parser.add_argument('--retrain', action = 'store_true',
help = 'Force retraining of model. Will overwrite existing model'+\
' (default: %(default)s)')
parser.add_argument('--time_lims', type = int, nargs = 2, default = [1500, 4500],
help = 'Time limits inferred firing rates (default: %(default)s)')

import argparse
import json
from pprint import pprint
import os
args = parser.parse_args()
data_dir = args.data_dir
script_path = os.path.abspath(__file__)
blech_clust_path = os.path.dirname(os.path.dirname(script_path))

if args.override_config:
print('Overriding config file\nUsing provided arguments\n')
train_steps = args.train_steps
hidden_size = args.hidden_size
bin_size = args.bin_size
train_test_split = args.train_test_split
use_pca = not args.no_pca
time_lims = args.time_lims
else:
config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
if not os.path.exists(config_path):
raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
with open(config_path, 'r') as f:
config = json.load(f)
print('Using config file\n')
train_steps = config['train_steps']
hidden_size = config['hidden_size']
bin_size = config['bin_size']
train_test_split = config['train_test_split']
use_pca = config['use_pca']
time_lims = config['time_lims']

params_dict = dict(
train_steps = train_steps,
hidden_size = hidden_size,
bin_size = bin_size,
train_test_split = train_test_split,
use_pca = use_pca,
time_lims = time_lims,
)
pprint(params_dict)

##############################

# Check that blechRNN is on the Desktop, if so, add to path
import sys

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy.stats import zscore
import tables
from pprint import pprint

blechRNN_path = os.path.join(os.path.expanduser('~'), 'Desktop', 'blechRNN')
if os.path.exists(blechRNN_path):
sys.path.append(blechRNN_path)
else:
raise FileNotFoundError('blechRNN not found on Desktop')
from src.model import autoencoderRNN
from src.train import train_model

# script_path = '/home/abuzarmahmood/Desktop/blech_clust/utils/infer_rnn_rates.py'
sys.path.append(blech_clust_path)
from utils.ephys_data import ephys_data
from utils.ephys_data import visualize as vz

# mse loss performs better than poisson loss
loss_name = 'mse'
def parse_arguments():
"""
Parse command-line arguments.

Returns:
argparse.Namespace: Parsed command-line arguments.
"""
parser = argparse.ArgumentParser(description='Infer firing rates using RNN')
parser.add_argument('data_dir', help='Path to data directory')
parser.add_argument('--override_config', action='store_true', help='Override config file and use provided arguments (default: %(default)s)')
parser.add_argument('--train_steps', type=int, default=15000, help='Number of training steps (default: %(default)s)')
parser.add_argument('--hidden_size', type=int, default=8, help='Hidden size of RNN (default: %(default)s)')
parser.add_argument('--bin_size', type=int, default=25, help='Bin size for binning spikes (default: %(default)s)')
parser.add_argument('--train_test_split', type=float, default=0.75, help='Fraction of data to use for training (default: %(default)s)')
parser.add_argument('--no_pca', action='store_true', help='Do not use PCA for preprocessing (default: %(default)s)')
parser.add_argument('--retrain', action='store_true', help='Force retraining of model. Will overwrite existing model (default: %(default)s)')
parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], help='Time limits inferred firing rates (default: %(default)s)')
return parser.parse_args()
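As a usage illustration for the refactored argument parser (hypothetical paths, not part of the diff):

# Hypothetical usage sketch, not part of this PR: exercise parse_arguments()
# by simulating a command line ('/path/to/data' is a placeholder path).
import sys
sys.argv = ['infer_rnn_rates.py', '/path/to/data',
            '--override_config', '--train_steps', '5000', '--no_pca']
args = parse_arguments()
# args.data_dir == '/path/to/data'; args.train_steps == 5000; args.no_pca is True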

args = parse_arguments()
data_dir = args.data_dir
script_path = os.path.abspath(__file__)
blech_clust_path = os.path.dirname(os.path.dirname(script_path))

output_path = os.path.join(data_dir, 'rnn_output')
artifacts_dir = os.path.join(output_path, 'artifacts')
plots_dir = os.path.join(output_path, 'plots')
def load_config(args, blech_clust_path):
"""
Load configuration parameters from a JSON file or use command-line arguments.

Args:
args (argparse.Namespace): Parsed command-line arguments.
blech_clust_path (str): Path to the BlechRNN configuration directory.

Returns:
dict: Configuration parameters.
"""
if args.override_config:
print('Overriding config file\nUsing provided arguments\n')
return {
'train_steps': args.train_steps,
'hidden_size': args.hidden_size,
'bin_size': args.bin_size,
'train_test_split': args.train_test_split,
'use_pca': not args.no_pca,
'time_lims': args.time_lims
}
else:
config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
if not os.path.exists(config_path):
raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
with open(config_path, 'r') as f:
config = json.load(f)
print('Using config file\n')
return config

params_dict = load_config(args, blech_clust_path)
pprint(params_dict)

if not os.path.exists(output_path):
os.mkdir(output_path)
if not os.path.exists(artifacts_dir):
os.mkdir(artifacts_dir)
if not os.path.exists(plots_dir):
os.mkdir(plots_dir)
# Extract parameters from the config
hidden_size = params_dict['hidden_size']
time_lims = params_dict['time_lims']
bin_size = params_dict['bin_size']
use_pca = params_dict['use_pca']
train_test_split = params_dict['train_test_split']
train_steps = params_dict['train_steps']
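For reference, load_config() expects params/blechrnn_params.json to provide the six keys read above; a hypothetical example whose values simply mirror the argparse defaults:

# Hypothetical contents of params/blechrnn_params.json (values mirror the
# argparse defaults shown earlier; illustrative only, not a file in this PR):
example_config = {
    "train_steps": 15000,
    "hidden_size": 8,
    "bin_size": 25,
    "train_test_split": 0.75,
    "use_pca": True,
    "time_lims": [1500, 4500],
}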

def setup_paths(data_dir):
"""
Set up directories for output, artifacts, and plots.

Args:
data_dir (str): Path to the data directory.

Returns:
tuple: Paths for output, artifacts, and plots directories.
"""
output_path = os.path.join(data_dir, 'rnn_output')
artifacts_dir = os.path.join(output_path, 'artifacts')
plots_dir = os.path.join(output_path, 'plots')

for path in [output_path, artifacts_dir, plots_dir]:
if not os.path.exists(path):
os.mkdir(path)

return output_path, artifacts_dir, plots_dir
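As a small usage note (the session path below is a placeholder), the helper both returns and creates the three directories nested under rnn_output:

# Hypothetical call; '/path/to/session' stands in for a real data directory.
output_path, artifacts_dir, plots_dir = setup_paths('/path/to/session')
# Creates (if missing):
#   /path/to/session/rnn_output/
#   /path/to/session/rnn_output/artifacts/
#   /path/to/session/rnn_output/plots/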

def check_blechRNN_path():
"""
Check if the blechRNN directory exists on the Desktop and append it to sys.path.

Raises:
FileNotFoundError: If the blechRNN directory is not found.
"""
blechRNN_path = os.path.join(os.path.expanduser('~'), 'Desktop', 'blechRNN')
if os.path.exists(blechRNN_path):
sys.path.append(blechRNN_path)
else:
raise FileNotFoundError('blechRNN not found on Desktop')

check_blechRNN_path()
sys.path.append(blech_clust_path)

loss_name = 'mse'
output_path, artifacts_dir, plots_dir = setup_paths(data_dir)

print(f'Processing data from {data_dir}')

data = ephys_data.ephys_data(data_dir)
data.get_spikes()

@@ -128,7 +151,6 @@
model_name = f'taste_{taste_ind}_hidden_{hidden_size}_loss_{loss_name}'
model_save_path = os.path.join(artifacts_dir, f'{model_name}.pt')

# taste_spikes = np.concatenate(spike_array)
# Cut taste_spikes to time limits
# Shape: (trials, neurons, time)
taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]]
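The binning step itself is elided from this view of the diff; as a hedged sketch, binning with bin_size would typically sum spike counts over non-overlapping windows along the time axis:

# Illustrative sketch only (the PR's actual binning code is not shown here):
# assume non-overlapping windows of width bin_size, summed along time.
n_trials, n_neurons, n_time = taste_spikes.shape
n_bins = n_time // bin_size
binned_spikes = taste_spikes[..., :n_bins * bin_size].reshape(
    n_trials, n_neurons, n_bins, bin_size).sum(axis=-1)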
@@ -161,7 +183,6 @@

# Perform standard scaling
scaler = StandardScaler()
# scaler = MinMaxScaler()
inputs_long = scaler.fit_transform(inputs_long)

if use_pca:
@@ -171,9 +192,6 @@
inputs_pca = pca_obj.fit_transform(inputs_long)
n_components = inputs_pca.shape[-1]

# # Scale the PCA outputs
# pca_scaler = StandardScaler()
# inputs_pca = pca_scaler.fit_transform(inputs_pca)

inputs_trial_pca = inputs_pca.reshape(inputs.shape[0], -1, n_components)

@@ -301,8 +319,6 @@
)
# Save artifacts and plots
torch.save(net, model_save_path)
# np.save(loss_path, loss)
# np.save(cross_val_loss_path, cross_val_loss)
with open(loss_path, 'w') as f:
json.dump(loss, f)
with open(cross_val_loss_path, 'w') as f:
@@ -349,12 +365,6 @@

# If pca was performed, first reverse PCA, then reverse pca standard scaling
if use_pca:
# # Reverse NMF scaling
# pred_firing_long = nmf_scaler.inverse_transform(pred_firing_long)
# pred_firing_long = pca_scaler.inverse_transform(pred_firing_long)

# Reverse NMF transform
# pred_firing_long = nmf_obj.inverse_transform(pred_firing_long)
pred_firing_long = pca_obj.inverse_transform(pred_firing_long)

# Reverse standard scaling