Coverage for brainbox/io/one.py: 57%

768 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 14:26 +0100

1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" 

2from dataclasses import dataclass, field 

3import gc 

4import logging 

5import re 

6import os 

7from pathlib import Path 

8from collections import defaultdict 

9 

10import numpy as np 

11import pandas as pd 

12from scipy.interpolate import interp1d 

13import matplotlib.pyplot as plt 

14 

15from one.api import ONE, One 

16from one.alf.path import get_alf_path, full_path_parts, filename_parts 

17from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound 

18from one.alf import cache 

19import one.alf.io as alfio 

20from neuropixel import TIP_SIZE_UM, trace_header 

21import spikeglx 

22 

23import ibldsp.voltage 

24from ibldsp.waveform_extraction import WaveformsLoader 

25from iblutil.util import Bunch 

26from iblatlas.atlas import AllenAtlas, BrainRegions 

27from iblatlas import atlas 

28from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times 

29from ibllib.pipes import histology 

30from ibllib.pipes.ephys_alignment import EphysAlignment 

31from ibllib.plots import vertical_lines, Density 

32 

33import brainbox.plot 

34from brainbox.io.spikeglx import Streamer 

35from brainbox.ephys_plots import plot_brain_regions 

36from brainbox.metrics.single_units import quick_unit_metrics 

37from brainbox.behavior.wheel import interpolate_position, velocity_filtered 

38from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter 

39 

# Module-level logger, registered under the shared 'ibllib' logger name
_logger = logging.getLogger('ibllib')


# Default attributes requested when loading the ALF `spikes` object
SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
# Default attributes requested when loading the ALF `clusters` object
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
# Default attributes for waveforms
WAVEFORMS_ATTRIBUTES = ['templates']

46 

47 

def load_lfp(eid, one=None, dataset_types=None, **kwargs):
    """
    Download the LFP datasets of a session and return one spikeglx reader per LFP file.

    TODO Verify works

    :param eid: session identifier understood by ONE
    :param one: an instance of ONE; a default instance is created when None
    :param dataset_types: additional dataset names / patterns to download on top of the LFP ones
    :param kwargs: forwarded as-is to spikeglx.Reader
    :return: list of spikeglx.Reader, one per LFP file found in the session folder
    """
    # `one` defaults to None in the signature, so it needs a fallback before being used
    one = one or ONE()
    extra_datasets = [] if dataset_types is None else list(dataset_types)
    lfp_datasets = ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*']
    # download only: the files are opened below through spikeglx
    for dset in extra_datasets + lfp_datasets:
        one.load_dataset(eid, dset, download_only=True)
    session_path = one.eid2path(eid)

    ephys_files = spikeglx.glob_ephys_files(session_path, bin_exists=False)
    return [spikeglx.Reader(ef['lf'], **kwargs) for ef in ephys_files if ef.get('lf', None)]

67 

68 

def _collection_filter_from_args(probe, spike_sorter=None):
    """
    Build the collection filter string `alf/{probe}/{spike_sorter}`.

    A None argument becomes a `*` wildcard, a `/*` sequence collapses to `*`, and a single
    trailing slash is removed.
    """
    pattern = '/'.join(('alf', f'{probe}', f'{spike_sorter}'))
    pattern = pattern.replace('None', '*').replace('/*', '*')
    if pattern.endswith('/'):
        pattern = pattern[:-1]
    return pattern

75 

76 

def _get_spike_sorting_collection(collections, pname):
    """
    Select the spike sorting collection to load for a probe among candidate collections.

    The collection equal to `alf/{pname}/pykilosort` is chosen when present. Otherwise the
    shortest collection containing `alf/{pname}` is chosen. Returns None when nothing matches.
    """
    preferred = f'alf/{pname}/pykilosort'
    selected = next((c for c in collections if c == preferred), None)
    if not selected:
        # fall back on the shortest matching collection
        candidates = sorted((c for c in collections if f'alf/{pname}' in c), key=len)
        selected = candidates[0] if candidates else None
    _logger.debug(f"selecting: {selected} to load amongst candidates: {collections}")
    return selected

88 

89 

def _channels_alyx2bunch(chans):
    """
    Convert a list of channel records (dicts) into a Bunch of arrays.

    The 'x', 'y', 'z' values are divided by 1e6; 'brain_region', 'axial' and 'lateral' are
    copied to 'atlas_id', 'axial_um' and 'lateral_um' respectively.
    """
    def _column(field_name):
        return np.array([ch[field_name] for ch in chans])

    return Bunch({
        'atlas_id': _column('brain_region'),
        'x': _column('x') / 1e6,
        'y': _column('y') / 1e6,
        'z': _column('z') / 1e6,
        'axial_um': _column('axial'),
        'lateral_um': _column('lateral'),
    })

100 

101 

def _channels_traj2bunch(xyz_chans, brain_atlas):
    """
    Build a channels dictionary from channel coordinates and an atlas.

    The regions are looked up with `brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans))`.
    Returns a dict with keys 'x', 'y', 'z' (columns 0, 1, 2 of `xyz_chans`), 'acronym'
    and 'atlas_id'.
    """
    labels = brain_atlas.get_labels(xyz_chans)
    regions = brain_atlas.regions.get(labels)
    x, y, z = (xyz_chans[:, axis] for axis in range(3))
    return {'x': x, 'y': y, 'z': z, 'acronym': regions['acronym'], 'atlas_id': regions['id']}

113 

114 

def _channels_bunch2alf(channels):
    """
    Convert a channels bunch ('x', 'y', 'z', 'atlas_id', 'lateral_um', 'axial_um') to ALF keys.

    Returns a dict with 'mlapdv' (x, y, z columns multiplied by 1e6),
    'brainLocationIds_ccf_2017' and 'localCoordinates' (lateral, axial columns).
    """
    mlapdv = np.c_[channels['x'], channels['y'], channels['z']] * 1e6
    local_coordinates = np.c_[channels['lateral_um'], channels['axial_um']]
    return {
        'mlapdv': mlapdv,
        'brainLocationIds_ccf_2017': channels['atlas_id'],
        'localCoordinates': local_coordinates,
    }

121 

122 

def _channels_alf2bunch(channels, brain_regions=None):
    """
    Convert ALF channels ('mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates') to a bunch.

    The output has keys 'x', 'y', 'z' (mlapdv columns as float64 divided by 1e6), 'acronym',
    'atlas_id', 'axial_um' and 'lateral_um'. Any other key of the input is carried over unchanged.
    'acronym' is None unless `brain_regions` is provided, in which case it is looked up from
    the atlas ids.
    """
    mlapdv = channels['mlapdv']
    local_coordinates = channels['localCoordinates']
    x, y, z = (mlapdv[:, axis].astype(np.float64) / 1e6 for axis in range(3))
    bunch = {
        'x': x,
        'y': y,
        'z': z,
        'acronym': None,
        'atlas_id': channels['brainLocationIds_ccf_2017'],
        'axial_um': local_coordinates[:, 1],
        'lateral_um': local_coordinates[:, 0],
    }
    # extra keys of the input carry over, the converted ALF keys do not
    alf_keys = ('mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates')
    extras = {k: channels[k] for k in channels if k not in bunch and k not in alf_keys}
    bunch.update(extras)
    if brain_regions:
        bunch['acronym'] = brain_regions.get(bunch['atlas_id'])['acronym']
    return bunch

141 

142 

