Coverage for ibllib/qc/task_qc_viewer/task_qc.py: 88%

148 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1import logging 

2import argparse 

3from itertools import cycle 

4import random 

5from collections.abc import Sized 

6from pathlib import Path 

7 

8import pandas as pd 

9import numpy as np 

10from matplotlib.colors import TABLEAU_COLORS 

11from one.api import ONE 

12from one.alf.spec import is_session_path 

13 

14import ibllib.plots as plots 

15from ibllib.misc import qt 

16from ibllib.qc.task_metrics import TaskQC 

17from ibllib.qc.task_qc_viewer import ViewEphysQC 

18from ibllib.pipes.dynamic_pipeline import get_trials_tasks 

19from ibllib.pipes.base_tasks import BehaviourTask 

20from ibllib.pipes.behavior_tasks import HabituationTrialsBpod, ChoiceWorldTrialsBpod 

21from ibllib.pipes.training_preprocessing import TrainingTrials 

22 

# Colour and matplotlib line style used to draw each trial event. The '*Trigger_times'
# entries are drawn 'dotted' and the corresponding event times 'solid', in the same colour.
EVENT_MAP = {
    'goCue_times': ['#2ca02c', 'solid'],  # green
    'goCueTrigger_times': ['#2ca02c', 'dotted'],  # green
    'errorCue_times': ['#d62728', 'solid'],  # red
    'errorCueTrigger_times': ['#d62728', 'dotted'],  # red
    'valveOpen_times': ['#17becf', 'solid'],  # cyan
    'stimFreeze_times': ['#0000ff', 'solid'],  # blue
    'stimFreezeTrigger_times': ['#0000ff', 'dotted'],  # blue
    'stimOff_times': ['#9400d3', 'solid'],  # dark violet
    'stimOffTrigger_times': ['#9400d3', 'dotted'],  # dark violet
    'stimOn_times': ['#e377c2', 'solid'],  # pink
    'stimOnTrigger_times': ['#e377c2', 'dotted'],  # pink
    'response_times': ['#8c564b', 'solid'],  # brown
}
# Per-event colours and line styles, in EVENT_MAP order; these are passed as the
# `color_map` and `linestyle` arguments of QcFrame.create_plots.
cm = [colour for colour, _ in EVENT_MAP.values()]
ls = [style for _, style in EVENT_MAP.values()]

# Checks whose non-PASS outcomes are listed separately when a QcFrame is built.
CRITICAL_CHECKS = (
    'check_audio_pre_trial',
    'check_correct_trial_event_sequence',
    'check_error_trial_event_sequence',
    'check_n_trial_events',
    'check_response_feedback_delays',
    'check_response_stimFreeze_delays',
    'check_reward_volume_set',
    'check_reward_volumes',
    'check_stimOn_goCue_delays',
    'check_stimulus_move_before_goCue',
    'check_wheel_move_before_feedback',
    'check_wheel_freeze_during_quiescence'
)


_logger = logging.getLogger(__name__)

55 

56 

