Coverage for ibllib/io/extractors/training_trials.py: 80%
375 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-02 18:55 +0100
1import logging
2import numpy as np
3from itertools import accumulate
4from packaging import version
5from one.alf.io import AlfBunch
7import ibllib.io.raw_data_loaders as raw
8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
9from ibllib.io.extractors.training_wheel import Wheel
# Public API of this module: only the aggregate TrainingTrials extractor is exported.
__all__ = ['TrainingTrials']

# Module-level logger used by the extractors below.
_logger = logging.getLogger(__name__)
class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to subject.
    **Optional:** saves _ibl_trials.feedbackType.npy

    For each trial, inspects the Bpod 'States timestamps' of the outcome states
    'correct', 'error', 'no_go', 'omit_correct', 'omit_error' and 'omit_no_go'.
    A state counts as triggered when its first timestamp is not NaN; a state missing
    from the trial is treated as not triggered. Exactly one outcome state must have
    been triggered in each trial, otherwise an AssertionError is raised.

    Sets feedbackType to +1 if the 'correct' state was triggered
    Sets feedbackType to -1 if the 'error' or 'no_go' state was triggered
    Leaves feedbackType at 0 if one of the 'omit_*' states was triggered
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    def _extract(self):
        # Mutually exclusive outcome states; built once rather than on every trial
        state_names = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')
        feedbackType = np.zeros(len(self.bpod_trials), np.int64)
        for i, t in enumerate(self.bpod_trials):
            states = t['behavior_data']['States timestamps']
            # A missing state defaults to [[nan]] so that it reads as "not triggered"
            outcome = {sn: ~np.isnan(states.get(sn, [[np.nan]])[0][0]) for sn in state_names}
            assert np.sum(list(outcome.values())) == 1, \
                f'trial {i}: expected exactly one outcome state, got {[k for k in outcome if outcome[k]]}'
            outcome = next(k for k in outcome if outcome[k])
            if outcome == 'correct':
                feedbackType[i] = 1
            elif outcome in ('error', 'no_go'):
                feedbackType[i] = -1
            # 'omit_*' outcomes keep the default value of 0
        return feedbackType
