Coverage for ibllib/oneibl/data_handlers.py: 32%

196 statements

coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

import logging
import pandas as pd
from pathlib import Path
import shutil
import os
import abc
from time import time

from one.api import ONE
from one.webclient import AlyxClient
from one.util import filter_datasets
from one.alf.files import add_uuid_string, session_path_parts
from ibllib.oneibl.registration import register_dataset, get_lab, get_local_data_repository
from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH


_logger = logging.getLogger(__name__)


class DataHandler(abc.ABC):
    def __init__(self, session_path, signature, one=None):
        """
        Base data handler class
        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        self.session_path = session_path
        self.signature = signature
        self.one = one

    def setUp(self):
        """Function to optionally overload to download required data to run task."""
        pass

    def getData(self, one=None):
        """
        Finds the datasets required for the task based on the input file signatures
        :return: dataframe of datasets required for the task
        """
        if self.one is None and one is None:
            return

        one = one or self.one
        session_datasets = one.list_datasets(one.path2eid(self.session_path), details=True)
        dfs = []
        for file in self.signature['input_files']:
            dfs.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
                                       wildcards=True, assert_unique=False))
        df = pd.concat(dfs)

        # In some cases the eid is stored in the index; if so, drop this level
        if 'eid' in df.index.names:
            df = df.droplevel(level='eid')
        return df

    def uploadData(self, outputs, version):
        """
        Function to optionally overload to upload and register data
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: list of dataset versions, one per output file
        """
        if isinstance(outputs, list):
            versions = [version for _ in outputs]
        else:
            versions = [version]

        return versions

    def cleanUp(self):
        """
        Function to optionally overload to clean up files after running task
        :return:
        """
        pass

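# Illustrative sketch (not part of the original module): the shape of the
# ``signature`` dict consumed by ``getData`` above. The filenames and
# collections below are hypothetical placeholders; ``filter_datasets`` only
# reads the first two elements of each input tuple, and the trailing boolean
# is assumed here to be the usual ibllib "required" flag.
_EXAMPLE_SIGNATURE = {
    'input_files': [
        ('_spikeglx_ephysData*.ap.cbin', 'raw_ephys_data/probe00', True),
        ('_spikeglx_ephysData*.ap.meta', 'raw_ephys_data/probe00', True),
    ],
    'output_files': [
        ('spikes.times.npy', 'alf/probe00', True),
    ],
}
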

class LocalDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks locally, with no architecture or db connection
        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signatures, one=one)

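# Illustrative sketch (not part of the original module): the lifecycle a task
# runner is expected to drive on any handler. The ``task`` object and its
# ``run``/``outputs`` attributes are hypothetical; only the
# setUp/getData/uploadData/cleanUp hooks come from DataHandler above.
def _example_handler_lifecycle(handler, task, version):
    handler.setUp()                                   # download or link required inputs
    inputs = handler.getData()                        # dataframe of datasets matching the input signatures
    task.run(inputs)                                  # hypothetical: the task consumes the inputs
    info = handler.uploadData(task.outputs, version)  # register (and possibly upload) the outputs
    handler.cleanUp()                                 # remove any temporary local copies
    return info
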

class ServerDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks on lab local servers when all data is available locally

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signatures, one=one)

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        data_repo = get_local_data_repository(self.one.alyx)
        return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)

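# Illustrative sketch (not part of the original module): registering task
# outputs from a lab server with ServerDataHandler above. The session path,
# output file and ibllib version string are hypothetical placeholders, and a
# locally configured ONE/Alyx connection is assumed.
def _example_server_registration():
    one = ONE()
    session_path = Path('/mnt/s0/Data/Subjects/SW001/2023-01-01/001')
    handler = ServerDataHandler(session_path, {'input_files': [], 'output_files': []}, one=one)
    outputs = [session_path.joinpath('alf', 'probe00', 'spikes.times.npy')]
    return handler.uploadData(outputs, '2.26.0')
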

class ServerGlobusDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks on lab local servers. Will download missing data from SDSC using Globus

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        from one.remote.globus import Globus, get_lab_from_endpoint_id  # noqa
        super().__init__(session_path, signatures, one=one)
        self.globus = Globus(client_name='server')

        # on local servers set up the local root path manually as some have different globus config paths
        self.globus.endpoints['local']['root_path'] = '/mnt/s0/Data/Subjects'

        # Find the lab
        self.lab = get_lab(self.session_path, self.one.alyx)

        # For cortex lab we need to get the endpoint from the ibl alyx
        if self.lab == 'cortexlab':
            self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ONE(base_url='https://alyx.internationalbrainlab.org').alyx)
        else:
            self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=self.one.alyx)

        self.local_paths = []

    def setUp(self):
        """Function to download necessary data to run tasks using globus-sdk."""
        if self.lab == 'cortexlab':
            one = ONE(base_url='https://alyx.internationalbrainlab.org')
            df = super().getData(one=one)
        else:
            one = self.one
            df = super().getData()

        if len(df) == 0:
            # If no datasets are found in the cache, work off the local file system only and do not
            # attempt to download any missing data using Globus
            return

        # Check for space on the local server. If less than 500 GB is free, don't download new data
        space_free = shutil.disk_usage(self.globus.endpoints['local']['root_path'])[2]
        if space_free < 500e9:
            _logger.warning("Space left on server is < 500GB, won't download new data")
            return

        rel_sess_path = '/'.join(df.iloc[0]['session_path'].split('/')[-3:])
        assert rel_sess_path.split('/')[0] == self.one.path2ref(self.session_path)['subject']

        target_paths = []
        source_paths = []
        for i, d in df.iterrows():
            sess_path = Path(rel_sess_path).joinpath(d['rel_path'])
            full_local_path = Path(self.globus.endpoints['local']['root_path']).joinpath(sess_path)
            if not full_local_path.exists():
                uuid = i
                self.local_paths.append(full_local_path)
                target_paths.append(sess_path)
                source_paths.append(add_uuid_string(sess_path, uuid))

        if len(target_paths) != 0:
            ts = time()
            for sp, tp in zip(source_paths, target_paths):
                _logger.info(f'Downloading {sp} to {tp}')
            self.globus.mv(f'flatiron_{self.lab}', 'local', source_paths, target_paths)
            _logger.debug(f'Complete. Time elapsed {time() - ts}')

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        data_repo = get_local_data_repository(self.one.alyx)
        return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)

    def cleanUp(self):
        """Clean up, remove the files that were downloaded via Globus once the task has completed."""
        for file in self.local_paths:
            os.unlink(file)

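# Illustrative sketch (not part of the original module): how FlatIron source
# paths map onto local server targets, mirroring the loop in
# ServerGlobusDataHandler.setUp above. The session path, dataset path and UUID
# are hypothetical placeholders.
def _example_globus_paths():
    rel_sess_path = 'SW001/2023-01-01/001'
    rel_path = 'alf/probe00/spikes.times.npy'
    uuid = '00000000-0000-0000-0000-000000000000'
    target = Path(rel_sess_path).joinpath(rel_path)  # destination relative to the local root_path
    source = add_uuid_string(target, uuid)           # FlatIron stores each file with its dataset UUID appended
    return source, target
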

class RemoteHttpDataHandler(DataHandler):
    def __init__(self, session_path, signature, one=None):
        """
        Data handler for running tasks on a remote compute node. Will download missing data via HTTP using ONE

        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signature, one=one)

    def setUp(self):
        """
        Function to download necessary data to run tasks using ONE
        :return:
        """
        df = super().getData()
        self.one._check_filesystem(df)

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via FTP patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        ftp_patcher = FTPPatcher(one=self.one)
        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
                                          versions=versions, **kwargs)

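# Illustrative sketch (not part of the original module): preparing a session on
# a remote compute node with RemoteHttpDataHandler above. The base_url, session
# path and input signature are hypothetical placeholders.
def _example_remote_http_setup():
    one = ONE(base_url='https://alyx.internationalbrainlab.org')
    session_path = Path.home().joinpath('Data', 'SW001', '2023-01-01', '001')
    signature = {'input_files': [('_ibl_trials.table.pqt', 'alf', True)], 'output_files': []}
    handler = RemoteHttpDataHandler(session_path, signature, one=one)
    handler.setUp()  # downloads any missing input datasets over HTTP via ONE
    return handler
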

class RemoteAwsDataHandler(DataHandler):
    def __init__(self, task, session_path, signature, one=None):
        """
        Data handler for running tasks on a remote compute node.
        Will download missing data from the private IBL AWS S3 data bucket

        :param task: parent task that the data handler is running for
        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signature, one=one)
        self.task = task

        self.local_paths = []

    def setUp(self):
        """Function to download necessary data to run tasks using AWS boto3."""
        df = super().getData()
        self.local_paths = self.one._download_aws(map(lambda x: x[1], df.iterrows()))

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via Globus
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        # Set up Globus
        from one.remote.globus import Globus  # noqa
        self.globus = Globus(client_name='server')
        self.lab = session_path_parts(self.session_path, as_dict=True)['lab']
        if self.lab == 'cortexlab' and 'cortexlab' in self.one.alyx.base_url:
            base_url = 'https://alyx.internationalbrainlab.org'
            _logger.warning('Changing Alyx client to %s', base_url)
            ac = AlyxClient(base_url=base_url)
        else:
            ac = self.one.alyx
        self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ac)

        # register datasets
        versions = super().uploadData(outputs, version)
        response = register_dataset(outputs, one=self.one, server_only=True, versions=versions, **kwargs)

        # upload directly via globus
        source_paths = []
        target_paths = []
        collections = {}

        for dset, out in zip(response, outputs):
            assert Path(out).name == dset['name']
            # find the FlatIron file record for this dataset
            fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository'])
            collection = '/'.join(fr['relative_path'].split('/')[:-1])
            if collection in collections.keys():
                collections[collection].update({f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}})
            else:
                collections[collection] = {f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}}

            # Set the exists status to false for the server file record until the upload has been verified
            self.one.alyx.rest('files', 'partial_update', id=fr['id'], data={'exists': False})

            source_paths.append(out)
            target_paths.append(add_uuid_string(fr['relative_path'], dset['id']))

        if len(target_paths) != 0:
            ts = time()
            for sp, tp in zip(source_paths, target_paths):
                _logger.info(f'Uploading {sp} to {tp}')
            self.globus.mv('local', f'flatiron_{self.lab}', source_paths, target_paths)
            _logger.debug(f'Complete. Time elapsed {time() - ts}')

        for collection, files in collections.items():
            globus_files = self.globus.ls(f'flatiron_{self.lab}', collection, remove_uuid=True, return_size=True)
            file_names = [str(gl[0]) for gl in globus_files]
            file_sizes = [gl[1] for gl in globus_files]

            for name, details in files.items():
                try:
                    idx = file_names.index(name)
                    size = file_sizes[idx]
                    if size == details['size']:
                        # update the file record if sizes match
                        self.one.alyx.rest('files', 'partial_update', id=details['fr_id'], data={'exists': True})
                    else:
                        _logger.warning(f'File {name} found on SDSC but sizes do not match')
                except ValueError:
                    _logger.warning(f'File {name} not found on SDSC')

        return response

        # ftp_patcher = FTPPatcher(one=self.one)
        # return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
        #                                   versions=versions, **kwargs)

    def cleanUp(self):
        """Clean up, remove the files that were downloaded from AWS once the task has completed."""
        if self.task.status == 0:
            for file in self.local_paths:
                os.unlink(file)

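# Illustrative sketch (not part of the original module): wiring a task to
# RemoteAwsDataHandler above. The ``task`` object and its ``signature``,
# ``run`` and ``status`` attributes are hypothetical; note that cleanUp only
# removes the downloaded inputs when task.status == 0.
def _example_remote_aws_run(task, one):
    handler = RemoteAwsDataHandler(task, task.session_path, task.signature, one=one)
    handler.setUp()           # pulls missing inputs from the private IBL S3 bucket
    task.status = task.run()  # hypothetical: 0 indicates success
    handler.cleanUp()         # deletes the local copies only on success
    return handler
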

class RemoteGlobusDataHandler(DataHandler):
    """
    Data handler for running tasks on a remote compute node. Will download missing data using Globus

    :param session_path: path to session
    :param signature: input and output file signatures
    :param one: ONE instance
    """
    def __init__(self, session_path, signature, one=None):
        super().__init__(session_path, signature, one=one)

    def setUp(self):
        """Function to download necessary data to run tasks using Globus."""
        # TODO
        pass

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via FTP patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        ftp_patcher = FTPPatcher(one=self.one)
        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
                                          versions=versions, **kwargs)


class SDSCDataHandler(DataHandler):
    """
    Data handler for running tasks on the SDSC compute node

    :param task: parent task that the data handler is running for
    :param session_path: path to session
    :param signature: input and output file signatures
    :param one: ONE instance
    """
    def __init__(self, task, session_path, signatures, one=None):
        super().__init__(session_path, signatures, one=one)
        self.task = task

    def setUp(self):
        """Function to create symlinks to necessary data to run tasks."""
        df = super().getData()

        SDSC_TMP = Path(SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
        for i, d in df.iterrows():
            file_path = Path(d['session_path']).joinpath(d['rel_path'])
            uuid = i
            file_uuid = add_uuid_string(file_path, uuid)
            file_link = SDSC_TMP.joinpath(file_path)
            file_link.parent.mkdir(exist_ok=True, parents=True)
            file_link.symlink_to(
                Path(SDSC_ROOT_PATH.joinpath(file_uuid)))

        self.task.session_path = SDSC_TMP.joinpath(d['session_path'])

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via SDSC patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        sdsc_patcher = SDSCPatcher(one=self.one)
        return sdsc_patcher.patch_datasets(outputs, dry=False, versions=versions, **kwargs)

    def cleanUp(self):
        """Function to clean up symlinks created to run task."""
        assert SDSC_PATCH_PATH.parts[0:4] == self.task.session_path.parts[0:4]
        shutil.rmtree(self.task.session_path)
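
# Illustrative sketch (not part of the original module): the symlink layout
# created by SDSCDataHandler.setUp above. The relative dataset path and UUID
# are hypothetical placeholders; SDSC_PATCH_PATH and SDSC_ROOT_PATH come from
# ibllib.oneibl.patcher as imported at the top of this module.
def _example_sdsc_symlink(task_name='ExampleTask'):
    rel_file = Path('mainenlab/Subjects/SW001/2023-01-01/001/alf/probe00/spikes.times.npy')
    uuid = '00000000-0000-0000-0000-000000000000'
    link = SDSC_PATCH_PATH.joinpath(task_name, rel_file)               # writable scratch location for the task
    target = SDSC_ROOT_PATH.joinpath(add_uuid_string(rel_file, uuid))  # read-only FlatIron copy of the dataset
    return link, target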