Coverage for ibllib/qc/task_qc_viewer/task_qc.py: 88%
154 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1import logging
2import argparse
3from itertools import cycle
4import random
5from collections.abc import Sized
6from pathlib import Path
8import pandas as pd
9import numpy as np
10from matplotlib.colors import TABLEAU_COLORS
11from one.api import ONE
12from one.alf.spec import is_session_path
14import ibllib.plots as plots
15from ibllib.misc import qt
16from ibllib.qc.task_metrics import TaskQC
17from ibllib.qc.task_qc_viewer import ViewEphysQC
18from ibllib.pipes.dynamic_pipeline import get_trials_tasks
19from ibllib.pipes.base_tasks import BehaviourTask
20from ibllib.pipes.behavior_tasks import HabituationTrialsBpod, ChoiceWorldTrialsBpod
# Plot style of each trial event as [colour, matplotlib line style]. A "Trigger" event
# shares the colour of the event it precedes but is drawn dotted rather than solid.
EVENT_MAP = {
    'goCue_times': ['#2ca02c', 'solid'],  # green
    'goCueTrigger_times': ['#2ca02c', 'dotted'],  # green
    'errorCue_times': ['#d62728', 'solid'],  # red
    'errorCueTrigger_times': ['#d62728', 'dotted'],  # red
    'valveOpen_times': ['#17becf', 'solid'],  # cyan
    'stimFreeze_times': ['#0000ff', 'solid'],  # blue
    'stimFreezeTrigger_times': ['#0000ff', 'dotted'],  # blue
    'stimOff_times': ['#9400d3', 'solid'],  # dark violet
    'stimOffTrigger_times': ['#9400d3', 'dotted'],  # dark violet
    'stimOn_times': ['#e377c2', 'solid'],  # pink
    'stimOnTrigger_times': ['#e377c2', 'dotted'],  # pink
    'response_times': ['#8c564b', 'solid'],  # brown
}
# Colours and line styles as parallel lists, in EVENT_MAP key order
cm = [colour for colour, _style in EVENT_MAP.values()]
ls = [style for _colour, style in EVENT_MAP.values()]
# Task QC checks that QcFrame reports separately as *critical* when they do not pass
CRITICAL_CHECKS = (
    'check_audio_pre_trial',
    'check_correct_trial_event_sequence',
    'check_error_trial_event_sequence',
    'check_n_trial_events',
    'check_response_feedback_delays',
    'check_response_stimFreeze_delays',
    'check_reward_volume_set',
    'check_reward_volumes',
    'check_stimOn_goCue_delays',
    'check_stimulus_move_before_goCue',
    'check_wheel_move_before_feedback',
    'check_wheel_freeze_during_quiescence'
)

