Coverage for ibllib/io/extractors/training_trials.py: 93%
377 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1import logging
2import numpy as np
3from pkg_resources import parse_version
4from one.alf.io import AlfBunch
6import ibllib.io.raw_data_loaders as raw
7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
8from ibllib.io.extractors.training_wheel import Wheel
# Public API: only the aggregate extractor and the convenience entry point are exported.
__all__ = ['TrainingTrials', 'extract_all']
# Module-level logger shared by all extractors below.
_logger = logging.getLogger(__name__)
class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to subject.
    **Optional:** saves _ibl_trials.feedbackType.npy

    For every trial, exactly one of the outcome states 'correct', 'error', 'no_go',
    'omit_correct', 'omit_error' or 'omit_no_go' must have a non-NaN onset timestamp.
    An AssertionError is raised otherwise. States missing from the trial data are
    treated as not entered.

    Sets feedbackType to +1 if the 'correct' state was entered.
    Sets feedbackType to -1 if the 'error' or 'no_go' state was entered.
    Leaves feedbackType at 0 for the 'omit_*' outcomes.
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    def _extract(self):
        # Hoisted out of the loop: the candidate outcome states are the same for every trial.
        state_names = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')
        feedbackType = np.zeros(len(self.bpod_trials), np.int64)
        for i, t in enumerate(self.bpod_trials):
            states = t['behavior_data']['States timestamps']
            # A state is considered entered when its first onset timestamp is not NaN;
            # np.nan (not the NumPy<2 alias np.NaN) keeps this working on NumPy 2.
            outcome = {sn: ~np.isnan(states.get(sn, [[np.nan]])[0][0]) for sn in state_names}
            assert np.sum(list(outcome.values())) == 1, \
                f'trial {i}: expected exactly one outcome state, got {[k for k, v in outcome.items() if v]}'
            outcome = next(k for k in outcome if outcome[k])
            if outcome == 'correct':
                feedbackType[i] = 1
            elif outcome in ('error', 'no_go'):
                feedbackType[i] = -1
            # 'omit_*' outcomes keep the default value of 0
        return feedbackType
