Coverage for ibllib/qc/qcplots.py: 0%
47 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Plots for trial QC
Example:
    one = ONE()
    # Session to run the QC on
    eid = 'c8ef527b-6f7f-4f08-8b99-5aeb9d2b3740'
    # Run QC
    qc = TaskQC(eid, one=one)
    plot_results(qc)
    plt.show()
12"""
from collections import Counter, defaultdict
from collections.abc import Sized
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ibllib.qc.task_metrics import TaskQC
def _group_checks_by_outcome(outcomes):
    """Group check names by their outcome label.

    Parameters
    ----------
    outcomes : dict
        Mapping of check name to outcome label (e.g. 'FAIL').

    Returns
    -------
    dict
        Mapping of outcome label to the list of check names carrying it, in
        first-seen order. The first six characters of each check name are
        stripped (presumably a common '_task_' prefix -- TODO confirm).
    """
    grouped = defaultdict(list)
    for check, label in outcomes.items():
        grouped[label].append(check[6:])
    return dict(grouped)


def _trial_level_failed(data, outcomes, n_trials):
    """Build a DataFrame of the per-trial values of the failed checks.

    Parameters
    ----------
    data : dict
        Mapping of check name to values (e.g. the QC object's metrics or
        passed dict).
    outcomes : dict
        Mapping of check name to outcome label; only 'FAIL' checks are kept.
        Checks missing from this mapping are skipped rather than raising.
    n_trials : int
        Number of trials; only sized values with exactly this length are
        kept, which excludes session-level (scalar) entries.

    Returns
    -------
    pandas.DataFrame
        One column per kept check (prefix-stripped name); empty if none.
    """
    failed = {
        name[6:]: values
        for name, values in data.items()
        if outcomes.get(name) == 'FAIL'
        and isinstance(values, Sized)
        and len(values) == n_trials
    }
    return pd.DataFrame.from_dict(failed)


def plot_results(qc_obj, save_path=None):
    """Print task QC outcomes and plot the failed trial-level checks.

    Two figures are created: a bar chart of the number of checks per outcome,
    and an A4-sized figure with a box plot of the failed checks' per-trial
    metric values and a bar plot of the proportion of trials that passed each
    failed check.

    Parameters
    ----------
    qc_obj : TaskQC
        QC object for one session. ``compute`` is called first if its
        ``passed`` attribute is falsy.
    save_path : str or pathlib.Path, optional
        Where to save the A4 figure. If it is an existing directory, the figure
        is saved inside it as ``<date>_<number>_<subject>_QC.png``. Otherwise
        it is used as the output file path, provided its parent folder exists.
        If the parent folder does not exist, a message is printed and nothing
        is saved. The outcome-count figure is never saved.

    Raises
    ------
    ValueError
        If ``qc_obj`` is not a TaskQC instance.
    """
    if not isinstance(qc_obj, TaskQC):
        raise ValueError('Input must be TaskQC object')

    if not qc_obj.passed:
        qc_obj.compute()
    _, _, outcomes = qc_obj.compute_session_status()

    # Print the check names grouped by outcome
    for label, checks in _group_checks_by_outcome(outcomes).items():
        print(f'The following checks were labelled {label}:')
        print('\n'.join(checks), '\n')

    # Collect some session details
    n_trials = qc_obj.extractor.data['intervals'].shape[0]
    det = qc_obj.one.get_details(qc_obj.eid)
    ref = f"{datetime.fromisoformat(det['start_time']).date()}_{det['number']:d}_{det['subject']}"
    title = ref + (' (Bpod data only)' if qc_obj.extractor.bpod_only else '')

    # Number of checks in each outcome category. Drawn on a new figure so that
    # a figure the caller already has open is not drawn over.
    counts = Counter(outcomes.values())
    fig_counts, ax_counts = plt.subplots()
    ax_counts.bar(range(len(counts)), list(counts.values()), align='center',
                  tick_label=list(counts.keys()))
    fig_counts.suptitle(title)
    ax_counts.set(ylabel='# QC checks', xlabel='outcome')

    a4_dims = (11.7, 8.27)  # A4 landscape (width, height) in inches
    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=a4_dims, constrained_layout=True)
    fig.suptitle(title)

    # Distribution of the per-trial metric values of the failed checks
    failed_metrics = _trial_level_failed(qc_obj.metrics, outcomes, n_trials)
    if not failed_metrics.empty:
        sns.boxplot(data=failed_metrics, orient='h', ax=ax0)
    ax0.tick_params(axis='y', labelrotation=30, labelsize=8)
    ax0.set(xscale='symlog', title='Metrics (failed)', xlabel='metric values (units vary)')

    # Proportion of trials that passed each failed check. The bar height is
    # seaborn's default estimator, the mean of each column.
    failed_passed = _trial_level_failed(qc_obj.passed, outcomes, n_trials)
    if not failed_passed.empty:
        sns.barplot(data=failed_passed, orient='h', ax=ax1)
    ax1.tick_params(axis='y', labelrotation=30, labelsize=8)
    ax1.set(title='Counts', xlabel='proportion of trials that passed')

    if save_path is not None:
        save_path = Path(save_path)
        if save_path.is_dir():
            fig.savefig(save_path.joinpath(f"{ref}_QC.png"))
        elif not save_path.parent.exists():
            # Path.is_dir() is False for a missing path, so the folder that
            # must exist for a file path is its parent
            print(f"Folder {save_path.parent} does not exist, not saving...")
        else:
            fig.savefig(save_path)