Coverage for ibllib/oneibl/data_handlers.py: 33% of 235 statements
coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Downloading of task dependent datasets and registration of task output datasets. 

2 

3The DataHandler class is used by the pipes.tasks.Task class to ensure dependent datasets are 

4present and to register and upload the output datasets. For examples on how to run a task using 

5specific data handlers, see :func:`ibllib.pipes.tasks`. 

6""" 

7import logging 

8import pandas as pd 

9from pathlib import Path 

10import shutil 

11import os 

12import abc 

13from time import time 

14 

15from one.api import ONE 

16from one.webclient import AlyxClient 

17from one.util import filter_datasets, ensure_list 

18from one.alf.files import add_uuid_string, session_path_parts 

19from ibllib.oneibl.registration import register_dataset, get_lab, get_local_data_repository 

20from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH 

21 

22 

23_logger = logging.getLogger(__name__) 

24 

25 
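# Editorial sketch (not part of the library): the ``signature`` passed to a data handler is a
# dict with 'input_files' and 'output_files' entries; each entry is indexed below as ``file[0]``
# (filename pattern) and ``file[1]`` (collection). The dataset names and the trailing "required"
# flag in this example are illustrative assumptions only:
#
#   EXAMPLE_SIGNATURE = {
#       'input_files': [('_iblrig_taskData.raw.*', 'raw_behavior_data', True)],
#       'output_files': [('*trials.table.pqt', 'alf', True)],
#   }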

class DataHandler(abc.ABC):
    def __init__(self, session_path, signature, one=None):
        """
        Base data handler class
        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        self.session_path = session_path
        self.signature = signature
        self.one = one
        self.processed = {}  # Map of filepaths and their processed records (e.g. upload receipts or Alyx records)

    def setUp(self):
        """Function to optionally overload to download required data to run task."""
        pass

    def getData(self, one=None):
        """Finds the datasets required for the task based on the input file signatures."""
        if self.one is None and one is None:
            return

        one = one or self.one
        session_datasets = one.list_datasets(one.path2eid(self.session_path), details=True)
        dfs = []
        for file in self.signature['input_files']:
            dfs.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
                                       wildcards=True, assert_unique=False))
        if len(dfs) == 0:
            return pd.DataFrame()
        df = pd.concat(dfs)

        # In some cases the eid is stored in the index; if so, drop this level
        if 'eid' in df.index.names:
            df = df.droplevel(level='eid')
        return df
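    # Editorial sketch (not part of the library): typical use of ``getData``, assuming a
    # configured ONE instance and an existing session path; ``signature`` is a hypothetical
    # signature dict as described at the top of the module:
    #
    #   from one.api import ONE
    #   handler = DataHandler('/mnt/s0/Data/Subjects/SW001/2024-01-01/001', signature, one=ONE())
    #   input_datasets = handler.getData()  # data frame of matching datasets, indexed by dataset id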

    def getOutputFiles(self):
        """Return a data frame of the output datasets found on disk that match the output file signatures."""
        assert self.session_path
        from functools import partial
        from one.alf.io import iter_datasets
        from one.alf.cache import DATASETS_COLUMNS, _get_dataset_info
        # Create a data frame of all ALF datasets within the session folder
        dsets = iter_datasets(self.session_path)
        records = [_get_dataset_info(self.session_path, dset, compute_hash=False) for dset in dsets]
        df = pd.DataFrame(records, columns=DATASETS_COLUMNS)
        # Filter outputs
        filt = partial(filter_datasets, df, wildcards=True, assert_unique=False)
        dids = pd.concat(filt(filename=file[0], collection=file[1]).index for file in self.signature['output_files'])
        present = df.loc[dids, :].copy()
        return present

    def uploadData(self, outputs, version):
        """
        Function to optionally overload to upload and register data
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: list of dataset versions, one per output file
        """
        if isinstance(outputs, list):
            versions = [version for _ in outputs]
        else:
            versions = [version]

        return versions

    def cleanUp(self):
        """Function to optionally overload to clean up files after running task."""
        pass

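# Editorial sketch (not part of the library): the handler life cycle as driven by a task runner,
# loosely following the usage described in the module docstring. ``MyTask`` and ``run_task`` are
# hypothetical placeholders, and the handler is any concrete subclass defined below:
#
#   from one.api import ONE
#   one = ONE()
#   handler = ServerDataHandler(session_path, MyTask.signature, one=one)
#   handler.setUp()                                              # ensure input datasets are present
#   output_files = run_task(session_path)                        # produce the task outputs (placeholder)
#   records = handler.uploadData(output_files, version='x.y.z')  # register/upload the outputs
#   handler.cleanUp()                                            # handler-specific tidy-up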

class LocalDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks locally, with no architecture or db connection
        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signatures, one=one)


class ServerDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks on lab local servers when all data is available locally

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signatures, one=one)

    def uploadData(self, outputs, version, clobber=False, **kwargs):
        """
        Upload and/or register output data.

        This is typically called by :meth:`ibllib.pipes.tasks.Task.register_datasets`.

        Parameters
        ----------
        outputs : list of pathlib.Path
            A set of ALF paths to register to Alyx.
        version : str, list of str
            The version of ibllib used to generate these output files.
        clobber : bool
            If True, re-upload outputs that have already been passed to this method.
        kwargs
            Optional keyword arguments for one.registration.RegistrationClient.register_files.

        Returns
        -------
        list of dicts, dict
            A list of newly created Alyx dataset records or the registration data if dry.
        """
        versions = super().uploadData(outputs, version)
        data_repo = get_local_data_repository(self.one.alyx)
        # If clobber = False, do not re-upload the outputs that have already been processed
        outputs = ensure_list(outputs)
        to_upload = list(filter(None if clobber else lambda x: x not in self.processed, outputs))
        records = register_dataset(to_upload, one=self.one, versions=versions, repository=data_repo, **kwargs) or []
        if kwargs.get('dry', False):
            return records
        # Store processed outputs
        self.processed.update({k: v for k, v in zip(to_upload, records) if v})
        return [self.processed[x] for x in outputs if x in self.processed]

    def cleanUp(self):
        """Empties and returns the processed dataset map."""
        super().cleanUp()
        processed = self.processed
        self.processed = {}
        return processed

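# Editorial sketch (not part of the library): repeated calls to ServerDataHandler.uploadData only
# register files not yet in ``self.processed`` unless ``clobber=True``; ``dry=True`` is forwarded
# to ``register_dataset`` and returns the registration data without updating ``processed``.
# The file names below are illustrative:
#
#   records = handler.uploadData([trials_pqt, wheel_npy], version='x.y.z')
#   records = handler.uploadData([trials_pqt, wheel_npy], version='x.y.z')   # already processed, no re-upload
#   records = handler.uploadData([trials_pqt], version='x.y.z', clobber=True)  # forced re-upload
#   leftover = handler.cleanUp()  # returns and empties the processed map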

class ServerGlobusDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks on lab local servers. Will download missing data from SDSC using Globus

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        from one.remote.globus import Globus, get_lab_from_endpoint_id  # noqa
        super().__init__(session_path, signatures, one=one)
        self.globus = Globus(client_name='server', headless=True)

        # on local servers set up the local root path manually as some have different globus config paths
        self.globus.endpoints['local']['root_path'] = '/mnt/s0/Data/Subjects'

        # Find the lab
        self.lab = get_lab(self.session_path, self.one.alyx)

        # For cortex lab we need to get the endpoint from the ibl alyx
        if self.lab == 'cortexlab':
            alyx = AlyxClient(base_url='https://alyx.internationalbrainlab.org', cache_rest=None)
            self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=alyx)
        else:
            self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=self.one.alyx)

        self.local_paths = []

    def setUp(self):
        """Function to download necessary data to run tasks using globus-sdk."""
        if self.lab == 'cortexlab':
            df = super().getData(one=ONE(base_url='https://alyx.internationalbrainlab.org'))
        else:
            df = super().getData(one=self.one)

        if len(df) == 0:
            # If no datasets are found in the cache, work off the local file system only and do
            # not attempt to download any missing data using Globus
            return

        # Check for space on the local server. If less than 500 GB is free, don't download new data
        space_free = shutil.disk_usage(self.globus.endpoints['local']['root_path'])[2]
        if space_free < 500e9:
            _logger.warning('Space left on server is < 500GB, won\'t re-download new data')
            return

        rel_sess_path = '/'.join(df.iloc[0]['session_path'].split('/')[-3:])
        assert rel_sess_path.split('/')[0] == self.one.path2ref(self.session_path)['subject']

        target_paths = []
        source_paths = []
        for i, d in df.iterrows():
            sess_path = Path(rel_sess_path).joinpath(d['rel_path'])
            full_local_path = Path(self.globus.endpoints['local']['root_path']).joinpath(sess_path)
            if not full_local_path.exists():
                uuid = i
                self.local_paths.append(full_local_path)
                target_paths.append(sess_path)
                source_paths.append(add_uuid_string(sess_path, uuid))

        if len(target_paths) != 0:
            ts = time()
            for sp, tp in zip(source_paths, target_paths):
                _logger.info(f'Downloading {sp} to {tp}')
            self.globus.mv(f'flatiron_{self.lab}', 'local', source_paths, target_paths)
            _logger.debug(f'Complete. Time elapsed {time() - ts}')

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        data_repo = get_local_data_repository(self.one.alyx)
        return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)

    def cleanUp(self):
        """Clean up, remove the files that were downloaded from Globus once task has completed."""
        for file in self.local_paths:
            os.unlink(file)

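# Editorial sketch (not part of the library): the Globus transfer above maps UUID-suffixed files
# on the flatiron endpoint onto plain ALF paths on the local endpoint; the values below are
# illustrative only:
#
#   from pathlib import Path
#   from one.alf.files import add_uuid_string
#   rel = Path('SW001/2024-01-01/001/alf/_ibl_trials.table.pqt')
#   src = add_uuid_string(rel, 'd1a2b3c4-0000-0000-0000-000000000000')
#   # src -> SW001/2024-01-01/001/alf/_ibl_trials.table.d1a2b3c4-0000-0000-0000-000000000000.pqt
#   # globus.mv(f'flatiron_{lab}', 'local', [src], [rel]) then copies it under the local root path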

class RemoteHttpDataHandler(DataHandler):
    def __init__(self, session_path, signature, one=None):
        """
        Data handler for running tasks on remote compute node. Will download missing data via http using ONE

        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signature, one=one)

    def setUp(self):
        """
        Function to download necessary data to run tasks using ONE
        :return:
        """
        df = super().getData()
        self.one._check_filesystem(df)

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via FTP patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        ftp_patcher = FTPPatcher(one=self.one)
        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
                                          versions=versions, **kwargs)

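# Editorial sketch (not part of the library): minimal remote (HTTP) usage, assuming ONE is set up
# for the target Alyx database; the session path and signature are illustrative placeholders:
#
#   from one.api import ONE
#   one = ONE(base_url='https://alyx.internationalbrainlab.org')
#   handler = RemoteHttpDataHandler(session_path, signature, one=one)
#   handler.setUp()   # downloads any missing input datasets over HTTP via ONE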

class RemoteAwsDataHandler(DataHandler):
    def __init__(self, task, session_path, signature, one=None):
        """
        Data handler for running tasks on remote compute node.

        This will download missing data from the private IBL S3 AWS data bucket. New datasets are
        uploaded via Globus.

        :param task: the parent task object
        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signature, one=one)
        self.task = task

        self.local_paths = []

    def setUp(self):
        """Function to download necessary data to run tasks using AWS boto3."""
        df = super().getData()
        self.local_paths = self.one._download_aws(map(lambda x: x[1], df.iterrows()))

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via Globus
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        # Set up Globus
        from one.remote.globus import Globus  # noqa
        self.globus = Globus(client_name='server', headless=True)
        self.lab = session_path_parts(self.session_path, as_dict=True)['lab']
        if self.lab == 'cortexlab' and 'cortexlab' in self.one.alyx.base_url:
            base_url = 'https://alyx.internationalbrainlab.org'
            _logger.warning('Changing Alyx client to %s', base_url)
            ac = AlyxClient(base_url=base_url)
        else:
            ac = self.one.alyx
        self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ac)

        # register datasets
        versions = super().uploadData(outputs, version)
        response = register_dataset(outputs, one=self.one, server_only=True, versions=versions, **kwargs)

        # upload directly via globus
        source_paths = []
        target_paths = []
        collections = {}

        for dset, out in zip(response, outputs):
            assert Path(out).name == dset['name']
            # find the flatiron (SDSC) file record for this dataset
            fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository'])
            collection = '/'.join(fr['relative_path'].split('/')[:-1])
            if collection in collections.keys():
                collections[collection].update({f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}})
            else:
                collections[collection] = {f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}}

            # Set the exists status to false for the server file record
            self.one.alyx.rest('files', 'partial_update', id=fr['id'], data={'exists': False})

            source_paths.append(out)
            target_paths.append(add_uuid_string(fr['relative_path'], dset['id']))

        if len(target_paths) != 0:
            ts = time()
            for sp, tp in zip(source_paths, target_paths):
                _logger.info(f'Uploading {sp} to {tp}')
            self.globus.mv('local', f'flatiron_{self.lab}', source_paths, target_paths)
            _logger.debug(f'Complete. Time elapsed {time() - ts}')

        for collection, files in collections.items():
            globus_files = self.globus.ls(f'flatiron_{self.lab}', collection, remove_uuid=True, return_size=True)
            file_names = [str(gl[0]) for gl in globus_files]
            file_sizes = [gl[1] for gl in globus_files]

            for name, details in files.items():
                try:
                    idx = file_names.index(name)
                    size = file_sizes[idx]
                    if size == details['size']:
                        # update the file record if sizes match
                        self.one.alyx.rest('files', 'partial_update', id=details['fr_id'], data={'exists': True})
                    else:
                        _logger.warning(f'File {name} found on SDSC but sizes do not match')
                except ValueError:
                    _logger.warning(f'File {name} not found on SDSC')

        return response

        # ftp_patcher = FTPPatcher(one=self.one)
        # return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
        #                                   versions=versions, **kwargs)

    def cleanUp(self):
        """Clean up, remove the files that were downloaded from AWS once task has completed."""
        if self.task.status == 0:
            for file in self.local_paths:
                os.unlink(file)

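# Editorial note (not part of the library): the upload path in RemoteAwsDataHandler.uploadData
# above is, in order: (1) register the datasets on Alyx with ``server_only=True``; (2) mark the
# flatiron file records as ``exists=False``; (3) transfer the files to the flatiron endpoint via
# Globus under UUID-suffixed names; (4) list the remote collection and flip ``exists`` back to
# True only when the remote file size matches the registered size.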

class RemoteGlobusDataHandler(DataHandler):
    """
    Data handler for running tasks on remote compute node. Will download missing data using Globus.

    :param session_path: path to session
    :param signature: input and output file signatures
    :param one: ONE instance
    """
    def __init__(self, session_path, signature, one=None):
        super().__init__(session_path, signature, one=one)

    def setUp(self):
        """Function to download necessary data to run tasks using globus."""
        # TODO
        pass

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via FTP patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        ftp_patcher = FTPPatcher(one=self.one)
        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
                                          versions=versions, **kwargs)


class SDSCDataHandler(DataHandler):
    """
    Data handler for running tasks on SDSC compute node

    :param task: the parent task object
    :param session_path: path to session
    :param signatures: input and output file signatures
    :param one: ONE instance
    """

    def __init__(self, task, session_path, signatures, one=None):
        super().__init__(session_path, signatures, one=one)
        self.task = task
        self.SDSC_PATCH_PATH = SDSC_PATCH_PATH
        self.SDSC_ROOT_PATH = SDSC_ROOT_PATH

    def setUp(self):
        """Function to create symlinks to necessary data to run tasks."""
        df = super().getData()

        SDSC_TMP = Path(self.SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
        for i, d in df.iterrows():
            file_path = Path(d['session_path']).joinpath(d['rel_path'])
            uuid = i
            file_uuid = add_uuid_string(file_path, uuid)
            file_link = SDSC_TMP.joinpath(file_path)
            file_link.parent.mkdir(exist_ok=True, parents=True)
            try:
                file_link.symlink_to(
                    Path(self.SDSC_ROOT_PATH.joinpath(file_uuid)))
            except FileExistsError:
                pass

        self.task.session_path = SDSC_TMP.joinpath(d['session_path'])

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via SDSC patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        sdsc_patcher = SDSCPatcher(one=self.one)
        return sdsc_patcher.patch_datasets(outputs, dry=False, versions=versions, **kwargs)

    def cleanUp(self):
        """Function to clean up symlinks created to run task."""
        assert SDSC_PATCH_PATH.parts[0:4] == self.task.session_path.parts[0:4]
        shutil.rmtree(self.task.session_path)

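# Editorial sketch (not part of the library): SDSCDataHandler.setUp links a task-local tree of
# plain ALF paths onto the UUID-suffixed files under SDSC_ROOT_PATH; the paths below are
# illustrative only:
#
#   SDSC_PATCH_PATH/TaskName/<lab>/Subjects/SW001/2024-01-01/001/alf/_ibl_trials.table.pqt
#       -> SDSC_ROOT_PATH/<lab>/Subjects/SW001/2024-01-01/001/alf/_ibl_trials.table.<uuid>.pqt
#
# The task then runs against the symlinked session path, and cleanUp removes the whole
# task-specific tree with shutil.rmtree.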

class PopeyeDataHandler(SDSCDataHandler):

    def __init__(self, task, session_path, signatures, one=None):
        super().__init__(task, session_path, signatures, one=one)
        self.SDSC_PATCH_PATH = Path(os.getenv('SDSC_PATCH_PATH', "/mnt/sdceph/users/ibl/data/quarantine/tasks/"))
        self.SDSC_ROOT_PATH = Path("/mnt/sdceph/users/ibl/data")

    def uploadData(self, outputs, version, **kwargs):
        raise NotImplementedError(
            "Cannot register data from Popeye. Login as Datauser and use the RegisterSpikeSortingSDSC task."
        )

    def cleanUp(self):
        """Symlinks are preserved until registration."""
        pass