Coverage for ibllib/qc/task_qc_viewer/task_qc.py: 88%

154 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1import logging 

2import argparse 

3from itertools import cycle 

4import random 

5from collections.abc import Sized 

6from pathlib import Path 

7 

8import pandas as pd 

9import numpy as np 

10from matplotlib.colors import TABLEAU_COLORS 

11from one.api import ONE 

12from one.alf.spec import is_session_path 

13 

14import ibllib.plots as plots 

15from ibllib.misc import qt 

16from ibllib.qc.task_metrics import TaskQC 

17from ibllib.qc.task_qc_viewer import ViewEphysQC 

18from ibllib.pipes.dynamic_pipeline import get_trials_tasks 

19from ibllib.pipes.base_tasks import BehaviourTask 

20from ibllib.pipes.behavior_tasks import HabituationTrialsBpod, ChoiceWorldTrialsBpod 

21 

# Trial events that can be displayed, each mapped to its [colour, line style].
# A '*Trigger_times' event shares the colour of its event but is drawn dotted.
EVENT_MAP = {
    'goCue_times': ['#2ca02c', 'solid'],  # green
    'goCueTrigger_times': ['#2ca02c', 'dotted'],  # green
    'errorCue_times': ['#d62728', 'solid'],  # red
    'errorCueTrigger_times': ['#d62728', 'dotted'],  # red
    'valveOpen_times': ['#17becf', 'solid'],  # cyan
    'stimFreeze_times': ['#0000ff', 'solid'],  # blue
    'stimFreezeTrigger_times': ['#0000ff', 'dotted'],  # blue
    'stimOff_times': ['#9400d3', 'solid'],  # dark violet
    'stimOffTrigger_times': ['#9400d3', 'dotted'],  # dark violet
    'stimOn_times': ['#e377c2', 'solid'],  # pink
    'stimOnTrigger_times': ['#e377c2', 'dotted'],  # pink
    'response_times': ['#8c564b', 'solid'],  # brown
}
# Transpose the [colour, style] pairs into two parallel lists, in EVENT_MAP key order.
cm, ls = map(list, zip(*EVENT_MAP.values()))
# Names of the QC checks that are reported separately as critical.
CRITICAL_CHECKS = (
    'check_audio_pre_trial',
    'check_correct_trial_event_sequence',
    'check_error_trial_event_sequence',
    'check_n_trial_events',
    'check_response_feedback_delays',
    'check_response_stimFreeze_delays',
    'check_reward_volume_set',
    'check_reward_volumes',
    'check_stimOn_goCue_delays',
    'check_stimulus_move_before_goCue',
    'check_wheel_move_before_feedback',
    'check_wheel_freeze_during_quiescence',
)


_logger = logging.getLogger(__name__)

54 

55 

