Coverage for ibllib/qc/qcplots.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Plots for trial QC 

2 

3Example: 

4 one = ONE() 

5 # Load data 

6 eid = 'c8ef527b-6f7f-4f08-8b99-5aeb9d2b3740' 

7 # Run QC 

8 qc = TaskQC(eid, one=one) 

9 plot_results(qc) 

10 plt.show() 

11 

12""" 

from collections import Counter
from collections.abc import Sized
from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ibllib.qc.task_metrics import TaskQC

22 

23 

def plot_results(qc_obj, save_path=None):
    """Plot and print a summary of the task QC results for one session.

    Prints the check names grouped by outcome and draws a bar chart of the number of checks
    per outcome. It also draws a two-panel figure of the trial-level checks whose outcome is
    FAIL:
    a box plot of their metric values, and a bar plot of the proportion of trials that passed.

    Parameters
    ----------
    qc_obj : ibllib.qc.task_metrics.TaskQC
        The QC object. If its ``passed`` attribute is falsy, ``qc_obj.compute()`` is called
        first.
    save_path : str or pathlib.Path, optional
        Where to save the two-panel figure. If this is an existing directory, the figure is
        saved inside it as ``<date>_<number>_<subject>_QC.png``. Otherwise it is used as the
        file path, provided its parent folder exists; if not, a message is printed and nothing
        is saved. When None (default), nothing is saved.

    Raises
    ------
    ValueError
        If ``qc_obj`` is not a TaskQC instance.
    """
    if not isinstance(qc_obj, TaskQC):
        raise ValueError('Input must be TaskQC object')

    if not qc_obj.passed:
        qc_obj.compute()
    # Only the per-check outcomes are needed here; overall outcome and results are unused.
    _, _, outcomes = qc_obj.compute_session_status()

    _print_outcomes(outcomes)

    # Collect some session details
    n_trials = qc_obj.extractor.data['intervals'].shape[0]
    det = qc_obj.one.get_details(qc_obj.eid)
    ref = f"{datetime.fromisoformat(det['start_time']).date()}_{det['number']:d}_{det['subject']}"
    title = ref + (' (Bpod data only)' if qc_obj.extractor.bpod_only else '')

    # Number of checks per outcome. Drawn on a dedicated figure so that it never draws over
    # whatever figure happens to be current.
    counts = Counter(outcomes.values())
    fig_counts, ax_counts = plt.subplots()
    ax_counts.bar(range(len(counts)), list(counts.values()), align='center',
                  tick_label=list(counts.keys()))
    fig_counts.suptitle(title)
    ax_counts.set(ylabel='# QC checks', xlabel='outcome')

    a4_dims = (11.7, 8.27)  # landscape A4, in inches
    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=a4_dims, constrained_layout=True)
    fig.suptitle(title)

    def get_trial_level_failed(d):
        """Return a DataFrame of the entries of `d` whose check outcome is 'FAIL'.

        Only values with exactly one element per trial are kept. The column names are the keys
        with their first 6 characters removed (presumably a '_task_' prefix; TODO confirm).
        """
        new_dict = {k[6:]: v for k, v in d.items()
                    if outcomes.get(k) == 'FAIL' and isinstance(v, Sized) and len(v) == n_trials}
        return pd.DataFrame.from_dict(new_dict)

    # Distribution of metric values for the failed trial-level checks
    failed_metrics = get_trial_level_failed(qc_obj.metrics)
    if not failed_metrics.empty:  # nothing to draw when no trial-level check failed
        sns.boxplot(data=failed_metrics, orient='h', ax=ax0)
    ax0.tick_params(axis='y', labelrotation=30, labelsize=8)
    ax0.set(xscale='symlog', title='Metrics (failed)', xlabel='metric values (units vary)')

    # Proportion of trials that passed each failed trial-level check: seaborn's barplot draws
    # the mean of each column (assumes per-trial pass values are booleans/fractions).
    failed_passed = get_trial_level_failed(qc_obj.passed)
    if not failed_passed.empty:
        sns.barplot(data=failed_passed, orient='h', ax=ax1)
    ax1.tick_params(axis='y', labelrotation=30, labelsize=8)
    ax1.set(title='Counts', xlabel='proportion of trials that passed')

    if save_path is not None:
        _save_figure(fig, Path(save_path), f"{ref}_QC.png")


def _print_outcomes(outcomes):
    """Print the check names grouped by outcome, in the order each outcome is first seen.

    The first 6 characters of each check name are dropped before printing.
    """
    by_outcome = {}
    for check, result in outcomes.items():
        by_outcome.setdefault(result, []).append(check[6:])
    for result, checks in by_outcome.items():
        print(f'The following checks were labelled {result}:')
        print('\n'.join(checks), '\n')


def _save_figure(fig, save_path, filename):
    """Save `fig` to `save_path` (a pathlib.Path).

    If `save_path` is an existing directory, the figure is written inside it as `filename`.
    Otherwise `save_path` is used as the file path. If its parent folder does not exist, a
    message is printed and nothing is saved.
    """
    if save_path.is_dir():
        fig.savefig(save_path.joinpath(filename))
    elif not save_path.parent.exists():
        print(f"Folder {save_path.parent} does not exist, not saving...")
    else:
        fig.savefig(save_path)