Coverage for ibllib/io/extractors/training_trials.py: 92%
386 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1import logging
2import numpy as np
3from itertools import accumulate
4from packaging import version
5from one.alf.io import AlfBunch
7import ibllib.io.raw_data_loaders as raw
8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
9from ibllib.io.extractors.training_wheel import Wheel
# Module-level logger shared by all extractors in this module
_logger = logging.getLogger(name=__name__)
# Public API of this module
__all__ = [
    'TrainingTrials',
    'extract_all',
]
class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to subject.
    **Optional:** saves _ibl_trials.feedbackType.npy

    Checks in raw datafile which of the mutually exclusive outcome states was entered.
    Raises an AssertionError unless exactly one of them was entered in each trial.

    Sets feedbackType to +1 if the correct state was entered
    Sets feedbackType to -1 if the error or no_go state was entered
    Leaves feedbackType at 0 for the omit_* outcomes
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    # Mutually exclusive outcome states; exactly one must have been entered in each trial
    _outcome_states = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')

    def _extract(self):
        feedbackType = np.zeros(len(self.bpod_trials), np.int64)
        for i, t in enumerate(self.bpod_trials):
            states = t['behavior_data']['States timestamps']
            # A state missing from the trial is treated as not entered (NaN onset).
            # np.nan rather than np.NaN: the latter alias was removed in NumPy 2.0.
            entered = {sn: not np.isnan(states.get(sn, [[np.nan]])[0][0]) for sn in self._outcome_states}
            assert sum(entered.values()) == 1
            outcome = next(sn for sn, was_entered in entered.items() if was_entered)
            if outcome == 'correct':
                feedbackType[i] = 1
            elif outcome in ('error', 'no_go'):
                feedbackType[i] = -1
            # omit_* outcomes keep the initial value of 0
        return feedbackType