def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
                        brain_regions=None):
    """
    Generic function to load spike sorting data using ONE.

    One spike sorting is loaded for each probe that has a `spikes` object in a collection
    matching the filter. When a probe has several candidate collections, the pykilosort one
    is preferred, otherwise the shortest collection (alf/probeXX) is used.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (may be in 'local' mode); a default instance is created when None
    collection : str
        collection filter word - accepts wildcards - can be a combination of spike sorter and
        probe. See `ALF documentation`_ for details.
    revision : str
        A particular revision return (defaults to latest revision). See `ALF documentation`_ for
        details.
    return_channels : bool
        Defaults to True, in which case the channels are loaded from disk and returned as well
    dataset_types : list of str
        Optional 'spikes.xxx' / 'clusters.xxx' names whose attributes are loaded in addition
        to the defaults
    brain_regions : iblatlas.regions.BrainRegions
        Optional, forwarded to the channels loader

    .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations. Only returned when
        return_channels is True.
    """
    one = one or ONE()
    # enumerate probes and load according to the name
    collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    probe_names = list({c.split('/')[1] for c in collections})
    spike_attributes, cluster_attributes = _get_attributes(dataset_types)

    spikes, clusters = {}, {}
    for pname in probe_names:
        probe_collection = _get_spike_sorting_collection(collections, pname)
        spikes[pname] = one.load_object(
            eid, collection=probe_collection, obj='spikes', attribute=spike_attributes, namespace='')
        clusters[pname] = one.load_object(
            eid, collection=probe_collection, obj='clusters', attribute=cluster_attributes, namespace='')

    if not return_channels:
        return spikes, clusters
    channels = _load_channels_locations_from_disk(
        eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
    return spikes, clusters, channels

206 

207 

def _get_attributes(dataset_types):
    """
    Return the spikes and clusters attributes to load.

    With None, the module defaults are returned. Otherwise the attribute part of every
    'spikes.xxx' / 'clusters.xxx' entry is merged with the defaults, without duplicates.
    """
    if dataset_types is None:
        return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES

    def _requested(prefix):
        return [dset.split('.')[1] for dset in dataset_types if prefix in dset]

    extra_spikes = _requested('spikes.')
    extra_clusters = _requested('clusters.')
    return list(set(SPIKES_ATTRIBUTES + extra_spikes)), list(set(CLUSTERS_ATTRIBUTES + extra_clusters))

217 

218 

def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None):
    """
    Load the channel locations of each probe from the ALF `channels` objects of a session.

    For every probe found in the collections matching the filter, the `channels` object is loaded
    from the collection picked by `_get_spike_sorting_collection`. When it has no
    `brainLocationIds_ccf_2017` key, an aligned `channels` object is looked up and interpolated
    onto the loaded channels. Probes for which no aligned dataset is found are left as loaded,
    i.e. with ALF keys and without 'x', 'y', 'z'.

    :param eid: session identifier passed to ONE
    :param collection: collection filter passed to `one.list_collections`
    :param one: ONE instance; it is used without a fallback, so it must be provided
    :param revision: revision passed to `one.list_collections`
    :param brain_regions: forwarded to `_channels_alf2bunch`, which fills 'acronym' when it is provided
    :return: Bunch with probe labels as keys
    """
    _logger.debug('loading spike sorting from disk')
    channels = Bunch({})
    collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    # the probe label is the second element of the collection, e.g. alf/probe00/pykilosort
    probes = list(set([c.split('/')[1] for c in collections]))
    for probe in probes:
        probe_collection = _get_spike_sorting_collection(collections, probe)
        channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels')
        # if the spike sorter has not aligned data, try and get the alignment available
        if 'brainLocationIds_ccf_2017' not in channels[probe].keys():
            aligned_channel_collections = one.list_collections(
                eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
            if len(aligned_channel_collections) == 0:
                _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
                # skip the reformatting below: this probe keeps its raw ALF keys
                continue
            _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
            ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
            channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
            channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
        # only have to reformat channels if we were able to load coordinates from disk
        channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
    return channels

243 

244 

def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
    """
    Interpolate aligned channel locations onto another channel map, using the channel depths.

    Oftentimes the channel map for different spike sorters may be different, so the alignment is
    interpolated onto the target channels. If there is no spike sorting in the base folder, the
    alignment doesn't have the localCoordinates field, so it is reconstructed from the Neuropixel
    map. This only happens for early pykilosort sorts.

    Note that the `channels` input is modified in place: the keys 'mlapdv' and
    'brainLocationIds_ccf_2017' are written to it.

    :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
     'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
     OR
     'x', 'y', 'z', 'atlas_id', 'axial_um', 'lateral_um'
     those are the guide for the interpolation
    :param channels: Bunch or dictionary of channels containing at least the key 'localCoordinates';
     if None, the local coordinates are taken from the Neuropixel version 1 trace header
    :param brain_regions: None (default) or iblatlas.regions.BrainRegions object
     if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017'
     if a brain region object is provided, outputs a dict with keys
     'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
    :return: Bunch or dictionary of channels with brain coordinates keys
    """
    NEUROPIXEL_VERSION = 1
    h = trace_header(version=NEUROPIXEL_VERSION)
    if channels is None:
        channels = {'localCoordinates': np.c_[h['x'], h['y']]}
    nch = channels['localCoordinates'].shape[0]
    # bunch-style input ('x', 'y', 'z', ...) is converted to ALF keys first
    if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())):
        channels_aligned = _channels_bunch2alf(channels_aligned)
    if 'localCoordinates' in channels_aligned.keys():
        aligned_depths = channels_aligned['localCoordinates'][:, 1]
    else:  # edge case for a few spike sorting sessions: fall back on the trace header depths
        assert channels_aligned['mlapdv'].shape[0] == 384
        aligned_depths = h['y']
    # work on unique depths; `iinv` maps the unique target depths back to every channel
    depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
    depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True)
    channels['mlapdv'] = np.zeros((nch, 3))
    # linear interpolation of each of the 3 coordinates as a function of depth
    for i in np.arange(3):
        channels['mlapdv'][:, i] = np.interp(
            depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
    # the brain locations have to be interpolated by nearest neighbour
    fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest')
    channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
    if brain_regions is not None:
        return _channels_alf2bunch(channels, brain_regions=brain_regions)
    else:
        return channels

287 

288 

def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
                                 brain_atlas=None, return_source=False):
    """
    Compute the channel locations of a probe from the insertion and trajectory records on Alyx.

    :param eid: session identifier
    :param probe: probe label, e.g. 'probe00'
    :param one: ONE instance; when it has no `alyx` attribute nothing can be queried and an
     empty dict is returned
    :param revision: revision passed to `one.list_collections` when locating the channels collection
    :param aligned: if True and the alignment is not resolved, use the stored ephys aligned
     histology track when the alignment count is positive
    :param brain_atlas: atlas instance (default: a new AllenAtlas instance)
    :param return_source: if True, also returns the source of the locations
    :return: channels, or (channels, source) when return_source is True.
     channels is a Bunch keyed by probe label, None when no histology tracing exists, or an
     empty dict when `one` has no `alyx` attribute.
     source is 'resolved', 'aligned', 'traced', '' when no tracing exists, or None when `one`
     has no `alyx` attribute.
    """
    if not hasattr(one, 'alyx'):
        # honour return_source here as well, so the return arity is the same on every path
        return ({}, None) if return_source else {}
    _logger.debug(f"trying to load from traj {probe}")
    channels = Bunch()
    # instantiate the atlas: the bare class cannot be used for the lookups downstream
    brain_atlas = brain_atlas or AllenAtlas()
    # the collection holding the local coordinates has to be found first
    insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0]
    collection = _collection_filter_from_args(probe=probe)
    collections = one.list_collections(eid, filename='channels*', collection=collection,
                                       revision=revision)
    probe_collection = _get_spike_sorting_collection(collections, probe)
    chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
    depths = chn_coords[:, 1]

    # `or {}` also covers records where the field is present but null
    insertion_json = insertion.get('json') or {}
    extended_qc = insertion_json.get('extended_qc') or {}
    tracing = extended_qc.get('tracing_exists', False)
    resolved = extended_qc.get('alignment_resolved', False)
    counts = extended_qc.get('alignment_count', 0)

    if tracing:
        xyz = np.array(insertion_json['xyz_picks']) / 1e6
        if resolved or (counts > 0 and aligned):
            if resolved:
                _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
                              f'Channel and cluster locations obtained from ephys aligned histology '
                              f'track.')
                source = 'resolved'
            else:
                _logger.debug(f'Channel locations for {eid}/{probe} have not been '
                              f'resolved. However, alignment flag set to True so channel and cluster'
                              f' locations will be obtained from latest available ephys aligned '
                              f'histology track.')
                source = 'aligned'
            # both cases use the alignment stored on the ephys aligned histology track
            traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
                                 provenance='Ephys aligned histology track')[0]
            align_key = extended_qc['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=brain_atlas, speedy=True)
            chans = ephysalign.get_channel_locations(feature, track)
        else:
            _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
                          f'Channel and cluster locations obtained from histology track.')
            # get the channels from histology tracing
            xyz = xyz[np.argsort(xyz[:, 2]), :]
            chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
            source = 'traced'
        channels[probe] = _channels_traj2bunch(chans, brain_atlas)
        channels[probe]['axial_um'] = chn_coords[:, 1]
        channels[probe]['lateral_um'] = chn_coords[:, 0]
    else:
        _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
        source = ''
        channels = None

    return (channels, source) if return_source else channels