_logger = logging.getLogger(__name__)
class QcFrame:
    """Interactive display helper built around a TaskQC object and its trial-level metrics."""

    qc = None
    """ibllib.qc.task_metrics.TaskQC: A TaskQC object containing extracted data"""

    frame = None
    """pandas.DataFrame: A table of trial-level QC metric values, one row per trial."""

    def __init__(self, qc):
        """
        An interactive display of task QC data.

        Prints the non-passing checks grouped by outcome, then the critical checks
        (see CRITICAL_CHECKS) that did not pass, and builds :attr:`frame`.

        Parameters
        ----------
        qc : ibllib.qc.task_metrics.TaskQC
            A TaskQC object containing extracted data for plotting.
        """
        assert qc.extractor and qc.metrics, 'Please run QC before passing to QcFrame'
        self.qc = qc

        # Print failed checks, grouped by outcome. Check names are displayed without their
        # 6-character '_task_' prefix. A dict keeps groups in first-seen order, so the
        # printed output is reproducible between runs (set order is hash dependent).
        _, _, outcomes = self.qc.compute_session_status()
        checks_by_outcome = {}
        for check, check_outcome in outcomes.items():
            checks_by_outcome.setdefault(check_outcome, []).append(check[6:])
        for check_outcome, names in checks_by_outcome.items():
            if check_outcome == 'PASS':
                continue
            print(f'The following checks were labelled {check_outcome}:')
            print('\n'.join(names), '\n')

        print('The following *critical* checks did not pass:')
        # CRITICAL_CHECKS are named 'check_*' whereas outcome keys are named '_task_*'
        critical_checks = {f'_{x.replace("check", "task")}' for x in CRITICAL_CHECKS}
        for check, check_outcome in outcomes.items():
            if check_outcome != 'PASS' and check in critical_checks:
                print(check[6:])

        # Make DataFrame from the trial level metrics
        self.frame = self._trial_level_frame(self.qc.metrics)
        intervals = self.qc.extractor.data['intervals']
        self.frame['intervals_0'] = intervals[:, 0]
        self.frame['intervals_1'] = intervals[:, 1]
        self.frame.insert(loc=0, column='trial_no', value=self.frame.index)

    def _trial_level_frame(self, metrics):
        """Return a DataFrame of the metrics that have one value per trial (prefix stripped)."""
        per_trial = {k[6:]: v for k, v in metrics.items()
                     if isinstance(v, Sized) and len(v) == self.n_trials}
        return pd.DataFrame.from_dict(per_trial)

    @property
    def n_trials(self):
        """int: The number of trials, i.e. the number of rows of the extracted 'intervals'."""
        return self.qc.extractor.data['intervals'].shape[0]

    def get_wheel_data(self):
        """
        Return the extracted wheel trace.

        :return: A dict with 're_pos' (wheel position) and 're_ts' (wheel timestamps); each is
         an empty array when the corresponding data were not extracted.
        """
        return {'re_pos': self.qc.extractor.data.get('wheel_position', np.array([])),
                're_ts': self.qc.extractor.data.get('wheel_timestamps', np.array([]))}

    def create_plots(self, axes,
                     wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
        """
        Plots the TTL data for frame2ttl (bnc1), sound (bnc2) and, if present, Bpod, overlaid
        with vertical lines at the trial event times.

        :param axes: An axes handle on which to plot the TTL events
        :param wheel_axes: An axes handle on which to plot the wheel trace
        :param trial_events: A list of Bpod trial events to plot, e.g. ['stimFreeze_times'],
         if None, go cue, go cue trigger, feedback, stimulus centre (or freeze), stimulus off
         and stimulus on events are plotted
        :param color_map: A sequence of colours cycled through for the events, default is the
         tableau color map
        :param linestyle: A sequence of line styles cycled through for the events, default is a
         random line style per event
        :return: None
        """
        color_map = color_map or TABLEAU_COLORS.keys()
        if trial_events is None:
            # Default trial events to plot as vertical lines
            trial_events = [
                'goCue_times',
                'goCueTrigger_times',
                'feedback_times',
                ('stimCenter_times'
                 if 'stimCenter_times' in self.qc.extractor.data
                 else 'stimFreeze_times'),  # handle habituationChoiceWorld exception
                'stimOff_times',
                'stimOn_times'
            ]

        plot_args = {
            'ymin': 0,
            'ymax': 4,
            'linewidth': 2,
            'ax': axes,
            'alpha': 0.5,
        }

        frame_ttls = self.qc.extractor.frame_ttls  # bnc1
        audio_ttls = self.qc.extractor.audio_ttls  # bnc2
        trial_data = self.qc.extractor.data

        # Each TTL channel is drawn as a square wave on its own row: frame2ttl at y=1,
        # sound at y=2 and Bpod at y=3
        if frame_ttls['times'].size:
            plots.squares(frame_ttls['times'], frame_ttls['polarities'] * 0.4 + 1, ax=axes, color='k')
        if audio_ttls['times'].size:
            plots.squares(audio_ttls['times'], audio_ttls['polarities'] * 0.4 + 2, ax=axes, color='k')
        linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))

        if self.qc.extractor.bpod_ttls is not None:
            bpod_ttls = self.qc.extractor.bpod_ttls
            plots.squares(bpod_ttls['times'], bpod_ttls['polarities'] * 0.4 + 3, ax=axes, color='k')
            plot_args['ymax'] = 4
            ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
        else:
            plot_args['ymax'] = 3
            ylabels = ['', 'frame2ttl', 'sound', '']

        self._plot_trial_events(trial_data, trial_events, color_map, linestyle, plot_args)

        axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
        axes.set_yticks(list(range(plot_args['ymax'] + 1)))
        axes.set_yticklabels(ylabels)
        axes.set_ylim([0, plot_args['ymax']])

        if wheel_axes is not None:
            wheel_data = self.get_wheel_data()
            positions = wheel_data['re_pos']
            # Reuse the TTL line settings, but span the vertical lines over the wheel position range
            wheel_plot_args = {
                **plot_args,
                'ax': wheel_axes,
                'ymin': positions.min() if positions.size else 0,
                'ymax': positions.max() if positions.size else 1,
            }
            wheel_axes.plot(wheel_data['re_ts'], positions, 'k-x')
            self._plot_trial_events(trial_data, trial_events, color_map, linestyle, wheel_plot_args)

    @staticmethod
    def _plot_trial_events(trial_data, trial_events, color_map, linestyle, plot_args):
        """Draw vertical lines at the times of each requested event present in `trial_data`."""
        # Colours and line styles are both cycled so that no event is silently dropped when
        # fewer styles than events are supplied
        for event, color, style in zip(trial_events, cycle(color_map), cycle(linestyle)):
            if event in trial_data:
                plots.vertical_lines(trial_data[event], label=event, color=color, linestyle=style, **plot_args)
