Coverage for ibllib/io/extractors/habituation_trials.py: 98%
64 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""Habituation ChoiceWorld Bpod trials extraction."""
2import logging
3import numpy as np
5from packaging import version
7import ibllib.io.raw_data_loaders as raw
8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
9from ibllib.io.extractors.biased_trials import ContrastLR
10from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes
12_logger = logging.getLogger(__name__)
class HabituationTrials(BaseBpodTrialsExtractor):
    """Bpod trials extractor for the habituation ChoiceWorld task."""

    # Names of the extracted variables; these are the keys of the dict returned by `_extract`.
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

    def __init__(self, *args, **kwargs):
        """
        Initialize the extractor and build `save_names` from `var_names`.

        All arguments are forwarded unchanged to the parent class. `save_names` holds one entry
        per variable in `var_names`: a '_ibl_trials.<name>.npy' filename, or None for the
        variables listed in `exclude` (presumably None means the variable is extracted but not
        saved to file - confirm against the base class).
        """
        super().__init__(*args, **kwargs)
        exclude = ['itiIn_times', 'stimCenter_times', 'stimCenterTrigger_times', 'position', 'phase']
        self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

    def _extract(self) -> dict:
        """
        Extract the Bpod trial events.

        For iblrig versions < 8.13 the Bpod state machine for this task had extremely misleading names!
        The 'iti' state was actually the delay between valve close and trial end (the stimulus is
        still present during this period), and the 'trial_start' state is actually the ITI during
        which there is a 1s Bpod TTL and gray screen period.

        In version 8.13 and later, the 'iti' state was renamed to 'post_reward' and 'trial_start'
        was renamed to 'iti'.

        Returns
        -------
        dict
            A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.

        Raises
        ------
        AssertionError
            If any detected stim on time occurs before the start of its trial interval.
        """
        # Extract all trials...

        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        # These are the frame2TTL pulses as a list of lists, one per trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report missing events: exactly three pulses are expected per trial
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):  # Check if all stim syncs have failed to be detected
            _logger.error(f'{self.session_path}: Missing ALL BNC1 TTLs ({n_missing} trials)')
        elif n_missing > 0:  # Check whether the stim sync failed to be detected on any trial
            _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # The 'trial_start'/'iti' state is in fact the 1s grey screen period, therefore the first
        # timestamp is really the end of the previous trial and also the stimOff trigger time. The
        # second timestamp is the true trial start time. This state was renamed in version 8.13.
        state_names = self.bpod_trials[0]['behavior_data']['States timestamps'].keys()
        rig_version = version.parse(self.settings['IBLRIG_VERSION'])
        # Some sessions report a version >= 8.13 but still use the old state names, so the state
        # names found in the data take precedence over the version number
        legacy_state_machine = 'post_reward' not in state_names and 'trial_start' in state_names

        key = 'iti' if (rig_version >= version.parse('8.13') and not legacy_state_machine) else 'trial_start'
        # The first timestamp of the first trial has no preceding trial to end, so it is dropped;
        # `ends` therefore holds one element fewer than `starts`
        (_, *ends), starts = zip(*[
            t['behavior_data']['States timestamps'][key][-1] for t in self.bpod_trials]
        )

        # StimOffTrigger times
        out['stimOffTrigger_times'] = np.array(ends)

        # StimOff times
        # There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse.
        # If 1 or more pulses are missing, we can not be confident of assigning the correct one.
        # The stim off pulse for a trial is read from the TTLs of the following trial (hence ttls[1:]).
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

        # Trial intervals
        # In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim
        # off time often happens after the trial end TTL front, i.e. after the 'trial_start' start
        # begins. For these trials, we set the trial end time as the stim off time.
        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]

        to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
        if np.any(to_correct):
            _logger.debug(
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
                to_correct.sum(), len(to_correct))
            out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

        # itiIn times
        out['itiIn_times'] = np.r_[ends, np.nan]

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times
        # Get the start of the first 'stim_center' state occurrence of each trial
        stim_center_state = np.array([tr['behavior_data']['States timestamps']
                                      ['stim_center'][0] for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0].T

        # StimCenter times
        # We expect there to be 3 pulses per trial; if this is the case, stim center will
        # be the third pulse. If any pulses are missing, we can only be confident of the correct
        # one if exactly one pulse occurs after the stim center trigger
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, last) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > last for pulse in sync) == 1):
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times
        # We expect there to be 3 pulses per trial; if this is the case, stim on will be the
        # second pulse. If 1 pulse is missing, we can only be confident of the correct one if
        # both pulses occur before the stim center trigger
        stimOn_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, last) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) == 2 and sum(pulse < last for pulse in sync) == 2):
                stimOn_times[i] = sync[1]
        out['stimOn_times'] = stimOn_times

        # RewardVolume
        trial_volume = [x['reward_amount'] for x in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # Phase and position
        out['position'] = np.array([t['position'] for t in self.bpod_trials])
        out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

        # Double-check that the early and late trial events occur within the trial intervals.
        # Raised explicitly (rather than via an `assert` statement) so that the check is not
        # stripped when Python runs with optimizations enabled (-O).
        idx = ~np.isnan(out['stimOn_times'][:n_trials])
        if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]):
            raise AssertionError('Stim on events occurring outside trial intervals')

        # Truncate arrays and return in correct order
        return {k: out[k][:n_trials] for k in self.var_names}