368 

369 

def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
    """
    Load the brain locations of each channel for a given session/probe

    The locations are first loaded from the datasets on disk. Probes for which this did not
    yield coordinates (no 'x' key) are then looked up through the trajectory records on Alyx.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    probe : str
        The probe label, e.g. 'probe01'. All probes are considered when None
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    brain_atlas : iblatlas.BrainAtlas
        Brain atlas object (default: Allen atlas)

    Returns
    -------
    dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    if isinstance(eid, dict):
        # session details dict: the eid is the last 36 characters of the url
        eid = eid['url'][-36:]
    else:
        eid = one.to_eid(eid)
    collection = _collection_filter_from_args(probe=probe)
    channels = _load_channels_locations_from_disk(eid, one=one, collection=collection,
                                                  brain_regions=brain_atlas.regions)
    incomplete_probes = [k for k in channels if 'x' not in channels[k]]
    for iprobe in incomplete_probes:
        channels_, _ = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
                                                    brain_atlas=brain_atlas, return_source=True)
        # the trajectory loader returns None (no tracing) or an empty dict (no Alyx access):
        # in both cases the channels loaded from disk are kept for this probe
        if channels_ and iprobe in channels_:
            channels[iprobe] = channels_[iprobe]
    return channels

411 

412 

def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                            brain_regions=None, nested=True, collection=None, return_collection=False):
    """
    From an eid, loads spikes, clusters and channels for all probes, and merges the channel
    information into the clusters.

    Deprecated: use brainbox.io.one.SpikeSortingLoader instead.

    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param revision: revision to load (None for default)
    :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param nested: if False and a single probe is loaded, do not output a dictionary with the probe name as key
    :param return_collection: (False) if True, will return the collection used to load
    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe), followed by the
     collection filter when return_collection is True
    """
    # the trailing space keeps the two sentences apart once the literals are concatenated
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    if collection is None:
        collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    spikes, clusters, channels = _load_spike_sorting(
        eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
        brain_regions=brain_regions, return_channels=True)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        # single probe: unwrap the per-probe dictionaries
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    return spikes, clusters, channels

453 

454 

def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                       brain_regions=None, return_collection=False):
    """
    From an eid, loads spikes and clusters for all probes.

    Deprecated: use brainbox.io.one.SpikeSortingLoader instead.

    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param revision: revision to load (None for default)
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param return_collection: (bool - False) if True, returns the collection for loading the data
    :return: spikes, clusters (dict of bunch, 1 bunch per probe), followed by the collection
     filter when return_collection is True
    """
    # the trailing space keeps the two sentences apart once the literals are concatenated
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
                                           return_channels=False, dataset_types=dataset_types,
                                           brain_regions=brain_regions)
    if return_collection:
        return spikes, clusters, collection
    return spikes, clusters

486 

487 

def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
                                    spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
    """
    For a given eid, get spikes, clusters and channels information, and merges clusters
    and channels information before returning all three variables.

    Deprecated: use brainbox.io.one.SpikeSortingLoader instead.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    probe : str
        The probe label, e.g. 'probe01'. All probes are loaded when None
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    dataset_types : list of str
        Optional additional spikes/clusters objects to add to the standard default list
    spike_sorter : str
        Name of the spike sorting you want to load (None for default which is pykilosort if it's
        available otherwise the default MATLAB kilosort)
    brain_atlas : iblatlas.atlas.BrainAtlas
        Brain atlas object (default: Allen atlas)
    nested : bool
        If False and a single probe is loaded, the outputs are not wrapped in a dictionary
        with the probe name as key
    return_collection: bool
        Returns an extra argument with the collection chosen

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    # the message names this function, and the trailing space keeps the sentences apart
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in '
                    'future versions. Use brainbox.io.one.SpikeSortingLoader instead')
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    # --- Get spikes and clusters data
    spikes, clusters, collection = load_spike_sorting(
        eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
    # -- Get brain regions and assign to clusters
    channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
                                      brain_atlas=brain_atlas)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        # single probe: unwrap the per-probe dictionaries
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    return spikes, clusters, channels

547 

548 

def load_ephys_session(eid, one=None):
    """
    Load the spikes, clusters and trials of a session.

    Spikes and clusters are loaded for all probes through `load_spike_sorting`, the trials
    through `one.load_object(eid, 'trials')`.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : oneibl.one.OneAlyx, optional
        ONE object to use for loading. Will generate internal one if not used, by default None

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    trials : one.alf.io.AlfBunch of numpy.ndarray
        The session trials data
    """
    # create the instance as documented, instead of asserting on the default value
    one = one or ONE()
    spikes, clusters = load_spike_sorting(eid, one=one)
    trials = one.load_object(eid, 'trials')
    return spikes, clusters, trials

584 

585 

def _remove_old_clusters(session_path, probe):
    """
    Delete the `clusters.metrics.csv` file of a probe from a local session folder, if present.

    The file is looked up in `<session_path>/alf/<probe>`; it is superseded by a .pqt file.
    """
    legacy_metrics = session_path / 'alf' / probe / 'clusters.metrics.csv'
    if not legacy_metrics.exists():
        return
    os.remove(legacy_metrics)
    _logger.info('Deleting old clusters.metrics.csv file')

596 

597 

def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
    """
    Takes (default and any extra) values in given keys from channels and assign them to clusters.
    If a key holds no usable channel data, the key is added to clusters as an empty list.
    Keys absent from channels are skipped with a warning.

    Parameters
    ----------
    dic_clus : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing cluster information. Modified in place.
    channels : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z',
        'axial_um', 'lateral_um')
    keys_to_add_extra : list of str
        Any extra keys to copy from the channels bunches

    Returns
    -------
    dict of one.alf.io.AlfBunch
        clusters (1 bunch per probe) with new keys values.
    """
    keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']
    if keys_to_add_extra is None:
        keys_to_add = keys_to_add_default
    else:
        # Append extra optional keys
        keys_to_add = list(set(list(keys_to_add_extra) + keys_to_add_default))

    for label, probe_channels in channels.items():
        if label not in dic_clus:
            _logger.warning(f'Probe {label}: no clusters found, could not merge channels')
            continue
        clu_ch = dic_clus[label]['channels']
        # highest channel index used by the clusters, computed once per probe;
        # -1 when there are no clusters so that any channel array is long enough
        max_ch = max(clu_ch) if len(clu_ch) else -1
        for key in keys_to_add:
            # explicit membership test: an assert would be stripped when running with -O
            if key not in probe_channels:
                _logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
                continue
            ch_key = probe_channels[key]
            nch_key = len(ch_key) if ch_key is not None else 0
            if ch_key is not None and max_ch < nch_key:  # clu_ch is used as an index
                dic_clus[label][key] = ch_key[clu_ch]
            else:
                _logger.warning(
                    f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels'
                    f' but expected {max_ch}. Data in new cluster key "{key}" is returned empty.')
                dic_clus[label][key] = []

    return dic_clus

645 

646 

def load_passive_rfmap(eid, one=None):
    """
    For a given eid load in the passive receptive field mapping protocol data

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : oneibl.one.OneAlyx, optional
        An instance of ONE (may be in 'local' - offline - mode)

    Returns
    -------
    one.alf.io.AlfBunch
        Passive receptive field mapping data, with the stimulus frames added under the
        'frames' key
    """
    one = one or ONE()

    # Load in the receptive field mapping data
    rf_map = one.load_object(eid, obj='passiveRFM', collection='alf')
    stim_file = one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', collection='raw_passive_data')
    raw_bytes = np.fromfile(stim_file, dtype="uint8")
    # the flat buffer is reshaped to 15 x 15 x n in Fortran order, then reordered to n x 15 x 15
    y_pix, x_pix = 15, 15
    stack = raw_bytes.reshape([y_pix, x_pix, -1], order="F")
    rf_map['frames'] = stack.transpose([2, 1, 0])

    return rf_map

675 

676 