class ContrastLR(BaseBpodTrialsExtractor):
    """
    Split the per-trial stimulus contrast into left and right contrast vectors.
    Optionally saves _ibl_trials.contrastLeft.npy and _ibl_trials.contrastRight.npy.

    A trial contributes its contrast to the left vector when its stimulus position
    is negative, and to the right vector when it is positive. The other side gets NaN.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        # Newer task data store the contrast as a flat float; older data nest it in a
        # dict under 'value'. The layout is decided once, from the first trial.
        is_flat = isinstance(self.bpod_trials[0]['contrast'], float)

        def side_contrasts(side):
            # side is -1 for left, +1 for right; NaN positions match neither side
            values = []
            for trial in self.bpod_trials:
                if side * np.sign(trial['position']) > 0:
                    values.append(trial['contrast'] if is_flat else trial['contrast']['value'])
                else:
                    values.append(np.nan)
            return np.array(values)

        return side_contrasts(-1), side_contrasts(1)
class ProbabilityLeft(BaseBpodTrialsExtractor):
    """Extract the per-trial probability of a left stimulus (`stim_probability_left`)."""
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        probabilities = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(probabilities)
class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Uses the sign of the stimulus position, the trial_correct flag and the no_go state.
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial
    If a trial is correct the choice of the animal was the inverse of the sign
    of the position.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        stim_side = np.array([np.sign(t['position']) for t in self.bpod_trials])
        # Explicit bool dtype: the flags are used as a mask below. Integer 0/1 values would
        # otherwise be fancy indices, and an empty float array cannot index at all.
        trial_correct = np.array([t['trial_correct'] for t in self.bpod_trials], dtype=bool)
        trial_nogo = np.array(
            [~np.isnan(t['behavior_data']['States timestamps']['no_go'][0][0])
             for t in self.bpod_trials], dtype=bool)
        choice = stim_side.copy()
        # correct trials: the wheel was turned opposite to the stimulus side
        choice[trial_correct] = -choice[trial_correct]
        choice[trial_nogo] = 0
        return choice.astype(int)
class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial is a repeat when its 'debias_trial' flag is set (if present), otherwise
    when trial['contrast']['type'] == 'RepeatContrast'. The counter grows by one on
    every repeated trial and resets to 0 on a non-repeated one.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        def is_repeat(trial):
            # task versions that record 'debias_trial' use it directly
            if 'debias_trial' in trial:
                return trial['debias_trial']
            return trial['contrast']['type'] == 'RepeatContrast'

        trial_repeated = np.array([is_repeat(trial) for trial in self.bpod_trials]).astype(int)
        repNum = trial_repeated.copy()
        run_length = 0
        for idx, repeated in enumerate(trial_repeated):
            run_length = run_length + 1 if repeated != 0 else 0
            repNum[idx] = run_length
        return repNum
class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load reward volume delivered for each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    Uses each trial's 'reward_amount' for correct trials and 0 for incorrect ones.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        volumes = []
        for trial in self.bpod_trials:
            volumes.append(trial['reward_amount'] if trial['trial_correct'] else 0)
        reward_volume = np.array(volumes).astype(np.float64)
        assert len(reward_volume) == len(self.bpod_trials)
        return reward_volume
class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    Gets reward and error state init times vectors,
    checks that every trial has at least one non-NaN time, then
    merges the vectors into a single per-trial feedback time.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for task versions < 5.0.0, taken from Bpod state onsets.

        For each trial, the onset of the first entered state among 'reward', 'error'
        and 'no_go' (in that order) is returned. Raises AssertionError if a trial
        entered none of them.
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        rw_times = [tr['behavior_data']['States timestamps']['reward'][0][0]
                    for tr in data]
        err_times = [tr['behavior_data']['States timestamps']['error'][0][0]
                     for tr in data]
        nogo_times = [tr['behavior_data']['States timestamps']['no_go'][0][0]
                      for tr in data]
        assert sum(np.isnan(rw_times) &
                   np.isnan(err_times) & np.isnan(nogo_times)) == 0
        # Take the first non-NaN time per trial. Squeezing an (n, 1) array would
        # collapse a single-trial session to a 0-d array and break on ragged rows.
        merge = np.array([np.array(times)[~np.isnan(times)][0] for times in
                          zip(rw_times, err_times, nogo_times)], dtype=float)
        return merge

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for task versions >= 5.0.0.

        Rewarded trials use the 'reward' state onset. Otherwise the last BNC2 rising
        front of the trial is used as the error sound time. The BNC2 fronts must
        contain at least two (deduplicated) events, presumably the go cue and the
        error noise. Trials with neither are NaN.
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        missed_bnc2 = 0
        rw_times, err_sound_times, merge = [np.zeros([len(data), ]) for _ in range(3)]

        for ind, tr in enumerate(data):
            st = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            st = np.delete(st, np.where(np.diff(st) < 0.020)[0] + 1)
            rw_times[ind] = tr['behavior_data']['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            err_sound_times[ind] = st[-1] if st.size >= 2 and np.isnan(rw_times[ind]) else np.nan
        if missed_bnc2 == len(data):
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        merge *= np.nan
        merge[~np.isnan(rw_times)] = rw_times[~np.isnan(rw_times)]
        merge[~np.isnan(err_sound_times)] = err_sound_times[~np.isnan(err_sound_times)]

        return merge

    def _extract(self):
        # From 5.0.0 the error tone is read from BNC2 fronts; before that only state onsets exist.
        if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
            merge = self.get_feedback_times_ge5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        else:
            merge = self.get_feedback_times_lt5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)
class Intervals(BaseBpodTrialsExtractor):
    """
    Per-trial [start, end] interval, as an (n_trials, 2) array.
    **Optional:** saves _ibl_trials.intervals.npy

    Built from the 'Trial start timestamp' and 'Trial end timestamp' values of the
    Bpod trial data. Per the task design, the end includes 1 or 2 s after feedback
    (depending on the feedback) plus 0.5 s of ITI.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        behavior = [trial['behavior_data'] for trial in self.bpod_trials]
        trial_starts = [b['Trial start timestamp'] for b in behavior]
        trial_ends = [b['Trial end timestamp'] for b in behavior]
        return np.column_stack([trial_starts, trial_ends])
class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    The response time is the exit timestamp of the first 'closed_loop' state entry.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop = [trial['behavior_data']['States timestamps']['closed_loop'] for trial in self.bpod_trials]
        # each entry is [[start, stop], ...]; keep the stop of the first occurrence
        return np.array([state[0][1] for state in closed_loop])
