Coverage for ibllib/io/extractors/base.py: 91%

193 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Base Extractor classes. 

2 

3A module for the base Extractor classes. The Extractor, given a session path, will extract the 

4processed data from raw hardware files and optionally save them. 

5""" 

6 

7import abc 

8from collections import OrderedDict 

9import json 

10from pathlib import Path 

11 

12import numpy as np 

13import pandas as pd 

14from ibllib.io import raw_data_loaders as raw 

15from ibllib.io.raw_data_loaders import load_settings, _logger 

16 

17 

class BaseExtractor(abc.ABC):
    """
    Base extractor class.

    Writing an extractor checklist:

    - on the child class, overload the _extract method
    - this method should output one or several numpy.arrays or dataframe with a consistent shape
    - save_names is a list or a string of filenames, there should be one per dataset
    - set save_names to None for a dataset that doesn't need saving (could be set dynamically in
      the _extract method)

    :param session_path: Absolute path of session folder (may be None)
    :type session_path: str/Path
    """

    session_path = None
    """pathlib.Path: Absolute path of session folder."""

    save_names = None
    """tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""

    var_names = None
    """tuple of str: A list of names for the extracted variables. These become the returned output keys."""

    default_path = Path('alf')  # relative to session
    """pathlib.Path: The default output folder relative to `session_path`."""

    def __init__(self, session_path=None):
        # Path(None) raises a TypeError, so only coerce when a session path was actually given;
        # otherwise keep None (matching the class-level default).
        self.session_path = Path(session_path) if session_path is not None else None

    def extract(self, save=False, path_out=None, **kwargs):
        """
        Run the extraction and optionally save the outputs to disk.

        :param save: if True, write the extracted datasets to disk (see `_save`)
        :param path_out: output folder; defaults to `session_path / default_path`
        :param kwargs: forwarded to the subclass `_extract` method
        :return: the output of `_extract`, and the saved file path(s) or None when not saving
        """
        out = self._extract(**kwargs)
        files = self._save(out, path_out=path_out) if save else None
        return out, files

    def _save(self, data, path_out=None):
        """
        Write the extracted data to disk according to `save_names`.

        :param data: the output of `_extract`: a single array/DataFrame when `save_names` is a
            str, a dict keyed by `var_names`, or a sequence with one element per `save_names` entry
        :param path_out: output folder (str or pathlib.Path); defaults to
            `session_path / default_path`
        :return: a single pathlib.Path when `save_names` is a str, otherwise a list of the file
            paths written (datasets whose save name is falsy are skipped)
        :raises ValueError: if a dataset to be saved is empty
        :raises TypeError: if a parquet dataset is not a pandas DataFrame
        """
        if not path_out:
            path_out = self.session_path.joinpath(self.default_path)
        # Accept str paths as well as pathlib.Path objects
        path_out = Path(path_out)

        def _write_to_disk(file_path, data):
            """Implements different save calls depending on file extension.

            Parameters
            ----------
            file_path : pathlib.Path
                The location to save the data.
            data : pandas.DataFrame, numpy.ndarray
                The data to save

            Raises
            ------
            ValueError
                The data are empty.
            TypeError
                A parquet file was requested for data that is not a DataFrame.
            """
            csv_separators = {".csv": ",", ".ssv": " ", ".tsv": "\t"}
            file_path = Path(file_path)
            # Ensure empty files are not created; we expect all datasets to have a non-zero size.
            # Only fall back to len() when there is no `size` attribute: 0-d numpy arrays have a
            # size but calling len() on them raises a TypeError.
            size = data.size if hasattr(data, 'size') else len(data)
            if size == 0:
                try:
                    filename = file_path.relative_to(self.session_path).as_posix()
                except (TypeError, ValueError):  # output is outside the session (or no session)
                    filename = file_path.as_posix()
                raise ValueError(f'Data for {filename} appears to be empty')
            file_path.parent.mkdir(exist_ok=True, parents=True)
            if file_path.suffix == ".npy":
                np.save(file_path, data)
            elif file_path.suffix in [".parquet", ".pqt"]:
                if not isinstance(data, pd.DataFrame):
                    _logger.error("Data is not a panda's DataFrame object")
                    raise TypeError("Data is not a panda's DataFrame object")
                data.to_parquet(file_path)
            elif file_path.suffix in csv_separators:
                sep = csv_separators[file_path.suffix]
                data.to_csv(file_path, sep=sep)
            else:
                # NOTE(review): nothing is written here, yet the caller still reports this path
                _logger.error(f"Don't know how to save {file_path.suffix} files yet")

        if self.save_names is None:
            file_paths = []
        elif isinstance(self.save_names, str):
            # A single dataset: the whole output is saved to one file and the path returned as-is
            file_paths = path_out.joinpath(self.save_names)
            _write_to_disk(file_paths, data)
        elif isinstance(data, dict):
            # Map each output key to its save name via its position in var_names
            file_paths = []
            for var, value in data.items():
                if fn := self.save_names[self.var_names.index(var)]:
                    fpath = path_out.joinpath(fn)
                    _write_to_disk(fpath, value)
                    file_paths.append(fpath)
        else:  # Should be list or tuple...
            assert len(data) == len(self.save_names), \
                f'expected {len(self.save_names)} outputs to match save_names, got {len(data)}'
            file_paths = []
            for value, fn in zip(data, self.save_names):
                if fn:
                    fpath = path_out.joinpath(fn)
                    _write_to_disk(fpath, value)
                    file_paths.append(fpath)
        return file_paths

    @abc.abstractmethod
    def _extract(self):
        """Extract the datasets; must be implemented by subclasses (see the class docstring)."""
        pass

124 

125 

class BaseBpodTrialsExtractor(BaseExtractor):
    """
    Base (abstract) extractor class for bpod jsonable data set.

    Wraps the _extract private method: before extraction it loads (when not provided) the Bpod
    trials data and the task settings, and the frame2ttl and audio TTL fronts.

    :param session_path: Absolute path of session folder.
    :type session_path: str
    """

    # Bpod trials data; loaded with `raw.load_data` when not passed to `extract`
    bpod_trials = None
    # Task settings dict; loaded with `raw.load_settings` when not passed to `extract`
    settings = None
    # Session sub-folder holding the raw task data (set by `extract`)
    task_collection = None
    # TTL fronts returned by `raw.load_bpod_fronts`, stored for QC purposes
    frame2ttl = None
    audio = None

    def extract(self, bpod_trials=None, settings=None, **kwargs):
        """
        :param: bpod_trials (optional) bpod trials from jsonable in a dictionary
        :param: settings (optional) bpod iblrig settings json file in a dictionary; this dict is
            not modified (a new dict is stored when a default IBLRIG_VERSION is filled in)
        :param: task_collection (optional, keyword) the session sub-folder containing the raw
            task data, defaults to 'raw_behavior_data'
        :param: save (bool) write output ALF files, defaults to False
        :param: path_out (pathlib.Path) output path (defaults to `{session_path}/alf`)
        :return: numpy.ndarray or list of ndarrays, list of filenames
        :rtype: dtype('float64')
        """
        self.bpod_trials = bpod_trials
        self.settings = settings
        self.task_collection = kwargs.pop('task_collection', 'raw_behavior_data')
        if self.bpod_trials is None:
            self.bpod_trials = raw.load_data(self.session_path, task_collection=self.task_collection)
        if not self.settings:
            self.settings = raw.load_settings(self.session_path, task_collection=self.task_collection)
        if self.settings is None:
            self.settings = {"IBLRIG_VERSION": "100.0.0"}
        elif not self.settings.get("IBLRIG_VERSION"):
            # Missing, empty or null version: fall back to the placeholder version. Build a new
            # dict rather than assigning into the settings, which may be the caller's own dict.
            self.settings = {**self.settings, "IBLRIG_VERSION": "100.0.0"}
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)

        return super().extract(**kwargs)

    @property
    def alf_path(self):
        """pathlib.Path: The full task collection filepath, or None when session_path is unset."""
        if self.session_path:
            return self.session_path.joinpath(self.task_collection or '').absolute()
        return None

174 

175 

def run_extractor_classes(classes, session_path=None, **kwargs):
    """
    Instantiate and run several extractors on one session, merging their outputs.

    :param classes: an Extractor class, or an iterable of Extractor classes
    :param session_path: the session path passed to each extractor (required)
    :param kwargs: passed to every extractor's `extract` method (e.g. save, path_out)
    :return: OrderedDict of outputs keyed by variable name, and a flat list of saved files
    """
    assert session_path
    # A lone class is not iterable; wrap it so the loop below handles both call forms
    try:
        iter(classes)
    except TypeError:
        classes = [classes]

    outputs, files = OrderedDict(), []
    for extractor_class in classes:
        extractor = extractor_class(session_path=session_path)
        out, saved = extractor.extract(**kwargs)
        # `saved` may be a list of paths, a single path, or None when nothing was written
        if isinstance(saved, list):
            files.extend(saved)
        elif saved is not None:
            files.append(saved)
        # Outputs are keyed by their own dict keys, a single var name, or positionally
        if isinstance(out, dict):
            outputs.update(out)
        elif isinstance(extractor.var_names, str):
            outputs[extractor.var_names] = out
        else:
            for position, name in enumerate(extractor.var_names):
                outputs[name] = out[position]
    return outputs, files

209 

210 

def _get_task_types_json_config():
    """
    Return the extractor types map.

    This function is only used for legacy sessions, i.e. those without an experiment description
    file and will be removed in favor of :func:`_get_task_extractor_map`, which directly returns
    the Bpod extractor class name. The experiment description file cuts out the need for pipeline
    name identifiers.

    The map bundled with this module is updated with the 'extractor_types.json' file of the
    personal projects repository (`projects.base`) when that package is importable and provides
    the file.

    Returns
    -------
    Dict[str, str]
        A map of task protocol to task extractor identifier, e.g. 'ephys', 'habituation', etc.

    See Also
    --------
    _get_task_extractor_map - returns a map of task protocol to Bpod trials extractor class name.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding
    with open(Path(__file__).parent.joinpath('extractor_types.json'), encoding='utf-8') as fp:
        task_types = json.load(fp)
    try:
        # look if there are custom extractor types in the personal projects repo
        import projects.base
        custom_extractors = Path(projects.base.__file__).parent.joinpath('extractor_types.json')
        _logger.debug('Loading extractor types from %s', custom_extractors)
        with open(custom_extractors, encoding='utf-8') as fp:
            custom_task_types = json.load(fp)
        task_types.update(custom_task_types)
    except (ModuleNotFoundError, FileNotFoundError):
        pass  # no projects repo / no custom map: the bundled map is used as-is
    return task_types

242 

243 

def get_task_protocol(session_path, task_collection='raw_behavior_data'):
    """
    Return the task protocol name from task settings.

    The settings are read with `load_settings`. None is returned when the settings are empty or
    missing, when they lack a 'PYBPOD_PROTOCOL' key, or when the settings file is not valid JSON
    (in which case an error is logged).

    Parameters
    ----------
    session_path : str, pathlib.Path
        The absolute session path.
    task_collection : str
        The session path directory containing the task settings file.

    Returns
    -------
    str or None
        The Pybpod task protocol name or None if not found.
    """
    try:
        settings = load_settings(session_path, task_collection=task_collection)
    except json.decoder.JSONDecodeError:
        _logger.error(f'Can\'t read settings for {session_path}')
        return None
    return settings.get('PYBPOD_PROTOCOL', None) if settings else None

274 

275 

def get_task_extractor_type(task_name):
    """
    Returns the task type string from the full pybpod task name.

    Parameters
    ----------
    task_name : str, pathlib.Path, None
        The complete task protocol name from the PYBPOD_PROTOCOL field of the task settings.
        A pathlib.Path is taken as a session path and its protocol is read with
        :func:`get_task_protocol` (default task collection). None returns None.

    Returns
    -------
    str or None
        The extractor type identifier. Examples include 'biased', 'habituation', 'training',
        'ephys', 'mock_ephys' and 'sync_ephys'. None if no protocol or no match was found (a
        warning is logged for the latter).

    Examples
    --------
    >>> get_task_extractor_type('_iblrig_tasks_biasedChoiceWorld3.7.0')
    'biased'

    >>> get_task_extractor_type('_iblrig_tasks_trainingChoiceWorld3.6.0')
    'training'
    """
    if isinstance(task_name, Path):
        task_name = get_task_protocol(task_name)
    # Checked for any input, not only for paths, so that None never reaches the substring
    # matching below (`tt in None` raises a TypeError).
    if task_name is None:
        return None
    task_types = _get_task_types_json_config()

    task_type = task_types.get(task_name, None)
    if task_type is None:  # Try lazy matching of name
        task_type = next((task_types[tt] for tt in task_types if tt in task_name), None)
    if task_type is None:
        _logger.warning(f'No extractor type found for {task_name}')
    return task_type

311 

312 

def get_session_extractor_type(session_path, task_collection='raw_behavior_data'):
    """
    Infer trials extractor type from task settings.

    The task protocol is read from the session's settings and mapped onto an extractor type,
    e.g. 'biased', 'habituation', 'training', 'ephys', 'mock_ephys' or 'sync_ephys'. Only intended
    for legacy sessions, i.e. those without an experiment description file.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path for which to determine the pipeline.
    task_collection : str
        The session path directory containing the raw task data.

    Returns
    -------
    str or False
        The task extractor type, or False when no protocol is found (an error is logged) or the
        protocol has no known extractor type.
    """
    task_protocol = get_task_protocol(session_path, task_collection=task_collection)
    if task_protocol is None:
        _logger.error(f'ABORT: No task protocol found in "{task_collection}" folder {session_path}')
        return False
    # Any falsy lookup result (e.g. None for an unknown protocol) is normalised to False
    return get_task_extractor_type(task_protocol) or False

343 

344 

def get_pipeline(session_path, task_collection='raw_behavior_data'):
    """
    Get the pre-processing pipeline name from a session path.

    The session's extractor type is looked up and mapped onto a pipeline name. Only suitable for
    legacy sessions, i.e. those without an experiment description file; this function will be
    removed in the future.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path for which to determine the pipeline.
    task_collection : str
        The session path directory containing the raw task data.

    Returns
    -------
    str
        The pipeline name inferred from the extractor type, e.g. 'ephys', 'training', 'widefield';
        an empty string when the extractor type could not be determined.
    """
    return _get_pipeline_from_task_type(
        get_session_extractor_type(session_path, task_collection=task_collection))

366 

367 

368def _get_pipeline_from_task_type(stype): 

369 """ 

370 Return the pipeline from the task type. 

371 

372 Some task types directly define the pipeline. Note this is only suitable for legacy sessions, 

373 i.e. those without an experiment description file. This function will be removed in the future. 

374 

375 Parameters 

376 ---------- 

377 stype : str 

378 The session type or task extractor type, e.g. 'habituation', 'ephys', etc. 

379 

380 Returns 

381 ------- 

382 str 

383 A task pipeline identifier. 

384 """ 

385 if stype in ['ephys_biased_opto', 'ephys', 'ephys_training', 'mock_ephys', 'sync_ephys']: 1H4$

386 return 'ephys' 1H4$

387 elif stype in ['habituation', 'training', 'biased', 'biased_opto']: 1H4$

388 return 'training' 1H4$

389 elif isinstance(stype, str) and 'widefield' in stype: 1H4$

390 return 'widefield' 

391 else: 

392 return stype or '' 1H4$

393 

394 

def _get_task_extractor_map():
    """
    Load the task protocol extractor map.

    The map bundled with this module is updated with the file of the same name from the personal
    projects repository (`projects.base`) when that package is importable and provides the file.

    Returns
    -------
    Dict[str, str]
        A map of task protocol to Bpod trials extractor class.
    """
    FILENAME = 'task_extractor_map.json'
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding
    with open(Path(__file__).parent.joinpath(FILENAME), encoding='utf-8') as fp:
        task_extractors = json.load(fp)
    try:
        # look if there are custom extractor types in the personal projects repo
        import projects.base
        custom_extractors = Path(projects.base.__file__).parent.joinpath(FILENAME)
        _logger.debug('Loading task extractor map from %s', custom_extractors)
        with open(custom_extractors, 'r', encoding='utf-8') as fp:
            custom_task_types = json.load(fp)
        task_extractors.update(custom_task_types)
    except (ModuleNotFoundError, FileNotFoundError):
        pass  # no projects repo / no custom map: the bundled map is used as-is
    return task_extractors

417 

418 

def get_bpod_extractor_class(session_path, task_collection='raw_behavior_data'):
    """
    Get the Bpod trials extractor class associated with a given Bpod session.

    Note that unlike :func:`get_session_extractor_type`, this function maps directly to the Bpod
    trials extractor class name. This is hardware invariant and is purely to determine the Bpod
    only trials extractor.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path containing Bpod behaviour data.
    task_collection : str
        The session_path sub-folder containing the Bpod settings file.

    Returns
    -------
    str
        The extractor class name.

    Raises
    ------
    ValueError
        No task protocol was found in the settings, or no extractor matches the protocol.
    """
    protocol = get_task_protocol(session_path, task_collection=task_collection)
    if protocol:
        return protocol2extractor(protocol)
    raise ValueError(f'No task protocol found in {Path(session_path) / task_collection}')

444 

445 

def protocol2extractor(protocol):
    """
    Get the Bpod trials extractor class associated with a given Bpod task protocol.

    The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of the
    _iblrig_taskSettings.raw.json file. An exact match in the extractor map is tried first, then
    the first map key contained in the protocol name.

    Parameters
    ----------
    protocol : str
        A Bpod task protocol name.

    Returns
    -------
    str
        The extractor class name.

    Raises
    ------
    ValueError
        The protocol is None or no extractor is associated with it.
    """
    # Guard before loading the map: substring matching against None would raise a TypeError
    if protocol is None:
        raise ValueError(f'No extractor associated with "{protocol}"')
    # Attempt to get extractor class from protocol
    extractor_map = _get_task_extractor_map()
    extractor = extractor_map.get(protocol, None)
    if extractor is None:  # Try lazy matching of name
        extractor = next((extractor_map[tt] for tt in extractor_map if tt in protocol), None)
    if extractor is None:
        raise ValueError(f'No extractor associated with "{protocol}"')
    return extractor