def load_wheel_reaction_times(eid, one=None):
    """
    Return the calculated reaction times for session. Reaction times are defined as the time
    between the go cue (onset tone) and the onset of the first substantial wheel movement. A
    movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the
    distance to threshold (~0.1 radians).

    Negative times mean the onset of the movement occurred before the go cue. Nans may occur if
    there was no detected movement within the period, or when the goCue_times or feedback_times
    are nan.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        one object to use for loading. Will generate internal one if not used, by default None

    Returns
    -------
    array-like
        reaction times
    """
    if one is None:
        one = ONE()

    trials = one.load_object(eid, 'trials')
    # first movement times already extracted: nothing to compute
    already_extracted = trials and 'firstMovement_times' in trials
    if already_extracted:
        return trials['firstMovement_times'] - trials['goCue_times']

    # otherwise compute them from the wheel moves, re-extracting those if necessary
    moves = one.load_object(eid, 'wheelMoves')
    needs_extraction = not moves or 'peakAmplitude' not in moves
    if needs_extraction:
        wheel = one.load_object(eid, 'wheel')
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
    assert trials and moves, 'unable to load trials and wheelMoves data'
    first_move_times, _, _ = extract_first_movement_times(moves, trials)
    return first_move_times - trials['goCue_times']

717 

718 

def load_iti(trials):
    """
    The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey
    screen commencing at stimulus off and lasting until the quiescent period at the start of the
    following trial. The ITI of a trial is measured up to the start of the next trial, therefore
    the last value is NaN.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'intervals', 'stimOff_times'}.

    Returns
    -------
    np.array
        An array of inter-trial intervals, the last value being NaN.

    Raises
    ------
    ValueError
        If either of the required keys is missing from trials.
    """
    required_keys = {'intervals', 'stimOff_times'}
    if not required_keys.issubset(set(trials.keys())):
        raise ValueError('trials must contain keys {"intervals", "stimOff_times"}')
    # Shift the trial start times by one so that each trial lines up with the start of the next one
    next_trial_starts = np.roll(trials['intervals'][:, 0], -1)
    # The last difference wraps around to the first trial and is discarded in favour of a NaN
    iti = (next_trial_starts - trials['stimOff_times'])[:-1]
    return np.r_[iti, np.nan]

739 

740 

def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
    """
    Compute the channel coordinates of a probe insertion from its highest ranked trajectory.

    The trajectories registered for the insertion are ranked by provenance
    (Resolved > Ephys aligned histology track > Histology track > Micro-manipulator > Planned)
    and the channel locations are computed from the best one.

    Parameters
    ----------
    ins : dict
        Probe insertion record. The 'id' key is always read; 'json' -> 'xyz_picks' is read for
        histology based trajectories and 'json' -> 'extended_qc' -> 'alignment_stored' for
        aligned ones.
    depths : np.array, optional
        Channel depths along the probe; defaults to the second column of the version 1 trace header.
        NOTE(review): presumably in micrometres given the division by 1e6 - confirm
    one : one.api.OneAlyx, optional
        An instance of ONE, a default ONE() is instantiated when omitted
    ba : iblatlas.atlas.AllenAtlas, optional
        Brain atlas, a default AllenAtlas() is instantiated when omitted

    Returns
    -------
    np.array
        The channel locations as returned by `histology.interpolate_along_track` or
        `EphysAlignment.get_channel_locations`

    Raises
    ------
    ValueError
        If no trajectory is registered for the insertion
    KeyError
        If a trajectory has a provenance that is not part of the ranking table
    """
    PROV_2_VAL = {
        'Resolved': 90,
        'Ephys aligned histology track': 70,
        'Histology track': 50,
        'Micro-manipulator': 30,
        'Planned': 10}

    one = one or ONE()
    ba = ba or atlas.AllenAtlas()
    trajectories = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id'])
    scores = [PROV_2_VAL[tr['provenance']] for tr in trajectories]
    if len(scores) == 0:
        # np.argmax would raise a ValueError too, but without saying which insertion is at fault
        raise ValueError(f"No trajectory found for probe insertion {ins['id']}")
    traj = trajectories[int(np.argmax(scores))]
    provenance = traj['provenance']
    if depths is None:
        depths = trace_header(version=1)[:, 1]

    if provenance in ('Planned', 'Micro-manipulator'):
        # a dedicated name is used so that the `ins` record is not shadowed
        insertion = atlas.Insertion.from_dict(traj)
        # Deepest coordinate first
        xyz = np.c_[insertion.tip, insertion.entry].T
        return histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)

    xyz = np.array(ins['json']['xyz_picks']) / 1e6
    if provenance == 'Histology track':
        # order the picks by their third coordinate before interpolating along the track
        xyz = xyz[np.argsort(xyz[:, 2]), :]
        return histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)

    # aligned or resolved trajectory: re-apply the stored alignment
    align_key = ins['json']['extended_qc']['alignment_stored']
    feature = traj['json'][align_key][0]
    track = traj['json'][align_key][1]
    ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                feature_prev=feature,
                                brain_atlas=ba, speedy=True)
    return ephysalign.get_channel_locations(feature, track)

779 

780 

@dataclass
class SpikeSortingLoader:
    """
    Object that will load spike sorting data for a given probe insertion.
    This class can be instantiated in several manners
    - With Alyx database probe id:
            SpikeSortingLoader(pid=pid, one=one)
    - With Alyx database eid and probe name:
            SpikeSortingLoader(eid=eid, pname='probe00', one=one)
    - From a local session and probe name:
            SpikeSortingLoader(session_path=session_path, pname='probe00')
      NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
    """
    one: One = None  # ONE instance used for all queries and downloads
    atlas: None = None  # brain atlas, set to an AllenAtlas instance by the post init function when left to None
    pid: str = None  # probe insertion id, takes precedence over eid / pname when provided
    eid: str = ''  # session id
    pname: str = ''  # probe name, for example 'probe00'
    session_path: Path = ''  # local path of the session
    # the following properties are the outcome of the post init function
    collections: list = None  # spike sorting collections found for the probe
    datasets: list = None  # list of all datasets belonging to the session
    # the following properties are the outcome of a reading function
    files: dict = None  # downloaded files, keyed by ALF object name
    raw_data_files: list = None  # list of raw ap and lf files corresponding to the recording
    collection: str = ''  # collection selected by the last download
    histology: str = ''  # 'alf', 'resolved', 'aligned' or 'traced'
    spike_sorter: str = 'pykilosort'
    spike_sorting_path: Path = None  # parent folder of the first downloaded clusters file
    _sync: dict = None  # probe synchronisation information, lazily loaded by _get_probe_info

811 

812 def __post_init__(self): 

813 # pid gets precedence 

814 if self.pid is not None: 1caedk

815 try: 1dk

816 self.eid, self.pname = self.one.pid2eid(self.pid) 1dk

817 except NotImplementedError: 

818 if self.eid == '' or self.pname == '': 

819 raise IOError("Cannot infer session id and probe name from pid. " 

820 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") 

821 self.session_path = self.one.eid2path(self.eid) 1dk

822 # then eid / pname combination 

823 elif self.session_path is None or self.session_path == '': 1cae

824 self.session_path = self.one.eid2path(self.eid) 1cae

825 # fully local providing a session path 

826 else: 

827 if self.one: 

828 self.eid = self.one.to_eid(self.session_path) 

829 else: 

830 self.one = One(cache_dir=self.session_path.parents[2], mode='local') 

831 df_sessions = cache._make_sessions_df(self.session_path) 

832 self.one._cache['sessions'] = df_sessions.set_index('id') 

833 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) 

834 self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) 

835 # populates default properties 

836 self.collections = self.one.list_collections( 1caedk

837 self.eid, filename='spikes*', collection=f"alf/{self.pname}*") 

838 self.datasets = self.one.list_datasets(self.eid) 1caedk

839 if self.atlas is None: 1caedk

840 self.atlas = AllenAtlas() 1caek

841 self.files = {} 1caedk

842 self.raw_data_files = [] 1caedk

843 

844 def _load_object(self, *args, **kwargs): 

845 """ 

846 This function is a wrapper around alfio.load_object that will remove the UUID in the 

847 filename if the object is on SDSC. 

848 """ 

849 remove_uuids = getattr(self.one, 'uuid_filenames', False) 1caed

850 d = alfio.load_object(*args, **kwargs) 1caed

851 if remove_uuids: 1caed

852 # pops the UUID in the key names 

853 keys = list(d.keys()) 

854 for k in keys: 

855 d[k[:-37]] = d.pop(k) 

856 return d 1caed

857 

858 @staticmethod 

859 def _get_attributes(dataset_types): 

860 """returns attributes to load for spikes and clusters objects""" 

861 dataset_types = [] if dataset_types is None else dataset_types 1caed

862 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1caed

863 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1caed

864 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1caed

865 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1caed

866 waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] 1caed

867 waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) 1caed

868 return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} 1caed

869 

870 def _get_spike_sorting_collection(self, spike_sorter=None): 