class ItiDuration(BaseBpodTrialsExtractor):
    """
    Duration of the inter-trial interval, computed from Bpod timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    Computed as 'Trial end timestamp' minus the response time from ResponseTimes.
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        response_times, _files = ResponseTimes(self.session_path).extract(
            save=False, task_collection=self.task_collection, bpod_trials=self.bpod_trials, settings=self.settings)
        trial_ends = np.array([trial['behavior_data']['Trial end timestamp'] for trial in self.bpod_trials])
        return trial_ends - response_times
class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time the go cue sound was commanded by the state machine.

    Sounds are triggered through PyBpod soft codes, so this is the time the play
    command was issued, not the sound onset. Latencies can be in the order of tens
    of ms. The sound onset itself needs a separate measurement, e.g. the xonar
    soundcard sync pulse.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        # >= 5.0.0 has a dedicated 'play_tone' state; older versions use the onset of 'closed_loop'.
        uses_play_tone = parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0')
        state = 'play_tone' if uses_play_tone else 'closed_loop'
        onsets = [trial['behavior_data']['States timestamps'][state][0][0] for trial in self.bpod_trials]
        return np.array(onsets)
class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial by the outcome state it entered.

    Returns 1 for 'reward', -1 for 'error', 0 for 'no_go' (checked in that order),
    and NaN (with a warning) when none of these states was entered.
    """
    save_names = '_ibl_trials.type.npy'
    # The extractor machinery reads `var_names`. The previous misspelt `var_name` left the
    # output unnamed; it is kept as an alias for any external code that referenced it.
    var_names = 'trial_type'
    var_name = var_names

    def _extract(self):
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            if not np.isnan(states["reward"][0][0]):
                trial_type.append(1)
            elif not np.isnan(states["error"][0][0]):
                trial_type.append(-1)
            elif not np.isnan(states["no_go"][0][0]):
                trial_type.append(0)
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)
class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the go cue times from the BNC2 fronts recorded by Bpod.

    For each trial with BNC2 port events, the first 'BNC2High' timestamp is used.
    If there is no rising front, the first 'BNC2Low' timestamp minus 0.1 is used.
    Trials without a usable BNC2 event are NaN. Warnings are logged when some or
    all trials are missing.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        def trial_go_cue(trial):
            if not raw.get_port_events(trial, 'BNC2'):
                return np.nan
            events = trial['behavior_data']['Events timestamps']
            rising = events.get('BNC2High', None)
            if rising:
                return rising[0]
            falling = events.get('BNC2Low', None)
            if falling:
                # no rising front recorded: estimate the onset 0.1 s before the falling front
                return falling[0] - 0.1
            return np.nan

        go_cue_times = np.array([trial_go_cue(tr) for tr in self.bpod_trials], dtype=float)

        is_missing = np.isnan(go_cue_times)
        nmissing = np.sum(is_missing)
        if np.all(is_missing):
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        elif np.any(is_missing):
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')

        return go_cue_times
class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Boolean mask of trials to include in analysis.

    Before 5.0.0 every trial is included. From 5.0.0, if the settings report that
    the subject-disengaged criterion was triggered, the trial numbered
    SUBJECT_DISENGAGED_TRIALNUM and all later trials are excluded.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        if parse_version(self.settings['IBLRIG_VERSION_TAG']) < parse_version('5.0.0'):
            return self.get_included_trials_lt5(data=self.bpod_trials)
        return self.get_included_trials_ge5(data=self.bpod_trials, settings=self.settings)

    @staticmethod
    def get_included_trials_lt5(data=False):
        return np.array([True for _ in data])

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        included = np.array([True for _ in data])
        disengaged = settings.get('SUBJECT_DISENGAGED_TRIGGERED', False)
        if disengaged is not False:
            # the trial number is converted to a 0-based index by subtracting 1
            first_excluded = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            included[first_excluded:] = False
        return included
class ItiInTimes(BaseBpodTrialsExtractor):
    """Onset of the 'exit_state' state per trial (all NaN for task versions < 5.0.0)."""
    var_names = 'itiIn_times'

    def _extract(self):
        if parse_version(self.settings["IBLRIG_VERSION_TAG"]) >= parse_version("5.0.0"):
            exit_states = [tr["behavior_data"]["States timestamps"]["exit_state"] for tr in self.bpod_trials]
            return np.array([state[0][0] for state in exit_states])
        return np.full(len(self.bpod_trials), np.nan)
class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Per-trial error cue trigger time.

    Uses the 'no_go' state onset if it was entered, otherwise the 'error' state
    onset, otherwise NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        def error_cue(trial):
            states = trial["behavior_data"]["States timestamps"]
            nogo_onset = states["no_go"][0][0]
            error_onset = states["error"][0][0]
            if not np.isnan(nogo_onset):
                return nogo_onset
            if not np.isnan(error_onset):
                return error_onset
            return np.nan

        return np.array([error_cue(tr) for tr in self.bpod_trials], dtype=float)
