diff --git a/blech_clean_slate.py b/blech_clean_slate.py
index ea2ab9bc..a7a88e2f 100644
--- a/blech_clean_slate.py
+++ b/blech_clean_slate.py
@@ -4,6 +4,11 @@
 Info file and sorting table (csv) as kept
 """
 # Import stuff!
+"""
+Module for cleaning and resetting the Blech dataset.
+
+This module provides utilities to clean and reset data for analysis.
+"""
 import os
 import shutil
 import sys
diff --git a/blech_common_avg_reference.py b/blech_common_avg_reference.py
index 4a0d6a97..62d67a5b 100644
--- a/blech_common_avg_reference.py
+++ b/blech_common_avg_reference.py
@@ -4,6 +4,11 @@
 
 # Import stuff!
 import tables
+"""
+Module for common average referencing in the Blech dataset.
+
+This module includes functions for applying common average reference techniques.
+"""
 import numpy as np
 import os
 import easygui
diff --git a/blech_exp_info.py b/blech_exp_info.py
index 53564383..2d2b1cc8 100644
--- a/blech_exp_info.py
+++ b/blech_exp_info.py
@@ -17,6 +17,11 @@
 X Misc Notes
 
 """
+"""
+Module for handling experimental information in the Blech dataset.
+
+This module includes functions to manage and retrieve experimental metadata.
+"""
 import json
 import numpy as np
 import os
diff --git a/blech_make_arrays.py b/blech_make_arrays.py
index b5dd711b..1c8b5f66 100644
--- a/blech_make_arrays.py
+++ b/blech_make_arrays.py
@@ -1,4 +1,9 @@
 # Import stuff!
+"""
+Module for creating arrays in the Blech dataset.
+
+This module includes functions to generate and manipulate data arrays.
+"""
 import numpy as np
 import tables
 import sys
diff --git a/blech_post_process.py b/blech_post_process.py
index 690a1995..2fd21af0 100644
--- a/blech_post_process.py
+++ b/blech_post_process.py
@@ -28,6 +28,11 @@
 import numpy as np
 import pylab as plt
 from sklearn.mixture import GaussianMixture
+"""
+Module for post-processing Blech dataset.
+
+This module provides functions for post-processing and analysis of data.
+"""
 import pandas as pd
 import matplotlib
 from glob import glob
diff --git a/blech_process.py b/blech_process.py
index 1e6cc073..a14248b5 100644
--- a/blech_process.py
+++ b/blech_process.py
@@ -35,6 +35,11 @@
 import pylab as plt
 import json
 import sys
+"""
+Module for processing Blech dataset.
+
+This module includes functions for data processing and manipulation.
+"""
 import numpy as np
 import warnings
 
diff --git a/blech_units_characteristics.py b/blech_units_characteristics.py
index 51239ccd..ca87c2fa 100644
--- a/blech_units_characteristics.py
+++ b/blech_units_characteristics.py
@@ -11,6 +11,11 @@
 - Dynamic population (ANOVA over time on PCA/other latents)
 
 """
+"""
+Module for analyzing unit characteristics in the Blech dataset.
+
+This module provides functions to compute and analyze characteristics of units.
+"""
 import numpy as np
 import tables
 import easygui
diff --git a/blech_units_plot.py b/blech_units_plot.py
index 741398dd..87436748 100644
--- a/blech_units_plot.py
+++ b/blech_units_plot.py
@@ -4,6 +4,11 @@
 import easygui
 import sys
 import os
+"""
+Module for plotting units in the Blech dataset.
+
+This module provides functions to visualize and analyze unit data.
+"""
 import matplotlib.pyplot as plt
 import shutil
 from tqdm import tqdm, trange
diff --git a/utils/infer_rnn_rates.py b/utils/infer_rnn_rates.py
index 67728508..92922242 100644
--- a/utils/infer_rnn_rates.py
+++ b/utils/infer_rnn_rates.py
@@ -1,113 +1,136 @@
 """
