Coverage for ibllib/pipes/base_tasks.py: 97%

201 statements  

coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1"""Abstract base classes for dynamic pipeline tasks.""" 

2import logging 

3from pathlib import Path 

4 

5from pkg_resources import parse_version 

6from one.webclient import no_cache 

7from iblutil.util import flatten 

8 

9from ibllib.pipes.tasks import Task 

10import ibllib.io.session_params as sess_params 

11from ibllib.qc.base import sign_off_dict, SIGN_OFF_CATEGORIES 

12from ibllib.io.raw_daq_loaders import load_timeline_sync_and_chmap 

13 

14_logger = logging.getLogger(__name__) 

15 

16 

17class DynamicTask(Task): 

18 

19 def __init__(self, session_path, **kwargs): 

20 super().__init__(session_path, **kwargs) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

21 self.session_params = self.read_params_file() 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

22 

23 # TODO Which should be default? 

24 # Sync collection 

25 self.sync_collection = self.get_sync_collection(kwargs.get('sync_collection', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

26 # Sync type 

27 self.sync = self.get_sync(kwargs.get('sync', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

28 # Sync extension 

29 self.sync_ext = self.get_sync_extension(kwargs.get('sync_ext', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

30 # Sync namespace 

31 self.sync_namespace = self.get_sync_namespace(kwargs.get('sync_namespace', None)) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

32 

33 def get_sync_collection(self, sync_collection=None): 

34 return sync_collection if sync_collection else sess_params.get_sync_collection(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

35 

36 def get_sync(self, sync=None): 

37 return sync if sync else sess_params.get_sync_label(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

38 

39 def get_sync_extension(self, sync_ext=None): 

40 return sync_ext if sync_ext else sess_params.get_sync_extension(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

41 

42 def get_sync_namespace(self, sync_namespace=None): 

43 return sync_namespace if sync_namespace else sess_params.get_sync_namespace(self.session_params) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

44 

45 def get_protocol(self, protocol=None, task_collection=None): 

46 return protocol if protocol else sess_params.get_task_protocol(self.session_params, task_collection) 1cfrsWd

47 

48 def get_task_collection(self, collection=None): 

49 if not collection: 1201cfrvs3zWd

50 collection = sess_params.get_task_collection(self.session_params) 12cfrsd

51 # If inferring the collection from the experiment description, assert only one returned 

52 assert collection is None or isinstance(collection, str) or len(collection) == 1 1201cfrvs3zWd

53 return collection 1201cfrvs3zWd

54 

55 def get_device_collection(self, device, device_collection=None): 

56 if device_collection: 1a9{2!#$01cjkghfrxvswLMNOPQRKSTXY%'ZiUV(3z456W78)*e+,-bdy

57 return device_collection 1a9{2!#$01cjkghfrxvswLMNOPQRKSTXY%'ZiUV(3z456W78)*e+,-bdy

58 collection_map = sess_params.get_collections(self.session_params['devices']) 1{

59 return collection_map.get(device) 1{

60 

61 def read_params_file(self): 

62 params = sess_params.read_params(self.session_path) 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

63 

64 if params is None: 1auA9qt2!#$01DmlEnFGoHIpcjkghfrxvswLMNOPQRKST/:;=?XY%'ZiUV(B3z@[]^_`C456W78)*e+,-Jbdy

65 return {} 1auA9q2!#$01DmlEnFGoHIpLMNOPQRKST/:;=?XY%'(@[]^_`C456W78)*e+,-

66 

67 # TODO figure out the best way 

68 # if params is None and self.one: 

69 # # Try to read params from alyx or try to download params file 

70 # params = self.one.load_dataset(self.one.path2eid(self.session_path), 'params.yml') 

71 # params = self.one.alyx.rest() 

72 

73 return params 1atcjkghfrxvswZiUVB3zJbdy
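
# --- Editor's illustrative sketch (not part of base_tasks.py) ----------------------
# Keyword arguments passed to a DynamicTask take precedence over values read from the
# session's experiment description; the session path and sync label below are
# hypothetical.
# >>> task = DynamicTask('/data/subject/2023-01-01/001', sync='nidq')
# >>> task.sync              # 'nidq' -- the explicit kwarg wins
# >>> task.sync_collection   # falls back to sess_params.get_sync_collection(...)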



class BehaviourTask(DynamicTask):

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.collection = self.get_task_collection(kwargs.get('collection', None))
        # Task type (protocol)
        self.protocol = self.get_protocol(kwargs.get('protocol', None), task_collection=self.collection)

        self.protocol_number = self.get_protocol_number(kwargs.get('protocol_number'), task_protocol=self.protocol)

        self.output_collection = 'alf'
        # Do not use kwargs.get('number', None) -- this will return None if number is 0
        if self.protocol_number is not None:
            self.output_collection += f'/task_{self.protocol_number:02}'
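
    # Editor's note (illustrative, not part of the module): because the check above is
    # `is not None`, protocol number 0 still gets its own sub-collection, e.g.
    # >>> f'alf/task_{0:02}'
    # 'alf/task_00'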


    def get_protocol(self, protocol=None, task_collection=None):
        return protocol if protocol else sess_params.get_task_protocol(self.session_params, task_collection)

    def get_task_collection(self, collection=None):
        if not collection:
            collection = sess_params.get_task_collection(self.session_params)
        # If inferring the collection from the experiment description, assert only one returned
        assert collection is None or isinstance(collection, str) or len(collection) == 1
        return collection

    def get_protocol_number(self, number=None, task_protocol=None):
        if number is None:  # Do not use "if not number" as that will return True if number is 0
            number = sess_params.get_task_protocol_number(self.session_params, task_protocol)
        # If inferring the number from the experiment description, assert only one returned (or something went wrong)
        assert number is None or isinstance(number, int)
        return number

    @staticmethod
    def _spacer_support(settings):
        """
        Spacer support was introduced in v7.1 for iblrig v7 and v8.0.1 in v8.

        Parameters
        ----------
        settings : dict
            The task settings dict.

        Returns
        -------
        bool
            True if task spacers are to be expected.
        """
        v = parse_version
        version = v(settings.get('IBLRIG_VERSION_TAG'))
        return version not in (v('100.0.0'), v('8.0.0')) and version >= v('7.1.0')
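
# Editor's illustrative check (not part of base_tasks.py): _spacer_support is a static
# method, so it can be probed directly with a settings dict.
# >>> BehaviourTask._spacer_support({'IBLRIG_VERSION_TAG': '7.2.3'})
# True
# >>> BehaviourTask._spacer_support({'IBLRIG_VERSION_TAG': '8.0.0'})
# False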



class VideoTask(DynamicTask):

    def __init__(self, session_path, cameras, **kwargs):
        super().__init__(session_path, cameras=cameras, **kwargs)
        self.cameras = cameras
        self.device_collection = self.get_device_collection('cameras', kwargs.get('device_collection', 'raw_video_data'))
        # self.collection = self.get_task_collection(kwargs.get('collection', None))


class AudioTask(DynamicTask):

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)
        self.device_collection = self.get_device_collection('microphone', kwargs.get('device_collection', 'raw_behavior_data'))


class EphysTask(DynamicTask):

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.pname = self.get_pname(kwargs.get('pname', None))
        self.nshanks, self.pextra = self.get_nshanks(kwargs.get('nshanks', None))
        self.device_collection = self.get_device_collection('neuropixel', kwargs.get('device_collection', 'raw_ephys_data'))

    def get_pname(self, pname):
        # pname can be a list or a string
        pname = self.kwargs.get('pname', pname)

        return pname

    def get_nshanks(self, nshanks=None):
        nshanks = self.kwargs.get('nshanks', nshanks)
        if nshanks is not None:
            pextra = [chr(97 + int(shank)) for shank in range(nshanks)]
        else:
            pextra = []

        return nshanks, pextra
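
# Editor's illustrative note (not part of base_tasks.py): pextra enumerates shank
# suffixes as lowercase letters, e.g. for a four-shank probe:
# >>> [chr(97 + int(shank)) for shank in range(4)]
# ['a', 'b', 'c', 'd']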



class WidefieldTask(DynamicTask):
    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.device_collection = self.get_device_collection('widefield', kwargs.get('device_collection', 'raw_widefield_data'))


class MesoscopeTask(DynamicTask):
    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.device_collection = self.get_device_collection(
            'mesoscope', kwargs.get('device_collection', 'raw_imaging_data_[0-9]*'))

    def get_signatures(self, **kwargs):
        """
        From the template signature of the task, create the exact list of inputs and outputs to expect based on the
        available device collection folders.

        Necessary because we don't know in advance how many device collection folders ("imaging bouts") to expect.
        """
        self.session_path = Path(self.session_path)
        # Glob for all device collection (raw imaging data) folders
        raw_imaging_folders = [p.name for p in self.session_path.glob(self.device_collection)]
        # For all inputs and outputs that are part of the device collection, expand to one file per folder.
        # All others are kept unchanged.
        self.input_files = [(sig[0], sig[1].replace(self.device_collection, folder), sig[2])
                            for folder in raw_imaging_folders for sig in self.signature['input_files']]
        self.output_files = [(sig[0], sig[1].replace(self.device_collection, folder), sig[2])
                             for folder in raw_imaging_folders for sig in self.signature['output_files']]
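
    # Editor's illustrative note (not part of the module): with two imaging bouts on
    # disk, a template entry such as ('*.tif', 'raw_imaging_data_[0-9]*', True) would be
    # expanded to one entry per folder, e.g.
    # ('*.tif', 'raw_imaging_data_00', True) and ('*.tif', 'raw_imaging_data_01', True).
    # The dataset name here is hypothetical.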


    def load_sync(self):
        """
        Load the sync and channel map.

        This method may be expanded to support other raw DAQ data formats.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        alf_path = self.session_path / self.sync_collection
        if self.get_sync_namespace() == 'timeline':
            # Load the sync and channel map from the raw DAQ data
            sync, chmap = load_timeline_sync_and_chmap(alf_path)
        else:
            raise NotImplementedError
        return sync, chmap



class RegisterRawDataTask(DynamicTask):
    """
    Base register raw task.
    To rename files:
    1. input and output files must have the same length
    2. output files must have the full filename
    """

    priority = 100
    job_size = 'small'

    def rename_files(self, symlink_old=False):

        # If either no inputs or no outputs are given, we don't do any renaming
        if not all(map(len, (self.input_files, self.output_files))):
            return

        # Otherwise we need to make sure there is one to one correspondence for renaming files
        assert len(self.input_files) == len(self.output_files)

        for before, after in zip(self.input_files, self.output_files):
            old_file, old_collection, required = before
            old_path = self.session_path.joinpath(old_collection).glob(old_file)
            old_path = next(old_path, None)
            # if the file doesn't exist and it is not required we are okay to continue
            if not old_path:
                if required:
                    raise FileNotFoundError(str(old_file))
                else:
                    continue

            new_file, new_collection, _ = after
            new_path = self.session_path.joinpath(new_collection, new_file)
            if old_path == new_path:
                continue
            new_path.parent.mkdir(parents=True, exist_ok=True)
            _logger.debug('%s -> %s', old_path.relative_to(self.session_path), new_path.relative_to(self.session_path))
            old_path.replace(new_path)
            if symlink_old:
                old_path.symlink_to(new_path)
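
    # Editor's illustrative note (not part of the module): input and output signatures
    # are matched pairwise, so a subclass pairing
    #   input_files  = [('photometry.raw.ppd', 'raw_photometry_data', True)]
    #   output_files = [('_neurophotometrics_fpData.raw.ppd', 'raw_photometry_data', True)]
    # would move (and optionally symlink) the raw file to its renamed counterpart.
    # The dataset names above are hypothetical.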


    def register_snapshots(self, unlink=False, collection=None):
        """
        Register any photos in the snapshots folder to the session. Typically imaging users will
        take numerous photos for reference. Supported extensions: .jpg, .jpeg, .png, .tif, .tiff

        If a .txt file with the same name exists in the same location, the contents will be added
        to the note text.

        Parameters
        ----------
        unlink : bool
            If true, files are deleted after upload.
        collection : str, list, optional
            Location of 'snapshots' folder relative to the session path. If None, uses
            'device_collection' attribute (if exists) or root session path.

        Returns
        -------
        list of dict
            The newly registered Alyx notes.
        """
        collection = getattr(self, 'device_collection', None) if collection is None else collection
        collection = collection or ''  # If not defined, use no collection
        if collection and '*' in collection:
            collection = [p.name for p in self.session_path.glob(collection)]
            # Check whether folders on disk contain '*'; this is to stop an infinite recursion
            assert not any('*' in c for c in collection), 'folders containing asterisks not supported'
        # If more than one collection exists, register snapshots in each collection
        if collection and not isinstance(collection, str):
            return flatten(filter(None, [self.register_snapshots(unlink, c) for c in collection]))
        snapshots_path = self.session_path.joinpath(*filter(None, (collection, 'snapshots')))
        if not snapshots_path.exists():
            return

        eid = self.one.path2eid(self.session_path, query_type='remote')
        if not eid:
            _logger.warning('Failed to upload snapshots: session not found on Alyx')
            return
        note = dict(user=self.one.alyx.user, content_type='session', object_id=eid, text='')

        notes = []
        exts = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
        for snapshot in filter(lambda x: x.suffix.lower() in exts, snapshots_path.glob('*.*')):
            _logger.debug('Uploading "%s"...', snapshot.relative_to(self.session_path))
            if snapshot.with_suffix('.txt').exists():
                with open(snapshot.with_suffix('.txt'), 'r') as txt_file:
                    note['text'] = txt_file.read().strip()
            else:
                note['text'] = ''
            with open(snapshot, 'rb') as img_file:
                files = {'image': img_file}
                notes.append(self.one.alyx.rest('notes', 'create', data=note, files=files))
            if unlink:
                snapshot.unlink()
        # If nothing else in the snapshots folder, delete the folder
        if unlink and next(snapshots_path.rglob('*'), None) is None:
            snapshots_path.rmdir()
        _logger.info('%i snapshots uploaded to Alyx', len(notes))
        return notes
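
    # Editor's illustrative sketch (not part of the module): a wildcard collection such
    # as 'raw_imaging_data_[0-9]*' is first globbed into concrete folders, then the
    # method recurses once per folder, e.g.
    # >>> task.register_snapshots(collection='raw_imaging_data_[0-9]*')
    # registers session_path/raw_imaging_data_00/snapshots, .../raw_imaging_data_01/snapshots, etc.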


    def _run(self, **kwargs):
        self.rename_files(**kwargs)
        out_files = []
        n_required = 0
        for file_sig in self.output_files:
            file_name, collection, required = file_sig
            n_required += required
            file_path = self.session_path.joinpath(collection).glob(file_name)
            file_path = next(file_path, None)
            if not file_path and not required:
                continue
            elif not file_path and required:
                _logger.error(f'expected {file_sig} missing')
            else:
                out_files.append(file_path)

        if len(out_files) < n_required:
            self.status = -1

        return out_files



class ExperimentDescriptionRegisterRaw(RegisterRawDataTask):
    """dict of list: custom sign off keys corresponding to specific devices"""
    sign_off_categories = SIGN_OFF_CATEGORIES

    @property
    def signature(self):
        signature = {
            'input_files': [],
            'output_files': [('*experiment.description.yaml', '', True)]
        }
        return signature

    def _run(self, **kwargs):
        # Register experiment description file
        out_files = super(ExperimentDescriptionRegisterRaw, self)._run(**kwargs)
        if not self.one.offline and self.status == 0:
            with no_cache(self.one.alyx):  # Ensure we don't load the cached JSON response
                eid = self.one.path2eid(self.session_path, query_type='remote')
                exp_dec = sess_params.read_params(out_files[0])
                data = sign_off_dict(exp_dec, sign_off_categories=self.sign_off_categories)
                self.one.alyx.json_field_update('sessions', eid, data=data)
        return out_files