class QcFrame:
    """An interactive display wrapper around an already-computed TaskQC object."""

    qc = None
    """ibllib.qc.task_metrics.TaskQC: A TaskQC object containing extracted data"""

    frame = None
    """pandas.DataFrame: A table of trial-level QC metrics, one row per trial."""

    def __init__(self, qc):
        """
        An interactive display of task QC data.

        Prints the names of the checks that did not pass (grouped by outcome, then the
        critical ones) and builds `self.frame`, a table of the trial-level metrics together
        with the trial start/end intervals.

        Parameters
        ----------
        qc : ibllib.qc.task_metrics.TaskQC
            A TaskQC object containing extracted data for plotting.
        """
        assert qc.extractor and qc.metrics, 'Please run QC before passing to QcFrame'
        self.qc = qc

        # Print the non-passing checks, grouped by outcome (in first-seen order so that the
        # output does not depend on set iteration order).
        _, _, outcomes = self.qc.compute_session_status()
        checks_by_outcome = {}
        for check, check_outcome in outcomes.items():
            # Check keys look like '_task_<name>'; drop the 6-character '_task_' prefix.
            checks_by_outcome.setdefault(check_outcome, []).append(check[6:])
        for check_outcome, names in checks_by_outcome.items():
            if check_outcome == 'PASS':
                continue
            print(f'The following checks were labelled {check_outcome}:')
            print('\n'.join(names), '\n')

        print('The following *critical* checks did not pass:')
        # CRITICAL_CHECKS uses the 'check_<name>' form; outcomes use '_task_<name>'.
        critical_checks = {f'_{x.replace("check", "task")}' for x in CRITICAL_CHECKS}
        for check, check_outcome in outcomes.items():
            if check_outcome != 'PASS' and check in critical_checks:
                print(check[6:])

        # Make a DataFrame from the trial-level metrics, i.e. those with one value per trial.
        n_trials = self.n_trials
        trial_metrics = {
            k[6:]: v for k, v in self.qc.metrics.items()
            if isinstance(v, Sized) and len(v) == n_trials
        }
        self.frame = pd.DataFrame.from_dict(trial_metrics)
        intervals = self.qc.extractor.data['intervals']
        self.frame['intervals_0'] = intervals[:, 0]
        self.frame['intervals_1'] = intervals[:, 1]
        self.frame.insert(loc=0, column='trial_no', value=self.frame.index)

    @property
    def n_trials(self):
        """int: The number of trials, i.e. the number of rows of the extracted intervals."""
        return self.qc.extractor.data['intervals'].shape[0]

    def get_wheel_data(self):
        """
        Return the extracted wheel trace.

        Returns
        -------
        dict
            're_pos' (wheel positions) and 're_ts' (wheel timestamps); each is an empty
            array when the corresponding data are missing from the extractor.
        """
        return {'re_pos': self.qc.extractor.data.get('wheel_position', np.array([])),
                're_ts': self.qc.extractor.data.get('wheel_timestamps', np.array([]))}

    @staticmethod
    def _plot_trial_events(trial_data, trial_events, color_map, linestyle, plot_args):
        """
        Draw each requested trial event found in `trial_data` as vertical lines.

        Colours and line styles are cycled and advance for every requested event, even one
        missing from `trial_data`, so the i-th event gets the same style on every axes.
        """
        for event, color, style in zip(trial_events, cycle(color_map), cycle(linestyle)):
            if event in trial_data:
                plots.vertical_lines(trial_data[event], label=event, color=color,
                                     linestyle=style, **plot_args)

    def create_plots(self, axes,
                     wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
        """
        Plot the TTL pulses and trial events, and optionally the wheel trace.

        The frame2ttl TTLs (bnc1) are drawn at level 1, the sound TTLs (bnc2) at level 2 and,
        when available, the Bpod TTLs at level 3. Trial events are drawn as vertical lines.

        :param axes: An axes handle on which to plot the TTL events
        :param wheel_axes: An axes handle on which to plot the wheel trace
        :param trial_events: A list of Bpod trial events to plot, e.g. ['stimFreeze_times'],
         if None, go cue, feedback and stimulus events are plotted
        :param color_map: A sequence of colours for the events, cycled when shorter than
         trial_events; default is the tableau color map
        :param linestyle: A sequence of line styles for the events, cycled when shorter than
         trial_events; default is a random style per event
        :return: None
        """
        color_map = color_map or TABLEAU_COLORS.keys()
        if trial_events is None:
            # Default trial events to plot as vertical lines
            trial_events = [
                'goCue_times',
                'goCueTrigger_times',
                'feedback_times',
                ('stimCenter_times'
                 if 'stimCenter_times' in self.qc.extractor.data
                 else 'stimFreeze_times'),  # handle habituationChoiceWorld exception
                'stimOff_times',
                'stimOn_times'
            ]
        # Materialize: the events are iterated once per axes and their count is needed below.
        trial_events = list(trial_events)

        plot_args = {
            'ymin': 0,
            'ymax': 4,
            'linewidth': 2,
            'ax': axes,
            'alpha': 0.5,
        }

        bnc1 = self.qc.extractor.frame_ttls
        bnc2 = self.qc.extractor.audio_ttls
        trial_data = self.qc.extractor.data

        if bnc1['times'].size:
            plots.squares(bnc1['times'], bnc1['polarities'] * 0.4 + 1, ax=axes, color='k')
        if bnc2['times'].size:
            plots.squares(bnc2['times'], bnc2['polarities'] * 0.4 + 2, ax=axes, color='k')
        linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))

        if self.qc.extractor.bpod_ttls is not None:
            bpttls = self.qc.extractor.bpod_ttls
            plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
            plot_args['ymax'] = 4
            ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
        else:
            plot_args['ymax'] = 3
            ylabels = ['', 'frame2ttl', 'sound', '']

        self._plot_trial_events(trial_data, trial_events, color_map, linestyle, plot_args)

        axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
        axes.set_yticks(list(range(plot_args['ymax'] + 1)))
        axes.set_yticklabels(ylabels)
        axes.set_ylim([0, plot_args['ymax']])

        if wheel_axes:
            wheel_data = self.get_wheel_data()
            re_pos = wheel_data['re_pos']
            # Same line properties, but spanning the wheel position range on the wheel axes
            wheel_plot_args = {
                **plot_args,
                'ax': wheel_axes,
                'ymin': re_pos.min() if re_pos.size else 0,
                'ymax': re_pos.max() if re_pos.size else 1,
            }
            wheel_axes.plot(wheel_data['re_ts'], re_pos, 'k-x')
            self._plot_trial_events(trial_data, trial_events, color_map, linestyle,
                                    wheel_plot_args)

