Skip to content

Commit 7acf21c

Browse files
Merge pull request #294 from katzlabbrandeis/42-small-datasets-for-automatic-testing
42 small datasets for automatic testing
2 parents 09aa677 + b858e31 commit 7acf21c

17 files changed

+1256
-669
lines changed

blech_clust.py

Lines changed: 37 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import pandas as pd
5959
import shutil
6060
import pylab as plt
61+
from ast import literal_eval
6162

6263
# Necessary blech_clust modules
6364
from utils import read_file
@@ -123,36 +124,6 @@ def initialize_groups(self):
123124
self.hf5.close()
124125
return continue_bool, reload_data_str
125126

126-
def get_digital_inputs(self, sampling_rate):
127-
"""Get digital input data from HDF5 file
128-
129-
Args:
130-
sampling_rate: Sampling rate of the data
131-
132-
Returns:
133-
numpy array of digital input data
134-
"""
135-
with tables.open_file(self.hdf5_name, 'r') as hf5:
136-
dig_in_list = [self._process_digital_input(x[:], sampling_rate)
137-
for x in hf5.root.digital_in]
138-
return np.stack(dig_in_list)
139-
140-
@staticmethod
141-
def _process_digital_input(data, sampling_rate):
142-
"""Process a single digital input channel
143-
144-
Args:
145-
data: Raw digital input data
146-
sampling_rate: Sampling rate
147-
148-
Returns:
149-
Processed digital input data
150-
"""
151-
len_dig_in = len(data)
152-
truncated = data[:(len_dig_in//sampling_rate)*sampling_rate]
153-
return np.reshape(truncated, (-1, sampling_rate)).sum(axis=-1)
154-
155-
156127
def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_frame,
157128
all_electrodes, all_params_dict):
158129
"""Generate bash scripts for running single and parallel processing
@@ -243,15 +214,7 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram
243214
info_dict = metadata_handler.info_dict
244215
file_list = metadata_handler.file_list
245216

246-
247-
# Get the type of data files (.rhd or .dat)
248-
if 'auxiliary.dat' in file_list:
249-
file_type = ['one file per signal type']
250-
elif sum(['rhd' in x for x in file_list]) > 1: # multiple .rhd files
251-
file_type = ['traditional']
252-
else:
253-
file_type = ['one file per channel']
254-
217+
file_type = info_dict['file_type']
255218

256219
# Create HDF5 handler and initialize groups
257220
hdf5_handler = HDF5Handler(dir_name, force_run)
@@ -286,31 +249,31 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram
286249
file_lists = {
287250
'one file per signal type': {
288251
'electrodes': ['amplifier.dat'],
289-
'dig_in': ['digitalin.dat']
290252
},
291253
'one file per channel': {
292254
'electrodes': sorted([name for name in file_list if name.startswith('amp-')]),
293-
'dig_in': sorted([name for name in file_list if name.startswith('board-DI')])
294255
},
295256
'traditional': {
296257
'rhd': sorted([name for name in file_list if name.endswith('.rhd')])
297258
}
298259
}
299260

300-
if file_type[0] != 'traditional':
301-
electrodes_list = file_lists[file_type[0]]['electrodes']
302-
dig_in_file_list = file_lists[file_type[0]]['dig_in']
261+
# Get digin and laser info
262+
print('Getting trial markers from digital inputs')
263+
# dig_in_array = hdf5_handler.get_digital_inputs(sampling_rate)
264+
this_dig_handler = read_file.DigInHandler(dir_name, file_type)
265+
this_dig_handler.load_dig_in_frame()
266+
267+
print('DigIn data loaded')
268+
print(this_dig_handler.dig_in_frame.drop(columns='pulse_times'))
269+
270+
if file_type != 'traditional':
271+
electrodes_list = file_lists[file_type]['electrodes']
303272

304-
if file_type == ['one file per channel']:
273+
if file_type == 'one file per channel':
305274
print("\tOne file per CHANNEL Detected")
306-
# Read dig-in data
307-
# Pull out the digital input channels used,
308-
# and convert them to integers
309-
dig_in_int = [x.split('-')[-1].split('.')[0] for x in dig_in_file_list]
310-
dig_in_int = sorted([(x) for x in dig_in_int])
311-
elif file_type == ['one file per signal type']:
275+
elif file_type == 'one file per signal type':
312276
print("\tOne file per SIGNAL Detected")
313-
dig_in_int = np.arange(info_dict['dig_ins']['count'])
314277

315278
# Use info file for port list calculation
316279
info_file = np.fromfile(dir_name + '/info.rhd', dtype=np.dtype('float32'))
@@ -324,25 +287,22 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram
324287
ports = info_dict['ports']
325288

326289
check_str = f'Amplifier files: {electrodes_list} \nSampling rate: {sampling_rate} Hz'\
327-
f'\nDigital input files: {dig_in_file_list} \n Ports : {ports} \n---------- \n \n'
290+
+ f'\n Ports : {ports} \n---------- \n \n'  # f-prefix required: explicit `+` concatenation does not inherit it, so {ports} would print literally
328291
print(check_str)
329292

330-
if file_type[0] == 'traditional':
293+
if file_type == 'traditional':
331294
# Announce which Intan recording layout was detected (message typo fixed).
print('Traditional INTAN file format detected')
332-
rhd_file_list = file_lists[file_type[0]]['rhd']
295+
rhd_file_list = file_lists[file_type]['rhd']
333296
with open(rhd_file_list[0], 'rb') as f:
334297
header = read_header(f)
335298
# temp_file, data_present = importrhdutilities.load_file(file_list[0])
336299
amp_channel_ports = [x['port_prefix'] for x in header['amplifier_channels']]
337300
amp_channel_names = [x['native_channel_name'] for x in header['amplifier_channels']]
338-
dig_in_channels = [x['native_channel_name'] for x in header['board_dig_in_channels']]
339-
dig_in_int = sorted([x.split('-')[-1].split('.')[0] for x in dig_in_channels])
340301
sampling_rate = int(header['sample_rate'])
341302
ports = np.unique(amp_channel_ports)
342303

343304
check_str = f"""
344305
== Amplifier channels: \n{amp_channel_names}\n
345-
== Digital input channels: \n{dig_in_channels}\n
346306
== Sampling rate: {sampling_rate} Hz\n
347307
== Ports: {ports}\n
348308
"""
@@ -368,22 +328,21 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram
368328

369329
# Read data files, and append to electrode arrays
370330
if reload_data_str in ['y', 'yes']:
371-
if file_type == ['one file per channel']:
372-
read_file.read_digins(hdf5_name, dig_in_int, dig_in_file_list)
331+
if file_type == 'one file per channel':
332+
# read_file.read_digins(hdf5_name, dig_in_int, dig_in_file_list)
373333
read_file.read_electrode_channels(hdf5_name, electrode_layout_frame)
374334
if len(emg_channels) > 0:
375335
read_file.read_emg_channels(hdf5_name, electrode_layout_frame)
376-
elif file_type == ['one file per signal type']:
377-
read_file.read_digins_single_file(hdf5_name, dig_in_int, dig_in_file_list)
336+
elif file_type == 'one file per signal type':
337+
# read_file.read_digins_single_file(hdf5_name, dig_in_int, dig_in_file_list)
378338
# This next line takes care of both electrodes and emgs
379339
read_file.read_electrode_emg_channels_single_file(
380340
hdf5_name, electrode_layout_frame, electrodes_list, num_recorded_samples, emg_channels)
381-
elif file_type == ['traditional']:
341+
elif file_type == 'traditional':
382342
read_file.read_traditional_intan(
383343
hdf5_name,
384344
rhd_file_list,
385345
electrode_layout_frame,
386-
dig_in_int,
387346
)
388347
else:
389348
print('Data already present...Not reloading data')
@@ -428,20 +387,18 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram
428387
##############################
429388
# Also output a plot with digin and laser info
430389

431-
# Get digin and laser info
432-
print('Getting trial markers from digital inputs')
433-
dig_in_array = hdf5_handler.get_digital_inputs(sampling_rate)
434390
# Downsample to 10 seconds
435-
# dig_in_array = dig_in_array[:, :(dig_in_array.shape[1]//sampling_rate)*sampling_rate]
436-
# dig_in_array = np.reshape(dig_in_array, (len(dig_in_array), -1, sampling_rate)).sum(axis=2)
437-
dig_in_markers = np.where(dig_in_array > 0)
438-
del dig_in_array
391+
dig_in_pulses = this_dig_handler.dig_in_frame.pulse_times.values
392+
dig_in_pulses = [literal_eval(x) for x in dig_in_pulses]
393+
# Take starts of pulses
394+
dig_in_pulses = [[x[0] for x in this_dig] for this_dig in dig_in_pulses]
395+
dig_in_markers = [np.array(x) / sampling_rate for x in dig_in_pulses]
439396

440397
# Check if laser is present
441-
laser_dig_in = info_dict['laser_params']['dig_in']
398+
laser_dig_in = info_dict['laser_params']['dig_in_nums']
442399

443400
dig_in_map = {}
444-
for num, name in zip(info_dict['taste_params']['dig_ins'], info_dict['taste_params']['tastes']):
401+
for num, name in zip(info_dict['taste_params']['dig_in_nums'], info_dict['taste_params']['tastes']):
445402
dig_in_map[num] = name
446403
for num in laser_dig_in:
447404
dig_in_map[num] = 'laser'
@@ -450,12 +407,16 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram
450407
dig_in_map = {num:dig_in_map[num] for num in sorted(list(dig_in_map.keys()))}
451408
dig_in_str = [f'{num}: {dig_in_map[num]}' for num in dig_in_map.keys()]
452409

453-
plt.scatter(dig_in_markers[1], dig_in_markers[0], s=50, marker='|', c='k')
410+
for i, vals in enumerate(dig_in_markers):
411+
plt.scatter(vals,
412+
np.ones_like(vals)*i,
413+
s=50, marker='|', c='k')
454414
# If there is a laser_dig_in, mark laser trials with axvline
455415
if len(laser_dig_in) > 0:
456-
laser_markers = np.where(dig_in_markers[0] == laser_dig_in)[0]
416+
# laser_markers = np.where(dig_in_markers[0] == laser_dig_in)[0]
417+
# Pulse-start times (in seconds) used to draw the laser axvlines below.
# NOTE(review): dig_in_markers is a list built per row of dig_in_frame, but it
# is indexed here by a dig-in *number* from info_dict, and only the first laser
# dig-in is used. This presumably assumes dig-in numbers coincide with row
# positions in dig_in_frame -- confirm against DigInHandler.
laser_markers = dig_in_markers[laser_dig_in[0]]
457418
for marker in laser_markers:
458-
plt.axvline(dig_in_markers[1][marker], c='yellow', lw=2, alpha = 0.5,
419+
plt.axvline(marker, c='yellow', lw=2, alpha = 0.5,
459420
zorder = -1)
460421
plt.yticks(np.array(list(dig_in_map.keys())), dig_in_str)
461422
plt.title('Digital Inputs')

blech_clust_post.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
DIR=$1
2-
BLECH_DIR=$HOME/Desktop/blech_clust
2+
SCRIPT_DIR=$0
3+
# Directory containing this script; the expansion is quoted so a path with
# spaces is passed to dirname as a single argument.
BLECH_DIR=$(dirname "$SCRIPT_DIR")
34
echo === Make Arrays ===
45
# Script path is quoted so a BLECH_DIR containing spaces is not word-split.
# NOTE(review): $DIR is left unquoted as in the original; quoting it would pass
# an empty argument when no directory is supplied -- confirm desired behavior.
python "$BLECH_DIR/blech_make_arrays.py" $DIR &&
56
echo === Quality Assurance ===

0 commit comments

Comments
 (0)