def get_bpod_trials_task(task):
    """
    Return the correct trials task for extracting only the Bpod trials.

    A task whose class is exactly ChoiceWorldTrialsBpod or HabituationTrialsBpod is returned
    unchanged. Any other task is rebuilt as a Bpod-only trials task with the same session path,
    collection, protocol number, protocol and ONE instance: HabituationTrialsBpod when the
    protocol name contains 'habituation', otherwise ChoiceWorldTrialsBpod.

    Parameters
    ----------
    task : ibllib.pipes.tasks.Task
        A pipeline task from which to derive the Bpod trials task.

    Returns
    -------
    ibllib.pipes.tasks.Task
        A Bpod choice world trials task instance.

    Raises
    ------
    AssertionError
        If the task must be converted but is not a BehaviourTask (only when assertions are enabled).
    """
    # Exact class comparison (not isinstance): only these two classes are returned as-is
    if task.__class__ in (ChoiceWorldTrialsBpod, HabituationTrialsBpod):
        return task

    # A dynamic pipeline task: build the Bpod-only equivalent
    assert isinstance(task, BehaviourTask)
    if 'habituation' in task.protocol:
        bpod_class = HabituationTrialsBpod
    else:
        bpod_class = ChoiceWorldTrialsBpod
    return bpod_class(task.session_path,
                      collection=task.collection, protocol_number=task.protocol_number,
                      protocol=task.protocol, one=task.one)
