Coverage for ibllib/qc/task_extractors.py: 97%

116 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1import logging 

2 

3import numpy as np 

4from scipy.interpolate import interp1d 

5 

6from ibllib.io.extractors import bpod_trials 

7from ibllib.io.extractors.base import get_session_extractor_type 

8from ibllib.io.extractors.training_wheel import get_wheel_position 

9from ibllib.io.extractors import ephys_fpga 

10import ibllib.io.raw_data_loaders as raw 

11from one.alf.spec import is_session_path 

12import one.alf.io as alfio 

13from one.api import ONE 

14 

15 

# Logger shared by everything in this module; messages go to the 'ibllib' logger.
_logger = logging.getLogger('ibllib')

# Keys that TaskQCExtractor.rename_data ensures are present in the extracted data:
# any key below that the extractors did not produce is filled with a NaN array.
# The entries are kept in sorted order.
REQUIRED_FIELDS = [
    'choice',
    'contrastLeft',
    'contrastRight',
    'correct',
    'errorCueTrigger_times',
    'errorCue_times',
    'feedbackType',
    'feedback_times',
    'firstMovement_times',
    'goCueTrigger_times',
    'goCue_times',
    'intervals',
    'itiIn_times',
    'phase',
    'position',
    'probabilityLeft',
    'quiescence',
    'response_times',
    'rewardVolume',
    'stimFreezeTrigger_times',
    'stimFreeze_times',
    'stimOffTrigger_times',
    'stimOff_times',
    'stimOnTrigger_times',
    'stimOn_times',
    'valveOpen_times',
    'wheel_moves_intervals',
    'wheel_moves_peak_amplitude',
    'wheel_position',
    'wheel_timestamps',
]

27 

28 

class TaskQCExtractor(object):
    """Extract the task data required to perform task quality control.

    Raw Bpod data (and, for FPGA-synced sessions, the sync channel events) are loaded, run
    through the trial extractors and stored in ``self.data`` after ``rename_data``.
    """

    def __init__(self, session_path, lazy=False, one=None, download_data=False, bpod_only=False,
                 sync_collection=None, sync_type=None, task_collection=None):
        """
        A class for extracting the task data required to perform task quality control.

        :param session_path: a valid session path
        :param lazy: if True, the data are not extracted immediately
        :param one: an instance of ONE, used to download the raw data if download_data is True
        :param download_data: if True, any missing raw data is downloaded via ONE; if no data
            can be found the extractor switches to lazy mode and nothing is extracted
        :param bpod_only: extract from raw Bpod data only, even for FPGA sessions
        :param sync_collection: collection holding the sync data (default 'raw_ephys_data')
        :param sync_type: the sync type (e.g. 'bpod' or 'nidq'); when not set it is inferred
            as 'nidq' for ephys sessions and 'bpod' otherwise
        :param task_collection: collection holding the raw task data
            (default 'raw_behavior_data')
        :raises ValueError: if session_path is not a valid session path
        """
        if not is_session_path(session_path):
            raise ValueError('Invalid session path')
        self.session_path = session_path
        self.one = one
        self.log = _logger

        self.data = None
        self.settings = None
        self.raw_data = None
        self.frame_ttls = self.audio_ttls = self.bpod_ttls = None
        self.type = None
        self.wheel_encoding = None
        self.bpod_only = bpod_only
        self.sync_collection = sync_collection or 'raw_ephys_data'
        self.sync_type = sync_type
        self.task_collection = task_collection or 'raw_behavior_data'
        # Kept on the instance so that _ensure_required_data can switch to lazy mode when no
        # data are found; previously the flag it set was never read.
        self.lazy = lazy

        if download_data:
            self.one = one or ONE()
            self._ensure_required_data()

        if not self.lazy:
            self.load_raw_data()
            self.extract_data()

    @staticmethod
    def _resolve_sync_type(sync_type, is_ephys):
        """Return the sync type, inferring it only when it was not explicitly set.

        The parentheses matter: ``a or b if c else d`` parses as ``(a or b) if c else d``,
        which silently replaced an explicit sync type with 'bpod' for non-ephys sessions.

        :param sync_type: the requested sync type, or None to infer it
        :param is_ephys: True if the session extractor type is 'ephys'
        :return: sync_type if set, otherwise 'nidq' for ephys sessions and 'bpod' for others
        """
        return sync_type or ('nidq' if is_ephys else 'bpod')

    def _ensure_required_data(self):
        """
        Attempt to download any required raw data if missing.

        Logs an error and sets ``self.lazy`` to True if none of the datasets are found, and
        logs a warning if only some of them are missing. No exception is raised.

        :return: None
        """
        dstypes = [
            '_iblrig_taskData.raw',
            '_iblrig_taskSettings.raw',
            '_iblrig_encoderPositions.raw',
            '_iblrig_encoderEvents.raw',
            '_iblrig_stimPositionScreen.raw',
            '_iblrig_syncSquareUpdate.raw',
            '_iblrig_encoderTrialInfo.raw',
            '_iblrig_ambientSensorData.raw',
        ]
        eid = self.one.path2eid(self.session_path)
        self.log.info(f'Downloading data for session {eid}')
        # Ensure we have the settings
        settings, _ = self.one.load_datasets(eid, ['_iblrig_taskSettings.raw.json'],
                                             collections=[self.task_collection],
                                             download_only=True, assert_present=False)
        # With assert_present=False missing datasets come back as None, so a list such as
        # [None] must not count as having the settings.
        has_settings = any(s is not None for s in (settings or []))

        is_ephys = get_session_extractor_type(self.session_path, task_collection=self.task_collection) == 'ephys'
        self.sync_type = self._resolve_sync_type(self.sync_type, is_ephys)
        is_fpga = 'bpod' not in self.sync_type

        if has_settings and is_ephys:
            dstypes.extend(['_spikeglx_sync.channels',
                            '_spikeglx_sync.polarities',
                            '_spikeglx_sync.times',
                            'ephysData.raw.meta',
                            'ephysData.raw.wiring'])
        elif has_settings and is_fpga:
            # Non-ephys session with an explicitly requested FPGA sync type
            dstypes.extend(['_spikeglx_sync.channels',
                            '_spikeglx_sync.polarities',
                            '_spikeglx_sync.times',
                            'DAQData.raw.meta',
                            'DAQData.wiring'])

        dataset = self.one.type2datasets(eid, dstypes, details=True)
        files = self.one._check_filesystem(dataset)

        missing = [True] * len(dstypes) if not files else [x is None for x in files]
        if self.session_path is None or all(missing):
            self.lazy = True
            self.log.error('Data not found on server, can\'t calculate QC.')
        elif any(missing):
            self.log.warning(
                f'Missing some datasets for session {eid} in path {self.session_path}'
            )

    def load_raw_data(self):
        """
        Load the TTLs, raw task data and task settings.

        Sets ``self.type``, ``self.sync_type``, ``self.settings``, ``self.raw_data``,
        ``self.frame_ttls`` and ``self.audio_ttls``; when extracting from the FPGA,
        ``self.bpod_ttls`` is set as well.

        :return: None
        """
        self.log.info(f'Loading raw data from {self.session_path}')
        self.type = self.type or get_session_extractor_type(self.session_path, task_collection=self.task_collection)
        # Find the sync type when it isn't explicitly set: if ephys we assume nidq, otherwise bpod
        self.sync_type = self._resolve_sync_type(self.sync_type, self.type == 'ephys')

        self.settings, self.raw_data = raw.load_bpod(self.session_path, task_collection=self.task_collection)
        # Fetch the TTLs for the photodiode and audio
        if self.sync_type == 'bpod' or self.bpod_only is True:  # Extract from Bpod
            self.frame_ttls, self.audio_ttls = raw.load_bpod_fronts(
                self.session_path, data=self.raw_data, task_collection=self.task_collection)
        else:  # Extract from FPGA
            sync, chmap = ephys_fpga.get_sync_and_chn_map(self.session_path, self.sync_collection)

            def channel_events(name):
                """Fetch the polarities and times for a given sync channel."""
                keys = ('polarities', 'times')
                mask = sync['channels'] == chmap[name]
                return dict(zip(keys, (sync[k][mask] for k in keys)))

            ttls = [ephys_fpga._clean_frame2ttl(channel_events('frame2ttl')),
                    ephys_fpga._clean_audio(channel_events('audio')),
                    channel_events('bpod')]
            self.frame_ttls, self.audio_ttls, self.bpod_ttls = ttls

    def extract_data(self):
        """Extract and load behaviour data for QC.

        NB: partial extraction when bpod_only attribute is False requires intervals and
        intervals_bpod to be assigned to the data attribute before calling this function.
        The raw data are loaded first if not already present; the result is stored in
        ``self.data`` after passing through ``rename_data``.

        :return: None
        """
        self.log.info(f'Extracting session: {self.session_path}')
        self.type = self.type or get_session_extractor_type(self.session_path, task_collection=self.task_collection)
        # Find the sync type when it isn't explicitly set: if ephys we assume nidq, otherwise bpod
        self.sync_type = self._resolve_sync_type(self.sync_type, self.type == 'ephys')

        use_fpga = self.sync_type != 'bpod' and not self.bpod_only
        # Wheel encoding label: 'X4' when extracting from the FPGA, 'X1' when from Bpod
        self.wheel_encoding = 'X4' if use_fpga else 'X1'

        if not self.raw_data:
            self.load_raw_data()
        # Run extractors
        if use_fpga:
            data, _ = ephys_fpga.extract_all(self.session_path, save=False, task_collection=self.task_collection)
            # Interpolate Bpod trial start times onto the FPGA trial start times so that
            # Bpod timestamps can be expressed on the FPGA clock
            bpod2fpga = interp1d(data['intervals_bpod'][:, 0], data['table']['intervals_0'],
                                 fill_value='extrapolate')
            # Add Bpod wheel data
            re_ts, pos = get_wheel_position(self.session_path, self.raw_data, task_collection=self.task_collection)
            data['wheel_timestamps_bpod'] = bpod2fpga(re_ts)
            data['wheel_position_bpod'] = pos
        else:
            kwargs = dict(save=False, bpod_trials=self.raw_data, settings=self.settings, task_collection=self.task_collection)
            trials, wheel, _ = bpod_trials.extract_all(self.session_path, **kwargs)
            if self.type == 'habituation':
                # Shortest length across the extracted trial arrays. Computed before position
                # and phase are added, because `data` below is the same dict as `trials`.
                n_trials = min(v.shape[0] for v in trials.values())
                data = trials
                data['position'] = np.array([t['position'] for t in self.raw_data])
                data['phase'] = np.array([t['stim_phase'] for t in self.raw_data])
                # Nasty hack to trim last trial due to stim off events happening at trial num + 1
                data = {k: v[:n_trials] for k, v in data.items()}
            else:
                data = {**trials, **wheel}
        # Update the data attribute with extracted data
        self.data = self.rename_data(data)

    @staticmethod
    def rename_data(data):
        """Rename the extracted data dict for use with TaskQC.

        Splits 'feedback_times' into 'errorCue_times' (NaN on correct trials) and
        'valveOpen_times' (NaN on incorrect trials), renames the camel-case wheel-move keys,
        adds a boolean 'correct' array and fills any key of REQUIRED_FIELDS still missing
        with NaNs (logging a warning).
        NB: The data is not copied before making changes.

        :param data: A dict of task data returned by the task extractors
        :return: the modified dict (a new dict when a 'table' key was expanded; the input
            still has 'table' popped in that case)
        """
        # Expand trials dataframe into key value pairs
        trials_table = data.pop('table', None)
        if trials_table is not None:
            data = {**data, **alfio.AlfBunch.from_df(trials_table)}
        correct = data['feedbackType'] > 0
        # get valve_time and errorCue_times from feedback_times
        if 'errorCue_times' not in data:
            data['errorCue_times'] = data['feedback_times'].copy()
            data['errorCue_times'][correct] = np.nan
        if 'valveOpen_times' not in data:
            data['valveOpen_times'] = data['feedback_times'].copy()
            data['valveOpen_times'][~correct] = np.nan
        if 'wheel_moves_intervals' not in data and 'wheelMoves_intervals' in data:
            data['wheel_moves_intervals'] = data.pop('wheelMoves_intervals')
        if 'wheel_moves_peak_amplitude' not in data and 'wheelMoves_peakAmplitude' in data:
            data['wheel_moves_peak_amplitude'] = data.pop('wheelMoves_peakAmplitude')
        data['correct'] = correct
        # Sorted so the warning below is deterministic
        diff_fields = sorted(set(REQUIRED_FIELDS).difference(data.keys()))
        for miss_field in diff_fields:
            data[miss_field] = data['feedback_times'] * np.nan
        if diff_fields:
            _logger.warning(f'QC extractor, missing fields filled with NaNs: {diff_fields}')
        return data