Coverage for ibllib/io/extractors/habituation_trials.py: 92%
53 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1import logging
2import numpy as np
4import ibllib.io.raw_data_loaders as raw
5from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
6from ibllib.io.extractors.biased_trials import ContrastLR
7from ibllib.io.extractors.training_trials import (
8 FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes
9)
11_logger = logging.getLogger(__name__)
class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract the trials datasets of a habituationChoiceWorld session from Bpod data.

    Datasets shared with trainingChoiceWorld are delegated to the training extractors.
    The habituation-specific ones (stimCenter, stimOn, stimOff, itiIn, ...) are derived
    here from the Bpod state timestamps and from the BNC1 (stim sync) TTL pulses, of
    which three are expected per trial.
    """
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times')

    def __init__(self, *args, **kwargs):
        """Initialise the extractor and assign ALF save names to the standard datasets."""
        super().__init__(*args, **kwargs)
        # Datasets in this set are extracted but given a save name of None (no ALF file name)
        exclude = {'itiIn_times', 'stimOffTrigger_times',
                   'stimCenter_times', 'stimCenterTrigger_times'}
        self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None
                                for x in self.var_names)

    def _extract(self):
        """Extract all habituation trials datasets.

        :returns: a list of arrays, one per entry of `var_names` and in that order, each
            truncated to the number of stimOff_times values, i.e. all trials but the last.
        """
        # Get all stim_sync events detected
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report trials that do not have exactly the 3 expected pulses
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):
            # Every trial is affected: stim syncs have likely failed to be detected at all
            _logger.error(f'{self.session_path}: Missing ALL BNC1 TTLs ({n_missing} trials)')
        elif n_missing > 0:
            _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times: first timestamp of the first 'stim_center' state per trial
        stim_center_state = np.array([tr['behavior_data']['States timestamps']['stim_center'][0]
                                      for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0]

        # StimCenter times
        # We expect 3 pulses per trial, in which case stim center is the third pulse. If any
        # pulses are missing, we can only be confident of the correct one if exactly one pulse
        # occurs after the stim center trigger. Otherwise the value is left as NaN.
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or sum(pulse > trigger for pulse in sync) == 1:
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times
        # We expect 3 pulses per trial, in which case stim on is the second pulse. If one pulse
        # is missing, we can only be confident of the correct one if both remaining pulses
        # occur before the stim center trigger. Otherwise the value is left as NaN.
        stimOn_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) == 2 and all(pulse < trigger for pulse in sync)):
                stimOn_times[i] = sync[1]
        out['stimOn_times'] = stimOn_times

        # RewardVolume
        trial_volume = [x['reward_amount'] for x in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # StimOffTrigger times
        # StimOff occurs at trial start (ignore the first trial's state update)
        out['stimOffTrigger_times'] = np.array(
            [tr["behavior_data"]["States timestamps"]["trial_start"][0][0]
             for tr in self.bpod_trials[1:]]
        )

        # StimOff times
        # There should be exactly three TTLs per trial, and stimOff_times is the first TTL
        # pulse of the following trial. If 1 or more pulses are missing we cannot be confident
        # of assigning the correct one, so NaN is used.
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan
                                         for sync in ttls[1:]])

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # ItiIn times
        out['itiIn_times'] = np.array(
            [tr["behavior_data"]["States timestamps"]["iti"][0][0] for tr in self.bpod_trials]
        )

        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        return [out[k][:n_trials] for k in self.var_names]
def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None):
    """Extract all datasets from habituationChoiceWorld.

    Only the HabituationTrials extractor is run, so only its datasets may be saved to disc.

    :param session_path: The session path where the raw data are saved
    :param save: If True, the datasets that are considered standard are saved to disc
    :param bpod_trials: The raw Bpod trial data; loaded from `task_collection` when falsy
    :param settings: The raw Bpod session settings; loaded from `task_collection` when falsy
    :param task_collection: The collection within the session path that holds the raw task
        data (default 'raw_behavior_data')
    :param save_path: Optional output location, forwarded to the extractor as `path_out`
    :returns: a dict of datasets and a corresponding list of file names
    """
    # Fall back to reading the raw task data from disc when it was not supplied
    bpod_trials = bpod_trials or raw.load_data(session_path, task_collection=task_collection)
    settings = settings or raw.load_settings(session_path, task_collection=task_collection)

    # Standard datasets that may be saved as ALFs
    out, fil = run_extractor_classes(
        HabituationTrials,
        save=save,
        session_path=session_path,
        bpod_trials=bpod_trials,
        settings=settings,
        task_collection=task_collection,
        path_out=save_path,
    )
    return out, fil