871 """ 

872 Filters a list or array of collections to get the relevant spike sorting dataset 

873 if there is a pykilosort, load it 

874 """ 

875 for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): 1caed

876 if sorter is None: 1caed

877 continue 

878 if sorter == "": 1caed

879 collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) 1cae

880 else: 

881 collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) 1ad

882 if collection is not None: 1caed

883 return collection 1caed

884 # if none is found amongst the defaults, prefers the shortest 

885 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) 

886 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") 

887 return collection 

888 

889 def load_spike_sorting_object(self, obj, *args, **kwargs): 

890 """ 

891 Loads an ALF object 

892 :param obj: object name, str between 'spikes', 'clusters' or 'channels' 

893 :param spike_sorter: (defaults to 'pykilosort') 

894 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

895 :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort' 

896 :param kwargs: additional arguments to be passed to one.api.One.load_object 

897 :param missing: 'raise' (default) or 'ignore' 

898 :return: 

899 """ 

900 self.download_spike_sorting_object(obj, *args, **kwargs) 

901 return self._load_object(self.files[obj]) 

902 

903 def get_version(self, spike_sorter=None): 

904 spike_sorter = (spike_sorter or self.spike_sorter) or 'iblsorter' 

905 collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 

906 dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy') 

907 return dset[0]['version'] if len(dset) else 'unknown' 

908 

909 def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None, 

910 attribute=None, missing='raise', **kwargs): 

911 """ 

912 Downloads an ALF object 

913 :param obj: object name, str between 'spikes', 'clusters' or 'channels' 

914 :param spike_sorter: (defaults to 'pykilosort') 

915 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

916 :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort' 

917 :param kwargs: additional arguments to be passed to one.api.One.load_object 

918 :param attribute: list of attributes to load for the object 

919 :param missing: 'raise' (default) or 'ignore' 

920 :return: 

921 """ 

922 if spike_sorter is None: 1caed

923 spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter' 1caed

924 if len(self.collections) == 0: 1caed

925 return {}, {}, {} 

926 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1caed

927 collection = collection or self.collection 1caed

928 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1caed

929 attributes = self._get_attributes(dataset_types) 1caed

930 try: 1caed

931 self.files[obj] = self.one.load_object( 1caed

932 self.eid, obj=obj, attribute=attributes.get(obj, None), 

933 collection=collection, download_only=True, **kwargs) 

934 except ALFObjectNotFound as e: 1cae

935 if missing == 'raise': 1cae

936 raise e 

937 

938 def download_spike_sorting(self, objects=None, **kwargs): 

939 """ 

940 Downloads spikes, clusters and channels 

941 :param spike_sorter: (defaults to 'pykilosort') 

942 :param dataset_types: list of extra dataset types 

943 :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels'] 

944 :return: 

945 """ 

946 objects = ['spikes', 'clusters', 'channels'] if objects is None else objects 1caed

947 for obj in objects: 1caed

948 self.download_spike_sorting_object(obj=obj, **kwargs) 1caed

949 self.spike_sorting_path = self.files['clusters'][0].parent 1caed

950 

951 def download_raw_electrophysiology(self, band='ap'): 

952 """ 

953 Downloads raw electrophysiology data files on local disk. 

954 :param band: "ap" (default) or "lf" for LFP band 

955 :return: list of raw data files full paths (ch, meta and cbin files) 

956 """ 

957 raw_data_files = [] 1a

958 for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']: 1a

959 try: 1a

960 # FIXME: this will fail if multiple LFP segments are found 

961 raw_data_files.append(self.one.load_dataset( 1a

962 self.eid, 

963 download_only=True, 

964 collection=f'raw_ephys_data/{self.pname}', 

965 dataset=suffix, 

966 check_hash=False, 

967 )) 

968 except ALFObjectNotFound: 1a

969 _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}") 1a

970 self.raw_data_files = list(set(self.raw_data_files + raw_data_files)) 1a

971 return raw_data_files 1a

972 

973 def raw_electrophysiology(self, stream=True, band='ap', **kwargs): 

974 """ 

975 Returns a reader for the raw electrophysiology data 

976 By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having 

977 downloaded the raw data file if necessary 

978 :param stream: 

979 :param band: 

980 :param kwargs: 

981 :return: 

982 """ 

983 if stream: 1k

984 return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) 1k

985 else: 

986 raw_data_files = self.download_raw_electrophysiology(band=band) 

987 cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None) 

988 if cbin_file is not None: 

989 return spikeglx.Reader(cbin_file) 

990 

991 def download_raw_waveforms(self, **kwargs): 

992 """ 

993 Downloads raw waveforms extracted from sorting to local disk. 

994 """ 

995 _logger.debug(f"loading waveforms from {self.collection}") 

996 return self.one.load_object( 

997 id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"], 

998 collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs 

999 ) 

1000 

1001 def raw_waveforms(self, **kwargs): 

1002 wf_paths = self.download_raw_waveforms(**kwargs) 

1003 return WaveformsLoader(wf_paths[0].parent) 

1004 

    def load_channels(self, **kwargs):
        """
        Loads channels
        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        -   alf: the final version of channel locations, same as resolved with the difference that data is on file
        -   resolved: channel locations alignments have been agreed upon
        -   aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        -   traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        Sets the histology property as a side effect.

        :param spike_sorter: (defaults to the spike_sorter property)
        :param dataset_types: list of extra dataset types
        :param kwargs: passed to download_spike_sorting_object for the channels object
        :return: Bunch of channels
        """
        # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
        self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
        # NOTE(review): raises a KeyError if the channels object could not be downloaded, as missing='ignore'
        # leaves self.files without a 'channels' entry
        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
            # when the merged object has inconsistent dimensions, only the electrode sites are kept
            if alfio.check_dimensions(channels) != 0:
                channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
                # the raw indices are then a plain range over the first dimension of the first attribute
                channels['rawInd'] = np.arange(channels[list(channels.keys())[0]].shape[0])
        if 'brainLocationIds_ccf_2017' not in channels:
            # no brain locations on file: query the trajectories from Alyx instead
            _logger.debug(f"loading channels from alyx for {self.files['channels']}")
            _channels, self.histology = _load_channel_locations_traj(
                self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
            # if nothing comes back, the channels loaded from file are returned as they are
            if _channels:
                channels = _channels[self.pname]
        else:
            channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
            self.histology = 'alf'
        return Bunch(channels)

1038 

1039 @staticmethod 

1040 def filter_files_by_namespace(all_files, namespace): 

1041 

1042 # Create dict for each file with available namespaces, no namespce is stored under the key None 

1043 namespace_files = defaultdict(dict) 1caed

1044 available_namespaces = [] 1caed

1045 for file in all_files: 1caed

1046 fparts = filename_parts(file.name, as_dict=True) 1caed

1047 fname = f"{fparts['object']}.{fparts['attribute']}" 1caed

1048 nspace = fparts['namespace'] 1caed

1049 available_namespaces.append(nspace) 1caed

1050 namespace_files[fname][nspace] = file 1caed

1051 

1052 if namespace not in set(available_namespaces): 1caed

1053 _logger.info(f'Could not find manual curation results for {namespace}, returning default' 1a

1054 f' non manually curated spikesorting data') 

1055 

1056 # Return the files with the chosen namespace. 

1057 files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()] 1caed

1058 # remove any None files 

1059 files = [f for f in files if f] 1caed

1060 return files 1caed

1061 

1062 def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, 

1063 namespace=None, **kwargs): 

1064 """ 

1065 Loads spikes, clusters and channels 

1066 

1067 There could be several spike sorting collections, by default the loader will get the pykilosort collection 

1068 

1069 The channel locations can come from several sources, it will load the most advanced version of the histology available, 

1070 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

1071 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

1072 - resolved: channel locations alignments have been agreed upon 

1073 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

1074 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

1075 

1076 :param spike_sorter: (defaults to 'pykilosort') 

1077 :param revision: for example "2024-05-06", (defaults to None): 

1078 :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one 

1079 :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates'] 

1080 :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table 

1081 :param namespace: None, if given will load the manually curated spikesorting with the given namespace, 

1082 e.g to load '_av_.clusters.depths use namespace='av' 

1083 :param kwargs: additional arguments to be passed to one.api.One.load_object 

1084 :return: 

1085 """ 

1086 if len(self.collections) == 0: 1caed

1087 return {}, {}, {} 

1088 self.files = {} 1caed

1089 self.spike_sorter = spike_sorter 1caed

1090 self.revision = revision 1caed

