Coverage for ibllib/qc/task_extractors.py: 83%
116 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1import logging
2import warnings
4import numpy as np
5from scipy.interpolate import interp1d
7from ibllib.io.extractors import bpod_trials
8from ibllib.io.extractors.base import get_session_extractor_type
9from ibllib.io.extractors.training_wheel import get_wheel_position
10from ibllib.io.extractors import ephys_fpga
11import ibllib.io.raw_data_loaders as raw
12from one.alf.spec import is_session_path
13import one.alf.io as alfio
14from one.api import ONE
17_logger = logging.getLogger('ibllib')
19REQUIRED_FIELDS = ['choice', 'contrastLeft', 'contrastRight', 'correct',
20 'errorCueTrigger_times', 'errorCue_times', 'feedbackType', 'feedback_times',
21 'firstMovement_times', 'goCueTrigger_times', 'goCue_times', 'intervals',
22 'itiIn_times', 'phase', 'position', 'probabilityLeft', 'quiescence',
23 'response_times', 'rewardVolume', 'stimFreezeTrigger_times',
24 'stimFreeze_times', 'stimOffTrigger_times', 'stimOff_times',
25 'stimOnTrigger_times', 'stimOn_times', 'valveOpen_times',
26 'wheel_moves_intervals', 'wheel_moves_peak_amplitude',
27 'wheel_position', 'wheel_timestamps']
class TaskQCExtractor:
    def __init__(self, session_path, lazy=False, one=None, download_data=False, bpod_only=False,
                 sync_collection=None, sync_type=None, task_collection=None):
        """
        A class for extracting the task data required to perform task quality control.

        :param session_path: a valid session path
        :param lazy: if True, the data are not loaded nor extracted on construction. This is also
            forced to True when download_data is True and none of the required datasets could be
            found.
        :param one: an instance of ONE, used to download the raw data if download_data is True
        :param download_data: if True, any missing raw data is downloaded via ONE
        :param bpod_only: extract from raw Bpod data only, even for FPGA sessions
        :param sync_collection: the collection holding the sync data (default 'raw_ephys_data')
        :param sync_type: the sync type, e.g. 'nidq' or 'bpod'. When None it is inferred from the
            session extractor type: 'nidq' for ephys sessions, otherwise 'bpod'.
        :param task_collection: the collection holding the raw task data
            (default 'raw_behavior_data')
        :raises ValueError: if session_path is not a valid session path
        """
        if not is_session_path(session_path):
            raise ValueError('Invalid session path')
        self.session_path = session_path
        self.one = one
        self.log = _logger

        self.data = None
        self.settings = None
        self.raw_data = None
        self.frame_ttls = self.audio_ttls = self.bpod_ttls = None
        self.type = None
        self.wheel_encoding = None
        self.bpod_only = bpod_only
        self.sync_collection = sync_collection or 'raw_ephys_data'
        self.sync_type = sync_type
        self.task_collection = task_collection or 'raw_behavior_data'
        # Stored so that _ensure_required_data can switch it on when no data could be found
        self.lazy = lazy

        if download_data:
            self.one = one or ONE()
            self._ensure_required_data()

        if not self.lazy:
            self.load_raw_data()
            self.extract_data()

    def _ensure_required_data(self):
        """
        Attempt to download any required raw data if missing.

        No exception is raised for missing data. If none of the required datasets is found
        locally, the `lazy` attribute is set to True (so construction skips extraction) and an
        error is logged. If only some are missing, a warning is logged.
        The `sync_type` attribute is inferred here when it was not set explicitly.
        """
        dstypes = [
            '_iblrig_taskData.raw',
            '_iblrig_taskSettings.raw',
            '_iblrig_encoderPositions.raw',
            '_iblrig_encoderEvents.raw',
            '_iblrig_stimPositionScreen.raw',
            '_iblrig_syncSquareUpdate.raw',
            '_iblrig_encoderTrialInfo.raw',
            '_iblrig_ambientSensorData.raw',
        ]
        eid = self.one.path2eid(self.session_path)
        self.log.info(f'Downloading data for session {eid}')
        # Ensure we have the settings
        settings, _ = self.one.load_datasets(eid, ['_iblrig_taskSettings.raw.json'],
                                             collections=[self.task_collection],
                                             download_only=True, assert_present=False)

        extractor_type = get_session_extractor_type(
            self.session_path, task_collection=self.task_collection)
        is_ephys = extractor_type == 'ephys'
        # Infer the sync type only when it wasn't explicitly set: nidq for ephys, otherwise bpod.
        # The parentheses matter. Without them, the conditional binds looser than `or`, and an
        # explicit sync_type would be discarded for every non-ephys session.
        self.sync_type = self.sync_type or ('nidq' if is_ephys else 'bpod')
        is_fpga = 'bpod' not in self.sync_type

        if settings and is_ephys:
            dstypes.extend(['_spikeglx_sync.channels',
                            '_spikeglx_sync.polarities',
                            '_spikeglx_sync.times',
                            'ephysData.raw.meta',
                            'ephysData.raw.wiring'])
        elif settings and is_fpga:
            dstypes.extend(['_spikeglx_sync.channels',
                            '_spikeglx_sync.polarities',
                            '_spikeglx_sync.times',
                            'DAQData.raw.meta',
                            'DAQData.wiring'])

        dataset = self.one.type2datasets(eid, dstypes, details=True)
        files = self.one._check_filesystem(dataset)

        # One flag per dataset type; an empty/None result means nothing was found at all
        missing = [True] * len(dstypes) if not files else [x is None for x in files]
        if self.session_path is None or all(missing):
            self.lazy = True
            self.log.error('Data not found on server, can\'t calculate QC.')
        elif any(missing):
            self.log.warning(
                f'Missing some datasets for session {eid} in path {self.session_path}'
            )

    def load_raw_data(self):
        """Loads the TTLs, raw task data and task settings.

        Sets the `type`, `sync_type`, `wheel_encoding`, `settings`, `raw_data`, `frame_ttls` and
        `audio_ttls` attributes. `bpod_ttls` is also set when extracting from the FPGA.
        """
        self.log.info(f'Loading raw data from {self.session_path}')
        self.type = self.type or get_session_extractor_type(
            self.session_path, task_collection=self.task_collection)
        # Finds the sync type when it isn't explicitly set, if ephys we assume nidq otherwise bpod
        self.sync_type = self.sync_type or ('nidq' if self.type == 'ephys' else 'bpod')
        self.wheel_encoding = 'X4' if (self.sync_type != 'bpod' and not self.bpod_only) else 'X1'

        self.settings, self.raw_data = raw.load_bpod(
            self.session_path, task_collection=self.task_collection)
        # Fetch the TTLs for the photodiode and audio
        if self.sync_type == 'bpod' or self.bpod_only is True:  # Extract from Bpod
            self.frame_ttls, self.audio_ttls = raw.load_bpod_fronts(
                self.session_path, data=self.raw_data, task_collection=self.task_collection)
        else:  # Extract from FPGA
            sync, chmap = ephys_fpga.get_sync_and_chn_map(self.session_path, self.sync_collection)

            def channel_events(name):
                """Fetches the polarities and times for a given channel"""
                keys = ('polarities', 'times')
                mask = sync['channels'] == chmap[name]
                return dict(zip(keys, (sync[k][mask] for k in keys)))

            ttls = [ephys_fpga._clean_frame2ttl(channel_events('frame2ttl')),
                    ephys_fpga._clean_audio(channel_events('audio')),
                    channel_events('bpod')]
            self.frame_ttls, self.audio_ttls, self.bpod_ttls = ttls

    def extract_data(self):
        """Extracts and loads behaviour data for QC.

        Emits a FutureWarning, since this method is scheduled for removal. Loads the raw data
        first if it has not been loaded. Data are extracted from the FPGA when the sync type is
        not 'bpod' and bpod_only is False, otherwise from Bpod. The result is passed through
        rename_data and stored in the `data` attribute, overwriting any previous value.

        NOTE(review): an earlier docstring said partial extraction with bpod_only False requires
        intervals and intervals_bpod to be pre-assigned to `data`. This method does not read
        `self.data`, so that requirement may be stale. Confirm before relying on it.
        """
        warnings.warn('The TaskQCExtractor.extract_data will be removed in the future, '
                      'use dynamic pipeline behaviour tasks instead.', FutureWarning)
        self.log.info(f'Extracting session: {self.session_path}')

        if not self.raw_data:
            self.load_raw_data()

        # Run extractors
        if self.sync_type != 'bpod' and not self.bpod_only:
            data, _ = ephys_fpga.extract_all(
                self.session_path, save=False, task_collection=self.task_collection)
            # Map Bpod times onto the FPGA clock using the trial start times on both clocks
            bpod2fpga = interp1d(data['intervals_bpod'][:, 0], data['table']['intervals_0'],
                                 fill_value='extrapolate')
            # Add Bpod wheel data
            re_ts, pos = get_wheel_position(
                self.session_path, self.raw_data, task_collection=self.task_collection)
            data['wheel_timestamps_bpod'] = bpod2fpga(re_ts)
            data['wheel_position_bpod'] = pos
        else:
            kwargs = dict(save=False, bpod_trials=self.raw_data, settings=self.settings,
                          task_collection=self.task_collection)
            trials, wheel, _ = bpod_trials.extract_all(self.session_path, **kwargs)
            if self.type == 'habituation':
                # Shortest first-axis length across the extracted trials datasets
                n_trials = min(trials[k].shape[0] for k in trials)
                data = trials
                data['position'] = np.array([t['position'] for t in self.raw_data])
                data['phase'] = np.array([t['stim_phase'] for t in self.raw_data])
                # Nasty hack to trim last trial due to stim off events happening at trial num + 1
                data = {k: v[:n_trials] for k, v in data.items()}
            else:
                data = {**trials, **wheel}
        # Update the data attribute with extracted data
        self.data = self.rename_data(data)

    @staticmethod
    def rename_data(data):
        """Rename the extracted data dict for use with TaskQC.

        Steps:
          - expands a 'table' trials dataframe (if present) into key value pairs;
          - splits 'feedback_times' into 'errorCue_times' (NaN on correct trials) and
            'valveOpen_times' (NaN on incorrect trials), unless those keys already exist;
          - renames 'wheelMoves_intervals' / 'wheelMoves_peakAmplitude' to their snake_case
            names;
          - adds a boolean 'correct' key (feedbackType > 0);
          - fills any key of REQUIRED_FIELDS that is still absent with NaNs shaped like
            'feedback_times', logging a warning.

        NB: The data is not copied before making changes. 'table' is always popped from the
        input dict. When a 'table' was present, a new dict is built and returned, so the later
        changes are not applied to the input dict.

        :param data: A dict of task data returned by the task extractors
        :return: the renamed dict (the input dict itself unless it contained a 'table' key)
        """
        # Expand trials dataframe into key value pairs
        trials_table = data.pop('table', None)
        if trials_table is not None:
            data = {**data, **alfio.AlfBunch.from_df(trials_table)}
        correct = data['feedbackType'] > 0
        # get valve_time and errorCue_times from feedback_times
        if 'errorCue_times' not in data:
            data['errorCue_times'] = data['feedback_times'].copy()
            data['errorCue_times'][correct] = np.nan
        if 'valveOpen_times' not in data:
            data['valveOpen_times'] = data['feedback_times'].copy()
            data['valveOpen_times'][~correct] = np.nan
        if 'wheel_moves_intervals' not in data and 'wheelMoves_intervals' in data:
            data['wheel_moves_intervals'] = data.pop('wheelMoves_intervals')
        if 'wheel_moves_peak_amplitude' not in data and 'wheelMoves_peakAmplitude' in data:
            data['wheel_moves_peak_amplitude'] = data.pop('wheelMoves_peakAmplitude')
        data['correct'] = correct
        diff_fields = list(set(REQUIRED_FIELDS).difference(set(data.keys())))
        for miss_field in diff_fields:
            data[miss_field] = data['feedback_times'] * np.nan
        if len(diff_fields):
            _logger.warning(f'QC extractor, missing fields filled with NaNs: {diff_fields}')
        return data