class QcFrame:

    qc = None
    """ibllib.qc.task_metrics.TaskQC: A TaskQC object containing extracted data"""

    frame = None
    """pandas.DataFrame: A table of the trial-level QC metrics (one row per trial)."""

    def __init__(self, qc):
        """
        An interactive display of task QC data.

        On construction, prints the checks that did not pass (grouped by outcome), then the
        critical checks that did not pass, and builds `frame` from the trial-level metrics.

        Parameters
        ----------
        qc : ibllib.qc.task_metrics.TaskQC
            A TaskQC object containing extracted data for plotting.
        """
        assert qc.extractor and qc.metrics, 'Please run QC before passing to QcFrame'
        self.qc = qc

        # Print failed checks, grouped by outcome
        _, _, outcomes = self.qc.compute_session_status()
        for status, checks in self._group_outcomes(outcomes).items():
            if status == 'PASS':
                continue
            print(f'The following checks were labelled {status}:')
            print('\n'.join(checks), '\n')

        print('The following *critical* checks did not pass:')
        # CRITICAL_CHECKS names start with 'check_'; outcome keys are matched as '_task_<name>'
        # and printed without that 6-character prefix
        critical_checks = {f'_{x.replace("check", "task")}' for x in CRITICAL_CHECKS}
        for k, v in outcomes.items():
            if v != 'PASS' and k in critical_checks:
                print(k[6:])

        # Make DataFrame from the trial level metrics
        self.frame = self._trial_level_frame(self.qc.metrics, self.n_trials)
        intervals = self.qc.extractor.data['intervals']
        self.frame['intervals_0'] = intervals[:, 0]
        self.frame['intervals_1'] = intervals[:, 1]
        self.frame.insert(loc=0, column='trial_no', value=self.frame.index)

    @staticmethod
    def _group_outcomes(outcomes):
        """Map each outcome to the check names with that outcome, minus the 6-char key prefix."""
        grouped = {}
        for key, status in outcomes.items():
            grouped.setdefault(status, []).append(key[6:])
        return grouped

    @staticmethod
    def _trial_level_frame(metrics, n_trials):
        """
        Build a DataFrame of the metrics holding exactly one value per trial.

        Metric keys are stripped of their 6-character prefix. Strings and bytes are excluded:
        although Sized, they are not treated as per-trial arrays.
        """
        per_trial = {
            k[6:]: v for k, v in metrics.items()
            if isinstance(v, Sized) and not isinstance(v, (str, bytes)) and len(v) == n_trials
        }
        return pd.DataFrame.from_dict(per_trial)

    @property
    def n_trials(self):
        """int: The number of trials, i.e. the number of rows of the extracted intervals."""
        return self.qc.extractor.data['intervals'].shape[0]

    def get_wheel_data(self):
        """
        Return the extracted wheel data.

        Returns
        -------
        dict
            're_pos' (wheel position) and 're_ts' (wheel timestamps); each is an empty numpy
            array when absent from the extracted data.
        """
        return {'re_pos': self.qc.extractor.data.get('wheel_position', np.array([])),
                're_ts': self.qc.extractor.data.get('wheel_timestamps', np.array([]))}

    def create_plots(self, axes,
                     wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
        """
        Plot the data for bnc1 (frame2ttl) and bnc2 (sound), plus Bpod TTLs and trial events.

        TTLs are drawn as square waves at y=1 (frame2ttl), y=2 (sound) and, when Bpod TTLs
        are available, y=3 (bpod). Trial events are drawn as vertical lines on `axes` and,
        if given, on `wheel_axes`.

        Parameters
        ----------
        axes : matplotlib.axes.Axes
            An axes handle on which to plot the TTL events.
        wheel_axes : matplotlib.axes.Axes, optional
            An axes handle on which to plot the wheel trace.
        trial_events : list of str, optional
            A list of Bpod trial events to plot, e.g. ['stimFreeze_times']. If None, the
            goCue, goCueTrigger, feedback, stimFreeze (stimCenter for habituation sessions),
            stimOff and stimOn events are plotted. Events absent from the data are skipped.
        color_map : iterable of str, optional
            Colours to use for the events (cycled); default is the tableau color map.
        linestyle : iterable of str, optional
            Line styles to use for the events (cycled); default is random.

        Returns
        -------
        None
        """
        trial_data = self.qc.extractor.data
        color_map = color_map or TABLEAU_COLORS.keys()
        if trial_events is None:
            # Default trial events to plot as vertical lines
            trial_events = [
                'goCue_times',
                'goCueTrigger_times',
                'feedback_times',
                # handle habituationChoiceWorld exception
                'stimCenter_times' if 'stimCenter_times' in trial_data else 'stimFreeze_times',
                'stimOff_times',
                'stimOn_times'
            ]
        linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))

        bnc1 = self.qc.extractor.frame_ttls
        bnc2 = self.qc.extractor.audio_ttls
        if bnc1['times'].size:
            plots.squares(bnc1['times'], bnc1['polarities'] * 0.4 + 1, ax=axes, color='k')
        if bnc2['times'].size:
            plots.squares(bnc2['times'], bnc2['polarities'] * 0.4 + 2, ax=axes, color='k')

        if self.qc.extractor.bpod_ttls is not None:
            bpttls = self.qc.extractor.bpod_ttls
            plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
            ymax = 4
            ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
        else:
            ymax = 3
            ylabels = ['', 'frame2ttl', 'sound', '']
        plot_args = {'ymin': 0, 'ymax': ymax, 'linewidth': 2, 'ax': axes}

        def plot_trial_events(**kwargs):
            # Colours AND line styles are cycled so that no event is silently dropped when
            # fewer line styles than events are provided
            for event, colour, style in zip(trial_events, cycle(color_map), cycle(linestyle)):
                if event in trial_data:
                    plots.vertical_lines(trial_data[event], label=event,
                                         color=colour, linestyle=style, **kwargs)

        plot_trial_events(**plot_args)
        axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
        axes.set_yticks(list(range(ymax + 1)))
        axes.set_yticklabels(ylabels)
        axes.set_ylim([0, ymax])

        if wheel_axes is not None:
            wheel_data = self.get_wheel_data()
            wheel_pos = wheel_data['re_pos']
            wheel_axes.plot(wheel_data['re_ts'], wheel_pos, 'k-x')
            # On the wheel axes the event lines span the range of the wheel trace
            plot_trial_events(**{**plot_args,
                                 'ax': wheel_axes,
                                 'ymin': wheel_pos.min() if wheel_pos.size else 0,
                                 'ymax': wheel_pos.max() if wheel_pos.size else 1})