class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Per-trial stimulus freeze trigger time (task versions >= 6.2.5).

    Uses the onset of 'freeze_reward' or 'freeze_error', whichever was entered.
    no_go trials are NaN. Earlier versions return all NaN. Raises AssertionError
    unless the freeze_reward, freeze_error and no_go counts add up to the number
    of trials.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        if parse_version(self.settings["IBLRIG_VERSION_TAG"]) < parse_version("6.2.5"):
            return np.full(len(self.bpod_trials), np.nan)

        def entered(trial, state):
            # entered when both timestamps of the first occurrence are non-NaN
            return bool(np.all(~np.isnan(trial["behavior_data"]["States timestamps"][state][0])))

        freeze_reward = np.array([entered(tr, "freeze_reward") for tr in self.bpod_trials])
        freeze_error = np.array([entered(tr, "freeze_error") for tr in self.bpod_trials])
        no_go = np.array([entered(tr, "no_go") for tr in self.bpod_trials])
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        trigger_times = []
        for is_reward, is_nogo, tr in zip(freeze_reward, no_go, self.bpod_trials):
            if is_nogo:
                trigger_times.append(np.nan)
            else:
                state = "freeze_reward" if is_reward else "freeze_error"
                trigger_times.append(tr["behavior_data"]["States timestamps"][state][0][0])
        return np.array(trigger_times, dtype=float)