-Use Auto-regressive RNN to infer firing rates from a given data set.
+This module uses an Auto-regressive RNN to infer firing rates from a given data set.
+
+It includes functions for parsing arguments, loading configurations, setting up paths,
+and checking the existence of necessary directories. The main functionality involves
+processing neural spike data, training an RNN model, and visualizing the results.
 """
 import argparse
-parser = argparse.ArgumentParser(description = 'Infer firing rates using RNN')
-parser.add_argument('data_dir', help = 'Path to data directory')
-parser.add_argument('--override_config', action = 'store_true',
-                    help = 'Override config file and use provided arguments'+\
-                            '(default: %(default)s)')
-parser.add_argument('--train_steps', type = int, default = 15000,
-                    help = 'Number of training steps (default: %(default)s)')
-# Hidden size of 8 was tested to be optimal across multiple datasets
-parser.add_argument('--hidden_size', type = int, default = 8,
-                    help = 'Hidden size of RNN (default: %(default)s)')
-parser.add_argument('--bin_size', type = int, default = 25,
-                    help = 'Bin size for binning spikes (default: %(default)s)')
-parser.add_argument('--train_test_split', type = float, default = 0.75,
-                    help = 'Fraction of data to use for training (default: %(default)s)')
-parser.add_argument('--no_pca', action = 'store_true',
-                    help = 'Do not use PCA for preprocessing (default: %(default)s)')
-parser.add_argument('--retrain', action = 'store_true',
-                    help = 'Force retraining of model. Will overwrite existing model'+\
-                            ' (default: %(default)s)')
-parser.add_argument('--time_lims', type = int, nargs = 2, default = [1500, 4500],
-                    help = 'Time limits inferred firing rates (default: %(default)s)')
-
+import argparse
 import json
-from pprint import pprint
 import os
-args = parser.parse_args()
-data_dir = args.data_dir
-script_path = os.path.abspath(__file__)
-blech_clust_path = os.path.dirname(os.path.dirname(script_path))
-
-if args.override_config:
-    print('Overriding config file\nUsing provided arguments\n')
-    train_steps = args.train_steps
-    hidden_size = args.hidden_size
-    bin_size = args.bin_size
-    train_test_split = args.train_test_split
-    use_pca = not args.no_pca
-    time_lims = args.time_lims
-else:
-    config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
-    with open(config_path, 'r') as f:
-        config = json.load(f)
-    print('Using config file\n')
-    train_steps = config['train_steps']
-    hidden_size = config['hidden_size']
-    bin_size = config['bin_size']
-    train_test_split = config['train_test_split']
-    use_pca = config['use_pca']
-    time_lims = config['time_lims']
-
-params_dict = dict(
-        train_steps = train_steps,
-        hidden_size = hidden_size,
-        bin_size = bin_size,
-        train_test_split = train_test_split,
-        use_pca = use_pca,
-        time_lims = time_lims,
-        )
-pprint(params_dict)
-
-##############################
-
-# Check that blechRNN is on the Desktop, if so, add to path
 import sys
+
 import numpy as np
-from sklearn.preprocessing import StandardScaler
-from sklearn.decomposition import PCA
 import torch
 import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
 from scipy.stats import zscore
 import tables
+from pprint import pprint
 
-blechRNN_path = os.path.join(os.path.expanduser('~'), 'Desktop', 'blechRNN')
-if os.path.exists(blechRNN_path):
-    sys.path.append(blechRNN_path)
-else:
-    raise FileNotFoundError('blechRNN not found on Desktop')
 from src.model import autoencoderRNN
 from src.train import train_model
-
-# script_path = '/home/abuzarmahmood/Desktop/blech_clust/utils/infer_rnn_rates.py'
-sys.path.append(blech_clust_path)
 from utils.ephys_data import ephys_data
 from utils.ephys_data import visualize as vz
 
-# mse loss performs better than poisson loss
-loss_name = 'mse'
+def parse_arguments():
+    """
+    Parse command-line arguments.
+
+    Returns:
+        argparse.Namespace: Parsed command-line arguments.
+    """
+    parser = argparse.ArgumentParser(description='Infer firing rates using RNN')
+    parser.add_argument('data_dir', help='Path to data directory')
+    parser.add_argument('--override_config', action='store_true', help='Override config file and use provided arguments (default: %(default)s)')
+    parser.add_argument('--train_steps', type=int, default=15000, help='Number of training steps (default: %(default)s)')
+    parser.add_argument('--hidden_size', type=int, default=8, help='Hidden size of RNN (default: %(default)s)')
+    parser.add_argument('--bin_size', type=int, default=25, help='Bin size for binning spikes (default: %(default)s)')
+    parser.add_argument('--train_test_split', type=float, default=0.75, help='Fraction of data to use for training (default: %(default)s)')
+    parser.add_argument('--no_pca', action='store_true', help='Do not use PCA for preprocessing (default: %(default)s)')
+    parser.add_argument('--retrain', action='store_true', help='Force retraining of model. Will overwrite existing model (default: %(default)s)')
+    parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], help='Time limits inferred firing rates (default: %(default)s)')
+    return parser.parse_args()
+
+args = parse_arguments()
+data_dir = args.data_dir
+script_path = os.path.abspath(__file__)
+blech_clust_path = os.path.dirname(os.path.dirname(script_path))
 
-output_path = os.path.join(data_dir, 'rnn_output')
-artifacts_dir = os.path.join(output_path, 'artifacts')
-plots_dir = os.path.join(output_path, 'plots')
+def load_config(args, blech_clust_path):
+    """
+    Load configuration parameters from a JSON file or use command-line arguments.
+
+    Args:
+        args (argparse.Namespace): Parsed command-line arguments.
+        blech_clust_path (str): Path to the BlechRNN configuration directory.
+
+    Returns:
+        dict: Configuration parameters.
+    """
+    if args.override_config:
+        print('Overriding config file\nUsing provided arguments\n')
+        return {
+            'train_steps': args.train_steps,
+            'hidden_size': args.hidden_size,
+            'bin_size': args.bin_size,
+            'train_test_split': args.train_test_split,
+            'use_pca': not args.no_pca,
+            'time_lims': args.time_lims
+        }
+    else:
+        config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
+        with open(config_path, 'r') as f:
+            config = json.load(f)
+        print('Using config file\n')
+        return config
+
+params_dict = load_config(args, blech_clust_path)
+pprint(params_dict)
 
-if not os.path.exists(output_path):
-    os.mkdir(output_path)
-if not os.path.exists(artifacts_dir):
-    os.mkdir(artifacts_dir)
-if not os.path.exists(plots_dir):
-    os.mkdir(plots_dir)
+# Extract parameters from the config
+hidden_size = params_dict['hidden_size']
+time_lims = params_dict['time_lims']
+bin_size = params_dict['bin_size']
+use_pca = params_dict['use_pca']
+train_test_split = params_dict['train_test_split']
+train_steps = params_dict['train_steps']
+
+def setup_paths(data_dir):
+    """
+    Set up directories for output, artifacts, and plots.
+
+    Args:
+        data_dir (str): Path to the data directory.
+
+    Returns:
+        tuple: Paths for output, artifacts, and plots directories.
+    """
+    output_path = os.path.join(data_dir, 'rnn_output')
+    artifacts_dir = os.path.join(output_path, 'artifacts')
+    plots_dir = os.path.join(output_path, 'plots')
+
+    for path in [output_path, artifacts_dir, plots_dir]:
+        if not os.path.exists(path):
+            os.mkdir(path)
+
+    return output_path, artifacts_dir, plots_dir
+
+def check_blechRNN_path():
+    """
+    Check if the blechRNN directory exists on the Desktop and append it to sys.path.
+
+    Raises:
+        FileNotFoundError: If the blechRNN directory is not found.
+    """
+    blechRNN_path = os.path.join(os.path.expanduser('~'), 'Desktop', 'blechRNN')
+    if os.path.exists(blechRNN_path):
+        sys.path.append(blechRNN_path)
+    else:
+        raise FileNotFoundError('blechRNN not found on Desktop')
+check_blechRNN_path()
+sys.path.append(blech_clust_path)
+loss_name = 'mse'
+output_path, artifacts_dir, plots_dir = setup_paths(data_dir)
 print(f'Processing data from {data_dir}')
-
 data = ephys_data.ephys_data(data_dir)
 data.get_spikes()
@@ -128,7 +151,6 @@
 model_name = f'taste_{taste_ind}_hidden_{hidden_size}_loss_{loss_name}'
 model_save_path = os.path.join(artifacts_dir, f'{model_name}.pt')
 
-# taste_spikes = np.concatenate(spike_array)
 # Cut taste_spikes to time limits
 # Shape: (trials, neurons, time)
 taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]]
@@ -161,7 +183,6 @@
 
 # Perform standard scaling
 scaler = StandardScaler()
-# scaler = MinMaxScaler()
 inputs_long = scaler.fit_transform(inputs_long)
 
 if use_pca:
@@ -171,9 +192,6 @@
     inputs_pca = pca_obj.fit_transform(inputs_long)
     n_components = inputs_pca.shape[-1]
 
-    # # Scale the PCA outputs
-    # pca_scaler = StandardScaler()
-    # inputs_pca = pca_scaler.fit_transform(inputs_pca)
 
     inputs_trial_pca = inputs_pca.reshape(inputs.shape[0], -1, n_components)
 
@@ -301,8 +319,6 @@
         )
 # Save artifacts and plots
 torch.save(net, model_save_path)
-# np.save(loss_path, loss)
-# np.save(cross_val_loss_path, cross_val_loss)
 with open(loss_path, 'w') as f:
     json.dump(loss, f)
 with open(cross_val_loss_path, 'w') as f:
@@ -349,12 +365,6 @@
 
 # If pca was performed, first reverse PCA, then reverse pca standard scaling
 if use_pca:
-    # # Reverse NMF scaling
-    # pred_firing_long = nmf_scaler.inverse_transform(pred_firing_long)
-    # pred_firing_long = pca_scaler.inverse_transform(pred_firing_long)
-
-    # Reverse NMF transform
-    # pred_firing_long = nmf_obj.inverse_transform(pred_firing_long)
     pred_firing_long = pca_obj.inverse_transform(pred_firing_long)
 
 # Reverse standard scaling