188 

189 

def get_bpod_trials_task(task):
    """
    Return a trials task that extracts the Bpod trials only.

    A task whose class is exactly ChoiceWorldTrialsBpod or HabituationTrialsBpod is returned
    unchanged. Any other task must be a BehaviourTask; a new Bpod-only trials task is then
    built from its session path, collection, protocol number, protocol and ONE instance.

    Parameters
    ----------
    task : ibllib.pipes.tasks.Task
        A pipeline task from which to derive the Bpod trials task.

    Returns
    -------
    ibllib.pipes.tasks.Task
        A Bpod choice world trials task instance.
    """
    # Exact class membership on purpose: subclasses are not treated as Bpod-only.
    if task.__class__ in (ChoiceWorldTrialsBpod, HabituationTrialsBpod):
        return task

    assert isinstance(task, BehaviourTask)
    # A dynamic pipeline task: choose the Bpod-only class from the protocol name.
    if 'habituation' in task.protocol:
        bpod_class = HabituationTrialsBpod
    else:
        bpod_class = ChoiceWorldTrialsBpod
    return bpod_class(
        task.session_path,
        collection=task.collection,
        protocol_number=task.protocol_number,
        protocol=task.protocol,
        one=task.one,
    )

214 

215 

def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=None, protocol_number=None):
    """
    Displays the task QC for a given session.

    NB: For this to work, all behaviour trials task classes must implement a `run_qc` method.

    Parameters
    ----------
    qc_or_session : str, pathlib.Path, ibllib.qc.task_metrics.TaskQC, QcFrame
        An experiment ID, session path, TaskQC object or QcFrame instance.
    bpod_only : bool
        If true, display Bpod extracted events instead of data from the DAQ.
    local : bool
        If true, asserts all data local (i.e. do not attempt to download missing datasets).
    one : one.api.One
        An instance of ONE.
    protocol_number : int
        If not None, displays the QC for the protocol number provided. Argument is ignored if
        `qc_or_session` is a TaskQC object or QcFrame instance.

    Returns
    -------
    QcFrame
        The QcFrame object.

    Raises
    ------
    ValueError
        If the session path cannot be resolved from the experiment ID, no non-passive
        behaviour task is found, the protocol number is out of range, or the selected
        protocol is passive.
    TypeError
        If the protocol number is not a non-negative integer.
    """
    # TaskQC whose extracted trial data are merged into the table. It stays None when a
    # pre-built QcFrame is passed in, in which case only the QC frame itself is tabulated.
    task_qc = None
    if isinstance(qc_or_session, QcFrame):
        qc = qc_or_session
    elif isinstance(qc_or_session, TaskQC):
        task_qc = qc_or_session
        qc = QcFrame(task_qc)
    else:  # assumed to be eid or session path
        one = one or ONE(mode='local' if local else 'remote')
        if not is_session_path(Path(qc_or_session)):
            eid = one.to_eid(qc_or_session)
            session_path = one.eid2path(eid)
            if session_path is None:
                raise ValueError(f'Failed to determine the session path for "{qc_or_session}"')
        else:
            session_path = Path(qc_or_session)

        tasks = get_trials_tasks(session_path, one=None if local else one, bpod_only=bpod_only)
        # Get the correct task and ensure not passive
        if protocol_number is None:
            if not (task := next((t for t in tasks if 'passive' not in t.name.lower()), None)):
                raise ValueError('No non-passive behaviour tasks found for session ' + '/'.join(session_path.parts[-3:]))
        elif not isinstance(protocol_number, int) or protocol_number < 0:
            raise TypeError('Protocol number must be a positive integer')
        elif protocol_number > len(tasks) - 1:
            raise ValueError('Invalid protocol number')
        else:
            task = tasks[protocol_number]
            if 'passive' in task.name.lower():
                raise ValueError('QC display not supported for passive protocols')
        _logger.debug('Using %s task', task.name)
        # Ensure required data are present
        task.location = 'server' if local else 'remote'  # affects whether missing data are downloaded
        task.setUp()
        if local:  # currently setUp does not raise on missing data
            task.assert_expected_inputs(raise_error=True)
        # Compute the QC and build the frame
        task_qc = task.run_qc(update=False)
        qc = QcFrame(task_qc)

    # Handle trial event names in habituationChoiceWorld
    events = list(EVENT_MAP)
    if 'stimCenter_times' in qc.qc.extractor.data:
        events = [e.replace('stimFreeze', 'stimCenter') for e in events]

    # Run QC and plot
    viewer = ViewEphysQC.viewqc(wheel=qc.get_wheel_data())
    qc.create_plots(viewer.wplot.canvas.ax,
                    wheel_axes=viewer.wplot.canvas.ax2,
                    trial_events=events,
                    color_map=cm,
                    linestyle=ls)

    # Update table and callbacks
    n_trials = qc.frame.shape[0]
    if task_qc is not None:
        # Only 1-D, per-trial values can become DataFrame columns; this skips scalars,
        # non-array values and 2-D data such as the (n_trials, 2) intervals array.
        df_trials = pd.DataFrame({
            k: v for k, v in task_qc.extractor.data.items()
            if getattr(v, 'ndim', None) == 1 and v.size == n_trials and not k.startswith('wheel')
        })
        df = df_trials.merge(qc.frame, left_index=True, right_index=True)
    else:
        df = qc.frame
    df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials})
    df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True)
    df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True)
    df = df.merge(df_pass.astype('boolean'), left_index=True, right_index=True)
    viewer.updateDataframe(df)
    qt.run_app()
    return qc

