Coverage for ibllib/pipes/base_tasks.py: 97% (201 statements)
coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""Abstract base classes for dynamic pipeline tasks."""
2import logging
3from pathlib import Path
5from pkg_resources import parse_version
6from one.webclient import no_cache
7from iblutil.util import flatten
9from ibllib.pipes.tasks import Task
10import ibllib.io.session_params as sess_params
11from ibllib.qc.base import sign_off_dict, SIGN_OFF_CATEGORIES
12from ibllib.io.raw_daq_loaders import load_timeline_sync_and_chmap
14_logger = logging.getLogger(__name__)
17class DynamicTask(Task):
19 def __init__(self, session_path, **kwargs):
20 super().__init__(session_path, **kwargs) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
21 self.session_params = self.read_params_file() 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
23 # TODO Which should be default?
24 # Sync collection
25 self.sync_collection = self.get_sync_collection(kwargs.get('sync_collection', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
26 # Sync type
27 self.sync = self.get_sync(kwargs.get('sync', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
28 # Sync extension
29 self.sync_ext = self.get_sync_extension(kwargs.get('sync_ext', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
30 # Sync namespace
31 self.sync_namespace = self.get_sync_namespace(kwargs.get('sync_namespace', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
33 def get_sync_collection(self, sync_collection=None):
34 return sync_collection if sync_collection else sess_params.get_sync_collection(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
36 def get_sync(self, sync=None):
37 return sync if sync else sess_params.get_sync_label(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
39 def get_sync_extension(self, sync_ext=None):
40 return sync_ext if sync_ext else sess_params.get_sync_extension(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
42 def get_sync_namespace(self, sync_namespace=None):
43 return sync_namespace if sync_namespace else sess_params.get_sync_namespace(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
45 def get_protocol(self, protocol=None, task_collection=None):
46 return protocol if protocol else sess_params.get_task_protocol(self.session_params, task_collection) 1cfrsWd
48 def get_task_collection(self, collection=None):
49 if not collection: 1201cfrvs3zWd
50 collection = sess_params.get_task_collection(self.session_params) 12cfrsd
51 # If inferring the collection from the experiment description, assert only one returned
52 assert collection is None or isinstance(collection, str) or len(collection) == 1 1201cfrvs3zWd
53 return collection 1201cfrvs3zWd
55 def get_device_collection(self, device, device_collection=None):
56 if device_collection: 1a9{2!#$01cjkghfrxvswLMNOPQRKSTXY%'ZiUV(3z456W78)*e+,-bdy
57 return device_collection 1a9{2!#$01cjkghfrxvswLMNOPQRKSTXY%'ZiUV(3z456W78)*e+,-bdy
58 collection_map = sess_params.get_collections(self.session_params['devices']) 1{
59 return collection_map.get(device) 1{
61 def read_params_file(self):
62 params = sess_params.read_params(self.session_path) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
64 if params is None: 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy
65 return {} 1auA9q2!#$01DmlEnFGoHIpLMNOPQRKST/:;=?XY%'(@[]^_`C456W78)*e+,-
67 # TODO figure out the best way
68 # if params is None and self.one:
69 # # Try to read params from alyx or try to download params file
70 # params = self.one.load_dataset(self.one.path2eid(self.session_path), 'params.yml')
71 # params = self.one.alyx.rest()
73 return params 1atcjkghfrxvswZiUVB3zJbdy
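
    # Usage sketch (hypothetical session path and kwargs): concrete subclasses resolve their sync
    # parameters from the experiment description file unless explicitly overridden, e.g.
    #
    #     task = AudioTask('/data/subject/2023-01-01/001', sync='nidq', sync_collection='raw_ephys_data')
    #     task.sync             # 'nidq' -- the keyword argument wins over the description file
    #     task.sync_namespace   # taken from the experiment description, or None if undefined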


class BehaviourTask(DynamicTask):

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.collection = self.get_task_collection(kwargs.get('collection', None))
        # Task type (protocol)
        self.protocol = self.get_protocol(kwargs.get('protocol', None), task_collection=self.collection)

        self.protocol_number = self.get_protocol_number(kwargs.get('protocol_number'), task_protocol=self.protocol)

        self.output_collection = 'alf'
        # Check against None explicitly -- a protocol number of 0 is valid but falsy
        if self.protocol_number is not None:
            self.output_collection += f'/task_{self.protocol_number:02}'

    def get_protocol(self, protocol=None, task_collection=None):
        return protocol if protocol else sess_params.get_task_protocol(self.session_params, task_collection)

    def get_task_collection(self, collection=None):
        if not collection:
            collection = sess_params.get_task_collection(self.session_params)
        # If inferring the collection from the experiment description, assert only one returned
        assert collection is None or isinstance(collection, str) or len(collection) == 1
        return collection

    def get_protocol_number(self, number=None, task_protocol=None):
        if number is None:  # Do not use "if not number" as that will return True if number is 0
            number = sess_params.get_task_protocol_number(self.session_params, task_protocol)
        # If inferred from the experiment description, the number must be a single integer (or None)
        assert number is None or isinstance(number, int)
        return number
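
    # Illustration of the output collection logic above (hypothetical values):
    #
    #     protocol_number = 0    ->  output_collection == 'alf/task_00'
    #     protocol_number = None ->  output_collection == 'alf'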

    @staticmethod
    def _spacer_support(settings):
        """
        Spacer support was introduced in v7.1 for iblrig v7 and v8.0.1 in v8.

        Parameters
        ----------
        settings : dict
            The task settings dict.

        Returns
        -------
        bool
            True if task spacers are to be expected.
        """
        v = parse_version
        version = v(settings.get('IBLRIG_VERSION_TAG'))
        return version not in (v('100.0.0'), v('8.0.0')) and version >= v('7.1.0')
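
    # Example checks against the rule above (illustrative settings dicts):
    #
    #     BehaviourTask._spacer_support({'IBLRIG_VERSION_TAG': '7.2.3'})    # True: >= 7.1.0
    #     BehaviourTask._spacer_support({'IBLRIG_VERSION_TAG': '8.0.0'})    # False: explicitly excluded
    #     BehaviourTask._spacer_support({'IBLRIG_VERSION_TAG': '100.0.0'})  # False: explicitly excluded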


class VideoTask(DynamicTask):

    def __init__(self, session_path, cameras, **kwargs):
        super().__init__(session_path, cameras=cameras, **kwargs)
        self.cameras = cameras
        self.device_collection = self.get_device_collection('cameras', kwargs.get('device_collection', 'raw_video_data'))
        # self.collection = self.get_task_collection(kwargs.get('collection', None))


class AudioTask(DynamicTask):

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)
        self.device_collection = self.get_device_collection('microphone', kwargs.get('device_collection', 'raw_behavior_data'))


class EphysTask(DynamicTask):

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.pname = self.get_pname(kwargs.get('pname', None))
        self.nshanks, self.pextra = self.get_nshanks(kwargs.get('nshanks', None))
        self.device_collection = self.get_device_collection('neuropixel', kwargs.get('device_collection', 'raw_ephys_data'))

    def get_pname(self, pname):
        # pname can be a list or a string
        pname = self.kwargs.get('pname', pname)

        return pname

    def get_nshanks(self, nshanks=None):
        nshanks = self.kwargs.get('nshanks', nshanks)
        if nshanks is not None:
            pextra = [chr(97 + int(shank)) for shank in range(nshanks)]
        else:
            pextra = []

        return nshanks, pextra
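
    # Example (hypothetical kwargs): a four-shank probe yields per-shank letter suffixes used
    # downstream to build per-shank probe labels (e.g. probe00a, probe00b, ...):
    #
    #     EphysTask(session_path, pname='probe00', nshanks=4).pextra  # ['a', 'b', 'c', 'd']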


class WidefieldTask(DynamicTask):
    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.device_collection = self.get_device_collection('widefield', kwargs.get('device_collection', 'raw_widefield_data'))


class MesoscopeTask(DynamicTask):
    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.device_collection = self.get_device_collection(
            'mesoscope', kwargs.get('device_collection', 'raw_imaging_data_[0-9]*'))

    def get_signatures(self, **kwargs):
        """
        From the template signature of the task, create the exact list of inputs and outputs to
        expect based on the available device collection folders.

        Necessary because we don't know in advance how many device collection folders
        ("imaging bouts") to expect.
        """
        self.session_path = Path(self.session_path)
        # Glob for all device collection (raw imaging data) folders
        raw_imaging_folders = [p.name for p in self.session_path.glob(self.device_collection)]
        # For all inputs and outputs that are part of the device collection, expand to one file per folder;
        # all others are kept unchanged
        self.input_files = [(sig[0], sig[1].replace(self.device_collection, folder), sig[2])
                            for folder in raw_imaging_folders for sig in self.signature['input_files']]
        self.output_files = [(sig[0], sig[1].replace(self.device_collection, folder), sig[2])
                             for folder in raw_imaging_folders for sig in self.signature['output_files']]
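
    # Sketch of the expansion above (hypothetical signature entry): with two imaging bouts on disk,
    # raw_imaging_data_00 and raw_imaging_data_01, the template entry
    #     ('*.tif', 'raw_imaging_data_[0-9]*', True)
    # expands to one entry per folder:
    #     ('*.tif', 'raw_imaging_data_00', True), ('*.tif', 'raw_imaging_data_01', True)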

    def load_sync(self):
        """
        Load the sync and channel map.

        This method may be expanded to support other raw DAQ data formats.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        alf_path = self.session_path / self.sync_collection
        if self.get_sync_namespace() == 'timeline':
            # Load the sync and channel map from the raw DAQ data
            sync, chmap = load_timeline_sync_and_chmap(alf_path)
        else:
            raise NotImplementedError
        return sync, chmap
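
    # Usage sketch (assumes a 'timeline' sync namespace; anything else raises NotImplementedError):
    #
    #     sync, chmap = task.load_sync()
    #     sync['times'], sync['polarities'], sync['channels']  # detected fronts on the DAQ
    #     chmap                                                # channel indices keyed by name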


class RegisterRawDataTask(DynamicTask):
    """
    Base register raw task.

    To rename files:
    1. the input and output file signatures must have the same length;
    2. the output file names must be full filenames (not glob patterns).
    """

    priority = 100
    job_size = 'small'

    def rename_files(self, symlink_old=False):

        # If either no inputs or no outputs are given, we don't do any renaming
        if not all(map(len, (self.input_files, self.output_files))):
            return

        # Otherwise we need to make sure there is a one-to-one correspondence for renaming files
        assert len(self.input_files) == len(self.output_files)

        for before, after in zip(self.input_files, self.output_files):
            old_file, old_collection, required = before
            old_path = self.session_path.joinpath(old_collection).glob(old_file)
            old_path = next(old_path, None)
            # If the file doesn't exist and it is not required, we are okay to continue
            if not old_path:
                if required:
                    raise FileNotFoundError(str(old_file))
                else:
                    continue

            new_file, new_collection, _ = after
            new_path = self.session_path.joinpath(new_collection, new_file)
            if old_path == new_path:
                continue
            new_path.parent.mkdir(parents=True, exist_ok=True)
            _logger.debug('%s -> %s', old_path.relative_to(self.session_path), new_path.relative_to(self.session_path))
            old_path.replace(new_path)
            if symlink_old:
                old_path.symlink_to(new_path)
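
    # Renaming sketch (hypothetical signatures): each input glob is resolved on disk and moved to the
    # corresponding output name, e.g.
    #
    #     input_files  = [('*.camlog', 'raw_imaging_data', True)]
    #     output_files = [('widefieldEvents.raw.camlog', 'raw_widefield_data', True)]
    #
    # moves raw_imaging_data/<match>.camlog to raw_widefield_data/widefieldEvents.raw.camlog,
    # leaving a symlink at the old location when symlink_old=True.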

    def register_snapshots(self, unlink=False, collection=None):
        """
        Register any photos in the snapshots folder to the session. Typically imaging users will
        take numerous photos for reference. Supported extensions: .jpg, .jpeg, .png, .tif, .tiff.

        If a .txt file with the same name exists in the same location, its contents will be added
        to the note text.

        Parameters
        ----------
        unlink : bool
            If True, files are deleted after upload.
        collection : str, list, optional
            Location of the 'snapshots' folder relative to the session path. If None, uses the
            'device_collection' attribute (if it exists) or the root session path.

        Returns
        -------
        list of dict
            The newly registered Alyx notes.
        """
        collection = getattr(self, 'device_collection', None) if collection is None else collection
        collection = collection or ''  # If not defined, use no collection
        if collection and '*' in collection:
            collection = [p.name for p in self.session_path.glob(collection)]
            # Check whether folders on disk contain '*'; this is to stop an infinite recursion
            assert not any('*' in c for c in collection), 'folders containing asterisks not supported'
        # If more than one collection exists, register snapshots in each collection
        if collection and not isinstance(collection, str):
            return flatten(filter(None, [self.register_snapshots(unlink, c) for c in collection]))
        snapshots_path = self.session_path.joinpath(*filter(None, (collection, 'snapshots')))
        if not snapshots_path.exists():
            return

        eid = self.one.path2eid(self.session_path, query_type='remote')
        if not eid:
            _logger.warning('Failed to upload snapshots: session not found on Alyx')
            return
        note = dict(user=self.one.alyx.user, content_type='session', object_id=eid, text='')

        notes = []
        exts = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
        for snapshot in filter(lambda x: x.suffix.lower() in exts, snapshots_path.glob('*.*')):
            _logger.debug('Uploading "%s"...', snapshot.relative_to(self.session_path))
            if snapshot.with_suffix('.txt').exists():
                with open(snapshot.with_suffix('.txt'), 'r') as txt_file:
                    note['text'] = txt_file.read().strip()
            else:
                note['text'] = ''
            with open(snapshot, 'rb') as img_file:
                files = {'image': img_file}
                notes.append(self.one.alyx.rest('notes', 'create', data=note, files=files))
            if unlink:
                snapshot.unlink()
        # If nothing else in the snapshots folder, delete the folder
        if unlink and next(snapshots_path.rglob('*'), None) is None:
            snapshots_path.rmdir()
        _logger.info('%i snapshots uploaded to Alyx', len(notes))
        return notes
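
    # Usage sketch (requires a ONE instance on the task; the collection value is hypothetical):
    #
    #     notes = task.register_snapshots(unlink=True, collection='raw_imaging_data_*')
    #
    # Each image under <collection>/snapshots is posted as an Alyx session note; a sibling .txt file
    # with the same stem provides the note text.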

    def _run(self, **kwargs):
        self.rename_files(**kwargs)
        out_files = []
        n_required = 0
        for file_sig in self.output_files:
            file_name, collection, required = file_sig
            n_required += required
            file_path = self.session_path.joinpath(collection).glob(file_name)
            file_path = next(file_path, None)
            if not file_path and not required:
                continue
            elif not file_path and required:
                _logger.error(f'expected {file_sig} missing')
            else:
                out_files.append(file_path)

        if len(out_files) < n_required:
            self.status = -1

        return out_files


class ExperimentDescriptionRegisterRaw(RegisterRawDataTask):
    """dict of list: custom sign off keys corresponding to specific devices"""
    sign_off_categories = SIGN_OFF_CATEGORIES

    @property
    def signature(self):
        signature = {
            'input_files': [],
            'output_files': [('*experiment.description.yaml', '', True)]
        }
        return signature

    def _run(self, **kwargs):
        # Register experiment description file
        out_files = super(ExperimentDescriptionRegisterRaw, self)._run(**kwargs)
        if not self.one.offline and self.status == 0:
            with no_cache(self.one.alyx):  # Ensure we don't load the cached JSON response
                eid = self.one.path2eid(self.session_path, query_type='remote')
            exp_dec = sess_params.read_params(out_files[0])
            data = sign_off_dict(exp_dec, sign_off_categories=self.sign_off_categories)
            self.one.alyx.json_field_update('sessions', eid, data=data)
        return out_files