188 

189 

def get_bpod_trials_task(task):
    """
    Return the correct trials task for extracting only the Bpod trials.

    Parameters
    ----------
    task : ibllib.pipes.tasks.Task
        A pipeline task from which to derive the Bpod trials task.

    Returns
    -------
    ibllib.pipes.tasks.Task
        A Bpod choice world trials task instance. If `task` is already a Bpod-only trials
        task it is returned as-is; otherwise a new task is instantiated from its attributes.
    """
    # Already Bpod only: a legacy TrainingTrials task (or subclass), or an instance of exactly
    # one of the dynamic pipeline Bpod classes. The exact class match means subclasses of those
    # fall through to the branches below.
    bpod_only_classes = (ChoiceWorldTrialsBpod, HabituationTrialsBpod)
    if isinstance(task, TrainingTrials) or task.__class__ in bpod_only_classes:
        return task

    if not isinstance(task, BehaviourTask):
        # A legacy pipeline task (should be EphysTrials as there are no other options)
        return TrainingTrials(task.session_path, one=task.one)

    # A dynamic pipeline task: choose the Bpod trials class matching the protocol name
    is_habituation = 'habituation' in task.protocol
    bpod_class = HabituationTrialsBpod if is_habituation else ChoiceWorldTrialsBpod
    return bpod_class(
        task.session_path,
        collection=task.collection,
        protocol_number=task.protocol_number,
        protocol=task.protocol,
        one=task.one,
    )

215 

216 

