Skip to content

Commit e1eb551

Browse files
authored
Merge pull request #844 from int-brain-lab/aggregate_training
Aggregate training
2 parents d26b34e + f095bd7 commit e1eb551

File tree

4 files changed

+51
-41
lines changed

4 files changed

+51
-41
lines changed

ibllib/ephys/ephysqc.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -205,57 +205,65 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
205205
return qc_files
206206

207207

208-
def rmsmap(sglx):
208+
def rmsmap(sglx, spectra=True, nmod=1):
209209
"""
210210
Computes RMS map in time domain and spectra for each channel of Neuropixel probe
211211
212212
:param sglx: Open spikeglx reader
213+
:param spectra: Whether to compute the spectra
214+
:param nmod: take every nmod-th window, for cases where we don't want to compute over the whole signal
213215
:return: a dictionary with amplitudes in channeltime space, channelfrequency space, time
214216
and frequency scales
215217
"""
216218
rms_win_length_samples = 2 ** np.ceil(np.log2(sglx.fs * RMS_WIN_LENGTH_SECS))
217219
# the window generator will generates window indices
218220
wingen = utils.WindowGenerator(ns=sglx.ns, nswin=rms_win_length_samples, overlap=0)
221+
nwin = np.ceil(wingen.nwin / nmod).astype(int)
219222
# pre-allocate output dictionary of numpy arrays
220-
win = {'TRMS': np.zeros((wingen.nwin, sglx.nc)),
221-
'nsamples': np.zeros((wingen.nwin,)),
223+
win = {'TRMS': np.zeros((nwin, sglx.nc)),
224+
'nsamples': np.zeros((nwin,)),
222225
'fscale': fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / sglx.fs, one_sided=True),
223-
'tscale': wingen.tscale(fs=sglx.fs)}
226+
'tscale': wingen.tscale(fs=sglx.fs)[::nmod]}
224227
win['spectral_density'] = np.zeros((len(win['fscale']), sglx.nc))
225228
# loop through the whole session
226229
with tqdm(total=wingen.nwin) as pbar:
227-
for first, last in wingen.firstlast:
230+
for iwindow, (first, last) in enumerate(wingen.firstlast):
231+
if np.mod(iwindow, nmod) != 0:
232+
continue
233+
228234
D = sglx.read_samples(first_sample=first, last_sample=last)[0].transpose()
229235
# remove low frequency noise below 1 Hz
230236
D = fourier.hp(D, 1 / sglx.fs, [0, 1])
231-
iw = wingen.iw
237+
iw = np.floor(wingen.iw / nmod).astype(int)
232238
win['TRMS'][iw, :] = utils.rms(D)
233239
win['nsamples'][iw] = D.shape[1]
234-
# the last window may be smaller than what is needed for welch
235-
if last - first < WELCH_WIN_LENGTH_SAMPLES:
236-
continue
237-
# compute a smoothed spectrum using welch method
238-
_, w = signal.welch(
239-
D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
240-
detrend='constant', return_onesided=True, scaling='density', axis=-1
241-
)
242-
win['spectral_density'] += w.T
240+
if spectra:
241+
# the last window may be smaller than what is needed for welch
242+
if last - first < WELCH_WIN_LENGTH_SAMPLES:
243+
continue
244+
# compute a smoothed spectrum using welch method
245+
_, w = signal.welch(
246+
D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
247+
detrend='constant', return_onesided=True, scaling='density', axis=-1
248+
)
249+
win['spectral_density'] += w.T
243250
# print at least every 20 windows
244251
if (iw % min(20, max(int(np.floor(wingen.nwin / 75)), 1))) == 0:
245252
pbar.update(iw)
246253
sglx.close()
247254
return win
248255

249256

250-
def extract_rmsmap(sglx, out_folder=None, overwrite=False):
257+
def extract_rmsmap(sglx, out_folder=None, overwrite=False, spectra=True, nmod=1):
251258
"""
252259
Wrapper for rmsmap that outputs _ibl_ephysRmsMap and _ibl_ephysSpectra ALF files
253260
254261
:param sglx: Open spikeglx Reader with data for which to compute rmsmap
255262
:param out_folder: folder in which to store output ALF files. Default uses the folder in which
256263
the `fbin` file lives.
257264
:param overwrite: do not re-extract if all ALF files already exist
258-
:param label: string or list of strings that will be appended to the filename before extension
265+
:param spectra: Whether to compute the spectral density across the signal
266+
:param nmod: take every nmod-th window, for cases where we don't want to compute over the whole signal
259267
:return: None
260268
"""
261269
if out_folder is None:
@@ -271,18 +279,19 @@ def extract_rmsmap(sglx, out_folder=None, overwrite=False):
271279
_logger.warning(f'RMS map already exists for .{sglx.type} data in {out_folder}, skipping. Use overwrite option.')
272280
return files_time + files_freq
273281
# crunch numbers
274-
rms = rmsmap(sglx)
282+
rms = rmsmap(sglx, spectra=spectra, nmod=nmod)
275283
# output ALF files, single precision with the optional label as suffix before extension
276284
if not out_folder.exists():
277285
out_folder.mkdir()
278286
tdict = {'rms': rms['TRMS'].astype(np.single), 'timestamps': rms['tscale'].astype(np.single)}
279-
fdict = {'power': rms['spectral_density'].astype(np.single),
280-
'freqs': rms['fscale'].astype(np.single)}
281287
out_time = alfio.save_object_npy(
282288
out_folder, object=alf_object_time, dico=tdict, namespace='iblqc')
283-
out_freq = alfio.save_object_npy(
284-
out_folder, object=alf_object_freq, dico=fdict, namespace='iblqc')
285-
return out_time + out_freq
289+
if spectra:
290+
fdict = {'power': rms['spectral_density'].astype(np.single),
291+
'freqs': rms['fscale'].astype(np.single)}
292+
out_freq = alfio.save_object_npy(
293+
out_folder, object=alf_object_freq, dico=fdict, namespace='iblqc')
294+
return out_time + out_freq if spectra else out_time
286295

