Coverage for ibllib/qc/task_extractors.py: 97%
36 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1import logging
3import numpy as np
5from one.alf.spec import is_session_path
6import one.alf.io as alfio
_logger = logging.getLogger('ibllib')

# Keys that TaskQCExtractor.rename_data ensures exist in the dict it returns.
# Any name listed here that the extracted data lacks is filled in by
# rename_data: None for names starting with 'wheel', otherwise a NaN array
# shaped like 'feedback_times'.
REQUIRED_FIELDS = [
    'choice',
    'contrastLeft',
    'contrastRight',
    'correct',
    'errorCueTrigger_times',
    'errorCue_times',
    'feedbackType',
    'feedback_times',
    'firstMovement_times',
    'goCueTrigger_times',
    'goCue_times',
    'intervals',
    'itiIn_times',
    'phase',
    'position',
    'probabilityLeft',
    'quiescence',
    'response_times',
    'rewardVolume',
    'stimFreezeTrigger_times',
    'stimFreeze_times',
    'stimOffTrigger_times',
    'stimOff_times',
    'stimOnTrigger_times',
    'stimOn_times',
    'valveOpen_times',
    'wheelMoves_intervals',
    'wheelMoves_peakAmplitude',
    'wheelMoves_peakVelocity_times',
    'wheel_position',
    'wheel_timestamps',
]
class TaskQCExtractor:
    def __init__(self, session_path):
        """
        A class for holding the task data required to perform task quality control.

        :param session_path: a valid session path
        :raises ValueError: if `session_path` is not a valid session path
        """
        if not is_session_path(session_path):
            raise ValueError('Invalid session path')
        self.session_path = session_path
        self.log = _logger

        # Data containers; all start as None here. Presumably populated by
        # loading/extraction code not shown in this view -- TODO confirm.
        self.data = None
        self.settings = None
        self.raw_data = None
        self.frame_ttls = self.audio_ttls = self.bpod_ttls = None
        self.wheel_encoding = None

    @staticmethod
    def rename_data(data):
        """Rename the extracted data dict for use with TaskQC.

        - A 'table' entry (trials dataframe), if present, is removed and expanded
          into individual key/value pairs; its columns override existing keys.
        - 'feedback_times' is split into 'errorCue_times' (NaN on correct trials)
          and 'valveOpen_times' (NaN on incorrect trials), unless those keys
          already exist.
        - 'correct' is set to the boolean mask ``feedbackType > 0``.
        - Every name in REQUIRED_FIELDS that is still missing is filled in:
          names starting with 'wheel' get None, all others get a NaN array
          shaped like 'feedback_times'. A warning lists the filled fields.

        NB: The data is not copied before making changes; the input dict is
        modified in place and returned.

        :param data: A dict of task data returned by the task extractors
        :return: the same dict after modifying the keys
        :raises KeyError: if 'feedbackType' is missing, or if 'feedback_times' is
            missing and needed to derive or fill other fields
        """
        # Expand trials dataframe into key value pairs. Update in place rather than
        # rebinding to a new dict, so the caller's dict is modified consistently
        # whether or not a table was present (as the docstring promises).
        trials_table = data.pop('table', None)
        if trials_table is not None:
            data.update(alfio.AlfBunch.from_df(trials_table))

        correct = data['feedbackType'] > 0
        # get valve_time and errorCue_times from feedback_times
        if 'errorCue_times' not in data:
            data['errorCue_times'] = data['feedback_times'].copy()
            data['errorCue_times'][correct] = np.nan
        if 'valveOpen_times' not in data:
            data['valveOpen_times'] = data['feedback_times'].copy()
            data['valveOpen_times'][~correct] = np.nan
        data['correct'] = correct

        # Sorted so the warning below lists fields in a deterministic order.
        diff_fields = sorted(set(REQUIRED_FIELDS).difference(data))
        for miss_field in diff_fields:
            # Wheel data are not per-trial arrays, so they get None instead of NaNs.
            data[miss_field] = None if miss_field.startswith('wheel') else data['feedback_times'] * np.nan
        if diff_fields:
            _logger.warning(f'QC extractor, missing fields filled with NaNs: {diff_fields}')
        return data