def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=None, protocol_number=None):
    """
    Displays the task QC for a given session.

    NB: For this to work, all behaviour trials task classes must implement a `run_qc` method.

    Parameters
    ----------
    qc_or_session : str, pathlib.Path, ibllib.qc.task_metrics.TaskQC, QcFrame
        An experiment ID, session path, TaskQC object or QcFrame instance.
    bpod_only : bool
        If true, display Bpod extracted events instead of data from the DAQ. Argument is
        ignored if `qc_or_session` is a TaskQC object or QcFrame instance.
    local : bool
        If true, asserts all data local (i.e. do not attempt to download missing datasets).
        Argument is ignored if `qc_or_session` is a TaskQC object or QcFrame instance.
    one : one.api.One
        An instance of ONE. If None and one is needed, a new instance is created
        ('local' mode if `local` is true, otherwise 'auto').
    protocol_number : int
        If not None, displays the QC for the protocol number provided. Argument is ignored if
        `qc_or_session` is a TaskQC object or QcFrame instance. If None, the first
        non-passive trials task is used.

    Returns
    -------
    QcFrame
        The QcFrame object.

    Raises
    ------
    TypeError
        If `protocol_number` is not a non-negative integer.
    ValueError
        If no session path can be resolved from the experiment ID, no non-passive trials
        task is found, or `protocol_number` is out of range or refers to a passive protocol.
    """
    if isinstance(qc_or_session, QcFrame):
        qc = qc_or_session
    elif isinstance(qc_or_session, TaskQC):
        qc = QcFrame(qc_or_session)
    else:  # assumed to be eid or session path
        one = one or ONE(mode='local' if local else 'auto')
        if is_session_path(Path(qc_or_session)):
            session_path = Path(qc_or_session)
        else:
            eid = one.to_eid(qc_or_session)
            # Fail early with a clear message rather than passing None on to get_trials_tasks
            session_path = one.eid2path(eid) if eid is not None else None
            if session_path is None:
                raise ValueError(f'Session not found: {qc_or_session}')
        tasks = get_trials_tasks(session_path, one=None if local else one)
        # Get the correct task and ensure not passive
        if protocol_number is None:
            if not (task := next((t for t in tasks if 'passive' not in t.name.lower()), None)):
                raise ValueError('No non-passive behaviour tasks found for session ' + '/'.join(session_path.parts[-3:]))
        elif not isinstance(protocol_number, int) or protocol_number < 0:
            raise TypeError('Protocol number must be a positive integer')
        elif protocol_number > len(tasks) - 1:
            raise ValueError('Invalid protocol number')
        else:
            task = tasks[protocol_number]
            if 'passive' in task.name.lower():
                raise ValueError('QC display not supported for passive protocols')
        # If Bpod only and not a dynamic pipeline Bpod behaviour task OR legacy TrainingTrials task
        if bpod_only and 'bpod' not in task.name.lower():
            # Use the dynamic pipeline Bpod behaviour task instead (should work with legacy pipeline too)
            task = get_bpod_trials_task(task)
        _logger.debug('Using %s task', task.name)
        # Ensure required data are present
        task.location = 'server' if local else 'remote'  # affects whether missing data are downloaded
        task.setUp()
        if local:  # currently setUp does not raise on missing data
            task.assert_expected_inputs(raise_error=True)
        # Compute the QC and build the frame
        task_qc = task.run_qc(update=False)
        qc = QcFrame(task_qc)

    # Handle trial event names in habituationChoiceWorld (stimCenter replaces stimFreeze)
    stim_event = 'stimCenter' if 'stimCenter_times' in qc.qc.extractor.data else 'stimFreeze'
    events = [name.replace('stimFreeze', stim_event) for name in EVENT_MAP]

    # Build the viewer and plot the (already computed) QC data
    w = ViewEphysQC.viewqc(wheel=qc.get_wheel_data())
    qc.create_plots(w.wplot.canvas.ax,
                    wheel_axes=w.wplot.canvas.ax2,
                    trial_events=events,
                    color_map=cm,
                    linestyle=ls)
    # Update table and callbacks
    w.update_df(qc.frame)
    qt.run_app()
    return qc

296 

297 

def _build_parser():
    """Return the argument parser for the task QC viewer command line interface."""
    parser = argparse.ArgumentParser(description='Quick viewer to see the behaviour data from '
                                                 'choice world sessions.')
    parser.add_argument('session', help='session uuid or path')
    parser.add_argument('--bpod', action='store_true', help='run QC on Bpod data only (no FPGA)')
    parser.add_argument('--local', action='store_true', help='run from disk location (lab server)')
    parser.add_argument('--protocol', type=int, default=None, dest='protocol_number',
                        help='protocol number of the task to display '
                             '(default: the first non-passive task)')
    return parser


def qc_gui_cli():
    """Run TaskQC viewer with wheel data.

    For information on the QC checks see the QC Flags & failures document:
    https://docs.google.com/document/d/1X-ypFEIxqwX6lU9pig4V_zrcR5lITpd8UJQWzW9I9zI/edit#

    Examples
    --------
    From the command line::

        ipython task_qc.py c9fec76e-7a20-4da4-93ad-04510a89473b
        ipython task_qc.py ./KS022/2019-12-10/001 --local
        ipython task_qc.py ./KS022/2019-12-10/001 --local --protocol 1
    """
    # Parse parameters
    args = _build_parser().parse_args()
    show_session_task_qc(qc_or_session=args.session, bpod_only=args.bpod,
                         local=args.local, protocol_number=args.protocol_number)


if __name__ == '__main__':
    qc_gui_cli()