Coverage for ibllib/io/extractors/biased_trials.py: 94%
104 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1from pathlib import Path, PureWindowsPath
3from pkg_resources import parse_version
4import numpy as np
5from one.alf.io import AlfBunch
7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
8import ibllib.io.raw_data_loaders as raw
9from ibllib.io.extractors.training_trials import (
10 Choice, FeedbackTimes, FeedbackType, GoCueTimes, GoCueTriggerTimes,
11 IncludedTrials, Intervals, ProbabilityLeft, ResponseTimes, RewardVolume,
12 StimOnTimes_deprecated, StimOnTriggerTimes, StimOnOffFreezeTimes, ItiInTimes,
13 StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, PhasePosQuiescence)
14from ibllib.io.extractors.training_wheel import Wheel
# Public API of this module.
__all__ = [
    'extract_all',
    'BiasedTrials',
    'EphysTrials',
]
class ContrastLR(BaseBpodTrialsExtractor):
    """
    Extract the left and right stimulus contrasts from the raw Bpod trial data.

    For each trial, the recorded ``contrast`` is assigned to the side given by the sign of the
    trial's ``position``: negative -> left, positive -> right. The other side is NaN. A trial
    whose position is exactly zero gets NaN on both sides.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self, **kwargs):
        """Return ``(contrastLeft, contrastRight)`` as numpy arrays, one value per Bpod trial."""
        left, right = [], []
        for trial in self.bpod_trials:
            side = np.sign(trial['position'])
            left.append(trial['contrast'] if side < 0 else np.nan)
            right.append(trial['contrast'] if side > 0 else np.nan)
        return np.array(left), np.array(right)
class ProbaContrasts(BaseBpodTrialsExtractor):
    """
    Bpod pre-generated values for probabilityLeft, contrastLR, phase, quiescence
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', None, None,
                  '_ibl_trials.probabilityLeft.npy', '_ibl_trials.quiescencePeriod.npy')
    var_names = ('contrastLeft', 'contrastRight', 'phase',
                 'position', 'probabilityLeft', 'quiescence')

    def _extract(self, **kwargs):
        """Extracts positions, contrasts, quiescent delay, stimulus phase and probability left
        from pregenerated session files. Used in ephysChoiceWorld extractions.
        Optional: saves alf contrastLR and probabilityLeft npy files

        :return: list of arrays ordered as ``var_names``
        """
        pe = self.get_pregenerated_events(self.bpod_trials, self.settings)
        # Order the outputs explicitly by var_names instead of relying on the alphabetical
        # order of the dict keys happening to match.
        return [pe[k] for k in self.var_names]

    @staticmethod
    def get_pregenerated_events(bpod_trials, settings):
        """
        Load the pre-generated stimulus sequence used by the session and truncate it to the
        number of Bpod trials.

        The session number is taken from ``PRELOADED_SESSION_NUM``. Failing that, it comes from
        ``PREGENERATED_SESSION_NUM``. Failing both, it is built from the digits in the file name
        of ``SESSION_LOADED_FILE_PATH`` (parsed as a Windows path).

        :param bpod_trials: sequence of raw Bpod trials (only its length is used)
        :param settings: task settings dict
        :return: dict with keys 'position', 'quiescence', 'phase', 'probabilityLeft',
            'contrastRight', 'contrastLeft'
        :raises ValueError: if no session number can be determined from the settings
        """
        num = settings.get("PRELOADED_SESSION_NUM", None)
        if num is None:
            num = settings.get("PREGENERATED_SESSION_NUM", None)
        if num is None:
            # Guard against an explicit None value as well as a missing key
            fn = settings.get('SESSION_LOADED_FILE_PATH', None) or ''
            fn = PureWindowsPath(fn).name
            num = ''.join(d for d in fn if d.isdigit())
            if num == '':
                raise ValueError("Can't extract left probability behaviour.")
        # Load the pregenerated file
        ntrials = len(bpod_trials)
        sessions_folder = Path(raw.__file__).parent.joinpath("extractors", "ephys_sessions")
        pcqsp = np.load(sessions_folder.joinpath(f"session_{num}_ephys_pcqs.npy"))
        # Columns 0-4: position, contrast, quiescence, phase, probabilityLeft
        pos = pcqsp[:ntrials, 0]
        con = pcqsp[:ntrials, 1]
        qui = pcqsp[:ntrials, 2]
        phase = pcqsp[:ntrials, 3]
        pLeft = pcqsp[:ntrials, 4]
        contrastRight = con.copy()
        contrastLeft = con.copy()
        contrastRight[pos < 0] = np.nan
        contrastLeft[pos > 0] = np.nan

        phase_path = sessions_folder.joinpath(f"session_{num}_stim_phase.npy")
        # parse_version expects a version string. A missing, empty or None tag is treated as
        # '0.0.0' rather than passing the int 0 (the previous default) through.
        version_tag = settings.get('IBLRIG_VERSION_TAG') or '0.0.0'
        is_patched_version = parse_version(str(version_tag)) > parse_version('6.4.0')
        if phase_path.exists() and is_patched_version:
            phase = np.load(phase_path)[:ntrials]

        return {'position': pos, 'quiescence': qui, 'phase': phase, 'probabilityLeft': pLeft,
                'contrastRight': contrastRight, 'contrastLeft': contrastLeft}