def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=None, protocol_number=None):
    """
    Displays the task QC for a given session.

    NB: For this to work, all behaviour trials task classes must implement a `run_qc` method.

    Parameters
    ----------
    qc_or_session : str, pathlib.Path, ibllib.qc.task_metrics.TaskQC, QcFrame
        An experiment ID, session path, TaskQC object or QcFrame instance.
    bpod_only : bool
        If true, display Bpod extracted events instead of data from the DAQ.
    local : bool
        If true, asserts all data local (i.e. do not attempt to download missing datasets).
    one : one.api.One
        An instance of ONE. If None and an experiment ID or session path is given, a new
        instance is created in 'local' mode if `local` is true, otherwise in 'remote' mode.
    protocol_number : int
        If not None, displays the QC for the protocol number provided. Argument is ignored if
        `qc_or_session` is a TaskQC object or QcFrame instance.

    Returns
    -------
    QcFrame
        The QcFrame object.

    Raises
    ------
    TypeError
        If `qc_or_session` is None, or `protocol_number` is not an integer >= 0.
    ValueError
        If the session path cannot be determined, no non-passive behaviour task is found,
        the protocol number is out of range, or the selected protocol is passive.
    """
    if qc_or_session is None:
        # Fail fast, before any ONE instance is created
        raise TypeError('qc_or_session must be an experiment ID, session path, TaskQC or QcFrame')

    # The TaskQC whose extracted trial data are shown alongside the metrics; stays None when a
    # QcFrame is passed in, in which case only the QcFrame table is displayed
    task_qc = None
    if isinstance(qc_or_session, QcFrame):
        qc = qc_or_session
    elif isinstance(qc_or_session, TaskQC):
        task_qc = qc_or_session
        qc = QcFrame(task_qc)
    else:  # assumed to be eid or session path
        one = one or ONE(mode='local' if local else 'remote')
        if not is_session_path(Path(qc_or_session)):
            eid = one.to_eid(qc_or_session)
            session_path = one.eid2path(eid)
            if session_path is None:
                raise ValueError(f'Failed to determine session path for "{qc_or_session}"')
        else:
            session_path = Path(qc_or_session)

        tasks = get_trials_tasks(session_path, one=None if local else one, bpod_only=bpod_only)
        # Get the correct task and ensure not passive
        if protocol_number is None:
            if not (task := next((t for t in tasks if 'passive' not in t.name.lower()), None)):
                raise ValueError('No non-passive behaviour tasks found for session ' + '/'.join(session_path.parts[-3:]))
        elif not isinstance(protocol_number, int) or protocol_number < 0:
            raise TypeError('Protocol number must be a positive integer')
        elif protocol_number > len(tasks) - 1:
            raise ValueError('Invalid protocol number')
        else:
            task = tasks[protocol_number]
            if 'passive' in task.name.lower():
                raise ValueError('QC display not supported for passive protocols')
        _logger.debug('Using %s task', task.name)
        # Ensure required data are present
        task.location = 'server' if local else 'remote'  # affects whether missing data are downloaded
        task.setUp()
        if local:  # currently setUp does not raise on missing data
            task.assert_expected_inputs(raise_error=True)
        # Compute the QC and build the frame
        task_qc = task.run_qc(update=False)
        qc = QcFrame(task_qc)

    # Handle trial event names in habituationChoiceWorld
    events = list(EVENT_MAP.keys())
    if 'stimCenter_times' in qc.qc.extractor.data:
        events = [x.replace('stimFreeze', 'stimCenter') for x in events]

    # Run QC and plot
    w = ViewEphysQC.viewqc(wheel=qc.get_wheel_data())
    qc.create_plots(w.wplot.canvas.ax,
                    wheel_axes=w.wplot.canvas.ax2,
                    trial_events=events,
                    color_map=cm,
                    linestyle=ls)

    # Update table and callbacks
    n_trials = qc.frame.shape[0]
    if task_qc is not None:
        # Only 1-D arrays with one value per trial can become DataFrame columns; values without
        # `ndim` (lists, scalars, dicts) and 2-D arrays are skipped rather than raising.
        # Wheel-prefixed keys are excluded from the table.
        df_trials = pd.DataFrame({
            k: v for k, v in task_qc.extractor.data.items()
            if getattr(v, 'ndim', None) == 1 and v.size == n_trials and not k.startswith('wheel')
        })
        df = df_trials.merge(qc.frame, left_index=True, right_index=True)
    else:
        df = qc.frame
    df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials})
    df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True)
    df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True)
    df = df.merge(df_pass.astype('boolean'), left_index=True, right_index=True)
    w.updateDataframe(df)
    qt.run_app()
    return qc
def qc_gui_cli():
    """Run TaskQC viewer with wheel data.

    Parses the command line arguments and passes them to `show_session_task_qc`.

    For information on the QC checks see the QC Flags & failures document:
    https://docs.google.com/document/d/1X-ypFEIxqwX6lU9pig4V_zrcR5lITpd8UJQWzW9I9zI/edit#

    Examples
    --------
    >>> ipython task_qc.py c9fec76e-7a20-4da4-93ad-04510a89473b
    >>> ipython task_qc.py ./KS022/2019-12-10/001 --local
    >>> ipython task_qc.py ./KS022/2019-12-10/001 --local --protocol 1
    """
    # Parse parameters
    parser = argparse.ArgumentParser(description='Quick viewer to see the behaviour data from '
                                                 'choice world sessions.')
    parser.add_argument('session', help='session uuid or path')
    parser.add_argument('--bpod', action='store_true', help='run QC on Bpod data only (no FPGA)')
    parser.add_argument('--local', action='store_true', help='run from disk location (lab server)')
    # Optional; when omitted the first non-passive behaviour protocol is displayed
    parser.add_argument('--protocol', type=int, default=None, dest='protocol_number',
                        help='protocol number of the task to display (default: first non-passive task)')
    args = parser.parse_args()  # returns data from the options specified (echo)

    show_session_task_qc(qc_or_session=args.session, bpod_only=args.bpod, local=args.local,
                         protocol_number=args.protocol_number)
if __name__ == '__main__':
    # Launch the command line viewer when this module is executed as a script
    qc_gui_cli()