Coverage for ibllib/io/extractors/biased_trials.py: 97%
114 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1from pathlib import Path, PureWindowsPath
3from packaging import version
4import numpy as np
5from one.alf.io import AlfBunch
7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
8import ibllib.io.raw_data_loaders as raw
9from ibllib.io.extractors.training_trials import (
10 Choice, FeedbackTimes, FeedbackType, GoCueTimes, GoCueTriggerTimes,
11 IncludedTrials, Intervals, ProbabilityLeft, ResponseTimes, RewardVolume,
12 StimOnTimes_deprecated, StimOnTriggerTimes, StimOnOffFreezeTimes, ItiInTimes,
13 StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, PhasePosQuiescence)
14from ibllib.io.extractors.training_wheel import Wheel
16__all__ = ['extract_all', 'BiasedTrials', 'EphysTrials']
class ContrastLR(BaseBpodTrialsExtractor):
    """Left and right stimulus contrasts taken from the raw Bpod trial data."""
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self, **kwargs):
        """Split each trial's contrast into a left and a right array.

        A trial's 'contrast' goes to the left array when the sign of its 'position' is negative
        and to the right array when it is positive; the other entry (or both, when the sign is
        neither) is NaN.

        :return: tuple (contrastLeft, contrastRight) of numpy arrays with one element per trial
        """
        left, right = [], []
        for trial in self.bpod_trials:
            side = np.sign(trial['position'])
            left.append(trial['contrast'] if side < 0 else np.nan)
            right.append(trial['contrast'] if side > 0 else np.nan)
        return np.array(left), np.array(right)
class ProbaContrasts(BaseBpodTrialsExtractor):
    """Pre-generated session values for contrastLR, phase, position, probabilityLeft, quiescence."""
    save_names = (
        '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', None, None,
        '_ibl_trials.probabilityLeft.npy', '_ibl_trials.quiescencePeriod.npy',
    )
    var_names = (
        'contrastLeft', 'contrastRight', 'phase',
        'position', 'probabilityLeft', 'quiescence',
    )

    def _extract(self, **kwargs):
        """Return the pre-generated trial values in the order given by `var_names`.

        Reads positions, contrasts, quiescent delay, stimulus phase and probability left from the
        pre-generated session files. Used in ephysChoiceWorld extractions.

        :return: list of arrays ordered by alphabetically sorted key
        """
        events = self.get_pregenerated_events(self.bpod_trials, self.settings)
        # Alphabetical key order coincides with the order of `var_names`
        return [events[name] for name in sorted(events)]

    @staticmethod
    def get_pregenerated_events(bpod_trials, settings):
        """Load the pre-generated session arrays, truncated to the number of Bpod trials.

        The session number is read from the first of 'PRELOADED_SESSION_NUM',
        'PREGENERATED_SESSION_NUM' or 'SESSION_TEMPLATE_ID' that is not None in the settings.
        Otherwise it is built from the digits in the file name of 'SESSION_LOADED_FILE_PATH'.

        :param bpod_trials: sequence of Bpod trials; only its length is used
        :param settings: task settings dict
        :return: dict with keys 'position', 'quiescence', 'phase', 'probabilityLeft',
         'contrastRight' and 'contrastLeft'
        :raises ValueError: if no session number can be determined from the settings
        """
        session_id = None
        for key in ('PRELOADED_SESSION_NUM', 'PREGENERATED_SESSION_NUM', 'SESSION_TEMPLATE_ID'):
            session_id = settings.get(key)
            if session_id is not None:
                break
        if session_id is None:
            # Fall back on the digits found in the loaded session file name
            loaded_file = PureWindowsPath(settings.get('SESSION_LOADED_FILE_PATH', '')).name
            session_id = ''.join(filter(str.isdigit, loaded_file))
            if not session_id:
                raise ValueError("Can't extract left probability behaviour.")

        # Load the pregenerated file, keeping one row per recorded trial
        n_trials = len(bpod_trials)
        sessions_dir = Path(raw.__file__).parent / "extractors" / "ephys_sessions"
        pcqs = np.load(sessions_dir / f"session_{session_id}_ephys_pcqs.npy")[:n_trials]
        # Column order: position, contrast, quiescence, phase, probability left
        position, contrast, quiescence, phase, proba_left = (pcqs[:, col] for col in range(5))

        # The contrast is assigned to one side only; the other side is NaN
        contrast_right = contrast.copy()
        contrast_left = contrast.copy()
        contrast_right[position < 0] = np.nan
        contrast_left[position > 0] = np.nan

        # When a separate phase file exists and the rig version is above 6.4.0, use it instead
        phase_file = sessions_dir / f"session_{session_id}_stim_phase.npy"
        rig_version = version.parse(settings.get('IBLRIG_VERSION') or '0')
        if rig_version > version.parse('6.4.0') and phase_file.exists():
            phase = np.load(phase_file)[:n_trials]

        return {
            'position': position,
            'quiescence': quiescence,
            'phase': phase,
            'probabilityLeft': proba_left,
            'contrastRight': contrast_right,
            'contrastLeft': contrast_left,
        }