class TrialsTableBiased(BaseBpodTrialsExtractor):
    """
    Build the biased-task trials table from Bpod raw data and return it with the remaining
    per-trial and wheel outputs.

    Table columns: intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft,
    contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times.

    Returned after the table, in ``var_names`` order: stimOff_times, stimFreeze_times,
    wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude,
    peakVelocity_times, is_final_movement.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the sub-extractors and split their outputs into the table and the extra outputs.

        ``extractor_classes`` is not used by this method.
        """
        extractors = [
            Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
            FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel,
        ]
        out, _ = run_extractor_classes(extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
                                       settings=self.settings, save=False, task_collection=self.task_collection)

        # Every output that is not returned on its own (i.e. not in var_names) becomes a table column
        columns = {name: out[name] for name in out if name not in self.var_names}
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        extras = tuple(out[name] for name in self.var_names if name != 'table')
        return (table.to_df(), *extras)
class TrialsTableEphys(BaseBpodTrialsExtractor):
    """
    Build the ephys-task trials table from Bpod raw data and return it with the remaining
    per-trial and wheel outputs.

    Contrasts and probabilityLeft come from ProbaContrasts (pre-generated session files) instead
    of ContrastLR / ProbabilityLeft.

    Table columns: intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft,
    contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times.

    Returned after the table, in ``var_names`` order: stimOff_times, stimFreeze_times,
    wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude,
    peakVelocity_times, is_final_movement, phase, position, quiescence.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  None, None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the sub-extractors and split their outputs into the table and the extra outputs.

        ``extractor_classes`` is not used by this method.
        """
        extractors = [
            Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ProbaContrasts,
            FeedbackTimes, FeedbackType, RewardVolume, Wheel,
        ]
        out, _ = run_extractor_classes(extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
                                       settings=self.settings, save=False, task_collection=self.task_collection)

        # Outputs listed in var_names are returned separately and excluded from the table
        columns = {}
        for name, value in out.items():
            if name not in self.var_names:
                columns[name] = value
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        extras = tuple(out[name] for name in self.var_names if name != 'table')
        return (table.to_df(), *extras)
class BiasedTrials(BaseBpodTrialsExtractor):
    """
    Same as training_trials.TrainingTrials except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater

    Runs the trigger-time extractors, TrialsTableBiased, IncludedTrials and PhasePosQuiescence,
    and returns their outputs in ``var_names`` order.
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, '_ibl_trials.included.npy',
                  None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run all sub-extractors without saving; ``extractor_classes`` is not used here."""
        extractors = [
            GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
            ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence,
        ]
        out, _ = run_extractor_classes(extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
                                       settings=self.settings, save=False, task_collection=self.task_collection)
        return tuple(out[name] for name in self.var_names)
class EphysTrials(BaseBpodTrialsExtractor):
    """
    Same as BiasedTrials except...
     - Contrast, phase, position, probabilityLeft and quiescence is extracted differently

    The trials table comes from TrialsTableEphys, which reads the contrasts and probabilityLeft
    from pre-generated session files. Outputs are returned in ``var_names`` order.
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
                  '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run all sub-extractors without saving; ``extractor_classes`` is not used here."""
        # NOTE(review): both TrialsTableEphys (via ProbaContrasts) and PhasePosQuiescence output
        # 'phase', 'position' and 'quiescence'. Which values end up in `out` depends on how
        # run_extractor_classes merges results; confirm this is intended.
        extractors = [
            GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
            ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence,
        ]
        out, _ = run_extractor_classes(extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
                                       settings=self.settings, save=False, task_collection=self.task_collection)
        return tuple(out[name] for name in self.var_names)
def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None,
                task_collection='raw_behavior_data', save_path=None):
    """
    Extract the trials data of a biased / ephys choice world session.

    Same as training_trials.extract_all except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater

    :param session_path: path to the session folder
    :param save: forwarded to run_extractor_classes (whether to save the extracted datasets)
    :param bpod_trials: raw Bpod trials data; loaded from `task_collection` when falsy
    :param settings: task settings dict; loaded from `task_collection` when falsy. A missing,
        empty or None IBLRIG_VERSION_TAG is treated as '100.0.0'. The caller's dict is not modified.
    :param extra_classes: additional BaseBpodTrialsExtractor subclasses for custom extractions
    :param task_collection: collection (sub-folder) holding the raw task data
    :param save_path: forwarded to run_extractor_classes as `path_out`
    :return: the (out, fil) pair returned by run_extractor_classes
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None:
        settings = {'IBLRIG_VERSION_TAG': '100.0.0'}

    # An unknown rig version takes the >= 5.0.0 branch below. Use .get() so a missing key does
    # not raise, and copy instead of mutating the caller's settings dict.
    if not settings.get('IBLRIG_VERSION_TAG'):
        settings = {**settings, 'IBLRIG_VERSION_TAG': '100.0.0'}

    # Version check
    if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
        # We now extract a single trials table
        base = [BiasedTrials]
    else:
        base = [
            GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    if extra_classes:
        base.extend(extra_classes)

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil