Coverage for ibllib/io/extractors/habituation_trials.py: 89%
72 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Habituation ChoiceWorld Bpod trials extraction."""
2import logging
3import numpy as np
5from packaging import version
7import ibllib.io.raw_data_loaders as raw
8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
9from ibllib.io.extractors.biased_trials import ContrastLR
10from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes
12_logger = logging.getLogger(__name__)
class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract habituationChoiceWorld trial events from the raw Bpod data.

    The extracted datasets are returned in the order given by `var_names`. Datasets listed in the
    `exclude` list of `__init__` are extracted but given a None save name, so they are presumably
    skipped when the base extractor saves ALF files (TODO confirm in BaseBpodTrialsExtractor).
    """
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Datasets that are returned by the extractor but receive no ALF save name
        exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times',
                   'stimCenterTrigger_times', 'position', 'phase']
        self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

    def _extract(self) -> dict:
        """
        Extract the Bpod trial events.

        For iblrig versions < 8.13 the Bpod state machine for this task had extremely misleading names!
        The 'iti' state was actually the delay between valve close and trial end (the stimulus is
        still present during this period), and the 'trial_start' state is actually the ITI during
        which there is a 1s Bpod TTL and gray screen period.

        In version 8.13 and later, the 'iti' state was renamed to 'post_reward' and 'trial_start'
        was renamed to 'iti'.

        If the rig version in the settings is empty, the state names alone are used to determine
        which state machine was used.

        Returns
        -------
        dict
            A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.
            All arrays are truncated to one fewer than the number of Bpod trials, because the
            stim off event of a trial is only recorded in the following trial.

        Raises
        ------
        AssertionError
            If any detected stim on event occurs before the start of its trial interval.
        """
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        # These are the frame2TTL pulses as a list of lists, one per trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report trials that do not contain exactly the three expected BNC1 pulses
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):  # the stim syncs failed to be detected on every trial
            _logger.error('%s: Missing ALL BNC1 TTLs (%i trials)', self.session_path, n_missing)
        elif n_missing > 0:  # the stim syncs failed to be detected on some trials
            _logger.warning('%s: Missing BNC1 TTLs on %i trial(s)', self.session_path, n_missing)

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # The 'trial_start'/'iti' state is in fact the 1s grey screen period, therefore the first
        # timestamp (state entry) is really the end of the previous trial and also the stimOff
        # trigger time. The second timestamp (state exit) is the true trial start time. This
        # state was renamed in version 8.13.
        state_names = self.bpod_trials[0]['behavior_data']['States timestamps'].keys()
        # An empty/missing version string cannot be parsed; assume a recent rig so that the
        # state names below decide which state machine was used
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        legacy_state_machine = 'post_reward' not in state_names and 'trial_start' in state_names
        key = 'iti' if (rig_version >= version.parse('8.13') and not legacy_state_machine) else 'trial_start'
        # Take the last occurrence of the state in each trial and split it into entry and exit
        # times; the first entry time is dropped as it does not end any extracted trial
        (_, *ends), starts = zip(*[
            t['behavior_data']['States timestamps'][key][-1] for t in self.bpod_trials]
        )

        # StimOffTrigger times
        out['stimOffTrigger_times'] = np.array(ends)

        # StimOff times
        # There should be exactly three TTLs per trial. stimOff_times should be the first TTL
        # pulse of the following trial. If 1 or more pulses are missing, we can not be confident
        # of assigning the correct one.
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

        # Trial intervals
        # In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim
        # off time often happens after the trial end TTL front, i.e. after the 'trial_start' state
        # begins. For these trials, we set the trial end time as the stim off time.
        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]

        to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
        if np.any(to_correct):
            _logger.debug(
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
                np.count_nonzero(to_correct), to_correct.size)
            out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

        # itiIn times
        out['itiIn_times'] = np.r_[ends, np.nan]

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times: entry time of the first 'stim_center' state of each trial
        stim_center_state = np.array([
            tr['behavior_data']['States timestamps']['stim_center'][0] for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0]

        # StimCenter times
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            # We expect there to be 3 pulses per trial; if this is the case, stim center will be
            # the third pulse. If any pulses are missing, we can only be confident of the correct
            # one if exactly one pulse occurs after the stim center trigger.
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > trigger for pulse in sync) == 1):
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times
        stim_on_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            # We expect there to be 3 pulses per trial; if this is the case, stim on will be the
            # second pulse. If 1 pulse is missing, we can only be confident of the correct one if
            # both pulses occur before the stim center trigger.
            if len(sync) == 3 or (len(sync) == 2 and sum(pulse < trigger for pulse in sync) == 2):
                stim_on_times[i] = sync[1]
        out['stimOn_times'] = stim_on_times

        # RewardVolume
        trial_volume = [x['reward_amount'] for x in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # Phase and position
        out['position'] = np.array([t['position'] for t in self.bpod_trials])
        out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

        # Check that no detected stim on event precedes the start of its trial interval. An
        # explicit raise (rather than `assert`) keeps this check active under `python -O`.
        idx = ~np.isnan(out['stimOn_times'][:n_trials])
        if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]):
            raise AssertionError('Stim on events occurring outside trial intervals')

        # Truncate arrays and return in correct order
        return {k: out[k][:n_trials] for k in self.var_names}
def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None):
    """Extract all datasets from habituationChoiceWorld.

    Note: only the datasets from the HabituationTrials extractor will be saved to disc.

    :param session_path: The session path where the raw data are saved
    :param save: If True, the datasets that are considered standard are saved to disc
    :param bpod_trials: The raw Bpod trial data; if falsy, they are loaded from `task_collection`
    :param settings: The raw Bpod task settings; if falsy, they are loaded from `task_collection`
    :param task_collection: The collection from which the raw Bpod data and settings are loaded
    :param save_path: Optional output location, forwarded to the extractor as `path_out`
    :returns: a dict of datasets and a corresponding list of file names
    """
    # Falsy defaults (False, None, empty containers) trigger loading from disc
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)

    # Standard datasets that may be saved as ALFs
    params = dict(session_path=session_path, bpod_trials=bpod_trials, settings=settings, task_collection=task_collection,
                  path_out=save_path)
    out, fil = run_extractor_classes(HabituationTrials, save=save, **params)
    return out, fil