class ContrastLR(BaseBpodTrialsExtractor):
    """
    Extract the left and right stimulus contrasts of every trial.

    Each trial's contrast is assigned to the side given by the sign of its stimulus position.
    Negative positions fill contrastLeft and positive positions fill contrastRight; the other
    side is NaN. Optionally saves _ibl_trials.contrastLeft.npy and
    _ibl_trials.contrastRight.npy to the alf folder.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        # iblrig v8 stores the contrast as a flat float; older versions use a dict with a 'value' key
        if isinstance(self.bpod_trials[0]['contrast'], float):
            def contrast_of(trial):
                return trial['contrast']
        else:
            def contrast_of(trial):
                return trial['contrast']['value']

        left = [contrast_of(trial) if np.sign(trial['position']) < 0 else np.nan
                for trial in self.bpod_trials]
        right = [contrast_of(trial) if np.sign(trial['position']) > 0 else np.nan
                 for trial in self.bpod_trials]
        return np.array(left), np.array(right)
class ProbabilityLeft(BaseBpodTrialsExtractor):
    """
    Extract the `stim_probability_left` value recorded for each trial.
    **Optional:** saves _ibl_trials.probabilityLeft.npy
    """
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        p_left = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(p_left)
class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Uses the sign of the stimulus position, trial_correct and the no_go state.
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial
    If a trial is correct the choice of the animal was the inverse of the sign
    of the position; otherwise it equals the sign of the position.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        stim_side = np.array([np.sign(t['position']) for t in self.bpod_trials])
        # Explicit bool dtype: an int (0/1) array would be treated as fancy indices, and an
        # empty list would give a float array that cannot be used as a mask.
        trial_correct = np.array([t['trial_correct'] for t in self.bpod_trials], dtype=bool)
        trial_nogo = np.array(
            [not np.isnan(t['behavior_data']['States timestamps']['no_go'][0][0])
             for t in self.bpod_trials], dtype=bool)
        choice = stim_side.copy()
        choice[trial_correct] = -choice[trial_correct]
        choice[trial_nogo] = 0
        return choice.astype(int)
class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial is repeated if its 'debias_trial' flag is set or, when that field is absent, if
    trial['contrast']['type'] == 'RepeatContrast'. The count increases by one for each
    consecutive repeated trial and resets to 0 on a non-repeated trial.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        def is_repeat(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            if 'contrast' in trial and isinstance(trial['contrast'], dict):
                return trial['contrast']['type'] == 'RepeatContrast'
            # Advanced choice world before 8.19.0 has no 'debias_trial' field and no
            # debiasing protocol, so trials are never repeats
            assert self.settings['PYBPOD_PROTOCOL'].startswith('_iblrig_tasks_advancedChoiceWorld')
            return False

        flags = np.fromiter((is_repeat(trial) for trial in self.bpod_trials), int)
        counts = []
        streak = 0
        for flag in flags:
            streak = streak + flag if flag else 0
            counts.append(streak)
        return np.array(counts, dtype=int)
class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load reward volume delivered for each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    Uses the trial's `reward_amount` for correct trials and 0 for all other trials.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        # One value per trial by construction, so no length check is needed
        trial_volume = [x['reward_amount'] if x['trial_correct'] else 0 for x in self.bpod_trials]
        return np.array(trial_volume).astype(np.float64)
class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    For iblrig < 5.0.0 the onsets of the reward, error and no_go states are merged
    (every trial must have entered one of them). For later versions, rewarded trials use
    the reward state onset and the other trials use the error tone front from BNC2High.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for iblrig < 5.0.0.

        Parameters
        ----------
        session_path : str, pathlib.Path
            Session path, used only when `data` is not provided.
        task_collection : str
            Collection of the raw task data, used only when `data` is not provided.
        data : list of dict
            The Bpod trial dicts; loaded from disk when falsy.

        Returns
        -------
        numpy.array
            The merged reward / error / no_go state onset of each trial (always 1-D).
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        rw_times = [tr['behavior_data']['States timestamps']['reward'][0][0] for tr in data]
        err_times = [tr['behavior_data']['States timestamps']['error'][0][0] for tr in data]
        nogo_times = [tr['behavior_data']['States timestamps']['no_go'][0][0] for tr in data]
        # Every trial must have entered at least one of the three states
        assert sum(np.isnan(rw_times) & np.isnan(err_times) & np.isnan(nogo_times)) == 0
        merge = np.array([np.array(times)[~np.isnan(times)] for times in
                          zip(rw_times, err_times, nogo_times)]).squeeze()
        # squeeze() turns a single-trial session into a 0-d array; keep the output 1-D
        return np.atleast_1d(merge)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for iblrig >= 5.0.0.

        Rewarded trials use the reward state onset. For the other trials, duplicated BNC2High
        fronts are removed; if at least two fronts remain (go tone and error noise), the last
        one is the error tone time, otherwise NaN.

        Parameters are as for `get_feedback_times_lt5`.
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        missed_bnc2 = 0
        rw_times = np.zeros(len(data))
        err_sound_times = np.zeros(len(data))
        for ind, tr in enumerate(data):
            st = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            st = np.delete(st, np.where(np.diff(st) < 0.020)[0] + 1)
            rw_times[ind] = tr['behavior_data']['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            err_sound_times[ind] = st[-1] if st.size >= 2 and np.isnan(rw_times[ind]) else np.nan
        if missed_bnc2 == len(data):
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        merge = np.full(len(data), np.nan)
        merge[~np.isnan(rw_times)] = rw_times[~np.isnan(rw_times)]
        merge[~np.isnan(err_sound_times)] = err_sound_times[~np.isnan(err_sound_times)]
        return merge

    def _extract(self):
        # A missing/empty version is treated as the latest rig version
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            merge = self.get_feedback_times_ge5(
                self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        else:
            merge = self.get_feedback_times_lt5(
                self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)
class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start and end times, as an (n_trials, 2) array.
    **Optional:** saves _ibl_trials.intervals.npy

    Uses the corrected 'Trial start timestamp' and 'Trial end timestamp' values from PyBpod.
    According to the original documentation, the trial end includes 1 or 2 seconds after
    feedback (depending on the feedback) and 0.5 seconds of iti.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        behaviour = [t['behavior_data'] for t in self.bpod_trials]
        trial_starts = [b['Trial start timestamp'] for b in behaviour]
        trial_ends = [b['Trial end timestamp'] for b in behaviour]
        return np.column_stack((trial_starts, trial_ends))
class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    The response time is the end timestamp of the closed_loop state.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop_ends = []
        for trial in self.bpod_trials:
            # index [0][1]: end of the first closed_loop state entry
            closed_loop_ends.append(trial['behavior_data']['States timestamps']['closed_loop'][0][1])
        return np.array(closed_loop_ends)
class ItiDuration(BaseBpodTrialsExtractor):
    """
    Calculate duration of iti from state timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    The iti duration is the trial end timestamp minus the response time extracted by
    the ResponseTimes extractor (end of the closed_loop state).
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        rt, _ = ResponseTimes(self.session_path).extract(
            save=False, task_collection=self.task_collection, bpod_trials=self.bpod_trials, settings=self.settings)
        ends = np.array([t['behavior_data']['Trial end timestamp'] for t in self.bpod_trials])
        return ends - rt
