Coverage for ibllib/qc/task_qc_viewer/task_qc.py: 88%
148 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1import logging
2import argparse
3from itertools import cycle
4import random
5from collections.abc import Sized
6from pathlib import Path
8import pandas as pd
9import numpy as np
10from matplotlib.colors import TABLEAU_COLORS
11from one.api import ONE
12from one.alf.spec import is_session_path
14import ibllib.plots as plots
15from ibllib.misc import qt
16from ibllib.qc.task_metrics import TaskQC
17from ibllib.qc.task_qc_viewer import ViewEphysQC
18from ibllib.pipes.dynamic_pipeline import get_trials_tasks
19from ibllib.pipes.base_tasks import BehaviourTask
20from ibllib.pipes.behavior_tasks import HabituationTrialsBpod, ChoiceWorldTrialsBpod
21from ibllib.pipes.training_preprocessing import TrainingTrials
# Trial event name -> [colour, line style] used when drawing the events as vertical lines.
EVENT_MAP = {
    'goCue_times': ['#2ca02c', 'solid'],  # green
    'goCueTrigger_times': ['#2ca02c', 'dotted'],  # green
    'errorCue_times': ['#d62728', 'solid'],  # red
    'errorCueTrigger_times': ['#d62728', 'dotted'],  # red
    'valveOpen_times': ['#17becf', 'solid'],  # cyan
    'stimFreeze_times': ['#0000ff', 'solid'],  # blue
    'stimFreezeTrigger_times': ['#0000ff', 'dotted'],  # blue
    'stimOff_times': ['#9400d3', 'solid'],  # dark violet
    'stimOffTrigger_times': ['#9400d3', 'dotted'],  # dark violet
    'stimOn_times': ['#e377c2', 'solid'],  # pink
    'stimOnTrigger_times': ['#e377c2', 'dotted'],  # pink
    'response_times': ['#8c564b', 'solid'],  # brown
}
# Colours and line styles in EVENT_MAP key order, so they can be zipped positionally with
# a list of the map's keys.
cm = [colour for colour, _style in EVENT_MAP.values()]
ls = [style for _colour, style in EVENT_MAP.values()]

# QC checks whose failure is reported separately as *critical* by QcFrame.
CRITICAL_CHECKS = (
    'check_audio_pre_trial',
    'check_correct_trial_event_sequence',
    'check_error_trial_event_sequence',
    'check_n_trial_events',
    'check_response_feedback_delays',
    'check_response_stimFreeze_delays',
    'check_reward_volume_set',
    'check_reward_volumes',
    'check_stimOn_goCue_delays',
    'check_stimulus_move_before_goCue',
    'check_wheel_move_before_feedback',
    'check_wheel_freeze_during_quiescence'
)