class TrialsTableBiased(BaseBpodTrialsExtractor):
    """
    Builds the trials table from Bpod raw data, with the following columns:
     intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
     feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    The following wheel data are returned alongside the table:
     wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
    """
    save_names = (
        '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
        '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
    )
    var_names = (
        'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
        'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
    )

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the sub-extractors and gather their outputs into a table plus separate variables.

        :param extractor_classes: optional list of extra extractor classes run after the base ones
        :return: tuple of the table as a data frame followed by the remaining `var_names` values
        """
        extras = extractor_classes or []
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
                      FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel] + extras
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Every output not listed in `var_names` is moved into the trials table
        reserved = set(self.var_names)
        table = AlfBunch({name: out.pop(name) for name in tuple(out) if name not in reserved})
        n_columns = len(table.keys())
        assert n_columns == 12

        frame = table.to_df()
        others = [out.pop(name) for name in self.var_names if name != 'table']
        return (frame, *others)
class TrialsTableEphys(BaseBpodTrialsExtractor):
    """
    Builds the trials table from Bpod raw data, with the following columns:
     intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
     feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    The following wheel data are returned alongside the table:
     wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
    """
    save_names = (
        '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
        '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
        None, None, None, '_ibl_trials.quiescencePeriod.npy',
    )
    var_names = (
        'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
        'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
        'phase', 'position', 'quiescence',
    )

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the sub-extractors and gather their outputs into a table plus separate variables.

        Contrasts and probability left come from ProbaContrasts here.

        :param extractor_classes: optional list of extra extractor classes run after the base ones
        :return: tuple of the table as a data frame followed by the remaining `var_names` values
        """
        extras = extractor_classes or []
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ProbaContrasts,
                      FeedbackTimes, FeedbackType, RewardVolume, Wheel] + extras
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Outputs listed in `var_names` are excluded from the trials table
        reserved = set(self.var_names)
        columns = {}
        for name, values in out.items():
            if name not in reserved:
                columns[name] = values
        table = AlfBunch(columns)
        n_columns = len(table.keys())
        assert n_columns == 12

        frame = table.to_df()
        others = [out.pop(name) for name in self.var_names if name != 'table']
        return (frame, *others)
class BiasedTrials(BaseBpodTrialsExtractor):
    """
    Same as training_trials.TrainingTrials except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater
    """
    save_names = (
        '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
        '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
        '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, '_ibl_trials.included.npy',
        None, None, '_ibl_trials.quiescencePeriod.npy',
    )
    var_names = (
        'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
        'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
        'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
        'phase', 'position', 'quiescence',
    )

    def _extract(self, extractor_classes=None, **kwargs) -> dict:
        """Run the trigger-time, trials table, included and phase/position/quiescence extractors.

        :param extractor_classes: optional list of extra extractor classes run after the base ones
        :return: dict mapping each name in `var_names` to its extracted value
        """
        extras = extractor_classes or []
        extractors = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials,
                      PhasePosQuiescence] + extras
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        # Only the variables declared in `var_names` are returned, in that order
        selected = {}
        for name in self.var_names:
            selected[name] = out[name]
        return selected
class EphysTrials(BaseBpodTrialsExtractor):
    """
    Same as BiasedTrials except...
     - Contrast, phase, position, probabilityLeft and quiescence is extracted differently
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
                  '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs) -> dict:
        """Run the ephys trials extractors, choosing the trials table class from the rig version.

        Also stores the fronts returned by raw.load_bpod_fronts on the instance as
        `frame2ttl` and `audio`.

        :param extractor_classes: optional list of extra extractor classes run after the base ones
        :return: dict mapping each name in `var_names` to its extracted value
        """
        extractor_classes = extractor_classes or []

        # For iblrig v8 we use the biased trials table instead. ContrastLeft, ContrastRight and ProbabilityLeft are
        # filled from the values in the bpod data itself rather than using the pregenerated session number.
        # An empty or None version entry falls through to the next candidate instead of being handed to
        # version.parse, mirroring the `or '0'` guard used in ProbaContrasts.get_pregenerated_events.
        iblrig_version = (
            self.settings.get('IBLRIG_VERSION')
            or self.settings.get('IBLRIG_VERSION_TAG')
            or '0'
        )
        if version.parse(iblrig_version) >= version.parse('8.0.0'):
            TrialsTable = TrialsTableBiased
        else:
            TrialsTable = TrialsTableEphys

        base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
                ErrorCueTriggerTimes, TrialsTable, IncludedTrials, PhasePosQuiescence]
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        out, _ = run_extractor_classes(
            base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        # Only the variables declared in `var_names` are returned
        return {k: out[k] for k in self.var_names}
def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None,
                task_collection='raw_behavior_data', save_path=None):
    """
    Same as training_trials.extract_all except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater

    :param session_path: session folder, passed to the raw data loaders and to the extractors
    :param save: forwarded to run_extractor_classes as `save`
    :param bpod_trials: raw Bpod trials; loaded with raw.load_data when falsy
    :param settings: task settings dict; loaded with raw.load_settings when falsy. Its
     'IBLRIG_VERSION' entry is set in place to '100.0.0' when missing or empty
    :param extra_classes: additional BaseBpodTrialsExtractor subclasses for custom extractions
    :param task_collection: collection name passed to the raw data loaders and to the extractors
    :param save_path: forwarded to run_extractor_classes as `path_out`
    :return: the (out, fil) pair returned by run_extractor_classes
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None:
        settings = {'IBLRIG_VERSION': '100.0.0'}

    # A missing, None or empty version is treated like the '100.0.0' fallback above,
    # rather than raising KeyError or failing inside version.parse
    if not settings.get('IBLRIG_VERSION'):
        settings['IBLRIG_VERSION'] = '100.0.0'

    # Version check
    if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'):
        # We now extract a single trials table
        base = [BiasedTrials]
    else:
        base = [
            GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    if extra_classes:
        base.extend(extra_classes)

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil