Coverage for ibllib/qc/dlc.py: 95%

168 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1"""DLC QC 

2This module runs a list of quality control metrics on the dlc traces. 

3 

4Example - Run DLC QC 

5 qc = DlcQC(eid, 'left', download_data=True) 

6 qc.run() 

7Question: 

8 We're not extracting the audio based on TTL length. Is this a problem? 

9""" 

10import logging 

11import warnings 

12from inspect import getmembers, isfunction 

13 

14import numpy as np 

15 

16from ibllib.qc import base 

17import one.alf.io as alfio 

18from one.alf.exceptions import ALFObjectNotFound 

19from one.alf.spec import is_session_path 

20from iblutil.util import Bunch 

21from brainbox.behavior.dlc import insert_idx, SAMPLING 

22 

23_log = logging.getLogger(__name__) 

24 

25 

26class DlcQC(base.QC): 

27 """A class for computing camera QC metrics""" 

28 

29 bbox = { 

30 'body': { 

31 'xrange': range(201, 500), 

32 'yrange': range(81, 330) 

33 }, 

34 'left': { 

35 'xrange': range(301, 700), 

36 'yrange': range(181, 470) 

37 }, 

38 'right': { 

39 'xrange': range(301, 600), 

40 'yrange': range(110, 275) 

41 }, 

42 } 

43 

44 dstypes = { 

45 'left': [ 

46 '_ibl_leftCamera.dlc.*', '_ibl_leftCamera.times.*', '_ibl_leftCamera.features.*', '_ibl_trials.table.*' 

47 ], 

48 'right': [ 

49 '_ibl_rightCamera.dlc.*', '_ibl_rightCamera.times.*', '_ibl_rightCamera.features.*', '_ibl_trials.table.*' 

50 ], 

51 'body': [ 

52 '_ibl_bodyCamera.dlc.*', '_ibl_bodyCamera.times.*' 

53 ], 

54 } 

55 

56 def __init__(self, session_path_or_eid, side, ignore_checks=['check_pupil_diameter_snr'], **kwargs): 

57 """ 

58 :param session_path_or_eid: A session eid or path 

59 :param side: The camera to run QC on 

60 :param ignore_checks: Checks that won't count towards aggregate QC, but will be run and added to extended QC 

61 :param log: A logging.Logger instance, if None the 'ibllib' logger is used 

62 :param one: An ONE instance for fetching and setting the QC on Alyx 

63 """ 

64 # Make sure the type of camera is chosen 

65 self.side = side 1ba

66 # When an eid is provided, we will download the required data by default (if necessary) 

67 download_data = not is_session_path(session_path_or_eid) 1ba

68 self.download_data = kwargs.pop('download_data', download_data) 1ba

69 super().__init__(session_path_or_eid, **kwargs) 1ba

70 self.data = Bunch() 1ba

71 

72 # checks to be added to extended QC but not taken into account for aggregate QC 

73 self.ignore_checks = ignore_checks 1ba

74 # QC outcomes map 

75 self.metrics = None 1ba

76 

77 def load_data(self, download_data: bool = None) -> None: 

78 """Extract the data from data files 

79 Extracts all the required task data from the data files. 

80 

81 Data keys: 

82 - camera_times (float array): camera frame timestamps extracted from frame headers 

83 - dlc_coords (dict): keys are the points traced by dlc, items are x-y coordinates of 

84 these points over time, those with likelihood <0.9 set to NaN 

85 

86 :param download_data: if True, any missing raw data is downloaded via ONE. 

87 """ 

88 if download_data is not None: 1da

89 self.download_data = download_data 

90 if self.one and not self.one.offline: 1da

91 self._ensure_required_data() 1da

92 

93 alf_path = self.session_path / 'alf' 1a

94 

95 # Load times 

96 self.data['camera_times'] = alfio.load_object(alf_path, f'{self.side}Camera')['times'] 1a

97 # Load dlc traces 

98 dlc_df = alfio.load_object(alf_path, f'{self.side}Camera', namespace='ibl')['dlc'] 1a