287296

288297
def raw_qc_session(session_path, overwrite=False):

ibllib/io/extractors/ephys_fpga.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=Fals
10831083
if audio_event_ttls is None:
10841084
# For training/biased/ephys protocols, the ready tone should be below 110 ms. The error
10851085
# tone should be between 400ms and 1200ms
1086-
audio_event_ttls = {'ready_tone': (0, 0.11), 'error_tone': (0.4, 1.2)}
1086+
audio_event_ttls = {'ready_tone': (0, 0.1101), 'error_tone': (0.4, 1.2)}
10871087
audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display)
10881088

10891089
return audio, audio_event_intervals

ibllib/pipes/dynamic_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def make_pipeline(session_path, **pkwargs):
532532
# The PostDLC plots require a trials object for QC
533533
# Find the first task that outputs a trials.table dataset
534534
trials_task = (
535-
t for t in tasks.values() if any('trials.table' in f for f in t.signature.get('output_files', []))
535+
t for t in tasks.values() if any('trials.table' in f[0] for f in t.signature.get('output_files', []))
536536
)
537537
if trials_task := next(trials_task, None):
538538
parents = [tasks['DLC'], tasks[f'VideoSyncQC_{sync}'], trials_task]

ibllib/pipes/training_status.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def load_combined_trials(sess_paths, one, force=True):
215215
return training.concatenate_trials(trials_dict)
216216

217217

218-
def get_latest_training_information(sess_path, one):
218+
def get_latest_training_information(sess_path, one, save=True):
219219
"""
220220
Extracts the latest training status.
221221
@@ -262,7 +262,8 @@ def get_latest_training_information(sess_path, one):
262262
df = df.sort_values('date')
263263
df = df.reset_index(drop=True)
264264
# Save our dataframe
265-
save_dataframe(df, subj_path)
265+
if save:
266+
save_dataframe(df, subj_path)
266267

267268
# Now go through the backlog and compute the training status for sessions. If for example one was missing as it is cumulative
268269
# we need to go through and compute all the backlog
@@ -288,10 +289,10 @@ def get_latest_training_information(sess_path, one):
288289
if 'ready4ephysrig' not in tr_st:
289290
sess = un_df.iloc[39].session_path
290291
df.loc[df['session_path'] == sess, 'training_status'] = 'unbiasable'
292+
if save:
293+
save_dataframe(df, subj_path)
291294

292-
save_dataframe(df, subj_path)
293-
294-
if one.mode != 'local':
295+
if one.mode != 'local' and save:
295296
upload_training_table_to_aws(lab, sub)
296297

297298
return df
@@ -519,11 +520,11 @@ def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collect
519520
sess_dict['n_delay'] = np.nan
520521
sess_dict['location'] = np.nan
521522
sess_dict['training_status'] = 'habituation'
522-
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
523+
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
523524
(np.nan, np.nan, np.nan, np.nan)
524-
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
525+
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
525526
(np.nan, np.nan, np.nan, np.nan)
526-
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
527+
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
527528
(np.nan, np.nan, np.nan, np.nan)
528529

529530
else:
@@ -534,18 +535,18 @@ def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collect
534535

535536
sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
536537
if sess_dict['task_protocol'] == 'training':
537-
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
538+
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
538539
training.compute_psychometric(trials)
539-
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
540+
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
540541
(np.nan, np.nan, np.nan, np.nan)
541-
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
542+
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
542543
(np.nan, np.nan, np.nan, np.nan)
543544
else:
544-
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
545+
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
545546
training.compute_psychometric(trials, block=0.5)
546-
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
547+
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
547548
training.compute_psychometric(trials, block=0.2)
548-
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
549+
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
549550
training.compute_psychometric(trials, block=0.8)
550551

551552
sess_dict['performance_easy'] = training.compute_performance_easy(trials)
@@ -646,8 +647,8 @@ def get_training_info_for_session(session_paths, one, force=True):
646647
for bias in [50, 20, 80]:
647648
sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
648649
sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
649-
sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
650-
sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]
650+
sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][2]
651+
sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][3]
651652

652653
# Case where two sessions on same day with different number of contrasts! Oh boy
653654
if sess_dict['combined_performance'].size != sess_dict['performance'].size:

0 commit comments

Comments
 (0)