class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get trigger times of goCue from state machine.

    This is the time at which the state machine issued the command to play the sound, which
    is the onset of the play_tone state for iblrig >= 5.0.0 and the onset of the closed_loop
    state before that. Sounds are triggered via PyBpod soft codes, so the actual sound onset
    may lag by tens of ms; a more accurate time requires the sound card sync pulse.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        trigger_state = 'play_tone' if rig_version >= version.parse('5.0.0') else 'closed_loop'
        onsets = [trial['behavior_data']['States timestamps'][trigger_state][0][0]
                  for trial in self.bpod_trials]
        return np.array(onsets)
class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial by the outcome state that was entered.

    1 if the reward state was entered, otherwise -1 for the error state, 0 for the no_go
    state, and NaN (with a warning) if none of them was entered.
    **Optional:** saves _ibl_trials.type.npy
    """
    save_names = '_ibl_trials.type.npy'
    # `var_names` is the attribute used by every other extractor; the original class
    # misspelled it as `var_name`, which is kept below as a backward-compatible alias.
    var_names = 'trial_type'
    var_name = var_names

    def _extract(self):
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            if not np.isnan(states["reward"][0][0]):
                trial_type.append(1)
            elif not np.isnan(states["error"][0][0]):
                trial_type.append(-1)
            elif not np.isnan(states["no_go"][0][0]):
                trial_type.append(0)
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)
class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the go cue times from the BNC2 TTL fronts recorded by Bpod.

    For each trial with BNC2 events, the go cue time is the first BNC2High front. If there
    are only BNC2Low fronts, the first BNC2Low front minus 0.1 s is used. Otherwise the value
    is NaN. A warning is logged when some or all trials are missing BNC2 TTLs.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        go_cue_times = np.full(len(self.bpod_trials), np.nan)
        for ind, tr in enumerate(self.bpod_trials):
            if not raw.get_port_events(tr, 'BNC2'):
                continue
            events = tr['behavior_data']['Events timestamps']
            rising = events.get('BNC2High', None)
            if rising:
                go_cue_times[ind] = rising[0]
                continue
            falling = events.get('BNC2Low', None)
            if falling:
                # NOTE(review): 0.1 s presumably is the go cue tone duration — confirm
                go_cue_times[ind] = falling[0] - 0.1

        n_missing = np.sum(np.isnan(go_cue_times))
        if np.all(np.isnan(go_cue_times)):
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({n_missing} trials)')
        elif np.any(np.isnan(go_cue_times)):
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {n_missing} trials')
        return go_cue_times
class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Boolean mask of the trials to include.

    For iblrig >= 5.0.0, if the SUBJECT_DISENGAGED_TRIGGERED setting is present and not False,
    trials from SUBJECT_DISENGAGED_TRIALNUM (1-based) onwards are excluded. Otherwise all
    trials are included.
    **Optional:** saves _ibl_trials.included.npy
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            trials_included = self.get_included_trials_ge5(data=self.bpod_trials, settings=self.settings)
        else:
            trials_included = self.get_included_trials_lt5(data=self.bpod_trials)
        return trials_included

    @staticmethod
    def get_included_trials_lt5(data=False):
        """All trials are included for iblrig < 5.0.0."""
        return np.ones(len(data), dtype=bool)

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """Exclude the trials from SUBJECT_DISENGAGED_TRIALNUM onwards when disengagement was triggered."""
        # np.ones(..., bool) rather than np.array([True ...]): the latter is float64 when data is empty
        trials_included = np.ones(len(data), dtype=bool)
        if settings.get('SUBJECT_DISENGAGED_TRIGGERED', False) is not False:
            idx = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1  # setting is 1-based
            trials_included[idx:] = False
        return trials_included
class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Onset of the exit_state for each trial (iblrig >= 5.0.0); all NaN for older versions.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("5.0.0"):
            return np.full(len(self.bpod_trials), np.nan)
        exit_onsets = [trial["behavior_data"]["States timestamps"]["exit_state"][0][0]
                       for trial in self.bpod_trials]
        return np.array(exit_onsets)
class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Error cue trigger time of each trial.

    The value is the no_go state onset if that state was entered, otherwise the error state
    onset. It is NaN if neither state was entered.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        trigger_times = np.full(len(self.bpod_trials), np.nan)
        for i, trial in enumerate(self.bpod_trials):
            states = trial["behavior_data"]["States timestamps"]
            nogo_onset = states["no_go"][0][0]
            error_onset = states["error"][0][0]
            # no_go takes precedence over error
            for onset in (nogo_onset, error_onset):
                if np.all(~np.isnan(onset)):
                    trigger_times[i] = onset
                    break
        return trigger_times
class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Trigger time of the stimulus freeze for each trial.

    All NaN for iblrig < 6.2.5. Otherwise each trial must have entered exactly one of the
    freeze_reward, freeze_error or no_go states (checked with an assert). The value is the
    onset of the freeze state that was entered, or NaN for no_go trials.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("6.2.5"):
            return np.ones(len(self.bpod_trials)) * np.nan

        def entered(tr, state):
            # True if no timestamp in the first entry of `state` is NaN
            return bool(np.all(~np.isnan(tr["behavior_data"]["States timestamps"][state][0])))

        freeze_reward = np.array([entered(tr, "freeze_reward") for tr in self.bpod_trials])
        freeze_error = np.array([entered(tr, "freeze_error") for tr in self.bpod_trials])
        no_go = np.array([entered(tr, "no_go") for tr in self.bpod_trials])
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        # Accumulate in a list: np.append in a loop copies the array every iteration (O(n^2))
        stimFreezeTrigger = []
        for r, n, tr in zip(freeze_reward, no_go, self.bpod_trials):
            if n:
                stimFreezeTrigger.append(np.nan)
                continue
            state = "freeze_reward" if r else "freeze_error"
            stimFreezeTrigger.append(tr["behavior_data"]["States timestamps"][state][0][0])
        return np.array(stimFreezeTrigger, dtype=float)
class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Trigger time of the stimulus offset for each trial.

    The trigger state is hide_stim for iblrig >= 6.2.5, exit_state for >= 5.0.0, and
    trial_start before that. For versions >= 5.0.0, trials that entered the no_go state use the
    no_go onset instead. Before 5.0.0 no specific no_go off trigger exists, so the trial_start
    onsets are returned as-is.
    """
    var_names = 'stimOffTrigger_times'

    def _extract(self):
        rig_version = self.settings["IBLRIG_VERSION"]
        if version.parse(rig_version or '100.0.0') >= version.parse("6.2.5"):
            trigger_state = "hide_stim"
        elif version.parse(rig_version) >= version.parse("5.0.0"):
            trigger_state = "exit_state"
        else:
            trigger_state = "trial_start"

        all_states = [trial["behavior_data"]["States timestamps"] for trial in self.bpod_trials]
        off_times = np.array([states[trigger_state][0][0] for states in all_states])
        if trigger_state == "trial_start":
            return off_times

        # Trials where the mouse did not move end in no_go: patch in the no_go onset.
        # The last trial may be NaN if the session was stopped after the response.
        nogo_times = np.array([states["no_go"][0][0] for states in all_states])
        has_nogo = ~np.isnan(nogo_times)
        off_times[has_nogo] = nogo_times[has_nogo]
        return off_times
