Coverage for ibllib/io/extractors/bpod_trials.py: 80%
61 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Trials data extraction from raw Bpod output.
3This module will extract the Bpod trials and wheel data based on the task protocol,
4i.e. habituation, training or biased.
5"""
6import logging
7import importlib
8from collections import OrderedDict
9import warnings
11from packaging import version
12from ibllib.io.extractors import habituation_trials, training_trials, biased_trials, opto_trials
13from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor
14from ibllib.io.extractors.habituation_trials import HabituationTrials
15from ibllib.io.extractors.training_trials import TrainingTrials
16from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
17from ibllib.io.extractors.base import get_session_extractor_type, BaseBpodTrialsExtractor
18import ibllib.io.raw_data_loaders as rawio
20_logger = logging.getLogger(__name__)
def _pop_wheel_data(trials):
    """
    Remove the wheel entries from an extracted trials dict and return them.

    The training and biased trials extractors return wheel data within the trials dict;
    splitting them out here avoids extracting the wheel twice.

    Parameters
    ----------
    trials : dict
        Extracted trials data. Modified in place: every key containing 'wheel' is removed.

    Returns
    -------
    collections.OrderedDict
        The removed wheel entries, in their original key order.
    """
    # Iterate over a snapshot of the keys because the dict is mutated while popping.
    return OrderedDict((key, trials.pop(key)) for key in tuple(trials.keys()) if 'wheel' in key)


def extract_all(session_path, save=True, bpod_trials=None, settings=None,
                task_collection='raw_behavior_data', extractor_type=None, save_path=None):
    """
    Extract a Bpod session's trials and wheel data from its path.

    NB: Wheel must be extracted first in order to extract trials.firstMovement_times.

    .. deprecated::
        This function will soon be removed; use `get_bpod_extractor` instead.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    save : bool
        If true, save the output files to save_path.
    bpod_trials : list of dict
        The loaded Bpod trial data. If None, attempts to load _iblrig_taskData.raw from
        task_collection.
    settings : dict
        The loaded Bpod settings. If None, attempts to load _iblrig_taskSettings.raw from
        task_collection.
    task_collection : str
        The subfolder containing the raw Bpod data files.
    extractor_type : str
        The type of extraction. Supported types are {'ephys', 'biased', 'biased_opto',
        'ephys_biased_opto', 'training', 'ephys_training', 'habituation'}. If None, extractor type
        determined from settings.
    save_path : str, pathlib.Path
        The location of the output files if save is true. Defaults to <session_path>/alf.

    Returns
    -------
    dict or None
        The extracted trials data. None for legacy habituation sessions (iblrig <= 5.0.0).
    dict or None
        The extracted wheel data. None for habituation sessions.
    list of pathlib.Path or None
        The output files if save is true, otherwise None.

    Raises
    ------
    ValueError
        If the extractor type could not be determined or has no extractor.
    """
    # stacklevel=2 attributes the deprecation warning to the caller rather than to this module.
    warnings.warn('`extract_all` functions soon to be removed, use `bpod_trials.get_bpod_extractor` instead',
                  FutureWarning, stacklevel=2)
    if not extractor_type:
        extractor_type = get_session_extractor_type(session_path, task_collection=task_collection)
    if not extractor_type:
        # Fail early with the same error as an unsupported type instead of a TypeError below.
        raise ValueError(f'No extractor for task {extractor_type}')
    _logger.info('Extracting %s as %s', session_path, extractor_type)
    # Only load from disk when the caller supplied nothing (as documented above); explicitly
    # passed data is always respected.
    if bpod_trials is None:
        bpod_trials = rawio.load_data(session_path, task_collection=task_collection)
    if settings is None:
        settings = rawio.load_settings(session_path, task_collection=task_collection)
    # .get: a missing board name must not abort extraction just for a log message.
    _logger.info('%s session on %s', extractor_type, settings.get('PYBPOD_BOARD'))

    # Determine which additional extractors are required
    extra = []
    if extractor_type == 'ephys':  # Should exclude 'ephys_biased'
        _logger.debug('Engaging biased TrialsTableEphys')
        extra.append(biased_trials.TrialsTableEphys)
    if extractor_type in ('biased_opto', 'ephys_biased_opto'):
        _logger.debug('Engaging opto_trials LaserBool')
        extra.append(opto_trials.LaserBool)

    # Determine base extraction
    if extractor_type in ('training', 'ephys_training'):
        trials, files_trials = training_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save,
            task_collection=task_collection, save_path=save_path)
        # files_trials should contain wheel files at the end.
        wheel = _pop_wheel_data(trials)
        files_wheel = []
    elif 'biased' in extractor_type or 'ephys' in extractor_type:
        trials, files_trials = biased_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save, extra_classes=extra,
            task_collection=task_collection, save_path=save_path)
        wheel = _pop_wheel_data(trials)
        files_wheel = []
    elif extractor_type == 'habituation':
        rig_version = settings.get('IBLRIG_VERSION')
        if rig_version and version.parse(rig_version) <= version.parse('5.0.0'):
            _logger.warning('No extraction of legacy habituation sessions')
            return None, None, None
        trials, files_trials = habituation_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save,
            task_collection=task_collection, save_path=save_path)
        wheel = None
        files_wheel = []
    else:
        raise ValueError(f'No extractor for task {extractor_type}')
    _logger.info('session extracted \n')  # timing info in log
    return trials, wheel, (files_trials + files_wheel) if save else None
def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
    """
    Instantiate the Bpod trials extractor suited to a session.

    The extractor class name is resolved from `protocol` when one is given, otherwise from the
    session's task settings. Names matching an extractor bundled with this module are used
    directly; any other name is imported from the `projects` package.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    protocol : str, optional
        The protocol name, otherwise uses the PYBPOD_PROTOCOL key in iblrig task settings files.
    task_collection : str
        The folder within the session that contains the raw task data.

    Returns
    -------
    BaseBpodTrialsExtractor
        An instance of the task extractor class, instantiated with the session path.

    Raises
    ------
    ValueError
        If the resolved `projects` module has no attribute with the extractor class name.
    """
    if protocol:
        extractor_name = protocol2extractor(protocol)
    else:
        extractor_name = get_bpod_extractor_class(session_path, task_collection=task_collection)

    bundled = {
        'HabituationTrials': HabituationTrials,
        'TrainingTrials': TrainingTrials,
        'BiasedTrials': BiasedTrials,
        'EphysTrials': EphysTrials,
    }
    if extractor_name in bundled:
        return bundled[extractor_name](session_path)

    # Not bundled: look for a custom extractor in the personal projects repository.
    qualified_name = extractor_name if extractor_name.startswith('projects.') else 'projects.' + extractor_name
    module_path, short_name = qualified_name.rsplit('.', 1)
    candidate = getattr(importlib.import_module(module_path), short_name, None)
    if not candidate:
        raise ValueError(f'extractor {short_name} not found')
    return candidate(session_path)