class ContrastLR(BaseBpodTrialsExtractor):
    """
    Get left and right contrasts from raw datafile. Optionally, saves
    _ibl_trials.contrastLeft.npy and _ibl_trials.contrastRight.npy to alf folder.

    The sign of each trial's stimulus position decides the side: a negative position
    fills contrastLeft, a positive one fills contrastRight, and the other side (or
    both, for a zero position) is NaN.

    The contrast is read from trial['contrast']['value'] when 'contrast' is a dict.
    Otherwise it is read from the flat trial['contrast'] value, which is the iblrigv8
    layout. The layout is detected from the first trial.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        # iblrigv8 has only flat values in the trial table so we can switch to parquet table when times come
        # and all the clutter here would fit in ~30 lines
        if len(self.bpod_trials) == 0:
            return np.array([], dtype=float), np.array([], dtype=float)
        # Test for the nested (dict) layout rather than for a float, so flat ints and
        # numpy scalars are also handled as flat values.
        nested = isinstance(self.bpod_trials[0]['contrast'], dict)
        contrast = np.array(
            [t['contrast']['value'] if nested else t['contrast'] for t in self.bpod_trials], dtype=float)
        side = np.array([np.sign(t['position']) for t in self.bpod_trials])
        contrastLeft = np.where(side < 0, contrast, np.nan)
        contrastRight = np.where(side > 0, contrast, np.nan)
        return contrastLeft, contrastRight
class ProbabilityLeft(BaseBpodTrialsExtractor):
    """
    Get the per-trial probability of the stimulus appearing on the left.
    **Optional:** saves _ibl_trials.probabilityLeft.npy

    Reads the 'stim_probability_left' field of each Bpod trial.
    """
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        probability_left = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(probability_left)
class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Uses the sign of the stimulus position, trial_correct and the no_go state.
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial
    If a trial is correct the choice of the animal was the inverse of the sign
    of the position.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        stim_side = np.array([np.sign(t['position']) for t in self.bpod_trials])
        # Explicit bool dtype: these arrays are used as masks. Without it, an empty session
        # yields float arrays (invalid indices) and integer 0/1 values would be interpreted
        # as positional indices instead of a mask.
        trial_correct = np.array([t['trial_correct'] for t in self.bpod_trials], dtype=bool)
        trial_nogo = np.array(
            [~np.isnan(t['behavior_data']['States timestamps']['no_go'][0][0])
             for t in self.bpod_trials], dtype=bool)
        choice = stim_side.copy()
        choice[trial_correct] = -choice[trial_correct]
        choice[trial_nogo] = 0
        return choice.astype(int)
class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial counts as repeated when its 'debias_trial' field is truthy. If that field is
    absent, it counts as repeated when trial['contrast']['type'] == 'RepeatContrast'.
    Protocols providing neither (advancedChoiceWorld / ccu_neuromodulatorChoiceWorld)
    have no repeated trials. The count resets to 0 on every non-repeated trial.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        def get_trial_repeat(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            if 'contrast' in trial and isinstance(trial['contrast'], dict):
                return trial['contrast']['type'] == 'RepeatContrast'
            # For advanced choice world and its subclasses before version 8.19.0 there was no 'debias_trial' field
            # and no debiasing protocol applied, so simply return False
            protocol = self.settings['PYBPOD_PROTOCOL']
            assert protocol.startswith(('_iblrig_tasks_advancedChoiceWorld', 'ccu_neuromodulatorChoiceWorld')), \
                f"cannot determine repeated trials for protocol '{protocol}'"
            return False

        trial_repeated = np.fromiter(map(get_trial_repeat, self.bpod_trials), int)
        # Running count of consecutive repeats, reset to 0 on each non-repeated trial
        repNum = np.fromiter(accumulate(trial_repeated, lambda x, y: x + y if y else 0), int)
        return repNum
class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load reward volume delivered for each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    Uses each trial's 'reward_amount' for correct trials (per 'trial_correct') and 0
    for all other trials, returned as float64.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        trial_volume = [x['reward_amount'] if x['trial_correct'] else 0 for x in self.bpod_trials]
        reward_volume = np.array(trial_volume).astype(np.float64)
        assert len(reward_volume) == len(self.bpod_trials)
        return reward_volume
class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    For iblrig < 5.0.0, each trial's feedback time is the start of whichever of the
    reward, error or no_go states was entered; at least one must be non-NaN per trial.
    For iblrig >= 5.0.0, rewarded trials use the reward state start time. Other trials
    use the last BNC2High front (after de-duplicating fronts closer than 20 ms) when
    at least two fronts exist, and NaN otherwise.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """Feedback times for iblrig < 5.0.0 from the reward / error / no_go state starts."""
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        rw_times = [tr['behavior_data']['States timestamps']['reward'][0][0] for tr in data]
        err_times = [tr['behavior_data']['States timestamps']['error'][0][0] for tr in data]
        nogo_times = [tr['behavior_data']['States timestamps']['no_go'][0][0] for tr in data]
        assert sum(np.isnan(rw_times) & np.isnan(err_times) & np.isnan(nogo_times)) == 0
        merge = np.array([np.array(times)[~np.isnan(times)]
                          for times in zip(rw_times, err_times, nogo_times)])
        # Drop only the per-trial axis. A bare squeeze() would also collapse a
        # single-trial session into a 0-d scalar.
        if merge.ndim == 2 and merge.shape[1] == 1:
            merge = merge[:, 0]
        return np.array(merge)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """Feedback times for iblrig >= 5.0.0 from the reward state and BNC2High fronts."""
        # get err and no go trig times -- look for BNC2High of trial -- verify
        # only 2 onset times go tone and noise, select 2nd/-1 OR select the one
        # that is greater than the nogo or err trial onset time
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        missed_bnc2 = 0
        rw_times = np.full(len(data), np.nan)
        err_sound_times = np.full(len(data), np.nan)

        for ind, tr in enumerate(data):
            st = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            st = np.delete(st, np.where(np.diff(st) < 0.020)[0] + 1)
            rw_times[ind] = tr['behavior_data']['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            err_sound_times[ind] = st[-1] if st.size >= 2 and np.isnan(rw_times[ind]) else np.nan
        if missed_bnc2 == len(data):
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')

        merge = np.full(len(data), np.nan)
        rewarded = ~np.isnan(rw_times)
        merge[rewarded] = rw_times[rewarded]
        has_err_sound = ~np.isnan(err_sound_times)
        merge[has_err_sound] = err_sound_times[has_err_sound]
        return merge

    def _extract(self):
        # Version check
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            merge = self.get_feedback_times_ge5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        else:
            merge = self.get_feedback_times_lt5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)
class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start to trial end, as an (n_trials, 2) array of [start, end] rows.
    Trial end includes 1 or 2 seconds after feedback (depending on the feedback)
    and 0.5 seconds of iti.
    **Optional:** saves _ibl_trials.intervals.npy

    Uses the corrected 'Trial start timestamp' and 'Trial end timestamp' values from PyBpod.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        behavior = [trial['behavior_data'] for trial in self.bpod_trials]
        trial_starts = [b['Trial start timestamp'] for b in behavior]
        trial_ends = [b['Trial end timestamp'] for b in behavior]
        # column_stack keeps the (0, 2) shape for an empty session
        return np.column_stack((trial_starts, trial_ends))
class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    Taken as the end timestamp of each trial's closed_loop state.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop_ends = [
            trial['behavior_data']['States timestamps']['closed_loop'][0][1]
            for trial in self.bpod_trials
        ]
        return np.array(closed_loop_ends)
class ItiDuration(BaseBpodTrialsExtractor):
    """
    Calculate duration of iti from state timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    Computed per trial as the 'Trial end timestamp' minus the response time produced
    by ResponseTimes (the end of the closed_loop state).
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        # Pass the already-loaded trials and settings through to ResponseTimes
        rt, _ = ResponseTimes(self.session_path).extract(
            save=False, task_collection=self.task_collection, bpod_trials=self.bpod_trials, settings=self.settings)
        ends = np.array([t['behavior_data']['Trial end timestamp'] for t in self.bpod_trials])
        iti_dur = ends - rt
        return iti_dur