1091 

1092 if good_units and namespace is not None: 1caed

1093 _logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with' 1a

1094 'good_units=False and filter the spikes post hoc by the good clusters.') 

1095 return [None] * 3 1a

1096 objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None 1caed

1097 self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) 1caed

1098 channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs) 1caed

1099 self.files['clusters'] = self.filter_files_by_namespace(self.files['clusters'], namespace) 1caed

1100 clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) 1caed

1101 

1102 if good_units: 1caed

1103 spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards) 

1104 else: 

1105 self.files['spikes'] = self.filter_files_by_namespace(self.files['spikes'], namespace) 1caed

1106 spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) 1caed

1107 if enforce_version: 1caed

1108 self._assert_version_consistency() 

1109 return spikes, clusters, channels 1caed

1110 

1111 def _assert_version_consistency(self): 

1112 """ 

1113 Makes sure the state of the spike sorting object matches the files downloaded 

1114 :return: None 

1115 """ 

1116 for k in ['spikes', 'clusters', 'channels', 'passingSpikes']: 

1117 for fn in self.files.get(k, []): 

1118 if self.spike_sorter: 

1119 assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \ 

1120 f"You required strict version {self.spike_sorter}, {fn} does not match" 

1121 if self.revision: 

1122 assert full_path_parts(fn)[5] == self.revision, \ 

1123 f"You required strict revision {self.revision}, {fn} does not match" 

1124 

1125 @staticmethod 

1126 def compute_metrics(spikes, clusters=None): 

1127 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size 

1128 metrics = pd.DataFrame(quick_unit_metrics( 

1129 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) 

1130 return metrics 

1131 

1132 @staticmethod 

1133 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False): 

1134 """ 

1135 Merge the metrics and the channel information into the clusters dictionary 

1136 :param spikes: 

1137 :param clusters: 

1138 :param channels: 

1139 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used 

1140 for clusters or analysis applications (defaults to None). 

1141 :param compute_metrics: if True, will explicitly recompute metrics (defaults to false) 

1142 :return: cluster dictionary containing metrics and histology 

1143 """ 

1144 if spikes == {}: 1ad

1145 return 

1146 nc = clusters['channels'].size 1ad

1147 # recompute metrics if they are not available 

1148 metrics = None 1ad

1149 if 'metrics' in clusters: 1ad

1150 metrics = clusters.pop('metrics') 1ad

1151 if metrics.shape[0] != nc: 1ad

1152 metrics = None 

1153 if metrics is None or compute_metrics is True: 1ad

1154 _logger.debug("recompute clusters metrics") 

1155 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) 

1156 if isinstance(cache_dir, Path): 

1157 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) 

1158 for k in metrics.keys(): 1ad

1159 clusters[k] = metrics[k].to_numpy() 1ad

1160 for k in channels.keys(): 1ad

1161 clusters[k] = channels[k][clusters['channels']] 1ad

1162 if cache_dir is not None: 1ad

1163 _logger.debug(f'caching clusters metrics in {cache_dir}') 

1164 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt')) 

1165 return clusters 1ad

1166 

1167 @property 

1168 def url(self): 

1169 """Gets flatiron URL for the session""" 

1170 webclient = getattr(self.one, '_web_client', None) 

1171 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None 

1172 

1173 def _get_probe_info(self): 

1174 if self._sync is None: 1e

1175 timestamps = self.one.load_dataset( 1e

1176 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}') 

1177 _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks 1e

1178 self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}') 

1179 try: 1e

1180 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1e

1181 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) 

1182 fs = spikeglx._get_fs_from_meta(ap_meta) 1e

1183 except ALFObjectNotFound: 

1184 ap_meta = None 

1185 fs = 30_000 

1186 self._sync = { 1e

1187 'timestamps': timestamps, 

1188 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'), 

1189 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), 

1190 'ap_meta': ap_meta, 

1191 'fs': fs, 

1192 } 

1193 

1194 def timesprobe2times(self, values, direction='forward'): 

1195 self._get_probe_info() 

1196 if direction == 'forward': 

1197 return self._sync['forward'](values * self._sync['fs']) 

1198 elif direction == 'reverse': 

1199 return self._sync['reverse'](values) / self._sync['fs'] 

1200 

1201 def samples2times(self, values, direction='forward'): 

1202 """ 

1203 Converts ephys sample values to session main clock seconds 

1204 :param values: numpy array of times in seconds or samples to resync 

1205 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse' 

1206 (seconds main time to samples probe time) 

1207 :return: 

1208 """ 

1209 self._get_probe_info() 1e

1210 return self._sync[direction](values) 1e

1211 

1212 @property 

1213 def pid2ref(self): 

1214 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1ca

1215 

1216 def _default_plot_title(self, spikes): 

1217 title = f"{self.pid2ref}, {self.pid} \n" \ 1c

1218 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" 

1219 return title 1c

1220 

1221 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, 

1222 drift=None, title=None, **kwargs): 

1223 """ 

1224 :param spikes: spikes dictionary or Bunch 

1225 :param channels: channels dictionary or Bunch. 

1226 :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png". 

1227 Otherwise, plot. 

1228 :param br: brain regions object (optional) 

1229 :param label: label for saved image (optional, default="raster") 

1230 :param time_series: timeseries dictionary for behavioral event times (optional) 

1231 :param **kwargs: kwargs passed to `driftmap()` (optional) 

1232 :return: 

1233 """ 

1234 br = br or BrainRegions() 1c

1235 time_series = time_series or {} 1c

1236 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c

1237 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') 

1238 axs[0, 1].set_axis_off() 1c

1239 # axs[0, 0].set_xticks([]) 

1240 if kwargs is None: 1c

1241 # set default raster plot parameters 

1242 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} 

1243 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c

1244 if title is None: 1c

1245 title = self._default_plot_title(spikes) 1c

1246 axs[0, 0].title.set_text(title) 1c

1247 for k, ts in time_series.items(): 1c

1248 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) 

1249 if 'atlas_id' in channels: 1c

1250 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c

1251 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) 

1252 axs[1, 0].set_ylim(0, 3800) 1c

1253 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c

1254 fig.tight_layout() 1c

1255 

1256 if drift is None: 1c

1257 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c

1258 if 'drift' in self.files: 1c

1259 drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) 

1260 if isinstance(drift, dict): 1c

1261 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) 

1262 axs[0, 0].set(ylim=[-15, 15]) 

1263 

1264 if save_dir is not None: 1c

1265 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1266 fig.savefig(png_file) 

1267 plt.close(fig) 

1268 gc.collect() 

1269 else: 

1270 return fig, axs 1c

1271 

1272 def plot_rawdata_snippet(self, sr, spikes, clusters, t0, 

1273 channels=None, 

1274 br: BrainRegions = None, 

1275 save_dir=None, 

1276 label='raster', 

1277 gain=-93, 

1278 title=None): 

1279 

1280 # compute the raw data offset and destripe, we take 400ms around t0 

1281 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs)) 

1282 raw = sr[first_sample:last_sample, :-sr.nsync].T 

1283 channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True 

1284 destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels) 

1285 # filter out the spikes according to good/bad clusters and to the time slice 

1286 spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample])) 

1287 ss = spikes['samples'][spike_sel] 

1288 sc = clusters['channels'][spikes['clusters'][spike_sel]] 

1289 sok = clusters['label'][spikes['clusters'][spike_sel]] == 1 

1290 if title is None: 

1291 title = self._default_plot_title(spikes) 

1292 # display the raw data snippet with spikes overlaid 

1293 fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col') 

1294 Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s') 

1295 axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5) 

1296 axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5) 

1297 axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035]) 

1298 # adds the channel locations if available 

1299 if (channels is not None) and ('atlas_id' in channels): 

1300 br = br or BrainRegions() 

1301 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 

1302 brain_regions=br, display=True, ax=axs[1], title=self.histology) 

1303 axs[1].get_yaxis().set_visible(False) 

1304 fig.tight_layout() 

1305 

1306 if save_dir is not None: 

1307 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1308 fig.savefig(png_file) 

1309 plt.close(fig) 

1310 gc.collect() 

1311 else: 

1312 return fig, axs 

1313 

1314 

1315@dataclass 

1316class SessionLoader: 

