Coverage for ibllib/io/extractors/base.py: 90%

148 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 14:26 +0100

1"""Base Extractor classes. 

2 

3A module for the base Extractor classes. The Extractor, given a session path, will extract the 

4processed data from raw hardware files and optionally save them. 

5""" 

6 

7import abc 

8from collections import OrderedDict 

9import json 

10from pathlib import Path 

11 

12import numpy as np 

13import pandas as pd 

14from ibllib.io import raw_data_loaders as raw 

15from ibllib.io.raw_data_loaders import load_settings, _logger 

16 

17 

class BaseExtractor(abc.ABC):
    """
    Base extractor class.

    Writing an extractor checklist:

    - on the child class, overload the _extract method
    - this method should output one or several numpy.arrays or dataframe with a consistent shape
    - save_names is a list or a string of filenames, there should be one per dataset
    - set save_names to None for a dataset that doesn't need saving (could be set dynamically in
      the _extract method)

    :param session_path: Absolute path of session folder
    :type session_path: str/Path
    """

    session_path = None
    """pathlib.Path: Absolute path of session folder."""

    save_names = None
    """tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""

    var_names = None
    """tuple of str: A list of names for the extracted variables. These become the returned output keys."""

    default_path = Path('alf')  # relative to session
    """pathlib.Path: The default output folder relative to `session_path`."""

    def __init__(self, session_path=None):
        # NOTE: despite the default, a None session_path makes Path(session_path) raise a TypeError
        self.session_path = Path(session_path)

    def extract(self, save=False, path_out=None, **kwargs):
        """
        Run the extraction and optionally write the result to disk.

        Parameters
        ----------
        save : bool
            If True, the extracted data are passed to `_save`.
        path_out : str, pathlib.Path, optional
            Output folder; when falsy, `session_path / default_path` is used.
        **kwargs
            Forwarded verbatim to `_extract`.

        Returns
        -------
        object
            Whatever `_extract` returned.
        pathlib.Path, list of pathlib.Path, None
            The value returned by `_save`, or None when `save` is False.
        """
        out = self._extract(**kwargs)
        files = self._save(out, path_out=path_out) if save else None
        return out, files

    def _save(self, data, path_out=None):
        """
        Write the extracted data to disk according to `save_names`.

        Parameters
        ----------
        data : object, dict, list, tuple
            The extractor output. A dict is matched to `save_names` through `var_names`; any other
            value is either written whole (when `save_names` is a str) or zipped with `save_names`.
        path_out : str, pathlib.Path, optional
            Output folder; when falsy, `session_path / default_path` is used.

        Returns
        -------
        pathlib.Path or list of pathlib.Path
            A single path when `save_names` is a str, otherwise a list with one path per dataset
            that has a non-empty save name (an empty list when `save_names` is None).

        Raises
        ------
        ValueError
            If a dataset to be saved is empty.
        TypeError
            If a parquet file name is paired with data that is not a pandas DataFrame.
        AssertionError
            If `data` is a sequence whose length differs from that of `save_names`.
        """
        if not path_out:
            path_out = self.session_path.joinpath(self.default_path)
        else:
            # Accept plain strings as well as Path objects
            path_out = Path(path_out)

        def _write_to_disk(file_path, data):
            """Implements different save calls depending on file extension.

            Parameters
            ----------
            file_path : pathlib.Path
                The location to save the data.
            data : pandas.DataFrame, numpy.ndarray
                The data to save

            """
            csv_separators = {
                ".csv": ",",
                ".ssv": " ",
                ".tsv": "\t"
            }
            file_path = Path(file_path)
            # Ensure empty files are not created; we expect all datasets to have a non-zero size.
            # len() is only used as a fallback: calling it unconditionally fails for 0-d arrays and
            # numpy scalars, which do have a `size` attribute.
            n_elements = data.size if hasattr(data, 'size') else len(data)
            if n_elements == 0:
                try:
                    filename = file_path.relative_to(self.session_path).as_posix()
                except ValueError:
                    # The output folder is outside of the session folder; report the full path
                    filename = file_path.as_posix()
                raise ValueError(f'Data for {filename} appears to be empty')
            file_path.parent.mkdir(exist_ok=True, parents=True)
            if file_path.suffix == ".npy":
                np.save(file_path, data)
            elif file_path.suffix in [".parquet", ".pqt"]:
                if not isinstance(data, pd.DataFrame):
                    _logger.error("Data is not a panda's DataFrame object")
                    raise TypeError("Data is not a panda's DataFrame object")
                data.to_parquet(file_path)
            elif file_path.suffix in csv_separators:
                sep = csv_separators[file_path.suffix]
                data.to_csv(file_path, sep=sep)
            else:
                # NOTE: nothing is written in this case; only an error is logged
                _logger.error(f"Don't know how to save {file_path.suffix} files yet")

        if self.save_names is None:
            file_paths = []
        elif isinstance(self.save_names, str):
            file_paths = path_out.joinpath(self.save_names)
            _write_to_disk(file_paths, data)
        elif isinstance(data, dict):
            file_paths = []
            for var, value in data.items():
                # A falsy save name means this variable should not be saved
                if fn := self.save_names[self.var_names.index(var)]:
                    fpath = path_out.joinpath(fn)
                    _write_to_disk(fpath, value)
                    file_paths.append(fpath)
        else:  # Should be list or tuple...
            assert len(data) == len(self.save_names)
            file_paths = []
            for value, fn in zip(data, self.save_names):
                if fn:
                    fpath = path_out.joinpath(fn)
                    _write_to_disk(fpath, value)
                    file_paths.append(fpath)
        return file_paths

    @abc.abstractmethod
    def _extract(self):
        pass

124 

125 

class BaseBpodTrialsExtractor(BaseExtractor):
    """
    Abstract base class for extractors working on Bpod jsonable data.

    The `extract` method loads the Bpod trials and task settings (unless provided by the caller)
    before delegating to the parent `extract`, which in turn calls the `_extract` private method.

    :param session_path: Absolute path of session folder.
    :type session_path: str
    """

    # Set on each call to `extract`
    bpod_trials = None
    settings = None
    task_collection = None
    # Values returned by raw.load_bpod_fronts; kept for QC purposes
    frame2ttl = None
    audio = None

    def extract(self, bpod_trials=None, settings=None, **kwargs):
        """
        Load the Bpod data and settings where required, then run the extraction.

        :param bpod_trials: (optional) bpod trials from jsonable in a dictionary
        :param settings: (optional) bpod iblrig settings json file in a dictionary
        :param task_collection: (str) session sub-folder containing the raw task data, defaults
         to 'raw_behavior_data'; removed from kwargs before they are passed on
        :param save: (bool) write output ALF files, defaults to False
        :param path_out: (pathlib.Path) output path (defaults to `{session_path}/alf`)
        :return: the extracted data and the saved file path(s), as returned by the parent
         class `extract` method
        """
        fallback_version = "100.0.0"
        self.bpod_trials = bpod_trials
        self.settings = settings
        self.task_collection = kwargs.pop('task_collection', 'raw_behavior_data')

        # Only read the trials from disk when none were passed in
        if self.bpod_trials is None:
            self.bpod_trials = raw.load_data(
                self.session_path, task_collection=self.task_collection)

        # Missing or empty settings trigger a load from disk
        if not self.settings:
            self.settings = raw.load_settings(
                self.session_path, task_collection=self.task_collection)

        # Guarantee that a rig version is always present in the settings
        if self.settings is None:
            self.settings = {"IBLRIG_VERSION": fallback_version}
        elif self.settings.get("IBLRIG_VERSION", "") == "":
            self.settings["IBLRIG_VERSION"] = fallback_version

        # Get all detected TTLs. These are stored for QC purposes
        fronts = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        self.frame2ttl, self.audio = fronts

        return super().extract(**kwargs)

    @property
    def alf_path(self):
        """pathlib.Path: The full task collection filepath (None without a session path)."""
        if not self.session_path:
            return None
        collection = self.task_collection or ''
        return self.session_path.joinpath(collection).absolute()

174 

175 

def run_extractor_classes(classes, session_path=None, **kwargs):
    """
    Instantiate and run one or more extractors on the same session.

    :param classes: an Extractor class, or an iterable of Extractor classes
    :param session_path: the session path passed to each extractor constructor (required)
    :param save: True/False
    :param path_out: (defaults to alf path)
    :param kwargs: extractor arguments, passed to the `extract` method of each extractor
    :return: ordered dictionary of outputs keyed by variable name, list of files
    """
    saved_files = []
    results = OrderedDict({})
    assert session_path
    # A lone class is not iterable: wrap it so that the loop below works in both cases
    try:
        iter(classes)
    except TypeError:
        classes = [classes]

    for extractor_class in classes:
        extractor = extractor_class(session_path=session_path)
        out, written = extractor.extract(**kwargs)

        # Extractors may return a list of files, a single file or None
        if written is not None:
            if isinstance(written, list):
                saved_files.extend(written)
            else:
                saved_files.append(written)

        # Merge the output into the results, keyed by variable name
        if isinstance(out, dict):
            results.update(out)
            continue
        names = extractor.var_names
        if isinstance(names, str):
            results[names] = out
            continue
        for position, key in enumerate(names):
            results[key] = out[position]

    return results, saved_files

209 

210 

def get_task_protocol(session_path, task_collection='raw_behavior_data'):
    """
    Return the task protocol name from task settings.

    None is returned when `load_settings` yields no settings (an empty or None result), when the
    settings file can not be parsed as JSON, or when the 'PYBPOD_PROTOCOL' key is absent. An
    error is logged if the settings file can not be parsed.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The absolute session path.
    task_collection : str
        The session path directory containing the task settings file.

    Returns
    -------
    str or None
        The Pybpod task protocol name or None if not found.
    """
    try:
        task_settings = load_settings(session_path, task_collection=task_collection)
    except json.decoder.JSONDecodeError:
        _logger.error(f'Can\'t read settings for {session_path}')
        return None
    if not task_settings:
        return None
    return task_settings.get('PYBPOD_PROTOCOL', None)

241 

242 

def _get_task_extractor_map():
    """
    Load the task protocol extractor map.

    The map is read from the 'task_extractor_map.json' file located next to this module. If a
    `projects` package can be imported and contains a file of the same name, its entries are
    merged in, overriding any default entries with the same key.

    Returns
    -------
    Dict[str, str]
        A map of task protocol to Bpod trials extractor class.
    """
    FILENAME = 'task_extractor_map.json'
    # JSON is UTF-8 encoded by specification: don't rely on the locale-dependent default encoding
    with open(Path(__file__).parent.joinpath(FILENAME), encoding='utf-8') as fp:
        task_extractors = json.load(fp)
    try:
        # look if there are custom extractor types in the personal projects repo
        import projects
    except (ModuleNotFoundError, FileNotFoundError):
        return task_extractors
    # Namespace packages (e.g. an unrelated folder called 'projects' on the path) have no
    # __file__, in which case there is no custom map to load
    package_file = getattr(projects, '__file__', None)
    if package_file is None:
        return task_extractors
    custom_extractors = Path(package_file).parent.joinpath(FILENAME)
    try:
        with open(custom_extractors, 'r', encoding='utf-8') as fp:
            custom_task_types = json.load(fp)
    except FileNotFoundError:
        pass  # the projects package doesn't define any custom extractors
    else:
        task_extractors.update(custom_task_types)
    return task_extractors

265 

266 

def get_bpod_extractor_class(session_path, task_collection='raw_behavior_data'):
    """
    Get the Bpod trials extractor class associated with a given Bpod session.

    The task protocol is read from the session's task settings and then mapped directly to the
    name of the Bpod trials extractor class. This is hardware invariant and is purely to
    determine the Bpod only trials extractor.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path containing Bpod behaviour data.
    task_collection : str
        The session_path sub-folder containing the Bpod settings file.

    Returns
    -------
    str
        The extractor class name.

    Raises
    ------
    ValueError
        If no task protocol could be determined for the session.
    """
    # Attempt to get protocol name from settings file
    task_protocol = get_task_protocol(session_path, task_collection=task_collection)
    if task_protocol:
        return protocol2extractor(task_protocol)
    raise ValueError(f'No task protocol found in {Path(session_path) / task_collection}')

292 

293 

def protocol2extractor(protocol):
    """
    Get the Bpod trials extractor class associated with a given Bpod task protocol.

    The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of the
    _iblrig_taskSettings.raw.json file. An exact match in the extractor map takes precedence;
    failing that, the first map key that is contained within the protocol name is used.

    Parameters
    ----------
    protocol : str
        A Bpod task protocol name.

    Returns
    -------
    str
        The extractor class name.

    Raises
    ------
    ValueError
        If no extractor is associated with the protocol.
    """
    task_map = _get_task_extractor_map()
    # Attempt to get extractor class from protocol
    class_name = task_map.get(protocol, None)
    if class_name is None:
        # Try lazy matching of name: stop at the first key found within the protocol name
        for task_type, candidate in task_map.items():
            if task_type in protocol:
                class_name = candidate
                break
    if class_name is None:
        raise ValueError(f'No extractor associated with "{protocol}"')
    return class_name