Coverage for ibllib/io/extractors/habituation_trials.py: 98%
64 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 14:26 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 14:26 +0100
1"""Habituation ChoiceWorld Bpod trials extraction."""
2import logging
3import numpy as np
5from packaging import version
7import ibllib.io.raw_data_loaders as raw
8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
9from ibllib.io.extractors.biased_trials import ContrastLR
10from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes
12_logger = logging.getLogger(__name__)
class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract habituation ChoiceWorld trial events from the Bpod data.

    The extracted variables are listed in `var_names`. A subset of them has no associated
    dataset file name (see `__init__`).
    """

    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

    def __init__(self, *args, **kwargs):
        """Initialise the extractor and derive `save_names` from `var_names`.

        All positional and keyword arguments are forwarded to the parent constructor. Variables
        in the exclusion set below get a save name of None; every other variable gets a name of
        the form '_ibl_trials.<var_name>.npy'.
        """
        super().__init__(*args, **kwargs)
        unsaved = {'itiIn_times', 'stimCenter_times', 'stimCenterTrigger_times', 'position', 'phase'}
        self.save_names = tuple(
            None if name in unsaved else f'_ibl_trials.{name}.npy' for name in self.var_names)

    def _extract(self) -> dict:
        """
        Extract the Bpod trial events.

        For iblrig versions < 8.13 the Bpod state machine for this task had extremely misleading
        names! The 'iti' state was actually the delay between valve close and trial end (the
        stimulus is still present during this period), and the 'trial_start' state is actually
        the ITI during which there is a 1s Bpod TTL and gray screen period.

        In version 8.13 and later, the 'iti' state was renamed to 'post_reward' and 'trial_start'
        was renamed to 'iti'.

        The last trial is dropped from the output because its stim off event occurs during the
        following trial.

        Returns
        -------
        dict
            A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.
        """
        trials = self.bpod_trials

        # All detected TTLs are stored on the instance for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=trials)
        # The frame2TTL (BNC1) pulses as a list of lists, one per trial
        bnc1_pulses = [raw.get_port_events(trial, 'BNC1') for trial in trials]

        # Report trials that do not have exactly three BNC1 pulses
        n_missing = sum(1 for pulses in bnc1_pulses if len(pulses) != 3)
        if n_missing == len(bnc1_pulses):
            # Every stim sync has failed to be detected
            _logger.error(f'{self.session_path}: Missing ALL BNC1 TTLs ({n_missing} trials)')
        elif n_missing > 0:
            # Some, but not all, stim syncs have failed to be detected
            _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

        # Datasets common to trainingChoiceWorld
        shared_extractors = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(
            shared_extractors, session_path=self.session_path, save=False, bpod_trials=trials,
            settings=self.settings, task_collection=self.task_collection)

        # The 'trial_start'/'iti' state is in fact the 1s grey screen period, therefore the first
        # timestamp is really the end of the previous trial and also the stimOff trigger time.
        # The second timestamp is the true trial start time. This state was renamed in version
        # 8.13, so pick the state name from the rig version and the states actually present.
        first_trial_states = trials[0]['behavior_data']['States timestamps']
        rig_version = version.parse(self.settings['IBLRIG_VERSION'])
        legacy_state_machine = (
            'post_reward' not in first_trial_states and 'trial_start' in first_trial_states)
        if legacy_state_machine or rig_version < version.parse('8.13'):
            key = 'trial_start'
        else:
            key = 'iti'
        # Last occurrence of the state in each trial; the very first 'end' precedes trial 0
        last_occurrence = [trial['behavior_data']['States timestamps'][key][-1] for trial in trials]
        (_, *ends), starts = zip(*last_occurrence)

        # StimOffTrigger times
        out['stimOffTrigger_times'] = np.array(ends)

        # StimOff times
        # There should be exactly three TTLs per trial. stimOff_times should be the first TTL
        # pulse of the following trial. If 1 or more pulses are missing, we can not be confident
        # of assigning the correct one.
        out['stimOff_times'] = np.array(
            [pulses[0] if len(pulses) == 3 else np.nan for pulses in bnc1_pulses[1:]])

        # Trial intervals
        # In terms of TTLs, the intervals are defined by the 'trial_start' state, however the
        # stim off time often happens after the trial end TTL front, i.e. after the 'trial_start'
        # state begins. For these trials, we set the trial end time as the stim off time.
        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        iti_onsets = np.r_[ends, np.nan]
        out['intervals'] = np.c_[starts, iti_onsets][:n_trials, :]

        stim_off = out['stimOff_times']
        late_stim_off = ~np.isnan(stim_off) & (stim_off > out['intervals'][:, 1])
        if np.any(late_stim_off):
            _logger.debug(
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
                late_stim_off.sum(), len(late_stim_off))
            out['intervals'][late_stim_off, 1] = stim_off[late_stim_off]

        # itiIn times
        out['itiIn_times'] = iti_onsets

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times: onset of the first 'stim_center' state of each trial
        stim_center_state = np.array(
            [trial['behavior_data']['States timestamps']['stim_center'][0] for trial in trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0].T

        # StimCenter and StimOn times
        # We expect there to be 3 pulses per trial; if this is the case, stim on is the second
        # pulse and stim center is the third.
        #  - stim center: if any pulses are missing, we can only be confident of the correct one
        #    if exactly one pulse occurs after the stim center trigger.
        #  - stim on: if 1 pulse is missing, we can only be confident of the correct one if both
        #    pulses occur before the stim center trigger.
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        stim_on_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (pulses, trigger) in enumerate(zip(bnc1_pulses, out['stimCenterTrigger_times'])):
            n_pulses = len(pulses)
            complete = n_pulses == 3
            if complete or (n_pulses > 0 and sum(pulse > trigger for pulse in pulses) == 1):
                stim_center_times[i] = pulses[-1]
            if complete or (n_pulses == 2 and sum(pulse < trigger for pulse in pulses) == 2):
                stim_on_times[i] = pulses[1]
        out['stimCenter_times'] = stim_center_times
        out['stimOn_times'] = stim_on_times

        # RewardVolume
        reward_amounts = [trial['reward_amount'] for trial in trials]
        out['rewardVolume'] = np.array(reward_amounts).astype(np.float64)

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # Phase and position
        out['position'] = np.array([trial['position'] for trial in trials])
        out['phase'] = np.array([trial['stim_phase'] for trial in trials])

        # Double-check that the stim on events do not precede the start of their trial interval
        kept_stim_on = out['stimOn_times'][:n_trials]
        detected = ~np.isnan(kept_stim_on)
        assert not np.any(kept_stim_on[detected] < out['intervals'][detected, 0]), \
            'Stim on events occurring outside trial intervals'

        # Truncate arrays and return in correct order
        return {name: out[name][:n_trials] for name in self.var_names}