1317 """ 

1318 Object to load session data for a give session in the recommended way. 

1319 

1320 Parameters 

1321 ---------- 

1322 one: one.api.ONE instance 

1323 Can be in remote or local mode (required) 

1324 session_path: string or pathlib.Path 

1325 The absolute path to the session (one of session_path or eid is required) 

1326 eid: string 

1327 database UUID of the session (one of session_path or eid is required) 

1328 

1329 If both are provided, session_path takes precedence over eid. 

1330 

1331 Examples 

1332 -------- 

1333 1) Load all available session data for one session: 

1334 >>> from one.api import ONE 

1335 >>> from brainbox.io.one import SessionLoader 

1336 >>> one = ONE() 

1337 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') 

1338 # Object is initiated, but no data is loaded as you can see in the data_info attribute 

1339 >>> sess_loader.data_info 

1340 name is_loaded 

1341 0 trials False 

1342 1 wheel False 

1343 2 pose False 

1344 3 motion_energy False 

1345 4 pupil False 

1346 

1347 # Loading all available session data, the data_info attribute now shows which data has been loaded 

1348 >>> sess_loader.load_session_data() 

1349 >>> sess_loader.data_info 

1350 name is_loaded 

1351 0 trials True 

1352 1 wheel True 

1353 2 pose True 

1354 3 motion_energy True 

1355 4 pupil False 

1356 

1357 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. 

1358 >>> type(sess_loader.trials) 

1359 pandas.core.frame.DataFrame 

1360 >>> sess_loader.trials.shape 

1361 (626, 18) 

1362 # Each data comes with its own timestamps in a column called 'times' 

1363 >>> sess_loader.wheel['times'] 

1364 0 0.134286 

1365 1 0.135286 

1366 2 0.136286 

1367 3 0.137286 

1368 4 0.138286 

1369 ... 

1370 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. 

1371 # The dataframes of all cameras are collected in a dictionary 

1372 >>> type(sess_loader.pose) 

1373 dict 

1374 >>> sess_loader.pose.keys() 

1375 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) 

1376 >>> sess_loader.pose['bodyCamera'].columns 

1377 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') 

1378 # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading 

1379 # functions, e.g. to resample the wheel position at 100 Hz: 

1380 >>> sess_loader.load_wheel(fs=100) 

1381 """ 

1382 one: One = None 

1383 session_path: Path = '' 

1384 eid: str = '' 

1385 revision: str = '' 

1386 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1387 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1388 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1389 pose: dict = field(default_factory=dict, repr=False) 

1390 motion_energy: dict = field(default_factory=dict, repr=False) 

1391 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1392 

1393 def __post_init__(self): 

1394 """ 

1395 Function that runs automatically after initiation of the dataclass attributes. 

1396 Checks for required inputs, sets session_path and eid, creates data_info table. 

1397 """ 

1398 if self.one is None: 1bf

1399 raise ValueError("An input to one is required. If not connection to a database is desired, it can be " 

1400 "a fully local instance of One.") 

1401 # If session path is given, takes precedence over eid 

1402 if self.session_path is not None and self.session_path != '': 1bf

1403 self.eid = self.one.to_eid(self.session_path) 1bf

1404 self.session_path = Path(self.session_path) 1bf

1405 # Providing no session path, try to infer from eid 

1406 else: 

1407 if self.eid is not None and self.eid != '': 

1408 self.session_path = self.one.eid2path(self.eid) 

1409 else: 

1410 raise ValueError("If no session path is given, eid is required.") 

1411 

1412 data_names = [ 1bf

1413 'trials', 

1414 'wheel', 

1415 'pose', 

1416 'motion_energy', 

1417 'pupil' 

1418 ] 

1419 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1bf

1420 

def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
    """
    Load the requested session data into the attributes of this object.

    Each data type is loaded by its dedicated method (load_trials, load_wheel, load_pose,
    load_motion_energy, load_pupil) called without arguments, and ends up in the attribute of the
    same name. Which data is loaded is tracked in the data_info table. A failure to load one data
    type is logged and does not prevent the remaining ones from being attempted.

    Parameters
    ----------
    trials: boolean
        Whether to load the trials data, default is True
    wheel: boolean
        Whether to load the wheel data, default is True
    pose: boolean
        Whether to load the pose tracking results, default is True
    motion_energy: boolean
        Whether to load the motion energy data, default is True
    pupil: boolean
        Whether to load the pupil diameter, default is True
    reload: boolean
        Whether to reload data that is already flagged as loaded in data_info, default is False
    """
    requested = (
        ('trials', trials, self.load_trials),
        ('wheel', wheel, self.load_wheel),
        ('pose', pose, self.load_pose),
        ('motion_energy', motion_energy, self.load_motion_energy),
        ('pupil', pupil, self.load_pupil),
    )
    for name, to_load, load_func in requested:
        # Rows are looked up by name so the request can never be paired with the wrong row
        is_row = self.data_info['name'] == name
        if not to_load:
            _logger.debug(f"Not loading {name} data, set to False.")
            continue
        if self.data_info.loc[is_row, 'is_loaded'].any() and not reload:
            _logger.debug(f"Not loading {name} data, is already loaded and reload=False.")
            continue
        try:
            _logger.info(f"Loading {name} data")
            load_func()
        except Exception as e:
            # Only Exception is caught: KeyboardInterrupt and SystemExit must be able to abort
            # a long running load instead of being reported as a missing dataset
            _logger.warning(f"Could not load {name} data.")
            _logger.debug(e)
        else:
            self.data_info.loc[is_row, 'is_loaded'] = True

1475 

1476 def _find_behaviour_collection(self, obj): 

1477 """ 

1478 Function to find the trial or wheel collection 

1479 

1480 Parameters 

1481 ---------- 

1482 obj: str 

1483 Alf object to load, either 'trials' or 'wheel' 

1484 """ 

1485 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj

1486 dsets = self.one.list_datasets(self.eid, dataset) 1fnj

1487 if len(dsets) == 0: 1fnj

1488 return 'alf' 1fn

1489 else: 

1490 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj

1491 if len(set(collections)) == 1: 1fj

1492 return collections[0] 1fj

1493 else: 

1494 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, ' 

1495 f'e.g sl.load_{obj}(collection="{collections[0]}")') 

1496 raise ALFMultipleCollectionsFound 

1497 

def load_trials(self, collection=None):
    """
    Load the trials data into SessionLoader.trials as a pandas DataFrame.

    The trials object is loaded with every attribute except itiDuration and converted with to_df().
    On success the 'trials' row of data_info is flagged as loaded.

    Parameters
    ----------
    collection: str
        Alf collection of trials data. When empty or None the collection is searched for with
        _find_behaviour_collection.
    """
    if not collection:
        collection = self._find_behaviour_collection('trials')
    # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex.
    # The attribute filter is a regular expression, so wildcard matching is switched off for the
    # duration of the call only: the previous setting is put back even if loading fails, otherwise
    # the shared ONE instance would stay in regex mode for every later query.
    previous_wildcards = getattr(self.one, 'wildcards', True)
    self.one.wildcards = False
    try:
        trials = self.one.load_object(
            self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None)
    finally:
        self.one.wildcards = previous_wildcards
    self.trials = trials.to_df()
    self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True

1516 

