Coverage for ibllib/io/extractors/base.py: 90%
148 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-02 18:55 +0100
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-02 18:55 +0100
1"""Base Extractor classes.
3A module for the base Extractor classes. The Extractor, given a session path, will extract the
4processed data from raw hardware files and optionally save them.
5"""
7import abc
8from collections import OrderedDict
9import json
10from pathlib import Path
12import numpy as np
13import pandas as pd
14from ibllib.io import raw_data_loaders as raw
15from ibllib.io.raw_data_loaders import load_settings, _logger
class BaseExtractor(abc.ABC):
    """
    Base extractor class.

    Writing an extractor checklist:

    - on the child class, overload the _extract method
    - this method should output one or several numpy.arrays or dataframe with a consistent shape
    - save_names is a list or a string of filenames, there should be one per dataset
    - set save_names to None for a dataset that doesn't need saving (could be set dynamically in
      the _extract method)

    :param session_path: Absolute path of session folder
    :type session_path: str/Path
    """

    session_path = None
    """pathlib.Path: Absolute path of session folder."""

    save_names = None
    """tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""

    var_names = None
    """tuple of str: A list of names for the extracted variables. These become the returned output keys."""

    default_path = Path('alf')  # relative to session
    """pathlib.Path: The default output folder relative to `session_path`."""

    def __init__(self, session_path=None):
        # NB: session_path is effectively required: if session_path is None, Path(session_path)
        # raises a TypeError.
        self.session_path = Path(session_path)

    def extract(self, save=False, path_out=None, **kwargs):
        """
        Run the extraction and optionally save the extracted datasets.

        :param save: if True, write the extracted datasets to disk (see `_save`)
        :param path_out: output folder, defaults to `{session_path}/{default_path}`
        :param kwargs: keyword arguments passed on to `_extract`
        :return: the output of `_extract`, and the saved file path(s) (None when save is False)
        """
        out = self._extract(**kwargs)
        files = self._save(out, path_out=path_out) if save else None
        return out, files

    def _save(self, data, path_out=None):
        """
        Write the extracted data to disk using the filenames in `save_names`.

        The save call depends on the file extension: .npy (numpy), .parquet/.pqt (pandas
        DataFrame only), .csv/.ssv/.tsv (DataFrame.to_csv). Other extensions are logged as an
        error and not written.

        :param data: the output of `_extract`. A single dataset when `save_names` is a str, a
          dict keyed by the names in `var_names`, or a sequence with one item per `save_names`.
        :param path_out: output folder (str or pathlib.Path), defaults to
          `{session_path}/{default_path}`
        :return: a single pathlib.Path when `save_names` is a str, otherwise a list of the
          pathlib.Path of each dataset with a filename
        :raises ValueError: if a dataset to save is empty
        :raises TypeError: if the data for a parquet file is not a pandas DataFrame
        """
        # Accept str output paths as well as pathlib.Path objects
        path_out = Path(path_out) if path_out else self.session_path.joinpath(self.default_path)

        csv_separators = {
            '.csv': ',',
            '.ssv': ' ',
            '.tsv': '\t',
        }

        def _write_to_disk(file_path, data):
            """Implements different save calls depending on file extension.

            Parameters
            ----------
            file_path : pathlib.Path
                The location to save the data.
            data : pandas.DataFrame, numpy.ndarray
                The data to save

            """
            file_path = Path(file_path)
            # Ensure empty files are not created; we expect all datasets to have a non-zero size.
            # Only fall back on len() for objects without a size attribute: len() raises a
            # TypeError for 0-d numpy arrays and numpy scalars, which do have a size.
            size = data.size if hasattr(data, 'size') else len(data)
            if size == 0:
                try:
                    filename = file_path.relative_to(self.session_path).as_posix()
                except ValueError:  # the output folder is outside of the session path
                    filename = file_path.as_posix()
                raise ValueError(f'Data for {filename} appears to be empty')
            file_path.parent.mkdir(exist_ok=True, parents=True)
            if file_path.suffix == '.npy':
                np.save(file_path, data)
            elif file_path.suffix in ('.parquet', '.pqt'):
                if not isinstance(data, pd.DataFrame):
                    _logger.error("Data is not a panda's DataFrame object")
                    raise TypeError("Data is not a panda's DataFrame object")
                data.to_parquet(file_path)
            elif file_path.suffix in csv_separators:
                data.to_csv(file_path, sep=csv_separators[file_path.suffix])
            else:
                _logger.error(f"Don't know how to save {file_path.suffix} files yet")

        if self.save_names is None:
            return []
        if isinstance(self.save_names, str):
            # A single dataset: the single file path is returned (not a list)
            file_path = path_out.joinpath(self.save_names)
            _write_to_disk(file_path, data)
            return file_path

        file_paths = []
        if isinstance(data, dict):
            # Each variable's filename is at the same position in save_names as its name in var_names
            for var, value in data.items():
                if fn := self.save_names[self.var_names.index(var)]:
                    fpath = path_out.joinpath(fn)
                    _write_to_disk(fpath, value)
                    file_paths.append(fpath)
        else:  # Should be list or tuple...
            assert len(data) == len(self.save_names), 'expected one save name per dataset'
            for value, fn in zip(data, self.save_names):
                if fn:  # datasets with a None filename are not saved
                    fpath = path_out.joinpath(fn)
                    _write_to_disk(fpath, value)
                    file_paths.append(fpath)
        return file_paths

    @abc.abstractmethod
    def _extract(self):
        """Extract the datasets; must be overloaded by the child class."""
        pass
class BaseBpodTrialsExtractor(BaseExtractor):
    """
    Base (abstract) extractor class for bpod jsonable data set.

    Wraps the _extract private method.

    :param session_path: Absolute path of session folder.
    :type session_path: str
    :param bpod_trials
    :param settings
    """

    bpod_trials = None
    settings = None
    task_collection = None
    frame2ttl = None
    audio = None

    def extract(self, bpod_trials=None, settings=None, **kwargs):
        """
        Load the Bpod trials data and settings (if not provided) and run the extraction.

        :param: bpod_trials (optional) bpod trials from jsonable in a dictionary
        :param: settings (optional) bpod iblrig settings json file in a dictionary. It is not
          modified: a copy is made when a default 'IBLRIG_VERSION' needs to be inserted.
        :param: task_collection (str) session sub-folder containing the raw Bpod data, defaults
          to 'raw_behavior_data'
        :param: save (bool) write output ALF files, defaults to False
        :param: path_out (pathlib.Path) output path (defaults to `{session_path}/alf`)
        :return: the output of `_extract` and the saved file path(s) (None when save is False)
        """
        self.bpod_trials = bpod_trials
        self.settings = settings
        self.task_collection = kwargs.pop('task_collection', 'raw_behavior_data')
        if self.bpod_trials is None:
            self.bpod_trials = raw.load_data(self.session_path, task_collection=self.task_collection)
        if not self.settings:
            self.settings = raw.load_settings(self.session_path, task_collection=self.task_collection)
        if self.settings is None:
            self.settings = {"IBLRIG_VERSION": "100.0.0"}
        elif not self.settings.get("IBLRIG_VERSION"):
            # Missing, empty or null version: use a default version on a copy of the settings so
            # that a dict passed in by the caller is not mutated
            self.settings = {**self.settings, "IBLRIG_VERSION": "100.0.0"}
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)

        return super().extract(**kwargs)

    @property
    def alf_path(self):
        """pathlib.Path: The full task collection filepath."""
        if self.session_path:
            return self.session_path.joinpath(self.task_collection or '').absolute()
def run_extractor_classes(classes, session_path=None, **kwargs):
    """
    Run one or more extractors with the same inputs and merge their outputs.

    :param classes: an Extractor class, or an iterable of Extractor classes
    :param session_path: session path used to instantiate each extractor (required)
    :param save: True/False, forwarded to each extractor's extract method
    :param path_out: (defaults to alf path), forwarded to each extractor's extract method
    :param kwargs: other extractor arguments, forwarded to each extractor's extract method
    :return: ordered dictionary of extracted variables keyed by variable name, list of files
    """
    outputs = OrderedDict()
    files = []
    assert session_path
    # A single class is not iterable: wrap it in a list
    try:
        iter(classes)
    except TypeError:
        classes = [classes]
    for extractor_class in classes:
        extractor = extractor_class(session_path=session_path)
        out, saved = extractor.extract(**kwargs)
        # The saved files may be a list, a single path, or None when nothing was saved
        if isinstance(saved, list):
            files += saved
        elif saved is not None:
            files.append(saved)
        # Key the outputs by variable name
        if isinstance(out, dict):
            outputs.update(out)
        elif isinstance(extractor.var_names, str):
            outputs[extractor.var_names] = out
        else:
            outputs.update({name: out[idx] for idx, name in enumerate(extractor.var_names)})
    return outputs, files
def get_task_protocol(session_path, task_collection='raw_behavior_data'):
    """
    Return the task protocol name from task settings.

    None is returned when the settings can not be loaded or parsed, when they are empty, or when
    they have no 'PYBPOD_PROTOCOL' key (e.g. because the session path and/or task collection do
    not exist, or the settings file is missing). A warning is logged if the session path or
    settings file doesn't exist. An error is logged if the settings file can not be parsed.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The absolute session path.
    task_collection : str
        The session path directory containing the task settings file.

    Returns
    -------
    str or None
        The Pybpod task protocol name or None if not found.
    """
    try:
        settings = load_settings(session_path, task_collection=task_collection)
    except json.decoder.JSONDecodeError:
        _logger.error(f'Can\'t read settings for {session_path}')
        return None
    # Empty / missing settings yield None, as does a missing protocol key
    return settings.get('PYBPOD_PROTOCOL', None) if settings else None
def _get_task_extractor_map():
    """
    Load the task protocol extractor map.

    The map bundled next to this module (task_extractor_map.json) is updated with a custom map
    of the same file name located in the root folder of the `projects` package, when that
    package is importable and the file exists.

    Returns
    -------
    Dict[str, str]
        A map of task protocol to Bpod trials extractor class.
    """
    FILENAME = 'task_extractor_map.json'
    # JSON is UTF-8 encoded; don't rely on the platform's default encoding
    with open(Path(__file__).parent.joinpath(FILENAME), encoding='utf-8') as fp:
        task_extractors = json.load(fp)
    # look if there are custom extractor types in the personal projects repo
    try:
        import projects
    except ModuleNotFoundError:
        return task_extractors
    # NB: a namespace package (no __init__.py) has no usable __file__, so Path() would fail
    projects_file = getattr(projects, '__file__', None)
    if projects_file:
        custom_extractors = Path(projects_file).parent.joinpath(FILENAME)
        try:
            with open(custom_extractors, 'r', encoding='utf-8') as fp:
                custom_task_types = json.load(fp)
        except FileNotFoundError:
            pass
        else:
            task_extractors.update(custom_task_types)
    return task_extractors
def get_bpod_extractor_class(session_path, task_collection='raw_behavior_data'):
    """
    Get the Bpod trials extractor class associated with a given Bpod session.

    Note that unlike :func:`get_session_extractor_type`, this function maps directly to the Bpod
    trials extractor class name. This is hardware invariant and is purely to determine the Bpod
    only trials extractor.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path containing Bpod behaviour data.
    task_collection : str
        The session_path sub-folder containing the Bpod settings file.

    Returns
    -------
    str
        The extractor class name.

    Raises
    ------
    ValueError
        No task protocol was found in the settings file, or no extractor is associated with the
        protocol.
    """
    # The protocol name is read from the task settings file
    protocol = get_task_protocol(session_path, task_collection=task_collection)
    if protocol:
        return protocol2extractor(protocol)
    raise ValueError(f'No task protocol found in {Path(session_path) / task_collection}')
def protocol2extractor(protocol):
    """
    Get the Bpod trials extractor class associated with a given Bpod task protocol.

    The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of the
    _iblrig_taskSettings.raw.json file. An exact match in the task extractor map is preferred;
    otherwise the first map key (in map order) that is a substring of the protocol is used.

    Parameters
    ----------
    protocol : str
        A Bpod task protocol name.

    Returns
    -------
    str
        The extractor class name.

    Raises
    ------
    ValueError
        No extractor is associated with the protocol.
    """
    extractor_map = _get_task_extractor_map()
    if (extractor := extractor_map.get(protocol, None)) is None:
        # Lazy matching: the first map key contained within the protocol name
        extractor = next((name for key, name in extractor_map.items() if key in protocol), None)
        if extractor is None:
            raise ValueError(f'No extractor associated with "{protocol}"')
    return extractor