_logger = logging.getLogger(__name__)
class QcFrame:
    """Interactive display of the task QC for a single session."""

    qc = None
    """ibllib.qc.task_metrics.TaskQC: A TaskQC object containing extracted data"""

    frame = None
    """pandas.DataFrame: A table of trial-level QC metrics (one row per trial)."""

    def __init__(self, qc):
        """
        An interactive display of task QC data.

        Prints the checks grouped by (non-passing) outcome and the critical checks that did
        not pass, then builds `frame`: a table of every QC metric with one value per trial,
        plus the trial interval start/end times.

        Parameters
        ----------
        qc : ibllib.qc.task_metrics.TaskQC
            A TaskQC object containing extracted data for plotting. Its `extractor` and
            `metrics` attributes must already be populated (i.e. QC must have been run).

        Raises
        ------
        AssertionError
            If `qc.extractor` or `qc.metrics` is empty/unset.
        """
        assert qc.extractor and qc.metrics, 'Please run QC before passing to QcFrame'
        self.qc = qc

        # Print failed checks, grouped by outcome. Keys look like '_task_<name>'; the first
        # six characters ('_task_') are stripped for display. A plain dict keeps the groups
        # in first-seen order, so the printed output is deterministic between runs.
        outcome, results, outcomes = self.qc.compute_session_status()
        checks_by_outcome = {}
        for check, check_outcome in outcomes.items():
            checks_by_outcome.setdefault(check_outcome, []).append(check[6:])
        for check_outcome, checks in checks_by_outcome.items():
            if check_outcome == 'PASS':
                continue
            print(f'The following checks were labelled {check_outcome}:')
            print('\n'.join(checks), '\n')

        print('The following *critical* checks did not pass:')
        # CRITICAL_CHECKS uses 'check_<name>'; outcome keys use '_task_<name>'
        critical_checks = {f'_{x.replace("check", "task")}' for x in CRITICAL_CHECKS}
        for check, check_outcome in outcomes.items():
            if check_outcome != 'PASS' and check in critical_checks:
                print(check[6:])

        # Make DataFrame from the trial level metrics, i.e. those with one value per trial
        n_trials = self.n_trials
        trial_metrics = {k[6:]: v for k, v in self.qc.metrics.items()
                         if isinstance(v, Sized) and len(v) == n_trials}
        self.frame = pd.DataFrame.from_dict(trial_metrics)
        intervals = self.qc.extractor.data['intervals']
        self.frame['intervals_0'] = intervals[:, 0]
        self.frame['intervals_1'] = intervals[:, 1]
        self.frame.insert(loc=0, column='trial_no', value=self.frame.index)

    @property
    def n_trials(self):
        """int: Number of trials, i.e. the number of rows of the extracted 'intervals' array."""
        return self.qc.extractor.data['intervals'].shape[0]

    def get_wheel_data(self):
        """
        Return the extracted wheel trace.

        Returns
        -------
        dict
            're_pos' holds the extracted 'wheel_position' and 're_ts' the extracted
            'wheel_timestamps'; each is an empty array when missing from the extracted data.
        """
        data = self.qc.extractor.data
        return {'re_pos': data.get('wheel_position', np.array([])),
                're_ts': data.get('wheel_timestamps', np.array([]))}

    def create_plots(self, axes,
                     wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
        """
        Plot the data for bnc1 (frame2ttl), bnc2 (sound) and, if present, the Bpod TTLs.

        TTLs are drawn as square waves at y = 1 (frame2ttl), y = 2 (sound) and y = 3 (Bpod).
        Trial events are drawn as vertical lines on `axes` and, when given, on `wheel_axes`
        alongside the wheel trace.

        :param axes: An axes handle on which to plot the TTL events
        :param wheel_axes: An axes handle on which to plot the wheel trace
        :param trial_events: A list of Bpod trial events to plot, e.g. ['stimFreeze_times'].
         If None, the go cue, feedback and stimulus events are plotted. Events absent from
         the extracted data are skipped.
        :param color_map: An iterable of colours, cycled over the events; default is the
         tableau color map.
        :param linestyle: An iterable of line styles, cycled over the events; default is a
         random style per event.
        :return: None
        """
        color_map = color_map or TABLEAU_COLORS.keys()
        trial_data = self.qc.extractor.data
        if trial_events is None:
            # Default trial events to plot as vertical lines
            trial_events = [
                'goCue_times',
                'goCueTrigger_times',
                'feedback_times',
                ('stimCenter_times'
                 if 'stimCenter_times' in trial_data
                 else 'stimFreeze_times'),  # handle habituationChoiceWorld exception
                'stimOff_times',
                'stimOn_times'
            ]

        bnc1 = self.qc.extractor.frame_ttls  # frame2ttl
        bnc2 = self.qc.extractor.audio_ttls  # sound

        if bnc1['times'].size:
            plots.squares(bnc1['times'], bnc1['polarities'] * 0.4 + 1, ax=axes, color='k')
        if bnc2['times'].size:
            plots.squares(bnc2['times'], bnc2['polarities'] * 0.4 + 2, ax=axes, color='k')
        linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))

        if self.qc.extractor.bpod_ttls is not None:
            bpttls = self.qc.extractor.bpod_ttls
            plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
            ymax = 4
            ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
        else:
            ymax = 3
            ylabels = ['', 'frame2ttl', 'sound', '']
        plot_args = {'ymin': 0, 'ymax': ymax, 'linewidth': 2, 'ax': axes}

        def plot_events(**kwargs):
            # Cycle both colours and line styles: zip stops at its shortest input, so a
            # style/colour list shorter than `trial_events` would otherwise drop events.
            for event, c, ln in zip(trial_events, cycle(color_map), cycle(linestyle)):
                if event in trial_data:
                    plots.vertical_lines(trial_data[event],
                                         label=event, color=c, linestyle=ln, **kwargs)

        plot_events(**plot_args)

        axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
        axes.set_yticks(list(range(ymax + 1)))
        axes.set_yticklabels(ylabels)
        axes.set_ylim([0, ymax])

        if wheel_axes is not None:
            wheel_data = self.get_wheel_data()
            pos = wheel_data['re_pos']
            # Same line args as the TTL axes, but spanning the wheel position range
            wheel_plot_args = {**plot_args,
                               'ax': wheel_axes,
                               'ymin': pos.min() if pos.size else 0,
                               'ymax': pos.max() if pos.size else 1}
            wheel_axes.plot(wheel_data['re_ts'], pos, 'k-x')
            plot_events(**wheel_plot_args)