def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
    """
    Load the wheel data into SessionLoader.wheel.

    The raw wheel position is resampled with interpolate_position at the frequency fs, then velocity
    and acceleration are obtained from velocity_filtered using the given filter settings. The result
    is a float32 DataFrame with the columns 'times', 'position', 'velocity' and 'acceleration', and
    the 'wheel' row of data_info is flagged as loaded.

    Parameters
    ----------
    fs: int, float
        Sampling frequency for the wheel position, default is 1000 Hz
    corner_frequency: int, float
        Corner frequency handed to velocity_filtered, default is 20
    order: int, float
        Filter order handed to velocity_filtered, default is 8
    collection: str
        Alf collection of wheel data. When empty or None the collection is searched for with
        _find_behaviour_collection.

    Raises
    ------
    ValueError
        If the raw wheel position and timestamps do not have the same length.
    """
    collection = collection or self._find_behaviour_collection('wheel')
    raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None)
    n_position = raw['position'].shape[0]
    n_timestamps = raw['timestamps'].shape[0]
    if n_position != n_timestamps:
        raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")

    # The frame is created up front so the column order is fixed regardless of the filling order
    self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration'])
    position, times = interpolate_position(raw['timestamps'], raw['position'], freq=fs)
    self.wheel['position'] = position
    self.wheel['times'] = times
    velocity, acceleration = velocity_filtered(
        self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
    self.wheel['velocity'] = velocity
    self.wheel['acceleration'] = acceleration
    self.wheel = self.wheel.apply(np.float32)

    is_wheel = self.data_info['name'] == 'wheel'
    self.data_info.loc[is_wheel, 'is_loaded'] = True

1547 

def load_pose(self, likelihood_thr=0.9, views=('left', 'right', 'body')):
    """
    Load the pose estimation results (DLC) into SessionLoader.pose.

    SessionLoader.pose is reset, then filled with one pandas DataFrame per requested view under the
    key '<view>Camera'. Each DataFrame is the output of likelihood_threshold for that camera with
    the timestamps inserted as first column 'times'.

    Parameters
    ----------
    likelihood_thr: float
        Threshold handed to likelihood_threshold together with the raw DLC data. Default is 0.9
    views: list or tuple
        Camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
    """
    # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
    self.pose = {}
    for view in views:
        camera = f'{view}Camera'
        pose_raw = self.one.load_object(self.eid, camera, attribute=['dlc', 'times'], revision=self.revision or None)
        # Double check if video timestamps are correct length or can be fixed
        times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
        self.pose[camera] = likelihood_threshold(dlc, likelihood_thr)
        self.pose[camera].insert(0, 'times', times_fixed)
        # NOTE(review): the flag is assumed to be raised per successfully loaded view (inside the
        # loop); the reviewed listing had lost its indentation - confirm against the module
        self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True

1572 

def load_motion_energy(self, views=('left', 'right', 'body')):
    """
    Load the motion energy data into SessionLoader.motion_energy.

    SessionLoader.motion_energy is reset, then filled with one pandas DataFrame per requested view
    under the key '<view>Camera'. Each DataFrame has the timestamps in the column 'times' followed
    by the ROI motion energy, in a column named 'whiskerMotionEnergy' for the left and right cameras
    and 'bodyMotionEnergy' for the body camera.

    Parameters
    ----------
    views: list or tuple
        Camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
    """
    names = {'left': 'whiskerMotionEnergy',
             'right': 'whiskerMotionEnergy',
             'body': 'bodyMotionEnergy'}
    # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
    self.motion_energy = {}
    for view in views:
        camera = f'{view}Camera'
        me_raw = self.one.load_object(
            self.eid, camera, attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None)
        # Double check if video timestamps are correct length or can be fixed
        times_fixed, motion_energy = self._check_video_timestamps(
            view, me_raw['times'], me_raw['ROIMotionEnergy'])
        self.motion_energy[camera] = pd.DataFrame(columns=[names[view]], data=motion_energy)
        self.motion_energy[camera].insert(0, 'times', times_fixed)
        # NOTE(review): the flag is assumed to be raised per successfully loaded view (inside the
        # loop); the reviewed listing had lost its indentation - confirm against the module
        self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True

1601 

def load_licks(self):
    """
    Placeholder for loading lick data: not yet implemented, calling it does nothing.
    """
    return None

1607 

def load_pupil(self, snr_thresh=5.):
    """
    Load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.

    The 'features' of the left camera are used when available, with the timestamps inserted as first
    column 'times'. Otherwise the diameter is computed on the fly from the left camera pose data
    thresholded at a likelihood of 0.9. Pose data that was already loaded is left untouched by this
    computation. On success the 'pupil' row of data_info is flagged as loaded.

    Parameters
    ----------
    snr_thresh: float
        An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data
        will be considered unusable and will be discarded.

    Raises
    ------
    ValueError
        If the SNR is below snr_thresh; SessionLoader.pupil is reset to an empty DataFrame first.
    """
    # Try to load from features
    feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None)
    if 'features' in feat_raw.keys():
        times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
        self.pupil = feats.copy()
        self.pupil.insert(0, 'times', times_fixed)

    # If unavailable compute on the fly
    else:
        _logger.info('Pupil diameter not available, trying to compute on the fly.')
        # The pose data must be thresholded at 0.9 here, and we don't know how previously loaded pose
        # data was thresholded, so it is always loaded afresh. load_pose replaces self.pose with a new
        # dictionary: keep a handle on the previous one and put it back afterwards, so that every
        # previously loaded view (not only the left one) survives, even if the reload fails.
        previous_pose = self.pose
        try:
            self.load_pose(views=['left'], likelihood_thr=0.9)
            dlc_thr = self.pose['leftCamera'].copy()
        finally:
            if previous_pose:
                self.pose = previous_pose

        # NOTE(review): no 'times' column is added on this path, unlike the features path - confirm
        self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr)
        try:
            self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left')
        except Exception as e:
            # Exception rather than BaseException: an interrupt must not be turned into NaNs
            _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. "
                          "Saving all NaNs for pupilDiameter_smooth.")
            _logger.debug(e)
            self.pupil['pupilDiameter_smooth'] = np.nan

    smooth = self.pupil['pupilDiameter_smooth']
    raw = self.pupil['pupilDiameter_raw']
    if not np.all(np.isnan(smooth)):
        # Signal is the variance of the smoothed trace, noise the variance of its residual to the raw
        # trace, both restricted to the samples where the two traces are defined
        good_idxs = np.where(~np.isnan(smooth) & ~np.isnan(raw))[0]
        snr = np.var(smooth[good_idxs]) / np.var(smooth[good_idxs] - raw[good_idxs])
        if snr < snr_thresh:
            self.pupil = pd.DataFrame()
            raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')

    # Keep data_info in sync when this loader is called directly, like the other loaders do
    self.data_info.loc[self.data_info['name'] == 'pupil', 'is_loaded'] = True

1656 

1657 def _check_video_timestamps(self, view, video_timestamps, video_data): 

1658 """ 

1659 Helper function to check for the length of the video frames vs video timestamps and fix in case 

1660 timestamps are longer than video frames. 

1661 """ 

1662 # If camera times are shorter than video data, or empty, no current fix 

1663 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf

1664 if video_timestamps.shape[0] == 0: 

1665 msg = f'Camera times empty for {view}Camera.' 

1666 else: 

1667 msg = f'Camera times are shorter than video data for {view}Camera.' 

1668 _logger.warning(msg) 

1669 raise ValueError(msg) 

1670 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. 

1671 # This is because the first few frames are sometimes not recorded. We can remove the first few 

1672 # timestamps in this case 

1673 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf

1674 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf

1675 return video_timestamps_fixed, video_data 1lmhf

1676 else: 

1677 return video_timestamps, video_data 

1678 

1679 

class EphysSessionLoader(SessionLoader):
    """
    Spike sorting enhanced version of SessionLoader

    Creates one SpikeSortingLoader per probe insertion of the session and keeps them, together with
    the spike sorting data once loaded, in the self.ephys dictionary keyed by probe name.
    >>> EphysSessionLoader(eid=eid, one=one)
    To select for a specific probe
    >>> EphysSessionLoader(eid=eid, one=one, pid=pid)
    """
    def __init__(self, *args, pname=None, pid=None, **kwargs):
        """
        Needs an active connection in order to get the list of insertions in the session

        :param args: positional arguments forwarded to SessionLoader
        :param pname: optional probe name to restrict the insertions query to
        :param pid: optional insertion id to restrict the insertions query to, ignored if pname is given
        :param kwargs: keyword arguments forwarded to SessionLoader
        """
        super().__init__(*args, **kwargs)
        # if necessary, restrict the query: the probe name has priority over the insertion id
        if pname is not None:
            query = {'name': pname}
        elif pid is not None:
            query = {'id': pid}
        else:
            query = {}
        insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **query)
        self.ephys = {
            ins['name']: {'ssl': SpikeSortingLoader(pid=ins['id'], one=self.one)}
            for ins in insertions
        }

    def load_session_data(self, *args, **kwargs):
        """
        Load the session data as SessionLoader does, then the spike sorting of every probe.
        """
        super().load_session_data(*args, **kwargs)
        self.load_spike_sorting()

    def load_spike_sorting(self, pnames=None):
        """
        Load spikes, clusters and channels for the given probes into self.ephys.

        :param pnames: list of probe names, all the probes in self.ephys when empty or None
        """
        if not pnames:
            pnames = list(self.ephys.keys())
        for pname in pnames:
            probe = self.ephys[pname]
            spikes, clusters, channels = probe['ssl'].load_spike_sorting()
            probe.update(spikes=spikes, clusters=clusters, channels=channels)

    @property
    def probes(self):
        """
        Dictionary mapping each probe name to the pid of its SpikeSortingLoader.
        """
        return {pname: probe['ssl'].pid for pname, probe in self.ephys.items()}