class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get trigger times of goCue from state machine.

    Current software solution for triggering sounds uses PyBpod soft codes.
    Delays can be in the order of 10's of ms. This is the time when the command
    to play the sound was executed, read from the start of the 'play_tone' state
    (iblrig >= 5.0.0, or unknown version) or of the 'closed_loop' state (older rigs).
    To measure accurate time, either getting the sound onset from xonar soundcard
    sync pulse (latencies may vary).
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        trigger_state = 'play_tone' if rig_version >= version.parse('5.0.0') else 'closed_loop'
        return np.array([
            trial['behavior_data']['States timestamps'][trigger_state][0][0]
            for trial in self.bpod_trials
        ])
class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial by the outcome state that was entered.
    **Optional:** saves _ibl_trials.type.npy

    1 if the 'reward' state was entered, else -1 if 'error' was entered, else 0 if
    'no_go' was entered. If none of them was entered, NaN is used and a warning is logged.
    """
    save_names = '_ibl_trials.type.npy'
    # The extractor framework reads `var_names`; the attribute was previously misspelled.
    var_names = 'trial_type'
    # Backward-compatible alias for code that read the old misspelled attribute
    var_name = var_names

    def _extract(self):
        # Checked in order; the first state with a non-NaN start timestamp wins
        outcome_codes = (('reward', 1), ('error', -1), ('no_go', 0))
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            for state, code in outcome_codes:
                if not np.isnan(states[state][0][0]):
                    trial_type.append(code)
                    break
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)
class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the go cue times from the BNC2 TTL fronts recorded by Bpod.
    **Optional:** saves _ibl_trials.goCue_times.npy

    For each trial with BNC2 port events, the go cue time is the first BNC2High front.
    If there is none, the first BNC2Low front minus 0.1 s is used instead. Trials without
    usable BNC2 events get NaN. A warning is logged when some, or all, trials are missing.

    Unlike GoCueTriggerTimes, which reads the state-machine command time, these are the
    times of fronts detected on the BNC2 input.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        go_cue_times = np.full(len(self.bpod_trials), np.nan)
        for ind, tr in enumerate(self.bpod_trials):
            if not raw.get_port_events(tr, 'BNC2'):
                continue
            events = tr['behavior_data']['Events timestamps']
            bnchigh = events.get('BNC2High', None)
            if bnchigh:
                go_cue_times[ind] = bnchigh[0]
                continue
            bnclow = events.get('BNC2Low', None)
            if bnclow:
                # NOTE(review): 0.1 s is presumably the go cue tone duration, shifting the
                # falling front back to the tone onset -- confirm
                go_cue_times[ind] = bnclow[0] - 0.1

        nmissing = np.sum(np.isnan(go_cue_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(go_cue_times)):
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        # Check if any stim_sync has failed be detected for every trial
        elif np.any(np.isnan(go_cue_times)):
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')

        return go_cue_times
class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Flag which trials are included.
    **Optional:** saves _ibl_trials.included.npy

    Before iblrig 5.0.0 every trial is included. From 5.0.0 (or an unknown version),
    if the settings contain SUBJECT_DISENGAGED_TRIGGERED with a value other than False,
    trials from SUBJECT_DISENGAGED_TRIALNUM (1-based) onwards are excluded.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            trials_included = self.get_included_trials_ge5(
                data=self.bpod_trials, settings=self.settings)
        else:
            trials_included = self.get_included_trials_lt5(data=self.bpod_trials)
        return trials_included

    @staticmethod
    def get_included_trials_lt5(data=False):
        """All trials are included for iblrig < 5.0.0."""
        return np.ones(len(data), dtype=bool)

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """Exclude trials from the subject-disengaged trial onwards, if that criterion was triggered."""
        # Explicit bool dtype so an empty session gives a bool (not float) array, like the lt5 variant
        trials_included = np.ones(len(data), dtype=bool)
        if 'SUBJECT_DISENGAGED_TRIGGERED' in settings and settings['SUBJECT_DISENGAGED_TRIGGERED'] is not False:
            # SUBJECT_DISENGAGED_TRIALNUM is 1-based
            idx = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            trials_included[idx:] = False
        return trials_included
class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Get the itiIn times of every trial.

    For iblrig >= 5.0.0 (or an unknown version) this is the start of the 'exit_state'
    state; older rigs have no such state and every trial gets NaN.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("5.0.0"):
            return np.array([
                trial["behavior_data"]["States timestamps"]["exit_state"][0][0]
                for trial in self.bpod_trials
            ])
        return np.full(len(self.bpod_trials), np.nan)
class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the error cue trigger time of every trial.

    Uses the start of the 'no_go' state if it was entered, otherwise the start of the
    'error' state if that was entered, otherwise NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        trigger_times = np.full(len(self.bpod_trials), np.nan)
        for index, trial in enumerate(self.bpod_trials):
            states = trial["behavior_data"]["States timestamps"]
            nogo_start = states["no_go"][0][0]
            error_start = states["error"][0][0]
            if not np.isnan(nogo_start):
                trigger_times[index] = nogo_start
            elif not np.isnan(error_start):
                trigger_times[index] = error_start
        return trigger_times
class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the times of the state machine command to freeze the stimulus.

    For iblrig < 6.2.5 there are no freeze states, and every trial gets NaN.
    Otherwise the total number of entered 'freeze_reward', 'freeze_error' and 'no_go'
    states must equal the number of trials (checked with an assertion). Each trial's
    trigger time is the start of 'freeze_reward' if it was entered, else of 'freeze_error'.
    No-go trials get NaN.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        n_trials = len(self.bpod_trials)
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("6.2.5"):
            return np.ones(n_trials) * np.nan

        def entered(trial, state):
            """True when every timestamp of the state's first occurrence is non-NaN."""
            return bool(np.all(~np.isnan(trial["behavior_data"]["States timestamps"][state][0])))

        freeze_reward = np.array([entered(tr, "freeze_reward") for tr in self.bpod_trials], dtype=bool)
        freeze_error = np.array([entered(tr, "freeze_error") for tr in self.bpod_trials], dtype=bool)
        no_go = np.array([entered(tr, "no_go") for tr in self.bpod_trials], dtype=bool)
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        # Preallocate instead of growing with np.append, which copies the array every trial
        stimFreezeTrigger = np.full(n_trials, np.nan)
        for i, (r, n, tr) in enumerate(zip(freeze_reward, no_go, self.bpod_trials)):
            if n:
                continue  # no stimulus freeze in no_go trials: leave NaN
            state = "freeze_reward" if r else "freeze_error"
            stimFreezeTrigger[i] = tr["behavior_data"]["States timestamps"][state][0][0]
        return stimFreezeTrigger