def get_bpod_trials_task(task):
    """
    Return the correct trials task for extracting only the Bpod trials.

    Tasks that already extract Bpod-only trials are returned as-is; a dynamic pipeline
    behaviour task is swapped for the matching Bpod trials class; anything else is treated
    as a legacy pipeline task and replaced with a TrainingTrials task.

    Parameters
    ----------
    task : ibllib.pipes.tasks.Task
        A pipeline task from which to derive the Bpod trials task.

    Returns
    -------
    ibllib.pipes.tasks.Task
        A Bpod choice world trials task instance.
    """
    bpod_only_classes = (ChoiceWorldTrialsBpod, HabituationTrialsBpod)
    if isinstance(task, TrainingTrials) or task.__class__ in bpod_only_classes:
        return task  # nothing to do; already Bpod only

    if not isinstance(task, BehaviourTask):
        # A legacy pipeline task (should be EphysTrials as there are no other options)
        return TrainingTrials(task.session_path, one=task.one)

    # A dynamic pipeline task: pick the Bpod class matching the protocol type
    if 'habituation' in task.protocol:
        bpod_class = HabituationTrialsBpod
    else:
        bpod_class = ChoiceWorldTrialsBpod
    return bpod_class(task.session_path,
                      collection=task.collection, protocol_number=task.protocol_number,
                      protocol=task.protocol, one=task.one)
