Coverage for brainbox/io/one.py: 57%

772 statements  

coverage.py v7.9.1, created at 2025-07-02 18:55 +0100

1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" 

2from dataclasses import dataclass, field 

3import gc 

4import logging 

5import re 

6import os 

7from pathlib import Path 

8from collections import defaultdict 

9 

10import numpy as np 

11import pandas as pd 

12from scipy.interpolate import interp1d 

13import matplotlib.pyplot as plt 

14 

15from one.api import ONE, One 

16from one.alf.path import get_alf_path, full_path_parts, filename_parts 

17from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound 

18from one.alf import cache 

19import one.alf.io as alfio 

20from neuropixel import TIP_SIZE_UM, trace_header 

21import spikeglx 

22 

23import ibldsp.voltage 

24from ibldsp.waveform_extraction import WaveformsLoader 

25from iblutil.util import Bunch 

26from iblatlas.atlas import AllenAtlas, BrainRegions 

27from iblatlas import atlas 

28from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times 

29from ibllib.pipes import histology 

30from ibllib.pipes.ephys_alignment import EphysAlignment 

31from ibllib.plots import vertical_lines, Density 

32 

33import brainbox.plot 

34from brainbox.io.spikeglx import Streamer 

35from brainbox.ephys_plots import plot_brain_regions 

36from brainbox.metrics.single_units import quick_unit_metrics 

37from brainbox.behavior.wheel import interpolate_position, velocity_filtered 

38from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter 

39 

40_logger = logging.getLogger('ibllib') 

41 

42 

43SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths'] 

44CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids'] 

45WAVEFORMS_ATTRIBUTES = ['templates'] 

46 

47 

48def load_lfp(eid, one=None, dataset_types=None, **kwargs): 

49 """ 

50 TODO Verify works 

51 From an eid, hits the Alyx database and downloads the standard set of datasets 

52 needed for LFP 

53 :param eid: experiment session identifier (UUID, URL, path or session details dict) 

54 :param dataset_types: additional dataset types to add to the list 

55 :param kwargs: extra keyword arguments passed to the spikeglx.Reader constructor 

56 :return: list of spikeglx.Reader, one per probe with an LF file 

57 """ 

58 if dataset_types is None: 

59 dataset_types = [] 

60 dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*'] 

61 [one.load_dataset(eid, dset, download_only=True) for dset in dtypes] 

62 session_path = one.eid2path(eid) 

63 

64 efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) 

65 if ef.get('lf', None)] 

66 return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles] 

67 
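# --- Editor's usage sketch (not part of the original module): calling load_lfp.
# The session id is a placeholder; keyword arguments are forwarded to spikeglx.Reader.
def _example_load_lfp():
    one = ONE()
    eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'  # placeholder session UUID
    readers = load_lfp(eid, one=one)  # one spikeglx.Reader per probe with an LF file
    return [(sr.fs, sr.ns) for sr in readers]  # sampling rate and sample count per reader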

68 

69def _collection_filter_from_args(probe, spike_sorter=None): 

70 collection = f'alf/{probe}/{spike_sorter}' 1g

71 collection = collection.replace('None', '*') 1g

72 collection = collection.replace('/*', '*') 1g

73 collection = collection[:-1] if collection.endswith('/') else collection 1g

74 return collection 1g

75 

76 

77def _get_spike_sorting_collection(collections, pname): 

78 """ 

79 Filters a list or array of collections to get the relevant spike sorting dataset 

80 if a pykilosort collection exists, it is preferred; otherwise the shortest matching collection is used 

81 """ 

82 # a pykilosort collection takes precedence 

83 collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) 1ga

84 # otherwise, prefers the shortest 

85 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) 1ga

86 _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") 1ga

87 return collection 1ga

88 

89 

90def _channels_alyx2bunch(chans): 

91 channels = Bunch({ 

92 'atlas_id': np.array([ch['brain_region'] for ch in chans]), 

93 'x': np.array([ch['x'] for ch in chans]) / 1e6, 

94 'y': np.array([ch['y'] for ch in chans]) / 1e6, 

95 'z': np.array([ch['z'] for ch in chans]) / 1e6, 

96 'axial_um': np.array([ch['axial'] for ch in chans]), 

97 'lateral_um': np.array([ch['lateral'] for ch in chans]) 

98 }) 

99 return channels 

100 

101 

102def _channels_traj2bunch(xyz_chans, brain_atlas): 

103 brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans)) 

104 channels = { 

105 'x': xyz_chans[:, 0], 

106 'y': xyz_chans[:, 1], 

107 'z': xyz_chans[:, 2], 

108 'acronym': brain_regions['acronym'], 

109 'atlas_id': brain_regions['id'] 

110 } 

111 

112 return channels 

113 

114 

115def _channels_bunch2alf(channels): 

116 channels_ = { 1i

117 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6, 

118 'brainLocationIds_ccf_2017': channels['atlas_id'], 

119 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]} 

120 return channels_ 1i

121 

122 

123def _channels_alf2bunch(channels, brain_regions=None): 

124 # reformat the dictionary according to the standard that comes out of Alyx 

125 channels_ = { 1icaed

126 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6, 

127 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6, 

128 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6, 

129 'acronym': None, 

130 'atlas_id': channels['brainLocationIds_ccf_2017'], 

131 'axial_um': channels['localCoordinates'][:, 1], 

132 'lateral_um': channels['localCoordinates'][:, 0], 

133 } 

134 # here if we have some extra keys, they will carry over to the next dictionary 

135 for k in channels: 1icaed

136 if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: 1icaed

137 channels_[k] = channels[k] 1icaed

138 if brain_regions: 1icaed

139 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icaed

140 return channels_ 1icaed

141 

142 

143def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, 

144 brain_regions=None): 

145 """ 

146 Generic function to load spike sorting data using ONE. 

147 

148 Will try to load one spike sorting for each probe present in the session that matches the collection filter 

149 For each probe it will load a spike sorting: 

150 - if there is one version: loads this one 

151 - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX) 

152 

153 Parameters 

154 ---------- 

155 eid : [str, UUID, Path, dict] 

156 Experiment session identifier; may be a UUID, URL, experiment reference string 

157 details dict or Path 

158 one : one.api.OneAlyx 

159 An instance of ONE (may be in 'local' mode) 

160 collection : str 

161 collection filter word - accepts wildcards - can be a combination of spike sorter and 

162 probe. See `ALF documentation`_ for details. 

163 revision : str 

164 A particular revision return (defaults to latest revision). See `ALF documentation`_ for 

165 details. 

166 return_channels : bool 

167 Defaults to True; if True, channel locations are also loaded from disk 

168 

169 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components 

170 

171 Returns 

172 ------- 

173 spikes : dict of one.alf.io.AlfBunch 

174 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

175 session and spike sorter, with keys ('clusters', 'times') 

176 clusters : dict of one.alf.io.AlfBunch 

177 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

178 ('channels', 'depths', 'metrics') 

179 channels : dict of one.alf.io.AlfBunch 

180 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

181 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs 

182 non-lateralized. 

183 """ 

184 one = one or ONE() 1ga

185 # enumerate probes and load according to the name 

186 collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) 1ga

187 if len(collections) == 0: 1ga

188 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

189 pnames = list(set(c.split('/')[1] for c in collections)) 1ga

190 spikes, clusters, channels = ({} for _ in range(3)) 1ga

191 

192 spike_attributes, cluster_attributes = _get_attributes(dataset_types) 1ga

193 

194 for pname in pnames: 1ga

195 probe_collection = _get_spike_sorting_collection(collections, pname) 1ga

196 spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', 1ga

197 attribute=spike_attributes, namespace='') 

198 clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', 1ga

199 attribute=cluster_attributes, namespace='') 

200 if return_channels: 1ga

201 channels = _load_channels_locations_from_disk( 1a

202 eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions) 

203 return spikes, clusters, channels 1a

204 else: 

205 return spikes, clusters 1g

206 

207 

208def _get_attributes(dataset_types): 

209 if dataset_types is None: 1ga

210 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1ga

211 else: 

212 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 

213 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 

214 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 

215 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 

216 return spike_attributes, cluster_attributes 

217 

218 

219def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None): 

220 _logger.debug('loading spike sorting from disk') 1a

221 channels = Bunch({}) 1a

222 collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) 1a

223 if len(collections) == 0: 1a

224 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

225 probes = list(set([c.split('/')[1] for c in collections])) 1a

226 for probe in probes: 1a

227 probe_collection = _get_spike_sorting_collection(collections, probe) 1a

228 channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels') 1a

229 # if the spike sorter has not aligned data, try and get the alignment available 

230 if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): 1a

231 aligned_channel_collections = one.list_collections( 

232 eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision) 

233 if len(aligned_channel_collections) == 0: 

234 _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}") 

235 continue 

236 _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}") 

237 ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe) 

238 channels_aligned = one.load_object(eid, 'channels', collection=ac_collection) 

239 channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe]) 

240 # only have to reformat channels if we were able to load coordinates from disk 

241 channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions) 1a

242 return channels 1a

243 

244 

245def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None): 

246 """ 

247 Oftentimes the channel maps of different spike sorters differ, so the alignment is interpolated onto the target channel map. 

248 If there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field, 

249 so we reconstruct it from the Neuropixel map. This only happens for early pykilosort sorts. 

250 :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys 

251 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017' 

252 OR 

253 'x', 'y', 'z', 'acronym', 'axial_um' 

254 those are the guide for the interpolation 

255 :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates' 

256 :param brain_regions: None (default) or iblatlas.regions.BrainRegions object 

257 if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017' 

258 if a brain region object is provided, outputs a dict with keys 

259 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um' 

260 :return: Bunch or dictionary of channels with brain coordinates keys 

261 """ 

262 NEUROPIXEL_VERSION = 1 1i

263 h = trace_header(version=NEUROPIXEL_VERSION) 1i

264 if channels is None: 1i

265 channels = {'localCoordinates': np.c_[h['x'], h['y']]} 

266 nch = channels['localCoordinates'].shape[0] 1i

267 if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): 1i

268 channels_aligned = _channels_bunch2alf(channels_aligned) 1i

269 if 'localCoordinates' in channels_aligned.keys(): 1i

270 aligned_depths = channels_aligned['localCoordinates'][:, 1] 1i

271 else: # this is an edge case for a few spike sorting sessions 

272 assert channels_aligned['mlapdv'].shape[0] == 384 

273 aligned_depths = h['y'] 

274 depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) 1i

275 depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) 1i

276 channels['mlapdv'] = np.zeros((nch, 3)) 1i

277 for i in np.arange(3): 1i

278 channels['mlapdv'][:, i] = np.interp( 1i

279 depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] 

280 # the brain locations have to be interpolated by nearest neighbour 

281 fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') 1i

282 channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) 1i

283 if brain_regions is not None: 1i

284 return _channels_alf2bunch(channels, brain_regions=brain_regions) 1i

285 else: 

286 return channels 1i

287 
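# --- Editor's sketch (not part of the original module): channel_locations_interpolation on
# synthetic inputs; the mlapdv and brain location values below are dummies for illustration.
def _example_channel_locations_interpolation():
    h = trace_header(version=1)
    n = h['x'].size
    channels_aligned = {
        'localCoordinates': np.c_[h['x'], h['y']],
        'mlapdv': np.random.uniform(-5e3, 5e3, (n, 3)),       # dummy micrometre coordinates
        'brainLocationIds_ccf_2017': np.ones(n, dtype=int),   # dummy region ids
    }
    # with channels=None the alignment is interpolated onto the default Neuropixel 1 channel map
    return channel_locations_interpolation(channels_aligned)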

288 

289def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, 

290 brain_atlas=None, return_source=False): 

291 if not hasattr(one, 'alyx'): 1a

292 return {}, None 1a

293 _logger.debug(f"trying to load from traj {probe}") 

294 channels = Bunch() 

295 brain_atlas = brain_atlas or AllenAtlas() 

296 # need to find the collection for this probe 

297 insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0] 

298 collection = _collection_filter_from_args(probe=probe) 

299 collections = one.list_collections(eid, filename='channels*', collection=collection, 

300 revision=revision) 

301 probe_collection = _get_spike_sorting_collection(collections, probe) 

302 chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection) 

303 depths = chn_coords[:, 1] 

304 

305 tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

306 get('tracing_exists', False) 

307 resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

308 get('alignment_resolved', False) 

309 counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

310 get('alignment_count', 0) 

311 

312 if tracing: 

313 xyz = np.array(insertion['json']['xyz_picks']) / 1e6 

314 if resolved: 

315 

316 _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. ' 

317 f'Channel and cluster locations obtained from ephys aligned histology ' 

318 f'track.') 

319 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

320 provenance='Ephys aligned histology track')[0] 

321 align_key = insertion['json']['extended_qc']['alignment_stored'] 

322 feature = traj['json'][align_key][0] 

323 track = traj['json'][align_key][1] 

324 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

325 feature_prev=feature, 

326 brain_atlas=brain_atlas, speedy=True) 

327 chans = ephysalign.get_channel_locations(feature, track) 

328 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

329 source = 'resolved' 

330 elif counts > 0 and aligned: 

331 _logger.debug(f'Channel locations for {eid}/{probe} have not been ' 

332 f'resolved. However, alignment flag set to True so channel and cluster' 

333 f' locations will be obtained from latest available ephys aligned ' 

334 f'histology track.') 

335 # get the latest user aligned channels 

336 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

337 provenance='Ephys aligned histology track')[0] 

338 align_key = insertion['json']['extended_qc']['alignment_stored'] 

339 feature = traj['json'][align_key][0] 

340 track = traj['json'][align_key][1] 

341 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

342 feature_prev=feature, 

343 brain_atlas=brain_atlas, speedy=True) 

344 chans = ephysalign.get_channel_locations(feature, track) 

345 

346 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

347 source = 'aligned' 

348 else: 

349 _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. ' 

350 f'Channel and cluster locations obtained from histology track.') 

351 # get the channels from histology tracing 

352 xyz = xyz[np.argsort(xyz[:, 2]), :] 

353 chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) 

354 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

355 source = 'traced' 

356 channels[probe]['axial_um'] = chn_coords[:, 1] 

357 channels[probe]['lateral_um'] = chn_coords[:, 0] 

358 

359 else: 

360 _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}') 

361 source = '' 

362 channels = None 

363 

364 if return_source: 

365 return channels, source 

366 else: 

367 return channels 

368 

369 

370def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None): 

371 """ 

372 Load the brain locations of each channel for a given session/probe 

373 

374 Parameters 

375 ---------- 

376 eid : [str, UUID, Path, dict] 

377 Experiment session identifier; may be a UUID, URL, experiment reference string 

378 details dict or Path 

379 probe : [str, list of str] 

380 The probe label(s), e.g. 'probe01' 

381 one : one.api.OneAlyx 

382 An instance of ONE (shouldn't be in 'local' mode) 

383 aligned : bool 

384 Whether to get the latest user aligned channel when not resolved or use histology track 

385 brain_atlas : iblatlas.BrainAtlas 

386 Brain atlas object (default: Allen atlas) 

387 Returns 

388 ------- 

389 dict of one.alf.io.AlfBunch 

390 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

391 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

392 optional: string 'resolved', 'aligned', 'traced' or '' 

393 """ 

394 one = one or ONE() 

395 brain_atlas = brain_atlas or AllenAtlas() 

396 if isinstance(eid, dict): 

397 ses = eid 

398 eid = ses['url'][-36:] 

399 else: 

400 eid = one.to_eid(eid) 

401 collection = _collection_filter_from_args(probe=probe) 

402 channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, 

403 brain_regions=brain_atlas.regions) 

404 incomplete_probes = [k for k in channels if 'x' not in channels[k]] 

405 for iprobe in incomplete_probes: 

406 channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned, 

407 brain_atlas=brain_atlas, return_source=True) 

408 if channels_ is not None: 

409 channels[iprobe] = channels_[iprobe] 

410 return channels 

411 
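# --- Editor's usage sketch (not part of the original module): load_channel_locations returns a
# dict keyed by probe label; the session id is a placeholder.
def _example_load_channel_locations():
    one = ONE()
    eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'  # placeholder session UUID
    channels = load_channel_locations(eid, probe='probe00', one=one, aligned=True)
    return channels['probe00']['acronym']  # brain region acronym for each channel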

412 

413def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

414 brain_regions=None, nested=True, collection=None, return_collection=False): 

415 """ 

416 From an eid, loads spikes and clusters for all probes 

417 The following set of dataset types are loaded: 

418 'clusters.channels', 

419 'clusters.depths', 

420 'clusters.metrics', 

421 'spikes.clusters', 

422 'spikes.times', 

423 'probes.description' 

424 :param eid: experiment UUID or pathlib.Path of the local session 

425 :param one: an instance of OneAlyx 

426 :param probe: name of probe to load in, if not given all probes for session will be loaded 

427 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

428 :param spike_sorter: name of the spike sorting you want to load (None for default) 

429 :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00" 

430 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

431 :param nested: if a single probe is required, do not output a dictionary with the probe name as key 

432 :param return_collection: (False) if True, will return the collection used to load 

433 :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe) 

434 """ 

435 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. ' 

436 'Use brainbox.io.one.SpikeSortingLoader instead') 

437 if collection is None: 

438 collection = _collection_filter_from_args(probe, spike_sorter) 

439 _logger.debug(f"load spike sorting with collection filter {collection}") 

440 kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types, 

441 brain_regions=brain_regions) 

442 spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True) 

443 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

444 if nested is False and len(spikes.keys()) == 1: 

445 k = list(spikes.keys())[0] 

446 channels = channels[k] 

447 clusters = clusters[k] 

448 spikes = spikes[k] 

449 if return_collection: 

450 return spikes, clusters, channels, collection 

451 else: 

452 return spikes, clusters, channels 

453 

454 

455def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

456 brain_regions=None, return_collection=False): 

457 """ 

458 From an eid, loads spikes and clusters for all probes 

459 The following set of dataset types are loaded: 

460 'clusters.channels', 

461 'clusters.depths', 

462 'clusters.metrics', 

463 'spikes.clusters', 

464 'spikes.times', 

465 'probes.description' 

466 :param eid: experiment UUID or pathlib.Path of the local session 

467 :param one: an instance of OneAlyx 

468 :param probe: name of probe to load in, if not given all probes for session will be loaded 

469 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

470 :param spike_sorter: name of the spike sorting you want to load (None for default) 

471 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

472 :param return_collection: (bool - False) if True, returns the collection used for loading the data 

473 :return: spikes, clusters (dict of bunch, 1 bunch per probe) 

474 """ 

475 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. ' 1g

476 'Use brainbox.io.one.SpikeSortingLoader instead') 

477 collection = _collection_filter_from_args(probe, spike_sorter) 1g

478 _logger.debug(f"load spike sorting with collection filter {collection}") 1g

479 spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, 1g

480 return_channels=False, dataset_types=dataset_types, 

481 brain_regions=brain_regions) 

482 if return_collection: 1g

483 return spikes, clusters, collection 

484 else: 

485 return spikes, clusters 1g

486 

487 

488def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None, 

489 spike_sorter=None, brain_atlas=None, nested=True, return_collection=False): 

490 """ 

491 For a given eid, get spikes, clusters and channels information, and merges clusters 

492 and channels information before returning all three variables. 

493 

494 Parameters 

495 ---------- 

496 eid : [str, UUID, Path, dict] 

497 Experiment session identifier; may be a UUID, URL, experiment reference string 

498 details dict or Path 

499 one : one.api.OneAlyx 

500 An instance of ONE (shouldn't be in 'local' mode) 

501 probe : [str, list of str] 

502 The probe label(s), e.g. 'probe01' 

503 aligned : bool 

504 Whether to get the latest user aligned channel when not resolved or use histology track 

505 dataset_types : list of str 

506 Optional additional spikes/clusters objects to add to the standard default list 

507 spike_sorter : str 

508 Name of the spike sorting you want to load (None for default which is pykilosort if it's 

509 available otherwise the default MATLAB kilosort) 

510 brain_atlas : iblatlas.atlas.BrainAtlas 

511 Brain atlas object (default: Allen atlas) 

512 return_collection: bool 

513 Returns an extra argument with the collection chosen 

514 

515 Returns 

516 ------- 

517 spikes : dict of one.alf.io.AlfBunch 

518 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

519 session and spike sorter, with keys ('clusters', 'times') 

520 clusters : dict of one.alf.io.AlfBunch 

521 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

522 ('channels', 'depths', 'metrics') 

523 channels : dict of one.alf.io.AlfBunch 

524 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

525 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

526 """ 

527 # --- Get spikes and clusters data 

528 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future versions. ' 

529 'Use brainbox.io.one.SpikeSortingLoader instead') 

530 one = one or ONE() 

531 brain_atlas = brain_atlas or AllenAtlas() 

532 spikes, clusters, collection = load_spike_sorting( 

533 eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True) 

534 # -- Get brain regions and assign to clusters 

535 channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned, 

536 brain_atlas=brain_atlas) 

537 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

538 if nested is False and len(spikes.keys()) == 1: 

539 k = list(spikes.keys())[0] 

540 channels = channels[k] 

541 clusters = clusters[k] 

542 spikes = spikes[k] 

543 if return_collection: 

544 return spikes, clusters, channels, collection 

545 else: 

546 return spikes, clusters, channels 

547 

548 

549def load_ephys_session(eid, one=None): 

550 """ 

551 From an eid, hits the Alyx database and downloads a standard default set of dataset types 

552 From a local session Path (pathlib.Path), loads a standard default set of dataset types 

553 to perform analysis: 

554 'clusters.channels', 

555 'clusters.depths', 

556 'clusters.metrics', 

557 'spikes.clusters', 

558 'spikes.times', 

559 'probes.description' 

560 

561 Parameters 

562 ---------- 

563 eid : [str, UUID, Path, dict] 

564 Experiment session identifier; may be a UUID, URL, experiment reference string 

565 details dict or Path 

566 one : one.api.OneAlyx, optional 

567 ONE object to use for loading. Will generate internal one if not used, by default None 

568 

569 Returns 

570 ------- 

571 spikes : dict of one.alf.io.AlfBunch 

572 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

573 session and spike sorter, with keys ('clusters', 'times') 

574 clusters : dict of one.alf.io.AlfBunch 

575 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

576 ('channels', 'depths', 'metrics') 

577 trials : one.alf.io.AlfBunch of numpy.ndarray 

578 The session trials data 

579 """ 

580 assert one 1g

581 spikes, clusters = load_spike_sorting(eid, one=one) 1g

582 trials = one.load_object(eid, 'trials') 1g

583 return spikes, clusters, trials 1g

584 
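# --- Editor's usage sketch (not part of the original module): load_ephys_session requires an
# existing ONE instance and returns per-probe spikes/clusters plus the trials table.
def _example_load_ephys_session():
    one = ONE()
    eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'  # placeholder session UUID
    spikes, clusters, trials = load_ephys_session(eid, one=one)
    return list(spikes.keys()), trials['goCue_times']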

585 

586def _remove_old_clusters(session_path, probe): 

587 # gets clusters and spikes from a local session folder 

588 probe_path = session_path.joinpath('alf', probe) 

589 

590 # look for the clusters.metrics.csv file; if it exists, delete it as we now have a .pqt file instead 

591 cluster_file = probe_path.joinpath('clusters.metrics.csv') 

592 

593 if cluster_file.exists(): 

594 os.remove(cluster_file) 

595 _logger.info('Deleting old clusters.metrics.csv file') 

596 

597 

598def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None): 

599 """ 

600 Takes (default and any extra) values in given keys from channels and assigns them to clusters. 

601 If channels does not contain any data, the new keys are added to clusters but left empty. 

602 

603 Parameters 

604 ---------- 

605 dic_clus : dict of one.alf.io.AlfBunch 

606 1 bunch per probe, containing cluster information 

607 channels : dict of one.alf.io.AlfBunch 

608 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z', 'localCoordinates') 

609 keys_to_add_extra : list of str 

610 Any extra keys to copy over from the channels bunches to the clusters bunches 

611 

612 Returns 

613 ------- 

614 dict of one.alf.io.AlfBunch 

615 clusters (1 bunch per probe) with new keys values. 

616 """ 

617 probe_labels = list(channels.keys()) # Convert dict_keys into list 

618 keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um'] 

619 

620 if keys_to_add_extra is None: 

621 keys_to_add = keys_to_add_default 

622 else: 

623 # Append extra optional keys 

624 keys_to_add = list(set(keys_to_add_extra + keys_to_add_default)) 

625 

626 for label in probe_labels: 

627 clu_ch = dic_clus[label]['channels'] 

628 for key in keys_to_add: 

629 try: 

630 assert key in channels[label].keys() # Check key is in channels 

631 ch_key = channels[label][key] 

632 nch_key = len(ch_key) if ch_key is not None else 0 

633 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index 

634 dic_clus[label][key] = ch_key[clu_ch] 

635 else: 

636 _logger.warning( 

637 f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels' 

638 f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.') 

639 dic_clus[label][key] = [] 

640 except AssertionError: 

641 _logger.warning(f'Either clusters or channels does not have key {key}, could not merge') 

642 continue 

643 

644 return dic_clus 

645 
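# --- Editor's sketch (not part of the original module): merge_clusters_channels on tiny
# made-up bunches; keys missing from the channels bunch are skipped with a warning.
def _example_merge_clusters_channels():
    dic_clus = {'probe00': Bunch(channels=np.array([0, 2, 2]))}
    channels = {'probe00': Bunch(acronym=np.array(['CA1', 'DG', 'VISp']),
                                 atlas_id=np.array([1, 2, 3]))}  # dummy atlas ids
    clusters = merge_clusters_channels(dic_clus, channels)
    return clusters['probe00']['acronym']  # -> ['CA1', 'VISp', 'VISp'], indexed by cluster channel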

646 

647def load_passive_rfmap(eid, one=None): 

648 """ 

649 For a given eid load in the passive receptive field mapping protocol data 

650 

651 Parameters 

652 ---------- 

653 eid : [str, UUID, Path, dict] 

654 Experiment session identifier; may be a UUID, URL, experiment reference string 

655 details dict or Path 

656 one : one.api.OneAlyx, optional 

657 An instance of ONE (may be in 'local' - offline - mode) 

658 

659 Returns 

660 ------- 

661 one.alf.io.AlfBunch 

662 Passive receptive field mapping data 

663 """ 

664 one = one or ONE() 

665 

666 # Load in the receptive field mapping data 

667 rf_map = one.load_object(eid, obj='passiveRFM', collection='alf') 

668 frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', 

669 collection='raw_passive_data'), dtype="uint8") 

670 y_pix, x_pix = 15, 15 

671 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0]) 

672 rf_map['frames'] = frames 

673 

674 return rf_map 

675 
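# --- Editor's usage sketch (not part of the original module): the returned bunch carries the
# raw stimulus frames reshaped to (n_frames, 15, 15); the session id is a placeholder.
def _example_load_passive_rfmap():
    one = ONE()
    eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'  # placeholder session UUID
    rf_map = load_passive_rfmap(eid, one=one)
    return rf_map['frames'].shape  # (n_frames, 15, 15)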

676 

677def load_wheel_reaction_times(eid, one=None): 

678 """ 

679 Return the calculated reaction times for session. Reaction times are defined as the time 

680 between the go cue (onset tone) and the onset of the first substantial wheel movement. A 

681 movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the 

682 distance to threshold (~0.1 radians). 

683 

684 Negative times mean the onset of the movement occurred before the go cue. Nans may occur if 

685 there was no detected movement within the period, or when the goCue_times or feedback_times 

686 are nan. 

687 

688 Parameters 

689 ---------- 

690 eid : [str, UUID, Path, dict] 

691 Experiment session identifier; may be a UUID, URL, experiment reference string 

692 details dict or Path 

693 one : one.api.OneAlyx, optional 

694 one object to use for loading. Will generate internal one if not used, by default None 

695 

696 Returns 

697 ------- 

698 array-like 

699 reaction times 

700 """ 

701 if one is None: 

702 one = ONE() 

703 

704 trials = one.load_object(eid, 'trials') 

705 # If already extracted, load and return 

706 if trials and 'firstMovement_times' in trials: 

707 return trials['firstMovement_times'] - trials['goCue_times'] 

708 # Otherwise load wheelMoves object and calculate 

709 moves = one.load_object(eid, 'wheelMoves') 

710 # Re-extract wheel moves if necessary 

711 if not moves or 'peakAmplitude' not in moves: 

712 wheel = one.load_object(eid, 'wheel') 

713 moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) 

714 assert trials and moves, 'unable to load trials and wheelMoves data' 

715 firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials) 

716 return firstMove_times - trials['goCue_times'] 

717 

718 

719def load_iti(trials): 

720 """ 

721 The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey 

722 screen commencing at stimulus off and lasting until the quiescent period at the start of the 

723 following trial. Note that the ITI of a given trial is the time between that trial 

724 and the next trial, therefore the last value is NaN. 

725 

726 Parameters 

727 ---------- 

728 trials : one.alf.io.AlfBunch 

729 An ALF trials object containing the keys {'intervals', 'stimOff_times'}. 

730 

731 Returns 

732 ------- 

733 np.array 

734 An array of inter-trial intervals, the last value being NaN. 

735 """ 

736 if not {'intervals', 'stimOff_times'} <= set(trials.keys()): 1o

737 raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') 1o

738 return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan] 1o

739 
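# --- Editor's worked example (not part of the original module): load_iti on a minimal
# synthetic trials bunch; the ITI is the next trial start minus the current stimulus off.
def _example_load_iti():
    trials = Bunch(intervals=np.array([[0., 9.], [10., 19.], [20., 29.]]),
                   stimOff_times=np.array([8., 18., 28.]))
    return load_iti(trials)  # -> array([2., 2., nan])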

740 

741def load_channels_from_insertion(ins, depths=None, one=None, ba=None): 

742 

743 PROV_2_VAL = { 

744 'Resolved': 90, 

745 'Ephys aligned histology track': 70, 

746 'Histology track': 50, 

747 'Micro-manipulator': 30, 

748 'Planned': 10} 

749 

750 one = one or ONE() 

751 ba = ba or atlas.AllenAtlas() 

752 traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id']) 

753 val = [PROV_2_VAL[tr['provenance']] for tr in traj] 

754 idx = np.argmax(val) 

755 traj = traj[idx] 

756 if depths is None: 

757 depths = trace_header(version=1)['y'] 

758 if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator': 

759 ins = atlas.Insertion.from_dict(traj) 

760 # Deepest coordinate first 

761 xyz = np.c_[ins.tip, ins.entry].T 

762 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

763 TIP_SIZE_UM) / 1e6) 

764 else: 

765 xyz = np.array(ins['json']['xyz_picks']) / 1e6 

766 if traj['provenance'] == 'Histology track': 

767 xyz = xyz[np.argsort(xyz[:, 2]), :] 

768 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

769 TIP_SIZE_UM) / 1e6) 

770 else: 

771 align_key = ins['json']['extended_qc']['alignment_stored'] 

772 feature = traj['json'][align_key][0] 

773 track = traj['json'][align_key][1] 

774 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

775 feature_prev=feature, 

776 brain_atlas=ba, speedy=True) 

777 xyz_channels = ephysalign.get_channel_locations(feature, track) 

778 return xyz_channels 

779 
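# --- Editor's usage sketch (not part of the original module): the insertion record comes from
# Alyx, as elsewhere in this module; the session id and probe name are placeholders.
def _example_load_channels_from_insertion():
    one = ONE()
    ins = one.alyx.rest('insertions', 'list', session='aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee',
                        name='probe00')[0]  # placeholder session UUID and probe name
    return load_channels_from_insertion(ins, one=one)  # (n_channels, 3) xyz in metres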

780 

781@dataclass 

782class SpikeSortingLoader: 

783 """ 

784 Object that will load spike sorting data for a given probe insertion. 

785 This class can be instantiated in several manners 

786 - With Alyx database probe id: 

787 SpikeSortingLoader(pid=pid, one=one) 

788 - With Alyx database eid and probe name: 

789 SpikeSortingLoader(eid=eid, pname='probe00', one=one) 

790 - From a local session and probe name: 

791 SpikeSortingLoader(session_path=session_path, pname='probe00') 

792 NB: When no ONE instance is passed, any datasets that are loaded will not be recorded. 

793 """ 

794 one: One = None 

795 atlas: None = None 

796 pid: str = None 

797 eid: str = '' 

798 pname: str = '' 

799 session_path: Path = '' 

800 # the following properties are the outcome of the post init function 

801 collections: list = None 

802 datasets: list = None # list of all datasets belonging to the session 

803 # the following properties are the outcome of a reading function 

804 files: dict = None 

805 raw_data_files: list = None # list of raw ap and lf files corresponding to the recording 

806 collection: str = '' 

807 histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced' 

808 spike_sorter: str = 'pykilosort' 

809 spike_sorting_path: Path = None 

810 _sync: dict = None 

811 revision: str = None 

812 

813 def __post_init__(self): 

814 # pid gets precedence 

815 if self.pid is not None: 1caedk

816 try: 1dk

817 self.eid, self.pname = self.one.pid2eid(self.pid) 1dk

818 except NotImplementedError: 

819 if self.eid == '' or self.pname == '': 

820 raise IOError("Cannot infer session id and probe name from pid. " 

821 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") 

822 self.session_path = self.one.eid2path(self.eid) 1dk

823 # then eid / pname combination 

824 elif self.session_path is None or self.session_path == '': 1cae

825 self.session_path = self.one.eid2path(self.eid) 1cae

826 # fully local providing a session path 

827 else: 

828 if self.one: 

829 self.eid = self.one.to_eid(self.session_path) 

830 else: 

831 self.one = One(cache_dir=self.session_path.parents[2], mode='local') 

832 df_sessions = cache._make_sessions_df(self.session_path) 

833 self.one._cache['sessions'] = df_sessions.set_index('id') 

834 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) 

835 self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) 

836 # populates default properties 

837 self.collections = self.one.list_collections( 1caedk

838 self.eid, filename='spikes*', collection=f"alf/{self.pname}*") 

839 self.datasets = self.one.list_datasets(self.eid) 1caedk

840 if self.atlas is None: 1caedk

841 self.atlas = AllenAtlas() 1caek

842 self.files = {} 1caedk

843 self.raw_data_files = [] 1caedk

844 

845 def _load_object(self, *args, **kwargs): 

846 """ 

847 This function is a wrapper around alfio.load_object that will remove the UUID in the 

848 filename if the object is on SDSC. 

849 """ 

850 remove_uuids = getattr(self.one, 'uuid_filenames', False) 1caed

851 d = alfio.load_object(*args, **kwargs) 1caed

852 if remove_uuids: 1caed

853 # pops the UUID in the key names 

854 keys = list(d.keys()) 

855 for k in keys: 

856 d[k[:-37]] = d.pop(k) 

857 return d 1caed

858 

859 @staticmethod 

860 def _get_attributes(dataset_types): 

861 """returns attributes to load for spikes and clusters objects""" 

862 dataset_types = [] if dataset_types is None else dataset_types 1caed

863 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1caed

864 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1caed

865 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1caed

866 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1caed

867 waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] 1caed

868 waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) 1caed

869 return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} 1caed

870 

871 def _get_spike_sorting_collection(self, spike_sorter=None): 

872 """ 

873 Filters a list or array of collections to get the relevant spike sorting dataset 

874 if there is a pykilosort, load it 

875 """ 

876 for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): 1caed

877 if sorter is None: 1caed

878 continue 

879 if sorter == "": 1caed

880 collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) 1cae

881 else: 

882 collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) 1ad

883 if collection is not None: 1caed

884 return collection 1caed

885 # if none is found amongst the defaults, prefers the shortest 

886 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) 

887 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") 

888 return collection 

889 

890 def load_spike_sorting_object(self, obj, *args, revision=None, **kwargs): 

891 """ 

892 Loads an ALF object 

893 :param obj: object name, str between 'spikes', 'clusters' or 'channels' 

894 :param spike_sorter: (defaults to 'pykilosort') 

895 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

896 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

897 :param kwargs: additional arguments to be passed to one.api.One.load_object 

898 :param missing: 'raise' (default) or 'ignore' 

899 :param revision: the dataset revision to load 

900 :return: 

901 """ 

902 revision = revision if revision is not None else self.revision 

903 self.download_spike_sorting_object(obj, *args, revision=revision, **kwargs) 

904 return self._load_object(self.files[obj]) 

905 

906 def get_version(self, spike_sorter=None): 

907 spike_sorter = (spike_sorter or self.spike_sorter) or 'iblsorter' 

908 collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 

909 dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy') 

910 return dset[0]['version'] if len(dset) else 'unknown' 

911 

912 def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None, 

913 attribute=None, missing='raise', revision=None, **kwargs): 

914 """ 

915 Downloads an ALF object 

916 :param obj: object name, one of 'spikes', 'clusters' or 'channels' 

917 :param spike_sorter: (defaults to 'pykilosort') 

918 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

919 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

920 :param kwargs: additional arguments to be passed to one.api.One.load_object 

921 :param attribute: list of attributes to load for the object 

922 :param missing: 'raise' (default) or 'ignore' 

923 :param revision: the dataset revision to load 

924 :return: 

925 """ 

926 revision = revision if revision is not None else self.revision 1caed

927 if spike_sorter is None: 1caed

928 spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter' 1caed

929 if len(self.collections) == 0: 1caed

930 return {}, {}, {} 

931 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1caed

932 collection = collection or self.collection 1caed

933 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1caed

934 attributes = self._get_attributes(dataset_types) 1caed

935 try: 1caed

936 self.files[obj] = self.one.load_object( 1caed

937 self.eid, obj=obj, attribute=attributes.get(obj, None), 

938 collection=collection, download_only=True, revision=revision, **kwargs) 

939 except ALFObjectNotFound as e: 1cae

940 if missing == 'raise': 1cae

941 raise e 

942 

943 def download_spike_sorting(self, objects=None, **kwargs): 

944 """ 

945 Downloads spikes, clusters and channels 

946 :param spike_sorter: (defaults to 'pykilosort') 

947 :param dataset_types: list of extra dataset types 

948 :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels'] 

949 :return: 

950 """ 

951 objects = ['spikes', 'clusters', 'channels'] if objects is None else objects 1caed

952 for obj in objects: 1caed

953 self.download_spike_sorting_object(obj=obj, **kwargs) 1caed

954 self.spike_sorting_path = self.files['clusters'][0].parent 1caed

955 

956 def download_raw_electrophysiology(self, band='ap'): 

957 """ 

958 Downloads raw electrophysiology data files on local disk. 

959 :param band: "ap" (default) or "lf" for LFP band 

960 :return: list of raw data files full paths (ch, meta and cbin files) 

961 """ 

962 raw_data_files = [] 1a

963 for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']: 1a

964 try: 1a

965 # FIXME: this will fail if multiple LFP segments are found 

966 raw_data_files.append(self.one.load_dataset( 1a

967 self.eid, 

968 download_only=True, 

969 collection=f'raw_ephys_data/{self.pname}', 

970 dataset=suffix, 

971 check_hash=False, 

972 )) 

973 except ALFObjectNotFound: 1a

974 _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}") 1a

975 self.raw_data_files = list(set(self.raw_data_files + raw_data_files)) 1a

976 return raw_data_files 1a

977 

978 def raw_electrophysiology(self, stream=True, band='ap', **kwargs): 

979 """ 

980 Returns a reader for the raw electrophysiology data 

981 By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having 

982 downloaded the raw data file if necessary 

983 :param stream: 

984 :param band: 

985 :param kwargs: 

986 :return: 

987 """ 

988 if stream: 1k

989 return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) 1k

990 else: 

991 raw_data_files = self.download_raw_electrophysiology(band=band) 

992 cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None) 

993 if cbin_file is not None: 

994 return spikeglx.Reader(cbin_file) 

995 
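# --- Editor's note (hedged sketch, not part of the original class): typical use of
# raw_electrophysiology from an instantiated loader; `ssl` and `pid` are placeholders.
#     ssl = SpikeSortingLoader(pid=pid, one=one)
#     sr = ssl.raw_electrophysiology(band='ap', stream=True)  # Streamer over the remote data
#     # sr.fs is the sampling rate; slicing such as sr[:30000, :] is assumed to stream
#     # only the requested chunk, as with a spikeglx.Reader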

996 def download_raw_waveforms(self, **kwargs): 

997 """ 

998 Downloads raw waveforms extracted from sorting to local disk. 

999 """ 

1000 _logger.debug(f"loading waveforms from {self.collection}") 

1001 return self.one.load_object( 

1002 id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"], 

1003 collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs 

1004 ) 

1005 

1006 def raw_waveforms(self, **kwargs): 

1007 wf_paths = self.download_raw_waveforms(**kwargs) 

1008 return WaveformsLoader(wf_paths[0].parent) 

1009 

1010 def load_channels(self, **kwargs): 

1011 """ 

1012 Loads channels 

1013 The channel locations can come from several sources; the most advanced version of the histology available is loaded, 

1014 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

1015 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

1016 - resolved: channel locations alignments have been agreed upon 

1017 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

1018 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

1019 

1020 :param spike_sorter: (defaults to 'pykilosort') 

1021 :param dataset_types: list of extra dataset types 

1022 :return: 

1023 """ 

1024 # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting 

1025 self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') 1caed

1026 self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs) 1caed

1027 channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) 1caed

1028 if 'electrodeSites' in self.files: # if common dict keys, electrodeSites prevails 1caed

1029 channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 1d

1030 if alfio.check_dimensions(channels) != 0: 1d

1031 channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 

1032 channels['rawInd'] = np.arange(channels[list(channels.keys())[0]].shape[0]) 

1033 if 'brainLocationIds_ccf_2017' not in channels: 1caed

1034 _logger.debug(f"loading channels from alyx for {self.files['channels']}") 1a

1035 _channels, self.histology = _load_channel_locations_traj( 1a

1036 self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True) 

1037 if _channels: 1a

1038 channels = _channels[self.pname] 

1039 else: 

1040 channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) 1caed

1041 self.histology = 'alf' 1caed

1042 return Bunch(channels) 1caed

1043 

1044 @staticmethod 

1045 def filter_files_by_namespace(all_files, namespace): 

1046 

1047 # Create a dict for each file with its available namespaces; files with no namespace are stored under the key None 1caed

1048 namespace_files = defaultdict(dict) 1caed

1049 available_namespaces = [] 1caed

1050 for file in all_files: 1caed

1051 fparts = filename_parts(file.name, as_dict=True) 1caed

1052 fname = f"{fparts['object']}.{fparts['attribute']}" 1caed

1053 nspace = fparts['namespace'] 1caed

1054 available_namespaces.append(nspace) 1caed

1055 namespace_files[fname][nspace] = file 1caed

1056 

1057 if namespace not in set(available_namespaces): 1caed

1058 _logger.info(f'Could not find manual curation results for {namespace}, returning default' 1a

1059 f' non manually curated spikesorting data') 

1060 

1061 # Return the files with the chosen namespace. 

1062 files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()] 1caed

1063 # remove any None files 

1064 files = [f for f in files if f] 1caed

1065 return files 1caed

1066 

1067 def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, 

1068 namespace=None, **kwargs): 

1069 """ 

1070 Loads spikes, clusters and channels 

1071 

1072 There could be several spike sorting collections, by default the loader will get the iblsorter collection, falling back to pykilosort 

1073 

1074 The channel locations can come from several sources; the most advanced version of the histology available is loaded, 

1075 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

1076 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

1077 - resolved: channel locations alignments have been agreed upon 

1078 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

1079 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

1080 

1081 :param spike_sorter: (defaults to 'iblsorter') 

1082 :param revision: for example "2024-05-06", (defaults to None): 

1083 :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one 

1084 :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates'] 

1085 :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table 

1086 :param namespace: None, if given will load the manually curated spikesorting with the given namespace, 

1087 e.g. to load '_av_.clusters.depths' use namespace='av' 

1088 :param kwargs: additional arguments to be passed to one.api.One.load_object 

1089 :return: 

1090 """ 

1091 if len(self.collections) == 0: 1caed

1092 return {}, {}, {} 

1093 self.files = {} 1caed

1094 self.spike_sorter = spike_sorter 1caed

1095 self.revision = revision 1caed

1096 

1097 if good_units and namespace is not None: 1caed

1098 _logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with ' 1a

1099 'good_units=False and filter the spikes post hoc by the good clusters.') 

1100 return [None] * 3 1a

1101 objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None 1caed

1102 self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) 1caed

1103 channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs) 1caed

1104 self.files['clusters'] = self.filter_files_by_namespace(self.files['clusters'], namespace) 1caed

1105 clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) 1caed

1106 

1107 if good_units: 1caed

1108 spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards) 

1109 else: 

1110 self.files['spikes'] = self.filter_files_by_namespace(self.files['spikes'], namespace) 1caed

1111 spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) 1caed

1112 if enforce_version: 1caed

1113 self._assert_version_consistency() 

1114 return spikes, clusters, channels 1caed

1115 
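# --- Editor's note (hedged sketch, not part of the original class): canonical use of the
# loader, following the class docstring; `pid` is a placeholder probe insertion UUID.
#     ssl = SpikeSortingLoader(pid=pid, one=ONE())
#     spikes, clusters, channels = ssl.load_spike_sorting()
#     clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)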

1116 def _assert_version_consistency(self): 

1117 """ 

1118 Makes sure the state of the spike sorting object matches the files downloaded 

1119 :return: None 

1120 """ 

1121 for k in ['spikes', 'clusters', 'channels', 'passingSpikes']: 

1122 for fn in self.files.get(k, []): 

1123 if self.spike_sorter: 

1124 assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \ 

1125 f"You required strict version {self.spike_sorter}, {fn} does not match" 

1126 if self.revision: 

1127 assert full_path_parts(fn)[5] == self.revision, \ 

1128 f"You required strict revision {self.revision}, {fn} does not match" 

1129 

1130 @staticmethod 

1131 def compute_metrics(spikes, clusters=None): 

1132 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size 

1133 metrics = pd.DataFrame(quick_unit_metrics( 

1134 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) 

1135 return metrics 

1136 

1137 @staticmethod 

1138 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False): 

1139 """ 

1140 Merge the metrics and the channel information into the clusters dictionary 

1141 :param spikes: 

1142 :param clusters: 

1143 :param channels: 

1144 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used 

1145 for clusters or analysis applications (defaults to None). 

1146 :param compute_metrics: if True, will explicitly recompute metrics (defaults to false) 

1147 :return: cluster dictionary containing metrics and histology 

1148 """ 

1149 if spikes == {}: 1ad

1150 return 

1151 nc = clusters['channels'].size 1ad

1152 # recompute metrics if they are not available 

1153 metrics = None 1ad

1154 if 'metrics' in clusters: 1ad

1155 metrics = clusters.pop('metrics') 1ad

1156 if metrics.shape[0] != nc: 1ad

1157 metrics = None 

1158 if metrics is None or compute_metrics is True: 1ad

1159 _logger.debug("recompute clusters metrics") 

1160 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) 

1161 if isinstance(cache_dir, Path): 

1162 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) 

1163 for k in metrics.keys(): 1ad

1164 clusters[k] = metrics[k].to_numpy() 1ad

1165 for k in channels.keys(): 1ad

1166 clusters[k] = channels[k][clusters['channels']] 1ad

1167 if cache_dir is not None: 1ad

1168 _logger.debug(f'caching clusters metrics in {cache_dir}') 

1169 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt')) 

1170 return clusters 1ad

1171 

1172 @property 

1173 def url(self): 

1174 """Gets flatiron URL for the session""" 

1175 webclient = getattr(self.one, '_web_client', None) 

1176 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None 

1177 

1178 def _get_probe_info(self, revision=None): 

1179 revision = revision if revision is not None else self.revision 1e

1180 if self._sync is None: 1e

1181 timestamps = self.one.load_dataset( 1e

1182 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision) 

1183 _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks 1e

1184 self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision) 

1185 try: 1e

1186 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1e

1187 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) 

1188 fs = spikeglx._get_fs_from_meta(ap_meta) 1e

1189 except ALFObjectNotFound: 

1190 ap_meta = None 

1191 fs = 30_000 

1192 self._sync = { 1e

1193 'timestamps': timestamps, 

1194 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'), 

1195 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), 

1196 'ap_meta': ap_meta, 

1197 'fs': fs, 

1198 } 

1199 

1200 def timesprobe2times(self, values, direction='forward'): 

1201 self._get_probe_info() 

1202 if direction == 'forward': 

1203 return self._sync['forward'](values * self._sync['fs']) 

1204 elif direction == 'reverse': 

1205 return self._sync['reverse'](values) / self._sync['fs'] 

1206 

1207 def samples2times(self, values, direction='forward'): 

1208 """ 

1209 Converts ephys sample values to session main clock seconds 

1210 :param values: numpy array of times in seconds or samples to resync 

1211 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse' 

1212 (seconds main time to samples probe time) 

1213 :return: 

1214 """ 

1215 self._get_probe_info() 1e

1216 return self._sync[direction](values) 1e

1217 
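As a hedged illustration of the conversion above (assuming `ssl` is a SpikeSortingLoader whose probe info can be downloaded), 'forward' maps probe samples onto session-clock seconds and 'reverse' maps them back:

    t_main = ssl.samples2times(spikes['samples'], direction='forward')   # probe samples -> session seconds
    samples = ssl.samples2times(t_main, direction='reverse')             # session seconds -> probe samples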

1218 @property 

1219 def pid2ref(self): 

1220 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1ca

1221 

1222 def _default_plot_title(self, spikes): 

1223 title = f"{self.pid2ref}, {self.pid} \n" \ 1c

1224 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" 

1225 return title 1c

1226 

1227 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, 

1228 drift=None, title=None, **kwargs): 

1229 """ 

1230 :param spikes: spikes dictionary or Bunch 

1231 :param channels: channels dictionary or Bunch. 

1232 :param save_dir: if specified, save to this directory as "{pid}_{pid2ref}_{label}.png". 

1233 Otherwise, plot. 

1234 :param br: brain regions object (optional) 

1235 :param label: label for saved image (optional, default="raster") 

1236 :param time_series: timeseries dictionary for behavioral event times (optional) 

1237 :param **kwargs: kwargs passed to `driftmap()` (optional) 

1238 :return: 

1239 """ 

1240 br = br or BrainRegions() 1c

1241 time_series = time_series or {} 1c

1242 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c

1243 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') 

1244 axs[0, 1].set_axis_off() 1c

1245 # axs[0, 0].set_xticks([]) 

1246 if not kwargs: 1c

1247 # set default raster plot parameters 

1248 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} 

1249 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c

1250 if title is None: 1c

1251 title = self._default_plot_title(spikes) 1c

1252 axs[0, 0].title.set_text(title) 1c

1253 for k, ts in time_series.items(): 1c

1254 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) 

1255 if 'atlas_id' in channels: 1c

1256 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c

1257 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) 

1258 axs[1, 0].set_ylim(0, 3800) 1c

1259 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c

1260 fig.tight_layout() 1c

1261 

1262 if drift is None: 1c

1263 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c

1264 if 'drift' in self.files: 1c

1265 drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) 

1266 if isinstance(drift, dict): 1c

1267 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) 

1268 axs[0, 0].set(ylim=[-15, 15]) 

1269 

1270 if save_dir is not None: 1c

1271 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1272 fig.savefig(png_file) 

1273 plt.close(fig) 

1274 gc.collect() 

1275 else: 

1276 return fig, axs 1c

1277 
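For illustration only, assuming `spikes` and `channels` were loaded through the same SpikeSortingLoader `ssl`; passing a directory saves the figure as "{pid}_{pid2ref}_{label}.png" instead of returning it, and extra keyword arguments are forwarded to driftmap():

    from pathlib import Path
    ssl.raster(spikes, channels, save_dir=Path('/tmp'), label='raster', t_bin=0.007, d_bin=10, vmax=0.5)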

1278 def plot_rawdata_snippet(self, sr, spikes, clusters, t0, 

1279 channels=None, 

1280 br: BrainRegions = None, 

1281 save_dir=None, 

1282 label='raster', 

1283 gain=-93, 

1284 title=None): 

1285 

1286 # compute the raw data sample offsets and destripe; we take 400 ms around t0 

1287 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs)) 

1288 raw = sr[first_sample:last_sample, :-sr.nsync].T 

1289 channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True 

1290 destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels) 

1291 # filter out the spikes according to good/bad clusters and to the time slice 

1292 spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample])) 

1293 ss = spikes['samples'][spike_sel] 

1294 sc = clusters['channels'][spikes['clusters'][spike_sel]] 

1295 sok = clusters['label'][spikes['clusters'][spike_sel]] == 1 

1296 if title is None: 

1297 title = self._default_plot_title(spikes) 

1298 # display the raw data snippet with spikes overlaid 

1299 fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col') 

1300 Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s') 

1301 axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5) 

1302 axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5) 

1303 axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035]) 

1304 # adds the channel locations if available 

1305 if (channels is not None) and ('atlas_id' in channels): 

1306 br = br or BrainRegions() 

1307 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 

1308 brain_regions=br, display=True, ax=axs[1], title=self.histology) 

1309 axs[1].get_yaxis().set_visible(False) 

1310 fig.tight_layout() 

1311 

1312 if save_dir is not None: 

1313 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1314 fig.savefig(png_file) 

1315 plt.close(fig) 

1316 gc.collect() 

1317 else: 

1318 return fig, axs 

1319 
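A sketch of a typical call, assuming `sr` is an opened spikeglx.Reader (or Streamer) on the AP band and that spikes, clusters and channels were loaded with the `samples` and `label` attributes used above; spikes from good clusters are drawn in green, the rest in red over the destriped snippet:

    fig, axs = ssl.plot_rawdata_snippet(sr, spikes, clusters, t0=600., channels=channels)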

1320 

1321@dataclass 

1322class SessionLoader: 

1323 """ 

1324 Object to load session data for a given session in the recommended way. 

1325 

1326 Parameters 

1327 ---------- 

1328 one: one.api.ONE instance 

1329 Can be in remote or local mode (required) 

1330 session_path: string or pathlib.Path 

1331 The absolute path to the session (one of session_path or eid is required) 

1332 eid: string 

1333 database UUID of the session (one of session_path or eid is required) 

1334 

1335 If both are provided, session_path takes precedence over eid. 

1336 

1337 Examples 

1338 -------- 

1339 1) Load all available session data for one session: 

1340 >>> from one.api import ONE 

1341 >>> from brainbox.io.one import SessionLoader 

1342 >>> one = ONE() 

1343 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') 

1344 # Object is instantiated, but no data is loaded, as you can see in the data_info attribute 

1345 >>> sess_loader.data_info 

1346 name is_loaded 

1347 0 trials False 

1348 1 wheel False 

1349 2 pose False 

1350 3 motion_energy False 

1351 4 pupil False 

1352 

1353 # Loading all available session data, the data_info attribute now shows which data has been loaded 

1354 >>> sess_loader.load_session_data() 

1355 >>> sess_loader.data_info 

1356 name is_loaded 

1357 0 trials True 

1358 1 wheel True 

1359 2 pose True 

1360 3 motion_energy True 

1361 4 pupil False 

1362 

1363 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. 

1364 >>> type(sess_loader.trials) 

1365 pandas.core.frame.DataFrame 

1366 >>> sess_loader.trials.shape 

1367 (626, 18) 

1368 # Each data comes with its own timestamps in a column called 'times' 

1369 >>> sess_loader.wheel['times'] 

1370 0 0.134286 

1371 1 0.135286 

1372 2 0.136286 

1373 3 0.137286 

1374 4 0.138286 

1375 ... 

1376 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. 

1377 # The dataframes of all cameras are collected in a dictionary 

1378 >>> type(sess_loader.pose) 

1379 dict 

1380 >>> sess_loader.pose.keys() 

1381 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) 

1382 >>> sess_loader.pose['bodyCamera'].columns 

1383 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') 

1384 # To control the loading of specific data, e.g. by specifying parameters, use the individual loading 

1385 functions: 

1386 >>> sess_loader.load_wheel(fs=100) 

1387 """ 

1388 one: One = None 

1389 session_path: Path = '' 

1390 eid: str = '' 

1391 revision: str = '' 

1392 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1393 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1394 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1395 pose: dict = field(default_factory=dict, repr=False) 

1396 motion_energy: dict = field(default_factory=dict, repr=False) 

1397 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1398 

1399 def __post_init__(self): 

1400 """ 

1401 Function that runs automatically after initialisation of the dataclass attributes. 

1402 Checks for required inputs, sets session_path and eid, creates data_info table. 

1403 """ 

1404 if self.one is None: 1bf

1405 raise ValueError("An input to one is required. If no connection to a database is desired, it can be " 

1406 "a fully local instance of One.") 

1407 # If session path is given, takes precedence over eid 

1408 if self.session_path is not None and self.session_path != '': 1bf

1409 self.eid = self.one.to_eid(self.session_path) 1bf

1410 self.session_path = Path(self.session_path) 1bf

1411 # If no session path is provided, try to infer it from the eid 

1412 else: 

1413 if self.eid is not None and self.eid != '': 

1414 self.session_path = self.one.eid2path(self.eid) 

1415 else: 

1416 raise ValueError("If no session path is given, eid is required.") 

1417 

1418 data_names = [ 1bf

1419 'trials', 

1420 'wheel', 

1421 'pose', 

1422 'motion_energy', 

1423 'pupil' 

1424 ] 

1425 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1bf

1426 

1427 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): 

1428 """ 

1429 Function to load available session data into the SessionLoader object. Input parameters control which 

1430 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input 

1431 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored 

1432 in SessionLoader.data_info 

1433 

1434 Parameters 

1435 ---------- 

1436 trials: boolean 

1437 Whether to load all trials data into SessionLoader.trials, default is True 

1438 wheel: boolean 

1439 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True 

1440 pose: boolean 

1441 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose, 

1442 default is True 

1443 motion_energy: boolean 

1444 Whether to load motion energy data (whisker pad for left/right camera, body for body camera) 

1445 into SessionLoader.motion_energy, default is True 

1446 pupil: boolean 

1447 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil, 

1448 default is True 

1449 reload: boolean 

1450 Whether to reload data that has already been loaded into this SessionLoader object, default is False 

1451 """ 

1452 load_df = self.data_info.copy() 1f

1453 load_df['to_load'] = [ 1f

1454 trials, 

1455 wheel, 

1456 pose, 

1457 motion_energy, 

1458 pupil 

1459 ] 

1460 load_df['load_func'] = [ 1f

1461 self.load_trials, 

1462 self.load_wheel, 

1463 self.load_pose, 

1464 self.load_motion_energy, 

1465 self.load_pupil 

1466 ] 

1467 

1468 for idx, row in load_df.iterrows(): 1f

1469 if row['to_load'] is False: 1f

1470 _logger.debug(f"Not loading {row['name']} data, set to False.") 

1471 elif row['is_loaded'] is True and reload is False: 1f

1472 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1f

1473 else: 

1474 try: 1f

1475 _logger.info(f"Loading {row['name']} data") 1f

1476 row['load_func']() 1f

1477 self.data_info.loc[idx, 'is_loaded'] = True 1f

1478 except BaseException as e: 

1479 _logger.warning(f"Could not load {row['name']} data.") 

1480 _logger.debug(e) 

1481 
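A short sketch of selective loading, assuming `sess_loader` was constructed as in the class docstring; only trials and wheel are fetched here, and data_info records what succeeded:

    sess_loader.load_session_data(pose=False, motion_energy=False, pupil=False)
    print(sess_loader.data_info)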

1482 def _find_behaviour_collection(self, obj): 

1483 """ 

1484 Function to find the trial or wheel collection 

1485 

1486 Parameters 

1487 ---------- 

1488 obj: str 

1489 Alf object to load, either 'trials' or 'wheel' 

1490 """ 

1491 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj

1492 dsets = self.one.list_datasets(self.eid, dataset) 1fnj

1493 if len(dsets) == 0: 1fnj

1494 return 'alf' 1fn

1495 else: 

1496 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj

1497 if len(set(collections)) == 1: 1fj

1498 return collections[0] 1fj

1499 else: 

1500 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, ' 

1501 f'e.g. sl.load_{obj}(collection="{collections[0]}")') 

1502 raise ALFMultipleCollectionsFound 

1503 

1504 def load_trials(self, collection=None): 

1505 """ 

1506 Function to load trials data into SessionLoader.trials 

1507 

1508 Parameters 

1509 ---------- 

1510 collection: str 

1511 Alf collection of trials data 

1512 """ 

1513 

1514 if not collection: 1fn

1515 collection = self._find_behaviour_collection('trials') 1fn

1516 # itiDuration frequently has a mismatched dimension and we don't need it; exclude it using a regex 

1517 self.one.wildcards = False 1fn

1518 self.trials = self.one.load_object( 1fn

1519 self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df() 

1520 self.one.wildcards = True 1fn

1521 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1fn

1522 
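A minimal example, assuming the trials table lives in a single collection that can be found automatically; the result is a wide pandas DataFrame with one row per trial:

    sess_loader.load_trials()
    trials = sess_loader.trials  # columns such as stimulus, feedback and choice variables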

1523 def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None): 

1524 """ 

1525 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position 

1526 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which 

1527 a Butterworth low-pass filter is applied. 

1528 

1529 Parameters 

1530 ---------- 

1531 fs: int, float 

1532 Sampling frequency for the wheel position, default is 1000 Hz 

1533 corner_frequency: int, float 

1534 Corner frequency of Butterworth low-pass filter, default is 20 

1535 order: int, float 

1536 Order of Butterworth low-pass filter, default is 8 

1537 collection: str 

1538 Alf collection of wheel data 

1539 """ 

1540 if not collection: 1fj

1541 collection = self._find_behaviour_collection('wheel') 1fj

1542 wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None) 1fj

1543 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1fj

1544 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps'") 

1545 # resample the wheel position and compute velocity, acceleration 

1546 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1fj

1547 self.wheel['position'], self.wheel['times'] = interpolate_position( 1fj

1548 wheel_raw['timestamps'], wheel_raw['position'], freq=fs) 

1549 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1fj

1550 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order) 

1551 self.wheel = self.wheel.apply(np.float32) 1fj

1552 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1fj

1553 
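As a non-authoritative example, re-sampling the wheel at 250 Hz with a gentler low-pass filter; the resulting float32 DataFrame holds times, position, velocity and acceleration:

    sess_loader.load_wheel(fs=250, corner_frequency=10, order=4)
    wheel = sess_loader.wheel  # columns ['times', 'position', 'velocity', 'acceleration']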

1554 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']): 

1555 """ 

1556 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a 

1557 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas 

1558 Dataframes with the timestamps and pose data, one set of columns (x, y, likelihood) for each body part tracked for that camera. 

1559 

1560 Parameters 

1561 ---------- 

1562 likelihood_thr: float 

1563 The position of each tracked body part comes with a likelihood of that estimate for each time point. 

1564 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set 

1565 likelihood_thr=1. Default is 0.9 

1566 views: list 

1567 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1568 """ 

1569 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger 

1570 self.pose = {} 1mhf

1571 for view in views: 1mhf

1572 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None) 1mhf

1573 # Double check if video timestamps are correct length or can be fixed 

1574 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1mhf

1575 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1mhf

1576 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1mhf

1577 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1mhf

1578 
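A usage sketch restricted to the left camera, with a stricter likelihood threshold; estimates below the threshold are set to NaN:

    sess_loader.load_pose(views=['left'], likelihood_thr=0.95)
    left_dlc = sess_loader.pose['leftCamera']  # 'times' plus x/y/likelihood columns per tracked point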

1579 def load_motion_energy(self, views=['left', 'right', 'body']): 

1580 """ 

1581 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a 

1582 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are 

1583 pandas Dataframes with the timestamps and motion energy data. 

1584 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad 

1585 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the 

1586 body (bodyMotionEnergy). 

1587 

1588 Parameters 

1589 ---------- 

1590 views: list 

1591 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1592 """ 

1593 names = {'left': 'whiskerMotionEnergy', 1lf

1594 'right': 'whiskerMotionEnergy', 

1595 'body': 'bodyMotionEnergy'} 

1596 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger 

1597 self.motion_energy = {} 1lf

1598 for view in views: 1lf

1599 me_raw = self.one.load_object( 1lf

1600 self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None) 

1601 # Double check if video timestamps are correct length or can be fixed 

1602 times_fixed, motion_energy = self._check_video_timestamps( 1lf

1603 view, me_raw['times'], me_raw['ROIMotionEnergy']) 

1604 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1lf

1605 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1lf

1606 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1lf

1607 
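For completeness, a hedged example restricted to the body camera; the resulting DataFrame carries a 'times' column and a 'bodyMotionEnergy' column:

    sess_loader.load_motion_energy(views=['body'])
    body_me = sess_loader.motion_energy['bodyCamera']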

1608 def load_licks(self): 

1609 """ 

1610 Not yet implemented 

1611 """ 

1612 pass 

1613 

1614 def load_pupil(self, snr_thresh=5.): 

1615 """ 

1616 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. 

1617 

1618 Parameters 

1619 ---------- 

1620 snr_thresh: float 

1621 An SNR is calculated from the raw and smoothed pupil diameter. If this SNR < snr_thresh, the data 

1622 will be considered unusable and will be discarded. 

1623 """ 

1624 # Try to load from features 

1625 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None) 1hf

1626 if 'features' in feat_raw.keys(): 1hf

1627 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features']) 

1628 self.pupil = feats.copy() 

1629 self.pupil.insert(0, 'times', times_fixed) 

1630 

1631 # If unavailable compute on the fly 

1632 else: 

1633 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1hf

1634 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1hf

1635 and 'leftCamera' in self.pose.keys()): 

1636 # If pose data is already loaded, we don't know if it was thresholded at 0.9, so we need a small workaround 

1637 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1hf

1638 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1hf

1639 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1hf

1640 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1hf

1641 else: 

1642 self.load_pose(views=['left'], likelihood_thr=0.9) 

1643 dlc_thr = self.pose['leftCamera'].copy() 

1644 

1645 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1hf

1646 try: 1hf

1647 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1hf

1648 except BaseException as e: 

1649 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. " 

1650 "Saving all NaNs for pupilDiameter_smooth.") 

1651 _logger.debug(e) 

1652 self.pupil['pupilDiameter_smooth'] = np.nan 

1653 

1654 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1hf

1655 good_idxs = np.where( 1hf

1656 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] 

1657 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1hf

1658 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) 

1659 if snr < snr_thresh: 1hf

1660 self.pupil = pd.DataFrame() 1h

1661 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h

1662 
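The acceptance test above boils down to the variance ratio below; a restatement with NumPy, assuming `raw` and `smooth` are the aligned, NaN-free diameter traces:

    snr = np.var(smooth) / np.var(smooth - raw)
    keep = snr >= snr_thresh  # the pupil DataFrame is discarded when the ratio falls below snr_thresh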

1663 def _check_video_timestamps(self, view, video_timestamps, video_data): 

1664 """ 

1665 Helper function to check for the length of the video frames vs video timestamps and fix in case 

1666 timestamps are longer than video frames. 

1667 """ 

1668 # If camera times are shorter than video data, or empty, no current fix 

1669 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf

1670 if video_timestamps.shape[0] == 0: 

1671 msg = f'Camera times empty for {view}Camera.' 

1672 else: 

1673 msg = f'Camera times are shorter than video data for {view}Camera.' 

1674 _logger.warning(msg) 

1675 raise ValueError(msg) 

1676 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. 

1677 # This is because the first few frames are sometimes not recorded. We can remove the first few 

1678 # timestamps in this case 

1679 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf

1680 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf

1681 return video_timestamps_fixed, video_data 1lmhf

1682 else: 

1683 return video_timestamps, video_data 

1684 

1685 

1686class EphysSessionLoader(SessionLoader): 

1687 """ 

1688 Spike sorting enhanced version of SessionLoader 

1689 Loads spike sorting data for all probes in the session into the self.ephys dict 

1690 >>> EphysSessionLoader(eid=eid, one=one) 

1691 To select for a specific probe 

1692 >>> EphysSessionLoader(eid=eid, one=one, pid=pid) 

1693 """ 

1694 def __init__(self, *args, pname=None, pid=None, **kwargs): 

1695 """ 

1696 Needs an active connection in order to get the list of insertions in the session 

1697 :param args: 

1698 :param kwargs: 

1699 """ 

1700 super().__init__(*args, **kwargs) 

1701 # if necessary, restrict the query 

1702 qargs = {} if pname is None else {'name': pname} 

1703 qargs = qargs or ({} if pid is None else {'id': pid}) 

1704 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs) 

1705 self.ephys = {} 

1706 for ins in insertions: 

1707 self.ephys[ins['name']] = {} 

1708 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one) 

1709 

1710 def load_session_data(self, *args, **kwargs): 

1711 super().load_session_data(*args, **kwargs) 

1712 self.load_spike_sorting() 

1713 

1714 def load_spike_sorting(self, pnames=None): 

1715 pnames = pnames or list(self.ephys.keys()) 

1716 for pname in pnames: 

1717 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting() 

1718 self.ephys[pname]['spikes'] = spikes 

1719 self.ephys[pname]['clusters'] = clusters 

1720 self.ephys[pname]['channels'] = channels 

1721 

1722 @property 

1723 def probes(self): 

1724 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}
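A closing sketch, assuming `one` and `eid` are valid; load spike sorting for every probe of the session and report the spike count per probe:

    esl = EphysSessionLoader(one=one, eid=eid)
    esl.load_spike_sorting()
    for pname, data in esl.ephys.items():
        print(pname, data['spikes']['times'].size)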