Coverage for ibllib/io/extractors/bpod_trials.py: 96%
24 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
"""Trials data extraction from raw Bpod output.

This module extracts the Bpod trials and wheel data based on the task protocol,
e.g. habituation, training, biased or ephys.
"""
6import importlib
8from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor, BaseExtractor
9from ibllib.io.extractors.habituation_trials import HabituationTrials
10from ibllib.io.extractors.training_trials import TrainingTrials
11from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
12from ibllib.io.extractors.base import BaseBpodTrialsExtractor
def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
    """
    Returns an extractor for a given session.

    The extractor class name is resolved from `protocol` (via `protocol2extractor`) when given,
    otherwise from the task settings in `task_collection` (via `get_bpod_extractor_class`).
    Names matching one of the ibllib built-in trials extractors are instantiated directly; any
    other name is imported from the personal ``projects`` package.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    protocol : str, optional
        The protocol name, otherwise uses the PYBPOD_PROTOCOL key in iblrig task settings files.
    task_collection : str
        The folder within the session that contains the raw task data. Only used to look up the
        extractor class name when `protocol` is not given.

    Returns
    -------
    BaseBpodTrialsExtractor
        An instance of the task extractor class, instantiated with the session path.

    Raises
    ------
    ValueError
        If a custom extractor class is not found in its module, or if the object it produces is
        not an instance of ibllib.io.extractors.base.BaseExtractor.
    ImportError
        If the module of a custom extractor cannot be imported (e.g. ModuleNotFoundError when the
        ``projects`` package is not installed).
    """
    # ibllib's own trials extractors, keyed by class name. Deliberately not called `builtins`,
    # which would shadow the standard-library module of that name.
    builtin_extractors = {
        'HabituationTrials': HabituationTrials,
        'TrainingTrials': TrainingTrials,
        'BiasedTrials': BiasedTrials,
        'EphysTrials': EphysTrials,
    }

    if protocol:
        extractor_class_name = protocol2extractor(protocol)
    else:
        extractor_class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
    if extractor_class_name in builtin_extractors:
        return builtin_extractors[extractor_class_name](session_path)

    # Not a built-in: look for a custom extractor in the personal projects repo. Names not
    # already prefixed with 'projects.' are resolved relative to the `projects` package.
    if not extractor_class_name.startswith('projects.'):
        extractor_class_name = 'projects.' + extractor_class_name
    module_name, class_name = extractor_class_name.rsplit('.', 1)
    module = importlib.import_module(module_name)
    extractor_class = getattr(module, class_name, None)
    if not extractor_class:
        # Include the module searched so a failed project lookup can be traced.
        raise ValueError(f'extractor {class_name} not found in module {module_name}')

    extractor = extractor_class(session_path)
    # Validation is done on the instance rather than the class, so any callable that returns
    # a BaseExtractor instance is accepted.
    if not isinstance(extractor, BaseExtractor):
        raise ValueError(
            f"{extractor} should be an Extractor class inheriting from ibllib.io.extractors.base.BaseExtractor")
    return extractor