def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=None, protocol_number=None):
    """
    Displays the task QC for a given session.

    NB: For this to work, all behaviour trials task classes must implement a `run_qc` method.

    Parameters
    ----------
    qc_or_session : str, pathlib.Path, ibllib.qc.task_metrics.TaskQC, QcFrame
        An experiment ID, session path, TaskQC object or QcFrame instance.
    bpod_only : bool
        If true, display Bpod extracted events instead of data from the DAQ.
    local : bool
        If true, asserts all data local (i.e. do not attempt to download missing datasets).
    one : one.api.One
        An instance of ONE. If None, one is instantiated ('local' mode when `local` is true,
        otherwise 'auto').
    protocol_number : int
        If not None, displays the QC for the protocol number provided. Argument is ignored if
        `qc_or_session` is a TaskQC object or QcFrame instance.

    Returns
    -------
    QcFrame
        The QcFrame object.

    Raises
    ------
    ValueError
        If the session path cannot be resolved from the experiment ID, no non-passive
        behaviour task is found, the protocol number is out of range, or the selected
        protocol is a passive one.
    TypeError
        If `protocol_number` is not a non-negative int.
    """
    if isinstance(qc_or_session, QcFrame):
        qc = qc_or_session
    elif isinstance(qc_or_session, TaskQC):
        qc = QcFrame(qc_or_session)
    else:  # assumed to be eid or session path
        one = one or ONE(mode='local' if local else 'auto')
        if not is_session_path(Path(qc_or_session)):
            eid = one.to_eid(qc_or_session)
            session_path = one.eid2path(eid)
            if session_path is None:
                # Fail early with a clear message rather than deep inside get_trials_tasks
                raise ValueError(f'Failed to determine session path for {qc_or_session}')
        else:
            session_path = Path(qc_or_session)
        tasks = get_trials_tasks(session_path, one=None if local else one)
        # Get the correct task and ensure not passive
        if protocol_number is None:
            if not (task := next((t for t in tasks if 'passive' not in t.name.lower()), None)):
                raise ValueError('No non-passive behaviour tasks found for session ' + '/'.join(session_path.parts[-3:]))
        elif not isinstance(protocol_number, int) or protocol_number < 0:
            raise TypeError('Protocol number must be a positive integer')
        elif protocol_number > len(tasks) - 1:
            raise ValueError('Invalid protocol number')
        else:
            task = tasks[protocol_number]
            if 'passive' in task.name.lower():
                raise ValueError('QC display not supported for passive protocols')
        # If Bpod only and not a dynamic pipeline Bpod behaviour task OR legacy TrainingTrials task
        if bpod_only and 'bpod' not in task.name.lower():
            # Use the dynamic pipeline Bpod behaviour task instead (should work with legacy pipeline too)
            task = get_bpod_trials_task(task)
        _logger.debug('Using %s task', task.name)
        # Ensure required data are present
        task.location = 'server' if local else 'remote'  # affects whether missing data are downloaded
        task.setUp()
        if local:  # currently setUp does not raise on missing data
            task.assert_expected_inputs(raise_error=True)
        # Compute the QC and build the frame
        task_qc = task.run_qc(update=False)
        qc = QcFrame(task_qc)

    # Handle trial event names in habituationChoiceWorld; names stay aligned with cm/ls
    events = list(EVENT_MAP.keys())
    if 'stimCenter_times' in qc.qc.extractor.data:
        events = [event.replace('stimFreeze', 'stimCenter') for event in events]

    # Plot the (already computed) QC data
    w = ViewEphysQC.viewqc(wheel=qc.get_wheel_data())
    qc.create_plots(w.wplot.canvas.ax,
                    wheel_axes=w.wplot.canvas.ax2,
                    trial_events=events,
                    color_map=cm,
                    linestyle=ls)
    # Update table and callbacks
    w.update_df(qc.frame)
    qt.run_app()
    return qc
def qc_gui_cli(argv=None):
    """Run TaskQC viewer with wheel data.

    For information on the QC checks see the QC Flags & failures document:
    https://docs.google.com/document/d/1X-ypFEIxqwX6lU9pig4V_zrcR5lITpd8UJQWzW9I9zI/edit#

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments to parse. When None (default), argparse reads sys.argv[1:].

    Examples
    --------
    From the command line::

        ipython task_qc.py c9fec76e-7a20-4da4-93ad-04510a89473b
        ipython task_qc.py ./KS022/2019-12-10/001 --local
    """
    # Parse parameters (note the trailing space: adjacent literals are concatenated)
    parser = argparse.ArgumentParser(description='Quick viewer to see the behaviour data from '
                                                 'choice world sessions.')
    parser.add_argument('session', help='session uuid or path')
    parser.add_argument('--bpod', action='store_true', help='run QC on Bpod data only (no FPGA)')
    parser.add_argument('--local', action='store_true', help='run from disk location (lab server)')
    args = parser.parse_args(argv)  # returns data from the options specified (echo)

    show_session_task_qc(qc_or_session=args.session, bpod_only=args.bpod, local=args.local)


if __name__ == '__main__':
    qc_gui_cli()