Coverage for ibllib/io/extractors/bpod_trials.py: 85%
61 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""Trials data extraction from raw Bpod output
2This module will extract the Bpod trials and wheel data based on the task protocol,
3i.e. habituation, training or biased.
4"""
5import logging
6import importlib
7from collections import OrderedDict
8import warnings
10from pkg_resources import parse_version
11from ibllib.io.extractors import habituation_trials, training_trials, biased_trials, opto_trials
12from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor
13from ibllib.io.extractors.habituation_trials import HabituationTrials
14from ibllib.io.extractors.training_trials import TrainingTrials
15from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
16from ibllib.io.extractors.base import get_session_extractor_type, BaseBpodTrialsExtractor
17import ibllib.io.raw_data_loaders as rawio
# Module-level logger, named after this module so output can be filtered per module.
_logger = logging.getLogger(name=__name__)
22def extract_all(session_path, save=True, bpod_trials=None, settings=None,
23 task_collection='raw_behavior_data', extractor_type=None, save_path=None):
24 """
25 Extracts a training session from its path. NB: Wheel must be extracted first in order to
26 extract trials.firstMovement_times.
28 Parameters
29 ----------
30 session_path : str, pathlib.Path
31 The path to the session to be extracted.
32 task_collection : str
33 The subfolder containing the raw Bpod data files.
34 save : bool
35 If true, save the output files to save_path.
36 bpod_trials : list of dict
37 The loaded Bpod trial data. If None, attempts to load _iblrig_taskData.raw from
38 raw_task_collection.
39 settings : dict
40 The loaded Bpod settings. If None, attempts to load _iblrig_taskSettings.raw from
41 raw_task_collection.
42 extractor_type : str
43 The type of extraction. Supported types are {'ephys', 'biased', 'biased_opto',
44 'ephys_biased_opto', 'training', 'ephys_training', 'habituation'}. If None, extractor type
45 determined from settings.
46 save_path : str, pathlib.Path
47 The location of the output files if save is true. Defaults to <session_path>/alf.
49 Returns
50 -------
51 dict
52 The extracted trials data.
53 dict
54 The extracted wheel data.
55 list of pathlib.Path
56 The output files if save is true.
57 """
58 warnings.warn('`extract_all` functions soon to be deprecated, use `bpod_trials.get_bpod_extractor` instead', FutureWarning) 1aescfdghijklmnobtpqr
59 if not extractor_type: 1aescfdghijklmnobtpqr
60 extractor_type = get_session_extractor_type(session_path, task_collection=task_collection) 1aescfdghijklmnobtpqr
61 _logger.info(f'Extracting {session_path} as {extractor_type}') 1aescfdghijklmnobtpqr
62 bpod_trials = bpod_trials or rawio.load_data(session_path, task_collection=task_collection) 1aescfdghijklmnobtpqr
63 settings = settings or rawio.load_settings(session_path, task_collection=task_collection) 1aescfdghijklmnobtpqr
64 _logger.info(f'{extractor_type} session on {settings["PYBPOD_BOARD"]}') 1aescfdghijklmnobtpqr
66 # Determine which additional extractors are required
67 extra = [] 1aescfdghijklmnobtpqr
68 if extractor_type == 'ephys': # Should exclude 'ephys_biased' 1aescfdghijklmnobtpqr
69 _logger.debug('Engaging biased TrialsTableEphys') 1afghijnopqr
70 extra.append(biased_trials.TrialsTableEphys) 1afghijnopqr
71 if extractor_type in ['biased_opto', 'ephys_biased_opto']: 1aescfdghijklmnobtpqr
72 _logger.debug('Engaging opto_trials LaserBool') 1klm
73 extra.append(opto_trials.LaserBool) 1klm
75 # Determine base extraction
76 if extractor_type in ['training', 'ephys_training']: 1aescfdghijklmnobtpqr
77 trials, files_trials = training_trials.extract_all(session_path, bpod_trials=bpod_trials, settings=settings, save=save, 1eb
78 task_collection=task_collection, save_path=save_path)
79 # This is hacky but avoids extracting the wheel twice.
80 # files_trials should contain wheel files at the end.
81 files_wheel = [] 1eb
82 wheel = OrderedDict({k: trials.pop(k) for k in tuple(trials.keys()) if 'wheel' in k}) 1eb
83 elif 'biased' in extractor_type or 'ephys' in extractor_type: 1aescfdghijklmnobtpqr
84 trials, files_trials = biased_trials.extract_all( 1aesfdghijklmnobpqr
85 session_path, bpod_trials=bpod_trials, settings=settings, save=save, extra_classes=extra,
86 task_collection=task_collection, save_path=save_path)
88 files_wheel = [] 1aesfdghijklmnobpqr
89 wheel = OrderedDict({k: trials.pop(k) for k in tuple(trials.keys()) if 'wheel' in k}) 1aesfdghijklmnobpqr
90 elif extractor_type == 'habituation': 1ct
91 if settings['IBLRIG_VERSION_TAG'] and \ 1ct
92 parse_version(settings['IBLRIG_VERSION_TAG']) <= parse_version('5.0.0'):
93 _logger.warning('No extraction of legacy habituation sessions') 1t
94 return None, None, None 1t
95 trials, files_trials = habituation_trials.extract_all(session_path, bpod_trials=bpod_trials, settings=settings, save=save, 1c
96 task_collection=task_collection, save_path=save_path)
97 wheel = None 1c
98 files_wheel = [] 1c
99 else:
100 raise ValueError(f'No extractor for task {extractor_type}')
101 _logger.info('session extracted \n') # timing info in log 1aescfdghijklmnobpqr
102 return trials, wheel, (files_trials + files_wheel) if save else None 1aescfdghijklmnobpqr
def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
    """
    Return an instantiated Bpod trials extractor for a session.

    The extractor class name is derived from `protocol` when given, otherwise from the session's
    task settings. Extractors bundled with this module are matched first; any other name is
    imported from the `projects` package, adding the 'projects.' prefix when it is absent.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    protocol : str, optional
        The protocol name, otherwise uses the PYBPOD_PROTOCOL key in iblrig task settings files.
    task_collection : str
        The folder within the session that contains the raw task data.

    Returns
    -------
    BaseBpodTrialsExtractor
        An instance of the task extractor class, instantiated with the session path.

    Raises
    ------
    ValueError
        If the custom extractor module is importable but has no attribute with the class name.
    """
    known_extractors = {
        'HabituationTrials': HabituationTrials,
        'TrainingTrials': TrainingTrials,
        'BiasedTrials': BiasedTrials,
        'EphysTrials': EphysTrials,
    }
    name = (protocol2extractor(protocol) if protocol
            else get_bpod_extractor_class(session_path, task_collection=task_collection))

    extractor_cls = known_extractors.get(name)
    if extractor_cls is not None:
        return extractor_cls(session_path)

    # Not a bundled extractor: look for a custom one in the personal projects repository
    if not name.startswith('projects.'):
        name = 'projects.' + name
    module_name, _, class_name = name.rpartition('.')
    extractor_cls = getattr(importlib.import_module(module_name), class_name, None)
    if not extractor_cls:
        raise ValueError(f'extractor {class_name} not found')
    return extractor_cls(session_path)