99 targets = np.unique(['_'.join(col.split('_')[:-1]) for col in dlc_df.columns]) 1a

100 # Set values to nan if likelihood is too low 

101 dlc_coords = {} 1a

102 for t in targets: 1a

103 idx = dlc_df.loc[dlc_df[f'{t}_likelihood'] < 0.9].index 1a

104 dlc_df.loc[idx, [f'{t}_x', f'{t}_y']] = np.nan 1a

105 dlc_coords[t] = np.array((dlc_df[f'{t}_x'], dlc_df[f'{t}_y'])) 1a

106 self.data['dlc_coords'] = dlc_coords 1a

107 

108 # load stim on times 

109 self.data['stimOn_times'] = alfio.load_object(alf_path, 'trials', namespace='ibl')['stimOn_times'] 1a

110 

111 # load pupil diameters 

112 if self.side in ['left', 'right']: 1a

113 features = alfio.load_object(alf_path, f'{self.side}Camera', namespace='ibl')['features'] 1a

114 self.data['pupilDiameter_raw'] = features['pupilDiameter_raw'] 1a

115 self.data['pupilDiameter_smooth'] = features['pupilDiameter_smooth'] 1a

116 

117 def _ensure_required_data(self): 

118 """ 

119 Ensures the datasets required for QC are local. If the download_data attribute is True, 

120 any missing data are downloaded. If all the data are not present locally at the end of 

121 it an exception is raised. 

122 :return: 

123 """ 

124 for ds in self.dstypes[self.side]: 1da

125 # Check if data available locally 

126 if not next(self.session_path.rglob(ds), None): 1da

127 # If download is allowed, try to download 

128 if self.download_data is True: 1d

129 assert self.one is not None, 'ONE required to download data' 1d

130 try: 1d

131 self.one.load_dataset(self.eid, ds, download_only=True) 1d

132 except ALFObjectNotFound: 1d

133 raise AssertionError(f'Dataset {ds} not found locally and failed to download') 1d

134 else: 

135 raise AssertionError(f'Dataset {ds} not found locally and download_data is False') 1d

136 

137 def _compute_trial_window_idxs(self): 

138 """Find start and end times of a window around stimulus onsets in video indices.""" 

139 window_lag = -0.5 1ca

140 window_len = 2.0 1ca

141 start_window = self.data['stimOn_times'] + window_lag 1ca

142 start_idx = insert_idx(self.data['camera_times'], start_window) 1ca

143 end_idx = np.array(start_idx + int(window_len * SAMPLING[self.side]), dtype='int64') 1ca

144 return start_idx, end_idx 1ca

145 

146 def _compute_proportion_nan_in_trial_window(self, body_part): 

147 """Find proportion of NaN frames for a given body part in trial-based windows.""" 

148 # find timepoints in windows around stimulus onset 

149 start_idx, end_idx = self._compute_trial_window_idxs() 1ca

150 # compute fraction of points in windows that are NaN 

151 dlc_coords = np.concatenate([self.data['dlc_coords'][body_part][0, start_idx[i]:end_idx[i]] 1ca

152 for i in range(len(start_idx))]) 

153 prop_nan = np.sum(np.isnan(dlc_coords)) / dlc_coords.shape[0] 1ca

154 return prop_nan 1ca

155 

156 def run(self, update: bool = False, **kwargs) -> (str, dict): 

157 """ 

158 Run DLC QC checks and return outcome 

159 :param update: if True, updates the session QC fields on Alyx 

160 :param download_data: if True, downloads any missing data if required 

161 :returns: overall outcome as a str, a dict of checks and their outcomes 

162 """ 

163 _log.info(f'Running DLC QC for {self.side} camera, session {self.eid}') 1da

164 namespace = f'dlc{self.side.capitalize()}' 1da

165 if all(x is None for x in self.data.values()): 1da

166 self.load_data(**kwargs) 1da

167 

168 def is_metric(x): 1a

169 return isfunction(x) and x.__name__.startswith('check_') 1a

170 

171 checks = getmembers(DlcQC, is_metric) 1a

172 self.metrics = {f'_{namespace}_' + k[6:]: fn(self) for k, fn in checks} 1a

173 

174 ignore_metrics = [f'_{namespace}_' + i[6:] for i in self.ignore_checks] 1a

175 metrics_to_aggregate = {k: v for k, v in self.metrics.items() if k not in ignore_metrics} 1a

176 outcome = self.overall_outcome(metrics_to_aggregate.values()) 1a

177 

178 if update: 1a

179 extended = { 1a

180 k: 'NOT_SET' if v is None else v 

181 for k, v in self.metrics.items() 

182 } 

183 self.update_extended_qc(extended) 1a

184 self.update(outcome, namespace) 1a

185 return outcome, self.metrics 1a

186 

187 def check_time_trace_length_match(self): 

188 ''' 

189 Check that the length of the DLC traces is the same length as the video. 

190 ''' 

191 dlc_coords = self.data['dlc_coords'] 1ia

192 times = self.data['camera_times'] 1ia

193 for target in dlc_coords.keys(): 1ia

194 if times.shape[0] != dlc_coords[target].shape[1]: 1ia

195 _log.warning(f'{self.side}Camera length of camera.times does not match ' 1ia

196 f'length of camera.dlc {target}') 

197 return 'FAIL' 1ia

198 return 'PASS' 1i

199 

200 def check_trace_all_nan(self): 

201 ''' 

202 Check that none of the dlc traces, except for the 'tube' traces, are all NaN. 

203 ''' 

204 dlc_coords = self.data['dlc_coords'] 1ja

205 for target in dlc_coords.keys(): 1ja

206 if 'tube' not in target: 1ja

207 if all(np.isnan(dlc_coords[target][0])) or all(np.isnan(dlc_coords[target][1])): 1ja

208 _log.warning(f'{self.side}Camera dlc trace {target} all NaN') 1ja

209 return 'FAIL' 1ja

210 return 'PASS' 1ja

211 

212 def check_mean_in_bbox(self): 

213 ''' 

214 Empirical bounding boxes around average dlc points, averaged across time and points; 

215 sessions with points out of this box were often faulty in terms of raw videos 

216 ''' 

217 

218 dlc_coords = self.data['dlc_coords'] 1ea

219 with warnings.catch_warnings(): 1ea

220 warnings.simplefilter("ignore", category=RuntimeWarning) 1ea

221 x_mean = np.nanmean([np.nanmean(dlc_coords[k][0]) for k in dlc_coords.keys()]) 1ea

222 y_mean = np.nanmean([np.nanmean(dlc_coords[k][1]) for k in dlc_coords.keys()]) 1ea

223 

224 xrange = self.bbox[self.side]['xrange'] 1ea

225 yrange = self.bbox[self.side]['yrange'] 1ea

226 if int(x_mean) not in xrange or int(y_mean) not in yrange: 1ea

227 return 'FAIL' 1e

228 else: 

229 return 'PASS' 1ea

230 

231 def check_pupil_blocked(self): 

232 ''' 

233 Check if pupil diameter is nan for more than 60 % of the frames 

234 (might be blocked by a whisker) 

235 Check if standard deviation is above a threshold, found for bad sessions 

236 ''' 

237 

238 if self.side == 'body': 1fa

239 return 'NOT_SET' 1fa

240 

241 if np.mean(np.isnan(self.data['pupilDiameter_raw'])) > 0.9: 1fa

242 _log.warning(f'{self.eid}, {self.side}Camera, pupil diameter too often NaN') 1f

243 return 'FAIL' 1f

244 

245 thr = 5 if self.side == 'right' else 10 1fa

246 if np.nanstd(self.data['pupilDiameter_raw']) > thr: 1fa

247 _log.warning(f'{self.eid}, {self.side}Camera, pupil diameter too unstable') 1f

248 return 'FAIL' 1f

249 

250 return 'PASS' 1fa

251 

252 def check_lick_detection(self): 

253 ''' 

254 Check if both of the two tongue edge points are less than 10 % NaN, indicating that 

255 wrong points are detected (spout edge, mouth edge) 

256 ''' 

257 

258 if self.side == 'body': 1ha

259 return 'NOT_SET' 1ha

260 dlc_coords = self.data['dlc_coords'] 1ha

261 nan_l = np.mean(np.isnan(dlc_coords['tongue_end_l'][0])) 1ha

262 nan_r = np.mean(np.isnan(dlc_coords['tongue_end_r'][0])) 1ha

263 if (nan_l < 0.1) and (nan_r < 0.1): 1ha

264 return 'FAIL' 1h

265 return 'PASS' 1ha

266 

267 def check_pupil_diameter_snr(self): 

268 if self.side == 'body': 1ga

269 return 'NOT_SET' 1ga

270 thresh = 5 if self.side == 'right' else 10 1ga

271 if 'pupilDiameter_raw' not in self.data.keys() or 'pupilDiameter_smooth' not in self.data.keys(): 1ga

272 return 'NOT_SET' 

273 # compute signal to noise ratio between raw and smooth dia 

274 good_idxs = np.where(~np.isnan(self.data['pupilDiameter_smooth']) & ~np.isnan(self.data['pupilDiameter_raw']))[0] 1ga

275 snr = (np.var(self.data['pupilDiameter_smooth'][good_idxs]) / 1ga

276 (np.var(self.data['pupilDiameter_smooth'][good_idxs] - self.data['pupilDiameter_raw'][good_idxs]))) 

277 if snr < thresh: 1ga

278 return 'FAIL', float(round(snr, 3)) 1g

279 return 'PASS', float(round(snr, 3)) 1ga

280 

281 def check_paw_close_nan(self): 

282 if self.side == 'body': 1ca

283 return 'NOT_SET' 1ca

284 thresh_fail = 0.20 # prop of NaNs above this threshold means the check fails 1ca

285 thresh_warning = 0.10 # prop of NaNs above this threshold means the check is a warning 1ca

286 # compute fraction of points in windows that are NaN 

287 prop_nan = self._compute_proportion_nan_in_trial_window(body_part='paw_r') 1ca

288 if prop_nan > thresh_fail: 1ca

289 return 'FAIL' 1c

290 elif prop_nan > thresh_warning: 1ca

291 return 'WARNING' 1c

292 else: 

293 return 'PASS' 1ca

294 

295 def check_paw_far_nan(self): 

296 if self.side == 'body': 1ca

297 return 'NOT_SET' 1ca

298 thresh_fail = 0.20 # prop of NaNs above this threshold means the check fails 1ca

299 thresh_warning = 0.10 # prop of NaNs above this threshold means the check is a warning 1ca

300 # compute fraction of points in windows that are NaN 

301 prop_nan = self._compute_proportion_nan_in_trial_window(body_part='paw_l') 1ca

302 if prop_nan > thresh_fail: 1ca

303 return 'FAIL' 1c

304 elif prop_nan > thresh_warning: 1ca

305 return 'WARNING' 1c

306 else: 

307 return 'PASS' 1ca

308 

309 

310def run_all_qc(session, cameras=('left', 'right', 'body'), one=None, **kwargs): 

311 """Run DLC QC for all cameras 

312 Run the DLC QC for left, right and body cameras. 

313 :param session: A session path or eid. 

314 :param update: If True, QC fields are updated on Alyx. 

315 :param cameras: A list of camera names to perform QC on. 

316 :return: dict of DlcQC objects 

317 """ 

318 qc = {} 

319 run_args = {k: kwargs.pop(k) for k in ('download_data', 'update') if k in kwargs.keys()} 

320 for camera in cameras: 

321 qc[camera] = DlcQC(session, side=camera, one=one, **kwargs) 

322 qc[camera].run(**run_args) 

323 return qc