307 

308 

def qc_gui_cli():
    """Run TaskQC viewer with wheel data.

    For information on the QC checks see the QC Flags & failures document:
    https://docs.google.com/document/d/1X-ypFEIxqwX6lU9pig4V_zrcR5lITpd8UJQWzW9I9zI/edit#

    Examples
    --------
    From the command line::

        ipython task_qc.py c9fec76e-7a20-4da4-93ad-04510a89473b
        ipython task_qc.py ./KS022/2019-12-10/001 --local
        ipython task_qc.py ./KS022/2019-12-10/001 --local --protocol-number 1
    """
    # Parse parameters
    parser = argparse.ArgumentParser(description='Quick viewer to see the behaviour data from '
                                                 'choice world sessions.')
    parser.add_argument('session', help='session uuid or path')
    parser.add_argument('--bpod', action='store_true', help='run QC on Bpod data only (no FPGA)')
    parser.add_argument('--local', action='store_true', help='run from disk location (lab server)')
    # Optional: None keeps the previous behaviour of showing the first non-passive task
    parser.add_argument('--protocol-number', dest='protocol_number', type=int, default=None,
                        help='protocol number of the task to display '
                             '(default: first non-passive task)')
    args = parser.parse_args()  # returns data from the options specified (echo)

    show_session_task_qc(qc_or_session=args.session, bpod_only=args.bpod, local=args.local,
                         protocol_number=args.protocol_number)


if __name__ == '__main__':
    qc_gui_cli()