class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Per-trial stimulus off trigger time.

    The source state depends on the task version:
    'hide_stim' (>= 6.2.5), 'exit_state' (>= 5.0.0) or 'trial_start' (older).
    For versions >= 5.0.0, trials that entered 'no_go' use the no_go onset instead.
    """
    var_names = 'stimOffTrigger_times'

    def _extract(self):
        version = parse_version(self.settings["IBLRIG_VERSION_TAG"])
        if version >= parse_version("6.2.5"):
            off_state = "hide_stim"
        elif version >= parse_version("5.0.0"):
            off_state = "exit_state"
        else:
            off_state = "trial_start"

        states = [tr["behavior_data"]["States timestamps"] for tr in self.bpod_trials]
        stim_off = np.array([s[off_state][0][0] for s in states])
        # Pre-5.0.0 tasks had no dedicated no_go off trigger: return the trial starts as-is
        if off_state == "trial_start":
            return stim_off

        # When the mouse did not move, the stimulus is switched off in the no_go state,
        # so those trials take the no_go onset instead.
        nogo_onsets = np.array([s["no_go"][0][0] for s in states])
        entered_nogo = ~np.isnan(nogo_onsets)
        stim_off[entered_nogo] = nogo_onsets[entered_nogo]
        return stim_off
class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """Per-trial onset of the first 'stim_on' state entry (the stimulus on trigger)."""
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # one [start, stop] row per trial; the trigger is the start column
        first_stim_on = np.array(
            [trial['behavior_data']['States timestamps']['stim_on'][0] for trial in self.bpod_trials])
        return first_stim_on[:, 0]
class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        # Version check
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the first stim_sync (BNC1, frame2TTL) front within each trial's stim_on state.

        A front is kept when on < t <= off, where [on, off] is the first stim_on
        state entry. Only the first such front is used, so there is exactly one
        value per trial. Trials without a front get NaN. Errors/warnings are
        logged for missing TTLs.
        return stimOn_times
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # Get all stim_sync events detected
        stim_sync_all = [np.array(raw.get_port_events(tr, 'BNC1')) for tr in data]
        # Get the stim_on_state that triggers the onset of the stim
        stim_on_state = np.array([tr['behavior_data']['States timestamps']
                                  ['stim_on'][0] for tr in data])

        stimOn_times = np.full(len(data), np.nan)
        for i, (sync, on, off) in enumerate(zip(stim_sync_all, stim_on_state[:, 0], stim_on_state[:, 1])):
            pulse = sync[(sync > on) & (sync <= off)]
            if pulse.size > 0:
                # Only the first front is the stim onset. Appending all fronts would
                # misalign the output with the trials.
                stimOn_times[i] = pulse[0]

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')

        # Check if any stim_sync has failed be detected for every trial
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')

        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the time of the statemachine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1High and BNC1Low)
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        stim_on = []
        bnc_h = []
        bnc_l = []
        for tr in data:
            events = tr['behavior_data']['Events timestamps']
            stim_on.append(tr['behavior_data']['States timestamps']['stim_on'][0][0])
            # -inf placeholder when a front type is absent: it is never > stim_on, so it
            # is never selected. (-np.inf replaces np.NINF, removed in NumPy 2.0.)
            bnc_h.append(np.array(events['BNC1High']) if 'BNC1High' in events else np.array([-np.inf]))
            bnc_l.append(np.array(events['BNC1Low']) if 'BNC1Low' in events else np.array([-np.inf]))

        stim_on = np.array(stim_on)

        count_missing = 0
        stimOn_times = np.zeros_like(stim_on, dtype=float)
        for i in range(len(stim_on)):
            hl = np.sort(np.concatenate([bnc_h[i], bnc_l[i]]))
            stot = hl[hl > stim_on[i]]
            if np.size(stot) == 0:
                stot = np.array([np.nan])
                count_missing += 1
            stimOn_times[i] = stot[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')

        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')

        return np.array(stimOn_times)
class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    Per trial: with at least two fronts, on = first and off = last. With three or
    more, freeze = second-to-last. Trials with fewer than two fronts are NaN.
    Freeze is NaN for no_go trials (choice == 0).
    """
    save_names = ('_ibl_trials.stimOn_times.npy', None, None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        choice = Choice(self.session_path).extract(
            bpod_trials=self.bpod_trials, task_collection=self.task_collection, settings=self.settings, save=False
        )[0]
        on_times, off_times, freeze_times = [], [], []
        for fronts in (raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials):
            n_fronts = len(fronts) if fronts else 0
            if n_fronts < 2:
                on_times.append(np.nan)
                off_times.append(np.nan)
                freeze_times.append(np.nan)
                continue
            on_times.append(fronts[0])
            off_times.append(fronts[-1])
            freeze_times.append(fronts[-2] if n_fronts >= 3 else np.nan)

        stimOn_times = np.array(on_times, dtype=float)
        stimOff_times = np.array(off_times, dtype=float)
        stimFreeze_times = np.array(freeze_times, dtype=float)
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan
        # Possible refinement: pick the fronts closest to the corresponding trigger times.
        return stimOn_times, stimOff_times, stimFreeze_times
class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """Extracts stimulus phase, position and quiescence period from Bpod data.

    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        def column(key):
            return np.array([trial[key] for trial in self.bpod_trials])

        phase = column('stim_phase')
        position = column('position')
        quiescence = column('quiescent_period')
        return phase, position, quiescence
class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extracts the following into a table from Bpod raw data:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Additionally extracts the following wheel data:
        wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes,
                      FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        # Everything not returned separately goes into the table. 12 columns are expected,
        # presumably the docstring list with firstMovement_times coming from Wheel.
        table = AlfBunch({name: values for name, values in out.items() if name not in self.var_names})
        assert len(table.keys()) == 12
        table_df = table.to_df()
        extras = tuple(out.pop(name) for name in self.var_names if name != 'table')
        return (table_df, *extras)
class TrainingTrials(BaseBpodTrialsExtractor):
    """Aggregate extractor running all training-task trial extractors (task versions >= 5.0.0)."""
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  None, None, None, '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', 'wheel_moves_peak_amplitude',
                 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence')

    def _extract(self):
        extractors = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence]
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        # return the outputs in the order declared by var_names
        return tuple(out[name] for name in self.var_names)
def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None):
    """Extract trials and wheel data.

    For task versions >= 5.0.0, outputs wheel data and trials.table dataset (+ some extra datasets)

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    save : bool
        If true save the data files to ALF
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset.
        Loaded from disk when not provided.
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset.
        Loaded from disk when not provided. A missing or empty IBLRIG_VERSION_TAG
        is treated as '100.0.0' (i.e. the latest extraction path).
    task_collection : str
        The collection within the session containing the raw task data.
    save_path : str, pathlib.Path
        Output folder passed to the extractors as `path_out`.

    Returns
    -------
    dict
        The extracted data, keyed by variable name.
    list or None
        The saved file paths if save is True (otherwise None).
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None:
        settings = {'IBLRIG_VERSION_TAG': '100.0.0'}
    elif not settings.get('IBLRIG_VERSION_TAG'):
        # Default only the version tag. Keep the other settings, and copy so the
        # caller's dict is not mutated.
        settings = {**settings, 'IBLRIG_VERSION_TAG': '100.0.0'}

    # Version check
    if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
        # We now extract a single trials table
        base = [TrainingTrials]
    else:
        base = [
            RepNum, GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice, IncludedTrials,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil