Coverage for ibllib/oneibl/data_handlers.py: 32%
196 statements
coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
import logging
import pandas as pd
from pathlib import Path
import shutil
import os
import abc
from time import time

from one.api import ONE
from one.webclient import AlyxClient
from one.util import filter_datasets
from one.alf.files import add_uuid_string, session_path_parts
from ibllib.oneibl.registration import register_dataset, get_lab, get_local_data_repository
from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH

_logger = logging.getLogger(__name__)
class DataHandler(abc.ABC):
    def __init__(self, session_path, signature, one=None):
        """
        Base data handler class
        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        self.session_path = session_path
        self.signature = signature
        self.one = one

    def setUp(self):
        """Function to optionally overload to download required data to run task."""
        pass

    def getData(self, one=None):
        """
        Finds the datasets required for the task based on the input file signatures
        :return: data frame of matching datasets found in the cache for this session
        """
        if self.one is None and one is None:
            return

        one = one or self.one
        session_datasets = one.list_datasets(one.path2eid(self.session_path), details=True)
        dfs = []
        for file in self.signature['input_files']:
            dfs.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
                                       wildcards=True, assert_unique=False))
        df = pd.concat(dfs)

        # In some cases the eid is stored in the index; if so, drop this level
        if 'eid' in df.index.names:
            df = df.droplevel(level='eid')
        return df

    def uploadData(self, outputs, version):
        """
        Function to optionally overload to upload and register data
        :param outputs: output files from task to register
        :param version: ibllib version
        :return:
        """
        if isinstance(outputs, list):
            versions = [version for _ in outputs]
        else:
            versions = [version]

        return versions

    def cleanUp(self):
        """
        Function to optionally overload to clean up files after running task
        :return:
        """
        pass
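# --- Illustrative sketch, not part of ibllib -----------------------------------
# A minimal example of how a concrete handler derives from DataHandler: setUp
# locates the inputs and uploadData registers the outputs. The class name and
# the bodies below are hypothetical and only show the expected shape of a subclass.
class _ExampleDataHandler(DataHandler):
    def setUp(self):
        # data frame of datasets matching the task's input signatures
        df = self.getData()
        return df

    def uploadData(self, outputs, version):
        # expand the single version string into one entry per output file
        versions = super().uploadData(outputs, version)
        return versions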
class LocalDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks locally, with no architecture or db connection

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signatures, one=one)
class ServerDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks on lab local servers when all data is available locally

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signatures, one=one)

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        data_repo = get_local_data_repository(self.one.alyx)
        return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)
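# --- Illustrative sketch, not part of ibllib -----------------------------------
# Typical handler lifecycle around a task run on a lab server. An authenticated
# ONE instance is assumed; `task.signature`, `task.out_files` and the version
# string are hypothetical placeholders for whatever pipeline code drives this.
def _example_server_run(session_path, task, one):
    handler = ServerDataHandler(session_path, task.signature, one=one)
    handler.setUp()  # no-op: all data is assumed to already be on the server
    # ... run the task here, producing task.out_files ...
    registered = handler.uploadData(task.out_files, version='x.y.z')
    handler.cleanUp()
    return registered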
class ServerGlobusDataHandler(DataHandler):
    def __init__(self, session_path, signatures, one=None):
        """
        Data handler for running tasks on lab local servers. Will download missing data from SDSC using Globus

        :param session_path: path to session
        :param signatures: input and output file signatures
        :param one: ONE instance
        """
        from one.remote.globus import Globus, get_lab_from_endpoint_id  # noqa
        super().__init__(session_path, signatures, one=one)
        self.globus = Globus(client_name='server')

        # on local servers set up the local root path manually as some have different globus config paths
        self.globus.endpoints['local']['root_path'] = '/mnt/s0/Data/Subjects'

        # Find the lab
        self.lab = get_lab(self.session_path, self.one.alyx)

        # For cortexlab we need to get the endpoint from the IBL Alyx
        if self.lab == 'cortexlab':
            self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ONE(base_url='https://alyx.internationalbrainlab.org').alyx)
        else:
            self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=self.one.alyx)

        self.local_paths = []

    def setUp(self):
        """Function to download necessary data to run tasks using globus-sdk."""
        if self.lab == 'cortexlab':
            one = ONE(base_url='https://alyx.internationalbrainlab.org')
            df = super().getData(one=one)
        else:
            one = self.one
            df = super().getData()

        if len(df) == 0:
            # If no datasets are found in the cache, work only off the local file system and do not
            # attempt to download any missing data using Globus
            return

        # Check for space on the local server. If less than 500 GB is free, don't download new data
        space_free = shutil.disk_usage(self.globus.endpoints['local']['root_path'])[2]
        if space_free < 500e9:
            _logger.warning("Space left on server is < 500GB, won't download new data")
            return

        rel_sess_path = '/'.join(df.iloc[0]['session_path'].split('/')[-3:])
        assert rel_sess_path.split('/')[0] == self.one.path2ref(self.session_path)['subject']

        target_paths = []
        source_paths = []
        for i, d in df.iterrows():
            sess_path = Path(rel_sess_path).joinpath(d['rel_path'])
            full_local_path = Path(self.globus.endpoints['local']['root_path']).joinpath(sess_path)
            if not full_local_path.exists():
                uuid = i
                self.local_paths.append(full_local_path)
                target_paths.append(sess_path)
                source_paths.append(add_uuid_string(sess_path, uuid))

        if len(target_paths) != 0:
            ts = time()
            for sp, tp in zip(source_paths, target_paths):
                _logger.info(f'Downloading {sp} to {tp}')
            self.globus.mv(f'flatiron_{self.lab}', 'local', source_paths, target_paths)
            _logger.debug(f'Complete. Time elapsed {time() - ts}')

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        data_repo = get_local_data_repository(self.one.alyx)
        return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)

    def cleanUp(self):
        """Clean up, remove the files that were downloaded via Globus once the task has completed."""
        for file in self.local_paths:
            os.unlink(file)
class RemoteHttpDataHandler(DataHandler):
    def __init__(self, session_path, signature, one=None):
        """
        Data handler for running tasks on remote compute node. Will download missing data via http using ONE

        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signature, one=one)

    def setUp(self):
        """
        Function to download necessary data to run tasks using ONE
        :return:
        """
        df = super().getData()
        self.one._check_filesystem(df)

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via FTP patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        ftp_patcher = FTPPatcher(one=self.one)
        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
                                          versions=versions, **kwargs)
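# --- Illustrative sketch, not part of ibllib -----------------------------------
# Same lifecycle on a remote compute node: inputs are pulled over HTTP with ONE
# and outputs are pushed back through the FTP patcher. `task.signature`,
# `task.out_files` and the version string are hypothetical placeholders.
def _example_remote_http_run(session_path, task, one):
    handler = RemoteHttpDataHandler(session_path, task.signature, one=one)
    handler.setUp()  # downloads any missing input datasets via ONE
    # ... run the task here, producing task.out_files ...
    return handler.uploadData(task.out_files, version='x.y.z')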
class RemoteAwsDataHandler(DataHandler):
    def __init__(self, task, session_path, signature, one=None):
        """
        Data handler for running tasks on remote compute node. Will download missing data from the private IBL AWS S3 data bucket

        :param task: task instance requiring the data
        :param session_path: path to session
        :param signature: input and output file signatures
        :param one: ONE instance
        """
        super().__init__(session_path, signature, one=one)
        self.task = task

        self.local_paths = []

    def setUp(self):
        """Function to download necessary data to run tasks using AWS boto3."""
        df = super().getData()
        self.local_paths = self.one._download_aws(map(lambda x: x[1], df.iterrows()))

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via Globus
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        # Set up Globus
        from one.remote.globus import Globus  # noqa
        self.globus = Globus(client_name='server')
        self.lab = session_path_parts(self.session_path, as_dict=True)['lab']
        if self.lab == 'cortexlab' and 'cortexlab' in self.one.alyx.base_url:
            base_url = 'https://alyx.internationalbrainlab.org'
            _logger.warning('Changing Alyx client to %s', base_url)
            ac = AlyxClient(base_url=base_url)
        else:
            ac = self.one.alyx
        self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ac)

        # register datasets
        versions = super().uploadData(outputs, version)
        response = register_dataset(outputs, one=self.one, server_only=True, versions=versions, **kwargs)

        # upload directly via globus
        source_paths = []
        target_paths = []
        collections = {}

        for dset, out in zip(response, outputs):
            assert Path(out).name == dset['name']
            # find the FlatIron file record for this dataset
            fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository'])
            collection = '/'.join(fr['relative_path'].split('/')[:-1])
            if collection in collections:
                collections[collection].update({f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}})
            else:
                collections[collection] = {f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}}

            # Set the exists status to False for the server file record until the upload is verified
            self.one.alyx.rest('files', 'partial_update', id=fr['id'], data={'exists': False})

            source_paths.append(out)
            target_paths.append(add_uuid_string(fr['relative_path'], dset['id']))

        if len(target_paths) != 0:
            ts = time()
            for sp, tp in zip(source_paths, target_paths):
                _logger.info(f'Uploading {sp} to {tp}')
            self.globus.mv('local', f'flatiron_{self.lab}', source_paths, target_paths)
            _logger.debug(f'Complete. Time elapsed {time() - ts}')

        for collection, files in collections.items():
            globus_files = self.globus.ls(f'flatiron_{self.lab}', collection, remove_uuid=True, return_size=True)
            file_names = [str(gl[0]) for gl in globus_files]
            file_sizes = [gl[1] for gl in globus_files]

            for name, details in files.items():
                try:
                    idx = file_names.index(name)
                    size = file_sizes[idx]
                    if size == details['size']:
                        # update the file record if sizes match
                        self.one.alyx.rest('files', 'partial_update', id=details['fr_id'], data={'exists': True})
                    else:
                        _logger.warning(f'File {name} found on SDSC but sizes do not match')
                except ValueError:
                    _logger.warning(f'File {name} not found on SDSC')

        return response

        # ftp_patcher = FTPPatcher(one=self.one)
        # return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
        #                                   versions=versions, **kwargs)

    def cleanUp(self):
        """Clean up, remove the files that were downloaded from AWS once the task has completed."""
        if self.task.status == 0:
            for file in self.local_paths:
                os.unlink(file)
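# --- Illustrative sketch, not part of ibllib -----------------------------------
# The AWS handler is constructed with the task itself so that cleanUp can check
# task.status before deleting the downloaded inputs. `task.signature`,
# `task.out_files`, `task.status` and the version string are hypothetical
# placeholders for whatever pipeline code drives this.
def _example_remote_aws_run(session_path, task, one):
    handler = RemoteAwsDataHandler(task, session_path, task.signature, one=one)
    handler.setUp()  # downloads any missing input datasets from S3
    # ... run the task here, producing task.out_files and setting task.status ...
    response = handler.uploadData(task.out_files, version='x.y.z')
    handler.cleanUp()  # removes the downloaded inputs only if task.status == 0
    return response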
class RemoteGlobusDataHandler(DataHandler):
    """
    Data handler for running tasks on remote compute node. Will download missing data using Globus

    :param session_path: path to session
    :param signature: input and output file signatures
    :param one: ONE instance
    """
    def __init__(self, session_path, signature, one=None):
        super().__init__(session_path, signature, one=one)

    def setUp(self):
        """Function to download necessary data to run tasks using Globus."""
        # TODO
        pass

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via FTP patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        ftp_patcher = FTPPatcher(one=self.one)
        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
                                          versions=versions, **kwargs)
class SDSCDataHandler(DataHandler):
    """
    Data handler for running tasks on SDSC compute node

    :param task: task instance requiring the data
    :param session_path: path to session
    :param signatures: input and output file signatures
    :param one: ONE instance
    """
    def __init__(self, task, session_path, signatures, one=None):
        super().__init__(session_path, signatures, one=one)
        self.task = task

    def setUp(self):
        """Function to create symlinks to necessary data to run tasks."""
        df = super().getData()

        SDSC_TMP = Path(SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
        for i, d in df.iterrows():
            file_path = Path(d['session_path']).joinpath(d['rel_path'])
            uuid = i
            file_uuid = add_uuid_string(file_path, uuid)
            file_link = SDSC_TMP.joinpath(file_path)
            file_link.parent.mkdir(exist_ok=True, parents=True)
            file_link.symlink_to(
                Path(SDSC_ROOT_PATH.joinpath(file_uuid)))

        self.task.session_path = SDSC_TMP.joinpath(d['session_path'])

    def uploadData(self, outputs, version, **kwargs):
        """
        Function to upload and register data of completed task via SDSC patcher
        :param outputs: output files from task to register
        :param version: ibllib version
        :return: output info of registered datasets
        """
        versions = super().uploadData(outputs, version)
        sdsc_patcher = SDSCPatcher(one=self.one)
        return sdsc_patcher.patch_datasets(outputs, dry=False, versions=versions, **kwargs)

    def cleanUp(self):
        """Function to clean up symlinks created to run task."""
        assert SDSC_PATCH_PATH.parts[0:4] == self.task.session_path.parts[0:4]
        shutil.rmtree(self.task.session_path)
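# --- Illustrative sketch, not part of ibllib -----------------------------------
# On the SDSC compute node the handler symlinks the UUID-named FlatIron files
# into a per-task scratch directory and repoints task.session_path at it, so the
# task reads plain ALF paths. `task.signature`, `task.out_files` and the version
# string are hypothetical placeholders.
def _example_sdsc_run(session_path, task, one):
    handler = SDSCDataHandler(task, session_path, task.signature, one=one)
    handler.setUp()  # builds the symlink tree and updates task.session_path
    # ... run the task against task.session_path, producing task.out_files ...
    response = handler.uploadData(task.out_files, version='x.y.z')
    handler.cleanUp()  # removes the symlinked scratch session directory
    return response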