class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the times of the state machine command to turn off the stimulus.
    **Optional:** saves _ibl_trials.stimOffTrigger_times.npy

    The trigger state depends on the iblrig version:
    'hide_stim' for >= 6.2.5 (or an unknown version), 'exit_state' for >= 5.0.0 and
    'trial_start' before that. From 5.0.0, trials in which the 'no_go' state was
    entered use the start of 'no_go' instead.
    """
    var_names = 'stimOffTrigger_times'
    # Was previously set to the stimOnTrigger_times file name, colliding with StimOnTriggerTimes
    save_names = '_ibl_trials.stimOffTrigger_times.npy'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("6.2.5"):
            stim_off_trigger_state = "hide_stim"
        elif rig_version >= version.parse("5.0.0"):
            stim_off_trigger_state = "exit_state"
        else:
            stim_off_trigger_state = "trial_start"

        stimOffTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"][stim_off_trigger_state][0][0]
             for tr in self.bpod_trials]
        )
        # If pre version 5.0.0 no specific nogo Off trigger was given, just return trial_starts
        if stim_off_trigger_state == "trial_start":
            return stimOffTrigger_times

        no_goTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"]["no_go"][0][0] for tr in self.bpod_trials]
        )
        # Stim off triggers are either in their own state, or in the no_go state if the mouse
        # did not move: patch with the no_go start time wherever no_go was entered.
        # NaNs may remain, e.g. in the last trial if the session was stopped after a response.
        is_nogo = ~np.isnan(no_goTrigger_times)
        stimOffTrigger_times[is_nogo] = no_goTrigger_times[is_nogo]
        return stimOffTrigger_times
class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the times of the state machine command to turn on the stimulus.
    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy

    Uses the first timestamp of each trial's 'stim_on' state.
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # Index the first timestamp per trial directly. Building a 2-D array and slicing
        # column 0 raised IndexError for a session with zero trials.
        return np.array([tr['behavior_data']['States timestamps']['stim_on'][0][0]
                         for tr in self.bpod_trials])
class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """Deprecated stimOn_times extraction from BNC1 fronts; use StimOnOffFreezeTimes instead."""
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        # Version check
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the first stim_sync pulse (BNC1High/BNC1Low, frame2TTL device) within the
        stim_on state (start, stop] of each trial.

        Trials with no such pulse get NaN. This keeps exactly one value per trial; extra
        pulses were previously all appended, misaligning the output with the trials.

        :return: stimOn_times, one value per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # Get all stim_sync events detected
        stim_sync_all = [np.array(raw.get_port_events(tr, 'BNC1')) for tr in data]
        # Get the stim_on_state [start, stop] that triggers the onset of the stim
        stim_on_state = [tr['behavior_data']['States timestamps']['stim_on'][0] for tr in data]

        stimOn_times = np.full(len(data), np.nan)
        for i, (sync, state) in enumerate(zip(stim_sync_all, stim_on_state)):
            on, off = state[0], state[1]
            pulse = sync[(sync > on) & (sync <= off)]
            if pulse.size > 0:
                stimOn_times[i] = pulse[0]

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')

        # Check if any stim_sync has failed be detected for every trial
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')

        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the time of the statemachine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1High and BNC1Low)
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)

        count_missing = 0
        stimOn_times = np.full(len(data), np.nan)
        for i, tr in enumerate(data):
            stim_on = tr['behavior_data']['States timestamps']['stim_on'][0][0]
            events = tr['behavior_data']['Events timestamps']
            # -inf placeholders (never after stim_on) stand in for missing channels.
            # np.NINF was removed in NumPy 2.0, so use -np.inf.
            bnc_h = np.array(events['BNC1High']) if 'BNC1High' in events else np.array([-np.inf])
            bnc_l = np.array(events['BNC1Low']) if 'BNC1Low' in events else np.array([-np.inf])
            fronts = np.sort(np.concatenate([bnc_h, bnc_l]))
            after_on = fronts[fronts > stim_on]
            if after_on.size == 0:
                count_missing += 1
            else:
                stimOn_times[i] = after_on[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')

        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')

        return np.array(stimOn_times)
class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    Each stimulus event is the first detected front of the BNC1 signal after the trigger state, but before the next
    trigger state. For stim on, the next trigger is the freeze trigger when one exists, otherwise the off trigger
    (rigs < 6.2.5, and no-go trials whose freeze trigger is NaN). The stim freeze time is the last front before the
    off trigger and is NaN for no-go trials (choice == 0).
    """
    save_names = ('_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOff_times.npy', None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        kwargs = dict(bpod_trials=self.bpod_trials, task_collection=self.task_collection,
                      settings=self.settings, save=False)
        choice = Choice(self.session_path).extract(**kwargs)[0]
        stimOnTrigger = StimOnTriggerTimes(self.session_path).extract(**kwargs)[0]
        stimFreezeTrigger = StimFreezeTriggerTimes(self.session_path).extract(**kwargs)[0]
        stimOffTrigger = StimOffTriggerTimes(self.session_path).extract(**kwargs)[0]
        f2TTL = [raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials]
        assert stimOnTrigger.size == stimFreezeTrigger.size == stimOffTrigger.size == choice.size == len(f2TTL)
        assert all(stimOnTrigger < np.nan_to_num(stimFreezeTrigger, nan=np.inf)) and \
               all(np.nan_to_num(stimFreezeTrigger, nan=-np.inf) < stimOffTrigger)

        n_trials = len(f2TTL)
        # Preallocate instead of growing with np.append, which copies the arrays every trial
        stimOn_times = np.full(n_trials, np.nan)
        stimOff_times = np.full(n_trials, np.nan)
        stimFreeze_times = np.full(n_trials, np.nan)
        # Same version fallback as StimFreezeTriggerTimes, so both agree on whether freeze
        # triggers exist (the previous '0' default also crashed on a None version).
        has_freeze = version.parse(self.settings.get('IBLRIG_VERSION') or '100.0.0') >= version.parse('6.2.5')
        for i, (tr, on, freeze, off) in enumerate(zip(f2TTL, stimOnTrigger, stimFreezeTrigger, stimOffTrigger)):
            tr = np.array(tr)
            # stim on: bounded by the freeze trigger if there is one. A NaN freeze trigger
            # (no-go trial) would make every comparison False, so fall back to the off trigger.
            lim = freeze if has_freeze and not np.isnan(freeze) else off
            idx, = np.where((on < tr) & (tr < lim))
            if idx.size > 0:
                stimOn_times[i] = tr[idx[0]]
            # stim off: first front after the off trigger
            idx, = np.where(off < tr)
            if idx.size > 0:
                stimOff_times[i] = tr[idx[0]]
            # stim freeze - take last event before off trigger
            if has_freeze:
                idx, = np.where((freeze < tr) & (tr < off))
            else:
                idx, = np.where(tr <= off)
            if idx.size > 0:
                stimFreeze_times[i] = tr[idx[-1]]
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan

        return stimOn_times, stimOff_times, stimFreeze_times
class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """
    Extract stimulus phase, position and quiescence from Bpod data.

    Reads the per-trial 'stim_phase', 'position' and 'quiescent_period' fields.
    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        field_names = ('stim_phase', 'position', 'quiescent_period')
        phase, position, quiescence = (
            np.array([trial[field] for trial in self.bpod_trials]) for field in field_names
        )
        return phase, position, quiescence
class PauseDuration(BaseBpodTrialsExtractor):
    """
    Extract pause duration from raw trial data.

    Trials without a 'pause_duration' field get 0 for iblrig < 8.9.0 (or an unknown
    version), and NaN otherwise.
    """
    save_names = None
    var_names = 'pause_duration'

    def _extract(self, **kwargs):
        rig_version = version.parse(self.settings.get('IBLRIG_VERSION') or '0')
        # pausing logic added in version 8.9.0
        if rig_version < version.parse('8.9.0'):
            fallback = 0.
        else:
            fallback = np.nan
        durations = [trial.get('pause_duration', fallback) for trial in self.bpod_trials]
        return np.fromiter(durations, dtype=float)
class TrialsTable(BaseBpodTrialsExtractor):
    """
    Build the trials table, plus wheel datasets, from Bpod raw data.

    The table (returned as a DataFrame) holds: intervals, goCue_times, response_times,
    choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType,
    rewardVolume, probabilityLeft, firstMovement_times.
    The remaining outputs, in `var_names` order after 'table', are the stimOff/stimFreeze
    times and the wheel data: wheel_timestamps, wheel_position, wheel_moves_intervals,
    wheel_moves_peak_amplitude, wheelMoves_peakVelocity_times and is_final_movement.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
                      FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        results, _ = run_extractor_classes(
            extractors,
            session_path=self.session_path,
            bpod_trials=self.bpod_trials,
            settings=self.settings,
            save=False,
            task_collection=self.task_collection,
        )
        # Anything not listed in var_names becomes a table column
        columns = {key: value for key, value in results.items() if key not in self.var_names}
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        table_df = table.to_df()
        extras = tuple(results.pop(name) for name in self.var_names if name != 'table')
        return (table_df, *extras)
class TrainingTrials(BaseBpodTrialsExtractor):
    """
    Extract all training-task trial datasets from Bpod raw data.

    Runs the component extractors and returns their outputs as a dict keyed by
    `var_names`. The components cover repNum, the trigger times, the trials table
    with its wheel data, phase/position/quiescence and pause duration.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  '_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt', '_ibl_trials.stimOff_times.npy',
                  None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None,
                  '_ibl_trials.quiescencePeriod.npy', None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
                 'wheelMoves_peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')

    def _extract(self) -> dict:
        extractors = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence, PauseDuration]
        results, _ = run_extractor_classes(
            extractors,
            session_path=self.session_path,
            bpod_trials=self.bpod_trials,
            settings=self.settings,
            save=False,
            task_collection=self.task_collection,
        )
        return dict((name, results[name]) for name in self.var_names)