class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Trigger time of the stimulus onset: the start of the first stim_on state entry of each trial.
    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # Index each trial directly instead of building a 2-D array and slicing [:, 0]:
        # the slice raised IndexError on an empty session, and the trailing .T on a 1-D
        # result was a no-op.
        return np.array([tr['behavior_data']['States timestamps']['stim_on'][0][0]
                         for tr in self.bpod_trials])
class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """
    Legacy stimulus onset extraction from Bpod BNC1 (frame2TTL) fronts.

    Deprecated: from iblrig 5.0.0, use StimOnOffFreezeTimes.
    **Optional:** saves _ibl_trials.stimOn_times.npy
    """
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Stimulus onset times for iblrig >= 5.0.0.

        For each trial, take the first BNC1 (frame2TTL) front within the stim_on state,
        i.e. after its start and at or before its end. Trials without such a front get NaN
        and a warning is logged.

        Returns
        -------
        numpy.array
            One stimulus onset time per trial.
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # Get all stim_sync events detected
        stim_sync_all = [np.array(raw.get_port_events(tr, 'BNC1')) for tr in data]
        # Get the stim_on state [start, end] that triggers the onset of the stim
        stim_on_state = np.array([tr['behavior_data']['States timestamps']['stim_on'][0] for tr in data])

        stimOn_times = []
        for sync, on, off in zip(stim_sync_all, stim_on_state[:, 0], stim_on_state[:, 1]):
            pulse = sync[(sync > on) & (sync <= off)]
            # Keep only the first front: appending all of them misaligned the following trials
            stimOn_times.append(pulse[0] if pulse.size > 0 else np.nan)
        stimOn_times = np.array(stimOn_times, dtype=float)

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')
        # Check if any stim_sync has failed be detected for every trial
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')
        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the time of the statemachine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1High and BNC1Low)
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        stim_on = []
        bnc_h = []
        bnc_l = []
        for tr in data:
            events = tr['behavior_data']['Events timestamps']
            stim_on.append(tr['behavior_data']['States timestamps']['stim_on'][0][0])
            # -inf placeholders can never follow stim_on, so such trials end up NaN.
            # -np.inf rather than np.NINF: that alias was removed in NumPy 2.0.
            bnc_h.append(np.array(events['BNC1High']) if 'BNC1High' in events else np.array([-np.inf]))
            bnc_l.append(np.array(events['BNC1Low']) if 'BNC1Low' in events else np.array([-np.inf]))

        stim_on = np.array(stim_on)
        count_missing = 0
        stimOn_times = np.zeros_like(stim_on)
        for i, (on, high, low) in enumerate(zip(stim_on, bnc_h, bnc_l)):
            fronts = np.sort(np.concatenate([high, low]))
            after_on = fronts[fronts > on]
            if after_on.size == 0:
                stimOn_times[i] = np.nan
                count_missing += 1
            else:
                stimOn_times[i] = after_on[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')
        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')
        return np.array(stimOn_times)
class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    Each stimulus event is the first detected front of the BNC1 signal after the trigger state,
    but before the next trigger state. The freeze time is the last front before the off trigger.
    Freeze times are NaN for no_go trials (choice == 0).
    """
    save_names = ('_ibl_trials.stimOn_times.npy', None, None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        kwargs = dict(bpod_trials=self.bpod_trials, task_collection=self.task_collection,
                      settings=self.settings, save=False)
        choice = Choice(self.session_path).extract(**kwargs)[0]
        stimOnTrigger = StimOnTriggerTimes(self.session_path).extract(**kwargs)[0]
        stimFreezeTrigger = StimFreezeTriggerTimes(self.session_path).extract(**kwargs)[0]
        stimOffTrigger = StimOffTriggerTimes(self.session_path).extract(**kwargs)[0]
        f2TTL = [np.array(raw.get_port_events(tr, name='BNC1')) for tr in self.bpod_trials]
        assert stimOnTrigger.size == stimFreezeTrigger.size == stimOffTrigger.size == choice.size == len(f2TTL)
        assert all(stimOnTrigger < np.nan_to_num(stimFreezeTrigger, nan=np.inf)) and \
            all(np.nan_to_num(stimFreezeTrigger, nan=-np.inf) < stimOffTrigger)

        # Same default as StimFreezeTriggerTimes: an empty/None version means the latest rig,
        # instead of crashing in version.parse
        has_freeze = version.parse(self.settings.get('IBLRIG_VERSION') or '100.0.0') >= version.parse('6.2.5')
        stimOn_times, stimOff_times, stimFreeze_times = [], [], []
        for ttl, on, freeze, off in zip(f2TTL, stimOnTrigger, stimFreezeTrigger, stimOffTrigger):
            # stim on: first front after the on trigger, before the freeze trigger (or the off
            # trigger when there is no freeze, e.g. no_go trials where the freeze trigger is NaN)
            lim = freeze if has_freeze and not np.isnan(freeze) else off
            window = ttl[(on < ttl) & (ttl < lim)]
            stimOn_times.append(window[0] if window.size > 0 else np.nan)
            # stim off: first front after the off trigger
            window = ttl[off < ttl]
            stimOff_times.append(window[0] if window.size > 0 else np.nan)
            # stim freeze: last front before the off trigger
            if has_freeze:
                window = ttl[(freeze < ttl) & (ttl < off)]
            else:
                window = ttl[ttl <= off]
            stimFreeze_times.append(window[-1] if window.size > 0 else np.nan)

        stimOn_times = np.array(stimOn_times, dtype=float)
        stimOff_times = np.array(stimOff_times, dtype=float)
        stimFreeze_times = np.array(stimFreeze_times, dtype=float)
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan
        return stimOn_times, stimOff_times, stimFreeze_times
class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """
    Extract the stimulus phase, stimulus position and quiescence period of each Bpod trial.

    Reads the 'stim_phase', 'position' and 'quiescent_period' fields. For extraction of
    pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        fields = ('stim_phase', 'position', 'quiescent_period')
        phase, position, quiescence = (np.array([trial[field] for trial in self.bpod_trials])
                                       for field in fields)
        return phase, position, quiescence
class PauseDuration(BaseBpodTrialsExtractor):
    """
    Extract the pause duration of each trial from the raw trial data.

    Trials without a 'pause_duration' field get 0 for iblrig < 8.9.0 (before the pausing
    logic existed) and NaN otherwise. A missing/empty version counts as '0'.
    """
    save_names = None
    var_names = 'pause_duration'

    def _extract(self, **kwargs):
        rig_version = version.parse(self.settings.get('IBLRIG_VERSION') or '0')
        if rig_version >= version.parse('8.9.0'):
            fill_value = np.nan
        else:
            fill_value = 0.
        durations = (trial.get('pause_duration', fill_value) for trial in self.bpod_trials)
        return np.fromiter(durations, dtype=float)
class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extract the trials table and wheel data from Bpod raw data.

    The table contains: intervals, goCue_times, response_times, choice, stimOn_times,
    contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft,
    firstMovement_times.
    The other outputs are stimOff_times, stimFreeze_times and the wheel data:
    wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude,
    peakVelocity_times and is_final_movement.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [
            Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
            FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel,
        ]
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        # Everything not returned separately goes into the table: the 12 columns listed above
        table = AlfBunch({name: data for name, data in out.items() if name not in self.var_names})
        assert len(table.keys()) == 12
        extra_names = [name for name in self.var_names if name != 'table']
        return (table.to_df(), *[out.pop(name) for name in extra_names])
class TrainingTrials(BaseBpodTrialsExtractor):
    """
    Run all trial extractors for training sessions and return their outputs keyed by var_names.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  None, None, None, '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None, None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
                 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')

    def _extract(self) -> dict:
        extractors = [
            RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
            StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence, PauseDuration,
        ]
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        selected = {}
        for name in self.var_names:
            selected[name] = out[name]
        return selected
def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None):
    """Extract trials and wheel data.

    For task versions >= 5.0.0, outputs wheel data and trials.table dataset (+ some extra datasets)

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    save : bool
        If true save the data files to ALF
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset; loaded from disk if falsy
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset; loaded from disk if falsy.
        A missing or empty IBLRIG_VERSION is treated as '100.0.0' (the latest version).
    task_collection : str
        The collection containing the raw task data
    save_path : str, pathlib.Path
        Output folder, passed to the extractors as `path_out`

    Returns
    -------
    The extracted data and the saved file paths, as returned by run_extractor_classes
    (file paths are presumably None when save is False — confirm against run_extractor_classes).
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None:
        settings = {'IBLRIG_VERSION': '100.0.0'}
    elif not settings.get('IBLRIG_VERSION'):
        # Only default the version: the other settings (e.g. PYBPOD_PROTOCOL,
        # SUBJECT_DISENGAGED_*) are still needed by the extractors. Copy to avoid
        # mutating the caller's dict.
        settings = {**settings, 'IBLRIG_VERSION': '100.0.0'}

    # Version check
    if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'):
        # We now extract a single trials table
        base = [TrainingTrials]
    else:
        base = [
            RepNum, GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice, IncludedTrials,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil