Coverage for ibllib/qc/dlc.py: 77%

172 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""DLC QC 

2This module runs a list of quality control metrics on the dlc traces. 

3 

4Example - Run DLC QC 

5 qc = DlcQC(eid, 'left', download_data=True) 

6 qc.run() 

7Question: 

8 We're not extracting the audio based on TTL length. Is this a problem? 

9""" 

10import logging 

11import warnings 

12from inspect import getmembers, isfunction 

13 

14import numpy as np 

15 

16from ibllib.qc import base 

17import one.alf.io as alfio 

18from one.alf.exceptions import ALFObjectNotFound 

19from one.alf.spec import is_session_path 

20from iblutil.util import Bunch 

21from brainbox.behavior.dlc import insert_idx, SAMPLING 

22 

23_log = logging.getLogger(__name__) 

24 

25 

class DlcQC(base.QC):
    """Quality control metrics for the DLC traces of one camera.

    Every method whose name starts with ``check_`` is discovered and executed by :meth:`run`.
    A check returns an outcome string ('PASS', 'WARNING', 'FAIL' or 'NOT_SET'), or a tuple whose
    first element is the outcome string followed by the value(s) the outcome was derived from.
    """

    # Per camera, the range the mean DLC x / y position must fall in (see check_mean_in_bbox)
    bbox = {
        'body': {
            'xrange': range(201, 500),
            'yrange': range(81, 330)
        },
        'left': {
            'xrange': range(301, 700),
            'yrange': range(181, 470)
        },
        'right': {
            'xrange': range(301, 600),
            'yrange': range(110, 275)
        },
    }

    # Per camera, the dataset patterns that _ensure_required_data expects to find locally
    dstypes = {
        'left': [
            '_ibl_leftCamera.dlc.*', '_ibl_leftCamera.times.*', '_ibl_leftCamera.features.*', '_ibl_trials.table.*'
        ],
        'right': [
            '_ibl_rightCamera.dlc.*', '_ibl_rightCamera.times.*', '_ibl_rightCamera.features.*', '_ibl_trials.table.*'
        ],
        'body': [
            '_ibl_bodyCamera.dlc.*', '_ibl_bodyCamera.times.*'
        ],
    }

    def __init__(self, session_path_or_eid, side, ignore_checks=None, **kwargs):
        """
        :param session_path_or_eid: A session eid or path
        :param side: The camera to run QC on
        :param ignore_checks: Checks that won't count towards aggregate QC, but will be run and added to
         extended QC. Defaults to ['check_pupil_diameter_snr'] when None.
        :param download_data: If True, missing datasets are downloaded. By default this is True when an
         eid (rather than a session path) is provided.
        :param log: A logging.Logger instance, if None the 'ibllib' logger is used
        :param one: An ONE instance for fetching and setting the QC on Alyx
        """
        # Make sure the type of camera is chosen
        self.side = side
        # When an eid is provided, we will download the required data by default (if necessary)
        download_data = not is_session_path(session_path_or_eid)
        self.download_data = kwargs.pop('download_data', download_data)
        super().__init__(session_path_or_eid, **kwargs)
        self.data = Bunch()

        # checks to be added to extended QC but not taken into account for aggregate QC.
        # The default is built per instance so that no list object is shared between instances.
        self.ignore_checks = ['check_pupil_diameter_snr'] if ignore_checks is None else ignore_checks
        # QC outcomes map
        self.metrics = None

    def load_data(self, download_data: bool = None) -> None:
        """Extract the data from data files
        Extracts all the required task data from the data files.

        Data keys:
            - camera_times (float array): camera frame timestamps extracted from frame headers
            - dlc_coords (dict): keys are the points traced by dlc, items are x-y coordinates of
                                 these points over time, those with likelihood <0.9 set to NaN
            - stimOn_times: the 'stimOn_times' column of the trials object
            - pupilDiameter_raw, pupilDiameter_smooth: pupil features (left and right cameras only)

        :param download_data: if True, any missing raw data is downloaded via ONE.
        :raises StopIteration: if one of the expected files cannot be found below the alf folder
        """
        if download_data is not None:
            self.download_data = download_data
        if self.one and not self.one.offline:
            self._ensure_required_data()

        alf_path = self.session_path / 'alf'
        camera = f'{self.side}Camera'

        # Load times
        cam_path = next(alf_path.rglob(f'*{camera}.times*')).parent
        self.data['camera_times'] = alfio.load_object(cam_path, camera)['times']
        # Load dlc traces
        dlc_path = next(alf_path.rglob(f'*{camera}.dlc*')).parent
        dlc_df = alfio.load_object(dlc_path, camera, namespace='ibl')['dlc']
        # Column names are '<target>_<x|y|likelihood>'; strip the last token to get the targets
        targets = np.unique(['_'.join(col.split('_')[:-1]) for col in dlc_df.columns])
        # Set values to nan if likelihood is too low
        dlc_coords = {}
        for t in targets:
            idx = dlc_df.loc[dlc_df[f'{t}_likelihood'] < 0.9].index
            dlc_df.loc[idx, [f'{t}_x', f'{t}_y']] = np.nan
            # Row 0 holds the x trace, row 1 the y trace
            dlc_coords[t] = np.array((dlc_df[f'{t}_x'], dlc_df[f'{t}_y']))
        self.data['dlc_coords'] = dlc_coords

        # load stim on times
        # NOTE(review): the trials table is loaded for every camera although dstypes['body'] does not
        # list it, so it is never downloaded on behalf of the body camera - confirm this is intended.
        trial_path = next(alf_path.rglob('*trials.table*')).parent
        self.data['stimOn_times'] = alfio.load_object(trial_path, 'trials', namespace='ibl')['stimOn_times']

        # load pupil diameters
        if self.side in ['left', 'right']:
            feat_path = next(alf_path.rglob(f'*{camera}.features*')).parent
            features = alfio.load_object(feat_path, camera, namespace='ibl')['features']
            self.data['pupilDiameter_raw'] = features['pupilDiameter_raw']
            self.data['pupilDiameter_smooth'] = features['pupilDiameter_smooth']

    def _ensure_required_data(self):
        """
        Ensures the datasets required for QC are local. If the download_data attribute is True,
        any missing data are downloaded. If all the data are not present locally at the end of
        it an exception is raised.

        :raises AssertionError: if a dataset is missing locally and downloading is disabled, no ONE
         instance is available, or the dataset could not be found remotely.
        """
        for ds in self.dstypes[self.side]:
            # Check if data available locally
            if next(self.session_path.rglob(ds), None):
                continue
            if self.download_data is not True:
                raise AssertionError(f'Dataset {ds} not found locally and download_data is False')
            # Raised explicitly (rather than with an assert statement) so that the validation
            # still happens when Python runs with optimisations enabled
            if self.one is None:
                raise AssertionError('ONE required to download data')
            try:
                self.one.load_dataset(self.eid, ds, download_only=True)
            except ALFObjectNotFound as ex:
                raise AssertionError(f'Dataset {ds} not found locally and failed to download') from ex

    def _compute_trial_window_idxs(self):
        """Find start and end times of a window around stimulus onsets in video indices.

        The window starts 0.5 s before each stimulus onset and spans 2 s worth of frames at the
        camera's sampling rate.

        :returns: start frame indices, end frame indices (one pair per trial)
        """
        window_lag = -0.5
        window_len = 2.0
        start_window = self.data['stimOn_times'] + window_lag
        start_idx = insert_idx(self.data['camera_times'], start_window)
        end_idx = np.array(start_idx + int(window_len * SAMPLING[self.side]), dtype='int64')
        return start_idx, end_idx

    def _compute_proportion_nan_in_trial_window(self, body_part):
        """Find proportion of NaN frames for a given body part in trial-based windows.

        Only the x trace is inspected.

        :param body_part: key of the DLC target in self.data['dlc_coords']
        :returns: fraction of NaN frames within the windows, or NaN when the windows contain no
         frames at all (e.g. no trials)
        """
        # find timepoints in windows around stimulus onset
        start_idx, end_idx = self._compute_trial_window_idxs()
        x_trace = self.data['dlc_coords'][body_part][0]
        windows = [x_trace[start:end] for start, end in zip(start_idx, end_idx)]
        n_frames = sum(w.shape[0] for w in windows)
        if n_frames == 0:
            # np.concatenate refuses an empty list and 0 / 0 is undefined: report "unknown"
            return np.nan
        # compute fraction of points in windows that are NaN
        return np.sum(np.isnan(np.concatenate(windows))) / n_frames

    def run(self, update: bool = False, **kwargs) -> (str, dict):
        """
        Run DLC QC checks and return outcome
        :param update: if True, updates the session QC fields on Alyx
        :param download_data: if True, downloads any missing data if required
        :returns: overall outcome as a str, a dict of checks and their outcomes
        """
        _log.info(f'Running DLC QC for {self.side} camera, session {self.eid}')
        namespace = f'dlc{self.side.capitalize()}'
        if all(x is None for x in self.data.values()):
            self.load_data(**kwargs)

        def is_metric(x):
            return isfunction(x) and x.__name__.startswith('check_')

        # Inspect the instance's own class so that checks defined on subclasses are run as well
        checks = getmembers(type(self), is_metric)
        # k[6:] strips the 'check_' prefix from the method name
        self.metrics = {f'_{namespace}_' + k[6:]: fn(self) for k, fn in checks}

        ignore_metrics = {f'_{namespace}_' + i[6:] for i in self.ignore_checks}
        # Some checks return (outcome, value); only the outcome takes part in the aggregate
        outcomes = [
            v[0] if isinstance(v, tuple) else v
            for k, v in self.metrics.items() if k not in ignore_metrics
        ]
        outcome = self.overall_outcome(outcomes)

        if update:
            extended = {
                k: 'NOT_SET' if v is None else v
                for k, v in self.metrics.items()
            }
            self.update_extended_qc(extended)
            self.update(outcome, namespace)
        return outcome, self.metrics

    def check_time_trace_length_match(self):
        '''
        Check that the length of the DLC traces is the same length as the video.
        '''
        n_frames = self.data['camera_times'].shape[0]
        for target, coords in self.data['dlc_coords'].items():
            if n_frames != coords.shape[1]:
                _log.warning(f'{self.side}Camera length of camera.times does not match '
                             f'length of camera.dlc {target}')
                return 'FAIL'
        return 'PASS'

    def check_trace_all_nan(self):
        '''
        Check that none of the dlc traces, except for the 'tube' traces, are all NaN.
        '''
        for target, coords in self.data['dlc_coords'].items():
            if 'tube' in target:
                continue
            if np.all(np.isnan(coords[0])) or np.all(np.isnan(coords[1])):
                _log.warning(f'{self.side}Camera dlc trace {target} all NaN')
                return 'FAIL'
        return 'PASS'

    def check_mean_in_bbox(self):
        '''
        Empirical bounding boxes around average dlc points, averaged across time and points;
        sessions with points out of this box were often faulty in terms of raw videos.
        Fails as well when no mean position can be computed (every trace is NaN).
        '''
        dlc_coords = self.data['dlc_coords']
        with warnings.catch_warnings():
            # nanmean warns on all-NaN input; those cases are handled below
            warnings.simplefilter("ignore", category=RuntimeWarning)
            x_mean = np.nanmean([np.nanmean(xy[0]) for xy in dlc_coords.values()])
            y_mean = np.nanmean([np.nanmean(xy[1]) for xy in dlc_coords.values()])

        # int() cannot convert NaN or inf; a session without any usable trace fails the check
        if not (np.isfinite(x_mean) and np.isfinite(y_mean)):
            _log.warning(f'{self.side}Camera mean dlc position could not be computed')
            return 'FAIL'

        xrange = self.bbox[self.side]['xrange']
        yrange = self.bbox[self.side]['yrange']
        if int(x_mean) not in xrange or int(y_mean) not in yrange:
            return 'FAIL'
        return 'PASS'

    def check_pupil_blocked(self):
        '''
        Check if pupil diameter is nan for more than 90 % of the frames
        (might be blocked by a whisker)
        Check if standard deviation is above a threshold, found for bad sessions
        (5 for the right camera, 10 otherwise). Not set for the body camera.
        '''
        if self.side == 'body':
            return 'NOT_SET'

        if np.mean(np.isnan(self.data['pupilDiameter_raw'])) > 0.9:
            _log.warning(f'{self.eid}, {self.side}Camera, pupil diameter too often NaN')
            return 'FAIL'

        thr = 5 if self.side == 'right' else 10
        if np.nanstd(self.data['pupilDiameter_raw']) > thr:
            _log.warning(f'{self.eid}, {self.side}Camera, pupil diameter too unstable')
            return 'FAIL'

        return 'PASS'

    def check_lick_detection(self):
        '''
        Check if both of the two tongue edge points are less than 10 % NaN, indicating that
        wrong points are detected (spout edge, mouth edge). Not set for the body camera.
        '''
        if self.side == 'body':
            return 'NOT_SET'
        dlc_coords = self.data['dlc_coords']
        nan_l = np.mean(np.isnan(dlc_coords['tongue_end_l'][0]))
        nan_r = np.mean(np.isnan(dlc_coords['tongue_end_r'][0]))
        if (nan_l < 0.1) and (nan_r < 0.1):
            return 'FAIL'
        return 'PASS'

    def check_pupil_diameter_snr(self):
        '''
        Check the signal to noise ratio of the pupil diameter: the variance of the smoothed trace
        divided by the variance of the difference between smoothed and raw traces, computed over
        frames where both are not NaN. Fails below 5 for the right camera and below 10 otherwise.

        :returns: 'NOT_SET' for the body camera, when the pupil data are not loaded or when no frame
         has both diameters; otherwise a tuple of the outcome and the SNR rounded to 3 decimals
        '''
        if self.side == 'body':
            return 'NOT_SET'
        thresh = 5 if self.side == 'right' else 10
        if 'pupilDiameter_raw' not in self.data or 'pupilDiameter_smooth' not in self.data:
            return 'NOT_SET'
        raw = self.data['pupilDiameter_raw']
        smooth = self.data['pupilDiameter_smooth']
        # compute signal to noise ratio between raw and smooth dia
        good_idxs = np.where(~np.isnan(smooth) & ~np.isnan(raw))[0]
        if good_idxs.size == 0:
            # The variance of nothing is NaN: there is no ratio to report
            return 'NOT_SET'
        snr = np.var(smooth[good_idxs]) / np.var(smooth[good_idxs] - raw[good_idxs])
        if snr < thresh:
            return 'FAIL', float(round(snr, 3))
        return 'PASS', float(round(snr, 3))

    def _grade_paw_nan(self, body_part):
        """Grade the proportion of NaN frames of a paw trace around stimulus onsets.

        :param body_part: key of the paw target in self.data['dlc_coords']
        :returns: 'NOT_SET' for the body camera or when the proportion cannot be computed, 'FAIL'
         above 20 % NaNs, 'WARNING' above 10 % NaNs, 'PASS' otherwise
        """
        if self.side == 'body':
            return 'NOT_SET'
        thresh_fail = 0.20  # prop of NaNs above this threshold means the check fails
        thresh_warning = 0.10  # prop of NaNs above this threshold means the check is a warning
        # compute fraction of points in windows that are NaN
        prop_nan = self._compute_proportion_nan_in_trial_window(body_part=body_part)
        if np.isnan(prop_nan):
            # No frame fell inside any trial window, so nothing can be said about the trace
            return 'NOT_SET'
        if prop_nan > thresh_fail:
            return 'FAIL'
        if prop_nan > thresh_warning:
            return 'WARNING'
        return 'PASS'

    def check_paw_close_nan(self):
        '''
        Check the proportion of NaN frames of the 'paw_r' trace in windows around stimulus onsets.
        '''
        return self._grade_paw_nan('paw_r')

    def check_paw_far_nan(self):
        '''
        Check the proportion of NaN frames of the 'paw_l' trace in windows around stimulus onsets.
        '''
        return self._grade_paw_nan('paw_l')

312 

313 

def run_all_qc(session, cameras=('left', 'right', 'body'), one=None, **kwargs):
    """Run DLC QC for all cameras
    Build one DlcQC object per requested camera and run it straight away.

    :param session: A session path or eid.
    :param update: If True, QC fields are updated on Alyx.
    :param download_data: If given, forwarded to each QC object's run method.
    :param cameras: A list of camera names to perform QC on.
    :param one: An ONE instance handed to every DlcQC object.
    :return: dict of DlcQC objects
    """
    # 'download_data' and 'update' belong to run(); everything else goes to the constructor
    run_kwargs = {}
    for key in ('download_data', 'update'):
        if key in kwargs:
            run_kwargs[key] = kwargs.pop(key)

    results = {}
    for side in cameras:
        camera_qc = DlcQC(session, side=side, one=one, **kwargs)
        results[side] = camera_qc
        camera_qc.run(**run_kwargs)
    return results