utils/infer_rnn_rates.py (171 changes: 73 additions & 98 deletions)
@@ -3,111 +3,99 @@
"""

import argparse
parser = argparse.ArgumentParser(description = 'Infer firing rates using RNN')
parser.add_argument('data_dir', help = 'Path to data directory')
parser.add_argument('--override_config', action = 'store_true',
help = 'Override config file and use provided arguments'+\
'(default: %(default)s)')
parser.add_argument('--train_steps', type = int, default = 15000,
help = 'Number of training steps (default: %(default)s)')
# Hidden size of 8 was tested to be optimal across multiple datasets
parser.add_argument('--hidden_size', type = int, default = 8,
help = 'Hidden size of RNN (default: %(default)s)')
parser.add_argument('--bin_size', type = int, default = 25,
help = 'Bin size for binning spikes (default: %(default)s)')
parser.add_argument('--train_test_split', type = float, default = 0.75,
help = 'Fraction of data to use for training (default: %(default)s)')
parser.add_argument('--no_pca', action = 'store_true',
help = 'Do not use PCA for preprocessing (default: %(default)s)')
parser.add_argument('--retrain', action = 'store_true',
help = 'Force retraining of model. Will overwrite existing model'+\
' (default: %(default)s)')
parser.add_argument('--time_lims', type = int, nargs = 2, default = [1500, 4500],
help = 'Time limits inferred firing rates (default: %(default)s)')

import argparse
import json
from pprint import pprint
import os
args = parser.parse_args()
data_dir = args.data_dir
script_path = os.path.abspath(__file__)
blech_clust_path = os.path.dirname(os.path.dirname(script_path))

if args.override_config:
print('Overriding config file\nUsing provided arguments\n')
train_steps = args.train_steps
hidden_size = args.hidden_size
bin_size = args.bin_size
train_test_split = args.train_test_split
use_pca = not args.no_pca
time_lims = args.time_lims
else:
config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
if not os.path.exists(config_path):
raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
with open(config_path, 'r') as f:
config = json.load(f)
print('Using config file\n')
train_steps = config['train_steps']
hidden_size = config['hidden_size']
bin_size = config['bin_size']
train_test_split = config['train_test_split']
use_pca = config['use_pca']
time_lims = config['time_lims']

params_dict = dict(
train_steps = train_steps,
hidden_size = hidden_size,
bin_size = bin_size,
train_test_split = train_test_split,
use_pca = use_pca,
time_lims = time_lims,
)
pprint(params_dict)

##############################

# Check that blechRNN is on the Desktop, if so, add to path
import sys

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy.stats import zscore
import tables
from pprint import pprint

blechRNN_path = os.path.join(os.path.expanduser('~'), 'Desktop', 'blechRNN')
if os.path.exists(blechRNN_path):
sys.path.append(blechRNN_path)
else:
raise FileNotFoundError('blechRNN not found on Desktop')
from src.model import autoencoderRNN
from src.train import train_model

# script_path = '/home/abuzarmahmood/Desktop/blech_clust/utils/infer_rnn_rates.py'
sys.path.append(blech_clust_path)
from utils.ephys_data import ephys_data
from utils.ephys_data import visualize as vz

# mse loss performs better than poisson loss
loss_name = 'mse'
def parse_arguments():
parser = argparse.ArgumentParser(description='Infer firing rates using RNN')
parser.add_argument('data_dir', help='Path to data directory')
parser.add_argument('--override_config', action='store_true', help='Override config file and use provided arguments (default: %(default)s)')
parser.add_argument('--train_steps', type=int, default=15000, help='Number of training steps (default: %(default)s)')
parser.add_argument('--hidden_size', type=int, default=8, help='Hidden size of RNN (default: %(default)s)')
parser.add_argument('--bin_size', type=int, default=25, help='Bin size for binning spikes (default: %(default)s)')
parser.add_argument('--train_test_split', type=float, default=0.75, help='Fraction of data to use for training (default: %(default)s)')
parser.add_argument('--no_pca', action='store_true', help='Do not use PCA for preprocessing (default: %(default)s)')
parser.add_argument('--retrain', action='store_true', help='Force retraining of model. Will overwrite existing model (default: %(default)s)')
parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], help='Time limits for inferred firing rates (default: %(default)s)')
return parser.parse_args()
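# Illustrative invocation (hypothetical paths; flags as defined above). Without
# --override_config the parameters are read from params/blechrnn_params.json instead:
#   python utils/infer_rnn_rates.py /path/to/data_dir --override_config --train_steps 20000 --no_pca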

args = parse_arguments()
data_dir = args.data_dir
script_path = os.path.abspath(__file__)
blech_clust_path = os.path.dirname(os.path.dirname(script_path))

output_path = os.path.join(data_dir, 'rnn_output')
artifacts_dir = os.path.join(output_path, 'artifacts')
plots_dir = os.path.join(output_path, 'plots')
def load_config(args, blech_clust_path):
if args.override_config:
print('Overriding config file\nUsing provided arguments\n')
return {
'train_steps': args.train_steps,
'hidden_size': args.hidden_size,
'bin_size': args.bin_size,
'train_test_split': args.train_test_split,
'use_pca': not args.no_pca,
'time_lims': args.time_lims
}
else:
config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
if not os.path.exists(config_path):
raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
with open(config_path, 'r') as f:
config = json.load(f)
print('Using config file\n')
return config
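# A minimal sketch of the expected params/blechrnn_params.json, using the CLI
# defaults as illustrative values (the repo's actual file may differ):
# {
#     "train_steps": 15000,
#     "hidden_size": 8,
#     "bin_size": 25,
#     "train_test_split": 0.75,
#     "use_pca": true,
#     "time_lims": [1500, 4500]
# }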

params_dict = load_config(args, blech_clust_path)
pprint(params_dict)

if not os.path.exists(output_path):
os.mkdir(output_path)
if not os.path.exists(artifacts_dir):
os.mkdir(artifacts_dir)
if not os.path.exists(plots_dir):
os.mkdir(plots_dir)
# Extract parameters from the config
hidden_size = params_dict['hidden_size']
time_lims = params_dict['time_lims']
bin_size = params_dict['bin_size']
use_pca = params_dict['use_pca']
train_test_split = params_dict['train_test_split']
train_steps = params_dict['train_steps']

def setup_paths(data_dir):
output_path = os.path.join(data_dir, 'rnn_output')
artifacts_dir = os.path.join(output_path, 'artifacts')
plots_dir = os.path.join(output_path, 'plots')

for path in [output_path, artifacts_dir, plots_dir]:
if not os.path.exists(path):
os.mkdir(path)

return output_path, artifacts_dir, plots_dir
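# Design note: os.makedirs(path, exist_ok=True) would collapse the existence
# check and creation into one call; the explicit loop mirrors the original
# script's os.mkdir usage.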

def check_blechRNN_path():
blechRNN_path = os.path.join(os.path.expanduser('~'), 'Desktop', 'blechRNN')
if os.path.exists(blechRNN_path):
sys.path.append(blechRNN_path)
else:
raise FileNotFoundError('blechRNN not found on Desktop')

check_blechRNN_path()
sys.path.append(blech_clust_path)

loss_name = 'mse'
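# mse loss performs better than poisson loss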
output_path, artifacts_dir, plots_dir = setup_paths(data_dir)

print(f'Processing data from {data_dir}')

data = ephys_data.ephys_data(data_dir)
data.get_spikes()

@@ -128,7 +116,6 @@
model_name = f'taste_{taste_ind}_hidden_{hidden_size}_loss_{loss_name}'
model_save_path = os.path.join(artifacts_dir, f'{model_name}.pt')

# taste_spikes = np.concatenate(spike_array)
# Cut taste_spikes to time limits
# Shape: (trials, neurons, time)
taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]]
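# A minimal sketch of the binning step that presumably follows in the collapsed
# lines (assumed, not shown in this hunk): sum spikes over non-overlapping windows
# of `bin_size` samples. With the default time_lims and bin_size this yields
# (trials, neurons, 120) bins from the 3000 retained samples.
#   binned_spikes = taste_spikes.reshape(
#       *taste_spikes.shape[:2], -1, bin_size).sum(axis=-1)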
@@ -161,7 +148,6 @@

# Perform standard scaling
scaler = StandardScaler()
# scaler = MinMaxScaler()
inputs_long = scaler.fit_transform(inputs_long)

if use_pca:
@@ -171,9 +157,6 @@
inputs_pca = pca_obj.fit_transform(inputs_long)
n_components = inputs_pca.shape[-1]

# # Scale the PCA outputs
# pca_scaler = StandardScaler()
# inputs_pca = pca_scaler.fit_transform(inputs_pca)

inputs_trial_pca = inputs_pca.reshape(inputs.shape[0], -1, n_components)
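# Assuming inputs_long is the (trials * time_bins, neurons) flattening of
# `inputs`, this restores a per-trial layout of shape (trials, time_bins, n_components).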

@@ -301,8 +284,6 @@
)
# Save artifacts and plots
torch.save(net, model_save_path)
# np.save(loss_path, loss)
# np.save(cross_val_loss_path, cross_val_loss)
with open(loss_path, 'w') as f:
json.dump(loss, f)
with open(cross_val_loss_path, 'w') as f:
@@ -349,12 +330,6 @@

# If PCA was performed, reverse the PCA first, then reverse the standard scaling that preceded it
if use_pca:
# # Reverse NMF scaling
# pred_firing_long = nmf_scaler.inverse_transform(pred_firing_long)
# pred_firing_long = pca_scaler.inverse_transform(pred_firing_long)

# Reverse NMF transform
# pred_firing_long = nmf_obj.inverse_transform(pred_firing_long)
pred_firing_long = pca_obj.inverse_transform(pred_firing_long)

# Reverse standard scaling
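# Presumably followed in the collapsed lines by the matching StandardScaler
# inverse, e.g. (assumed, not shown here):
#   pred_firing_long = scaler.inverse_transform(pred_firing_long)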