Coverage for brainbox/io/one.py: 53%

747 statements  

coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" 

2from dataclasses import dataclass, field 

3import gc 

4import logging 

5import re 

6import os 

7from pathlib import Path 

8 

9import numpy as np 

10import pandas as pd 

11from scipy.interpolate import interp1d 

12import matplotlib.pyplot as plt 

13 

14from one.api import ONE, One 

15from one.alf.path import get_alf_path, full_path_parts 

16from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound 

17from one.alf import cache 

18import one.alf.io as alfio 

19from neuropixel import TIP_SIZE_UM, trace_header 

20import spikeglx 

21 

22import ibldsp.voltage 

23from ibldsp.waveform_extraction import WaveformsLoader 

24from iblutil.util import Bunch 

25from iblatlas.atlas import AllenAtlas, BrainRegions 

26from iblatlas import atlas 

27from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times 

28from ibllib.pipes import histology 

29from ibllib.pipes.ephys_alignment import EphysAlignment 

30from ibllib.plots import vertical_lines, Density 

31 

32import brainbox.plot 

33from brainbox.io.spikeglx import Streamer 

34from brainbox.ephys_plots import plot_brain_regions 

35from brainbox.metrics.single_units import quick_unit_metrics 

36from brainbox.behavior.wheel import interpolate_position, velocity_filtered 

37from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter 

38 

39_logger = logging.getLogger('ibllib') 

40 

41 

42SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths'] 

43CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids'] 

44WAVEFORMS_ATTRIBUTES = ['templates'] 

45 

46 

47def load_lfp(eid, one=None, dataset_types=None, **kwargs): 

48 """ 

49 TODO Verify works 

50 From an eid, hits the Alyx database and downloads the standard set of datasets 

51 needed for LFP 

52 :param eid: experiment session identifier (UUID, URL or experiment reference string) 

53 :param dataset_types: additional dataset types to add to the list 

54 :param kwargs: additional keyword arguments passed to the spikeglx.Reader constructor (e.g. open=True to open the readers) 

55 :return: list of spikeglx.Reader 

56 """ 

57 if dataset_types is None: 

58 dataset_types = [] 

59 dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*'] 

60 [one.load_dataset(eid, dset, download_only=True) for dset in dtypes] 

61 session_path = one.eid2path(eid) 

62 

63 efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) 

64 if ef.get('lf', None)] 

65 return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles] 

66 
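A minimal usage sketch (not part of the module), assuming a configured ONE instance; the eid string is a placeholder and, given the TODO above, this should be treated as indicative only:

from one.api import ONE
from brainbox.io.one import load_lfp

one = ONE()
eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'  # placeholder session UUID
readers = load_lfp(eid, one=one)  # one spikeglx.Reader per probe with an LF file
for sr in readers:
    print(sr.fs, sr.ns)  # sampling rate and number of samples of each LF recording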

67 

68def _collection_filter_from_args(probe, spike_sorter=None): 

69 collection = f'alf/{probe}/{spike_sorter}' 1g

70 collection = collection.replace('None', '*') 1g

71 collection = collection.replace('/*', '*') 1g

72 collection = collection[:-1] if collection.endswith('/') else collection 1g

73 return collection 1g

74 
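Traced by hand from the four lines above, the wildcard collection strings produced look like this (illustrative only):

_collection_filter_from_args('probe00')                # -> 'alf/probe00*'
_collection_filter_from_args('probe00', 'pykilosort')  # -> 'alf/probe00/pykilosort'
_collection_filter_from_args(None)                     # -> 'alf**'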

75 

76def _get_spike_sorting_collection(collections, pname): 

77 """ 

78 Filters a list or array of collections to get the relevant spike sorting dataset 

79 if there is a pykilosort, load it 

80 """ 

81 # 

82 collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) 1gb

83 # otherwise, prefers the shortest 

84 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) 1gb

85 _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") 1gb

86 return collection 1gb

87 

88 

89def _channels_alyx2bunch(chans): 

90 channels = Bunch({ 

91 'atlas_id': np.array([ch['brain_region'] for ch in chans]), 

92 'x': np.array([ch['x'] for ch in chans]) / 1e6, 

93 'y': np.array([ch['y'] for ch in chans]) / 1e6, 

94 'z': np.array([ch['z'] for ch in chans]) / 1e6, 

95 'axial_um': np.array([ch['axial'] for ch in chans]), 

96 'lateral_um': np.array([ch['lateral'] for ch in chans]) 

97 }) 

98 return channels 

99 

100 

101def _channels_traj2bunch(xyz_chans, brain_atlas): 

102 brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans)) 

103 channels = { 

104 'x': xyz_chans[:, 0], 

105 'y': xyz_chans[:, 1], 

106 'z': xyz_chans[:, 2], 

107 'acronym': brain_regions['acronym'], 

108 'atlas_id': brain_regions['id'] 

109 } 

110 

111 return channels 

112 

113 

114def _channels_bunch2alf(channels): 

115 channels_ = { 1i

116 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6, 

117 'brainLocationIds_ccf_2017': channels['atlas_id'], 

118 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]} 

119 return channels_ 1i

120 

121 

122def _channels_alf2bunch(channels, brain_regions=None): 

123 # reformat the dictionary according to the standard that comes out of Alyx 

124 channels_ = { 1icbed

125 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6, 

126 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6, 

127 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6, 

128 'acronym': None, 

129 'atlas_id': channels['brainLocationIds_ccf_2017'], 

130 'axial_um': channels['localCoordinates'][:, 1], 

131 'lateral_um': channels['localCoordinates'][:, 0], 

132 } 

133 # here if we have some extra keys, they will carry over to the next dictionary 

134 for k in channels: 1icbed

135 if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: 1icbed

136 channels_[k] = channels[k] 1icbed

137 if brain_regions: 1icbed

138 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icbed

139 return channels_ 1icbed

140 

141 

142def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, 

143 brain_regions=None): 

144 """ 

145 Generic function to load spike sorting data using ONE. 

146 

147 Will try to load one spike sorting for any probe present for the eid matching the collection 

148 For each probe it will load a spike sorting: 

149 - if there is one version: loads this one 

150 - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX) 

151 

152 Parameters 

153 ---------- 

154 eid : [str, UUID, Path, dict] 

155 Experiment session identifier; may be a UUID, URL, experiment reference string 

156 details dict or Path 

157 one : one.api.OneAlyx 

158 An instance of ONE (may be in 'local' mode) 

159 collection : str 

160 collection filter word - accepts wildcards - can be a combination of spike sorter and 

161 probe. See `ALF documentation`_ for details. 

162 revision : str 

163 A particular revision return (defaults to latest revision). See `ALF documentation`_ for 

164 details. 

165 return_channels : bool 

166 Defaults to True; if True, channel locations are also loaded from disk and returned 

167 

168 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components 

169 

170 Returns 

171 ------- 

172 spikes : dict of one.alf.io.AlfBunch 

173 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

174 session and spike sorter, with keys ('clusters', 'times') 

175 clusters : dict of one.alf.io.AlfBunch 

176 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

177 ('channels', 'depths', 'metrics') 

178 channels : dict of one.alf.io.AlfBunch 

179 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

180 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs 

181 non-lateralized. 

182 """ 

183 one = one or ONE() 1gb

184 # enumerate probes and load according to the name 

185 collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) 1gb

186 if len(collections) == 0: 1gb

187 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

188 pnames = list(set(c.split('/')[1] for c in collections)) 1gb

189 spikes, clusters, channels = ({} for _ in range(3)) 1gb

190 

191 spike_attributes, cluster_attributes = _get_attributes(dataset_types) 1gb

192 

193 for pname in pnames: 1gb

194 probe_collection = _get_spike_sorting_collection(collections, pname) 1gb

195 spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', 1gb

196 attribute=spike_attributes) 

197 clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', 1g

198 attribute=cluster_attributes) 

199 if return_channels: 1g

200 channels = _load_channels_locations_from_disk( 

201 eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions) 

202 return spikes, clusters, channels 

203 else: 

204 return spikes, clusters 1g

205 

206 

207def _get_attributes(dataset_types): 

208 if dataset_types is None: 1gb

209 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1gb

210 else: 

211 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 

212 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 

213 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 

214 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 

215 return spike_attributes, cluster_attributes 

216 

217 

218def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None): 

219 _logger.debug('loading spike sorting from disk') 

220 channels = Bunch({}) 

221 collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) 

222 if len(collections) == 0: 

223 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

224 probes = list(set([c.split('/')[1] for c in collections])) 

225 for probe in probes: 

226 probe_collection = _get_spike_sorting_collection(collections, probe) 

227 channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels') 

228 # if the spike sorter has not aligned data, try and get the alignment available 

229 if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): 

230 aligned_channel_collections = one.list_collections( 

231 eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision) 

232 if len(aligned_channel_collections) == 0: 

233 _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}") 

234 continue 

235 _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}") 

236 ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe) 

237 channels_aligned = one.load_object(eid, 'channels', collection=ac_collection) 

238 channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe]) 

239 # only have to reformat channels if we were able to load coordinates from disk 

240 channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions) 

241 return channels 

242 

243 

244def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None): 

245 """ 

246 The channel maps of different spike sorters may differ, so the alignment is interpolated onto the target channel map. 

247 If there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field, 

248 so it is reconstructed from the Neuropixel map. This only happens for early pykilosort sorts. 

249 :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys 

250 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017' 

251 OR 

252 'x', 'y', 'z', 'acronym', 'axial_um' 

253 those are the guide for the interpolation 

254 :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates' 

255 :param brain_regions: None (default) or iblatlas.regions.BrainRegions object 

256 if None, returns a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017' 

257 if a brain regions object is provided, outputs a dict with keys 

258 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um' 

259 :return: Bunch or dictionary of channels with brain coordinates keys 

260 """ 

261 NEUROPIXEL_VERSION = 1 1i

262 h = trace_header(version=NEUROPIXEL_VERSION) 1i

263 if channels is None: 1i

264 channels = {'localCoordinates': np.c_[h['x'], h['y']]} 

265 nch = channels['localCoordinates'].shape[0] 1i

266 if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): 1i

267 channels_aligned = _channels_bunch2alf(channels_aligned) 1i

268 if 'localCoordinates' in channels_aligned.keys(): 1i

269 aligned_depths = channels_aligned['localCoordinates'][:, 1] 1i

270 else: # this is an edge case for a few spike sorting sessions 

271 assert channels_aligned['mlapdv'].shape[0] == 384 

272 aligned_depths = h['y'] 

273 depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) 1i

274 depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) 1i

275 channels['mlapdv'] = np.zeros((nch, 3)) 1i

276 for i in np.arange(3): 1i

277 channels['mlapdv'][:, i] = np.interp( 1i

278 depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] 

279 # the brain locations have to be interpolated by nearest neighbour 

280 fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') 1i

281 channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) 1i

282 if brain_regions is not None: 1i

283 return _channels_alf2bunch(channels, brain_regions=brain_regions) 1i

284 else: 

285 return channels 1i

286 

287 

288def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, 

289 brain_atlas=None, return_source=False): 

290 if not hasattr(one, 'alyx'): 1b

291 return {}, None 1b

292 _logger.debug(f"trying to load from traj {probe}") 

293 channels = Bunch() 

294 brain_atlas = brain_atlas or AllenAtlas() 

295 # need to find the collection 

296 insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0] 

297 collection = _collection_filter_from_args(probe=probe) 

298 collections = one.list_collections(eid, filename='channels*', collection=collection, 

299 revision=revision) 

300 probe_collection = _get_spike_sorting_collection(collections, probe) 

301 chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection) 

302 depths = chn_coords[:, 1] 

303 

304 tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

305 get('tracing_exists', False) 

306 resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

307 get('alignment_resolved', False) 

308 counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

309 get('alignment_count', 0) 

310 

311 if tracing: 

312 xyz = np.array(insertion['json']['xyz_picks']) / 1e6 

313 if resolved: 

314 

315 _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. ' 

316 f'Channel and cluster locations obtained from ephys aligned histology ' 

317 f'track.') 

318 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

319 provenance='Ephys aligned histology track')[0] 

320 align_key = insertion['json']['extended_qc']['alignment_stored'] 

321 feature = traj['json'][align_key][0] 

322 track = traj['json'][align_key][1] 

323 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

324 feature_prev=feature, 

325 brain_atlas=brain_atlas, speedy=True) 

326 chans = ephysalign.get_channel_locations(feature, track) 

327 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

328 source = 'resolved' 

329 elif counts > 0 and aligned: 

330 _logger.debug(f'Channel locations for {eid}/{probe} have not been ' 

331 f'resolved. However, alignment flag set to True so channel and cluster' 

332 f' locations will be obtained from latest available ephys aligned ' 

333 f'histology track.') 

334 # get the latest user aligned channels 

335 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

336 provenance='Ephys aligned histology track')[0] 

337 align_key = insertion['json']['extended_qc']['alignment_stored'] 

338 feature = traj['json'][align_key][0] 

339 track = traj['json'][align_key][1] 

340 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

341 feature_prev=feature, 

342 brain_atlas=brain_atlas, speedy=True) 

343 chans = ephysalign.get_channel_locations(feature, track) 

344 

345 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

346 source = 'aligned' 

347 else: 

348 _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. ' 

349 f'Channel and cluster locations obtained from histology track.') 

350 # get the channels from histology tracing 

351 xyz = xyz[np.argsort(xyz[:, 2]), :] 

352 chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) 

353 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

354 source = 'traced' 

355 channels[probe]['axial_um'] = chn_coords[:, 1] 

356 channels[probe]['lateral_um'] = chn_coords[:, 0] 

357 

358 else: 

359 _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}') 

360 source = '' 

361 channels = None 

362 

363 if return_source: 

364 return channels, source 

365 else: 

366 return channels 

367 

368 

369def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None): 

370 """ 

371 Load the brain locations of each channel for a given session/probe 

372 

373 Parameters 

374 ---------- 

375 eid : [str, UUID, Path, dict] 

376 Experiment session identifier; may be a UUID, URL, experiment reference string 

377 details dict or Path 

378 probe : [str, list of str] 

379 The probe label(s), e.g. 'probe01' 

380 one : one.api.OneAlyx 

381 An instance of ONE (shouldn't be in 'local' mode) 

382 aligned : bool 

383 Whether to get the latest user aligned channel when not resolved or use histology track 

384 brain_atlas : iblatlas.BrainAtlas 

385 Brain atlas object (default: Allen atlas) 

386 Returns 

387 ------- 

388 dict of one.alf.io.AlfBunch 

389 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

390 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

391 optional: string 'resolved', 'aligned', 'traced' or '' 

392 """ 

393 one = one or ONE() 

394 brain_atlas = brain_atlas or AllenAtlas() 

395 if isinstance(eid, dict): 

396 ses = eid 

397 eid = ses['url'][-36:] 

398 else: 

399 eid = one.to_eid(eid) 

400 collection = _collection_filter_from_args(probe=probe) 

401 channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, 

402 brain_regions=brain_atlas.regions) 

403 incomplete_probes = [k for k in channels if 'x' not in channels[k]] 

404 for iprobe in incomplete_probes: 

405 channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned, 

406 brain_atlas=brain_atlas, return_source=True) 

407 if channels_ is not None: 

408 channels[iprobe] = channels_[iprobe] 

409 return channels 

410 
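A short usage sketch, assuming a remote ONE instance (the eid is a placeholder); the result is keyed by probe label:

from one.api import ONE
from brainbox.io.one import load_channel_locations

one = ONE()
channels = load_channel_locations('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee', probe='probe00', one=one)
print(channels['probe00']['acronym'][:5])  # brain region acronyms of the first five channels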

411 

412def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

413 brain_regions=None, nested=True, collection=None, return_collection=False): 

414 """ 

415 From an eid, loads spikes and clusters for all probes 

416 The following set of dataset types are loaded: 

417 'clusters.channels', 

418 'clusters.depths', 

419 'clusters.metrics', 

420 'spikes.clusters', 

421 'spikes.times', 

422 'probes.description' 

423 :param eid: experiment UUID or pathlib.Path of the local session 

424 :param one: an instance of OneAlyx 

425 :param probe: name of probe to load in, if not given all probes for session will be loaded 

426 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

427 :param spike_sorter: name of the spike sorting you want to load (None for default) 

428 :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00" 

429 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

430 :param nested: if False and a single probe is loaded, output the bunches directly instead of a dictionary keyed by probe name 

431 :param return_collection: (False) if True, will return the collection used to load 

432 :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe) 

433 """ 

434 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.' 

435 ' Use brainbox.io.one.SpikeSortingLoader instead') 

436 if collection is None: 

437 collection = _collection_filter_from_args(probe, spike_sorter) 

438 _logger.debug(f"load spike sorting with collection filter {collection}") 

439 kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types, 

440 brain_regions=brain_regions) 

441 spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True) 

442 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

443 if nested is False and len(spikes.keys()) == 1: 

444 k = list(spikes.keys())[0] 

445 channels = channels[k] 

446 clusters = clusters[k] 

447 spikes = spikes[k] 

448 if return_collection: 

449 return spikes, clusters, channels, collection 

450 else: 

451 return spikes, clusters, channels 

452 

453 

454def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

455 brain_regions=None, return_collection=False): 

456 """ 

457 From an eid, loads spikes and clusters for all probes 

458 The following set of dataset types are loaded: 

459 'clusters.channels', 

460 'clusters.depths', 

461 'clusters.metrics', 

462 'spikes.clusters', 

463 'spikes.times', 

464 'probes.description' 

465 :param eid: experiment UUID or pathlib.Path of the local session 

466 :param one: an instance of OneAlyx 

467 :param probe: name of probe to load in, if not given all probes for session will be loaded 

468 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

469 :param spike_sorter: name of the spike sorting you want to load (None for default) 

470 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

471 :param return_collection:(bool - False) if True, returns the collection for loading the data 

472 :return: spikes, clusters (dict of bunch, 1 bunch per probe) 

473 """ 

474 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.' 1g

475 ' Use brainbox.io.one.SpikeSortingLoader instead') 

476 collection = _collection_filter_from_args(probe, spike_sorter) 1g

477 _logger.debug(f"load spike sorting with collection filter {collection}") 1g

478 spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, 1g

479 return_channels=False, dataset_types=dataset_types, 

480 brain_regions=brain_regions) 

481 if return_collection: 1g

482 return spikes, clusters, collection 

483 else: 

484 return spikes, clusters 1g

485 
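A usage sketch for this (deprecated) loader, assuming a configured ONE instance and a placeholder eid; new code should prefer the SpikeSortingLoader class defined further down:

from one.api import ONE
from brainbox.io.one import load_spike_sorting

one = ONE()
spikes, clusters = load_spike_sorting('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee', one=one)
# both are dicts keyed by probe label, e.g. spikes['probe00']['times'] and clusters['probe00']['depths']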

486 

487def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None, 

488 spike_sorter=None, brain_atlas=None, nested=True, return_collection=False): 

489 """ 

490 For a given eid, get spikes, clusters and channels information, and merges clusters 

491 and channels information before returning all three variables. 

492 

493 Parameters 

494 ---------- 

495 eid : [str, UUID, Path, dict] 

496 Experiment session identifier; may be a UUID, URL, experiment reference string 

497 details dict or Path 

498 one : one.api.OneAlyx 

499 An instance of ONE (shouldn't be in 'local' mode) 

500 probe : [str, list of str] 

501 The probe label(s), e.g. 'probe01' 

502 aligned : bool 

503 Whether to get the latest user aligned channel when not resolved or use histology track 

504 dataset_types : list of str 

505 Optional additional spikes/clusters objects to add to the standard default list 

506 spike_sorter : str 

507 Name of the spike sorting you want to load (None for default which is pykilosort if it's 

508 available otherwise the default MATLAB kilosort) 

509 brain_atlas : iblatlas.atlas.BrainAtlas 

510 Brain atlas object (default: Allen atlas) 

511 return_collection: bool 

512 Returns an extra argument with the collection chosen 

513 

514 Returns 

515 ------- 

516 spikes : dict of one.alf.io.AlfBunch 

517 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

518 session and spike sorter, with keys ('clusters', 'times') 

519 clusters : dict of one.alf.io.AlfBunch 

520 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

521 ('channels', 'depths', 'metrics') 

522 channels : dict of one.alf.io.AlfBunch 

523 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

524 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

525 """ 

526 # --- Get spikes and clusters data 

527 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future versions.' 

528 ' Use brainbox.io.one.SpikeSortingLoader instead') 

529 one = one or ONE() 

530 brain_atlas = brain_atlas or AllenAtlas() 

531 spikes, clusters, collection = load_spike_sorting( 

532 eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True) 

533 # -- Get brain regions and assign to clusters 

534 channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned, 

535 brain_atlas=brain_atlas) 

536 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

537 if nested is False and len(spikes.keys()) == 1: 

538 k = list(spikes.keys())[0] 

539 channels = channels[k] 

540 clusters = clusters[k] 

541 spikes = spikes[k] 

542 if return_collection: 

543 return spikes, clusters, channels, collection 

544 else: 

545 return spikes, clusters, channels 

546 

547 

548def load_ephys_session(eid, one=None): 

549 """ 

550 From an eid, hits the Alyx database and downloads a standard default set of dataset types 

551 From a local session Path (pathlib.Path), loads a standard default set of dataset types 

552 to perform analysis: 

553 'clusters.channels', 

554 'clusters.depths', 

555 'clusters.metrics', 

556 'spikes.clusters', 

557 'spikes.times', 

558 'probes.description' 

559 

560 Parameters 

561 ---------- 

562 eid : [str, UUID, Path, dict] 

563 Experiment session identifier; may be a UUID, URL, experiment reference string 

564 details dict or Path 

565 one : oneibl.one.OneAlyx, optional 

566 ONE object to use for loading. Will generate internal one if not used, by default None 

567 

568 Returns 

569 ------- 

570 spikes : dict of one.alf.io.AlfBunch 

571 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

572 session and spike sorter, with keys ('clusters', 'times') 

573 clusters : dict of one.alf.io.AlfBunch 

574 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

575 ('channels', 'depths', 'metrics') 

576 trials : one.alf.io.AlfBunch of numpy.ndarray 

577 The session trials data 

578 """ 

579 assert one 1g

580 spikes, clusters = load_spike_sorting(eid, one=one) 1g

581 trials = one.load_object(eid, 'trials') 1g

582 return spikes, clusters, trials 1g

583 
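A minimal sketch, reusing the ONE instance from the examples above (eid placeholder):

spikes, clusters, trials = load_ephys_session('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee', one=one)
print(trials['goCue_times'].size)  # number of trials in the session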

584 

585def _remove_old_clusters(session_path, probe): 

586 # gets clusters and spikes from a local session folder 

587 probe_path = session_path.joinpath('alf', probe) 

588 

589 # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead 

590 cluster_file = probe_path.joinpath('clusters.metrics.csv') 

591 

592 if cluster_file.exists(): 

593 os.remove(cluster_file) 

594 _logger.info('Deleting old clusters.metrics.csv file') 

595 

596 

597def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None): 

598 """ 

599 Takes (default and any extra) values in given keys from channels and assign them to clusters. 

600 If channels does not contain any data, the new keys are added to clusters but left empty. 

601 

602 Parameters 

603 ---------- 

604 dic_clus : dict of one.alf.io.AlfBunch 

605 1 bunch per probe, containing cluster information 

606 channels : dict of one.alf.io.AlfBunch 

607 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z', 'localCoordinates') 

608 keys_to_add_extra : list of str 

609 Any extra keys to load into channels bunches 

610 

611 Returns 

612 ------- 

613 dict of one.alf.io.AlfBunch 

614 clusters (1 bunch per probe) with new keys values. 

615 """ 

616 probe_labels = list(channels.keys()) # Convert dict_keys into list 

617 keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um'] 

618 

619 if keys_to_add_extra is None: 

620 keys_to_add = keys_to_add_default 

621 else: 

622 # Append extra optional keys 

623 keys_to_add = list(set(keys_to_add_extra + keys_to_add_default)) 

624 

625 for label in probe_labels: 

626 clu_ch = dic_clus[label]['channels'] 

627 for key in keys_to_add: 

628 try: 

629 assert key in channels[label].keys() # Check key is in channels 

630 ch_key = channels[label][key] 

631 nch_key = len(ch_key) if ch_key is not None else 0 

632 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index 

633 dic_clus[label][key] = ch_key[clu_ch] 

634 else: 

635 _logger.warning( 

636 f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels' 

637 f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.') 

638 dic_clus[label][key] = [] 

639 except AssertionError: 

640 _logger.warning(f'Either clusters or channels does not have key {key}, could not merge') 

641 continue 

642 

643 return dic_clus 

644 

645 

646def load_passive_rfmap(eid, one=None): 

647 """ 

648 For a given eid load in the passive receptive field mapping protocol data 

649 

650 Parameters 

651 ---------- 

652 eid : [str, UUID, Path, dict] 

653 Experiment session identifier; may be a UUID, URL, experiment reference string 

654 details dict or Path 

655 one : oneibl.one.OneAlyx, optional 

656 An instance of ONE (may be in 'local' - offline - mode) 

657 

658 Returns 

659 ------- 

660 one.alf.io.AlfBunch 

661 Passive receptive field mapping data 

662 """ 

663 one = one or ONE() 

664 

665 # Load in the receptive field mapping data 

666 rf_map = one.load_object(eid, obj='passiveRFM', collection='alf') 

667 frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', 

668 collection='raw_passive_data'), dtype="uint8") 

669 y_pix, x_pix = 15, 15 

670 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0]) 

671 rf_map['frames'] = frames 

672 

673 return rf_map 

674 

675 

676def load_wheel_reaction_times(eid, one=None): 

677 """ 

678 Return the calculated reaction times for session. Reaction times are defined as the time 

679 between the go cue (onset tone) and the onset of the first substantial wheel movement. A 

680 movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the 

681 distance to threshold (~0.1 radians). 

682 

683 Negative times mean the onset of the movement occurred before the go cue. Nans may occur if 

684 there was no detected movement within the period, or when the goCue_times or feedback_times 

685 are nan. 

686 

687 Parameters 

688 ---------- 

689 eid : [str, UUID, Path, dict] 

690 Experiment session identifier; may be a UUID, URL, experiment reference string 

691 details dict or Path 

692 one : one.api.OneAlyx, optional 

693 one object to use for loading. Will generate internal one if not used, by default None 

694 

695 Returns 

696 ---------- 

697 array-like 

698 reaction times 

699 """ 

700 if one is None: 

701 one = ONE() 

702 

703 trials = one.load_object(eid, 'trials') 

704 # If already extracted, load and return 

705 if trials and 'firstMovement_times' in trials: 

706 return trials['firstMovement_times'] - trials['goCue_times'] 

707 # Otherwise load wheelMoves object and calculate 

708 moves = one.load_object(eid, 'wheelMoves') 

709 # Re-extract wheel moves if necessary 

710 if not moves or 'peakAmplitude' not in moves: 

711 wheel = one.load_object(eid, 'wheel') 

712 moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) 

713 assert trials and moves, 'unable to load trials and wheelMoves data' 

714 firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials) 

715 return firstMove_times - trials['goCue_times'] 

716 
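Usage sketch (eid placeholder, numpy imported as np): the returned array has one entry per trial and is negative when the movement started before the go cue:

rt = load_wheel_reaction_times('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee', one=one)
print(np.nanmedian(rt))  # median reaction time in seconds, ignoring NaN trials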

717 

718def load_iti(trials): 

719 """ 

720 The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey 

721 screen commencing at stimulus off and lasting until the quiescent period at the start of the 

722 following trial. Note that the ITI of each trial is measured to the start of the next trial, 

723 therefore the last value is NaN. 

724 

725 Parameters 

726 ---------- 

727 trials : one.alf.io.AlfBunch 

728 An ALF trials object containing the keys {'intervals', 'stimOff_times'}. 

729 

730 Returns 

731 ------- 

732 np.array 

733 An array of inter-trial intervals, the last value being NaN. 

734 """ 

735 if not {'intervals', 'stimOff_times'} <= set(trials.keys()): 1o

736 raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') 1o

737 return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan] 1o

738 
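A toy worked example of the computation above, with made-up times for three trials:

import numpy as np
trials = {'intervals': np.array([[0., 9.], [10., 19.], [20., 29.]]),
          'stimOff_times': np.array([8., 18., 28.])}
load_iti(trials)  # -> array([2., 2., nan]); i.e. intervals[1:, 0] - stimOff_times[:-1], NaN-padded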

739 

740def load_channels_from_insertion(ins, depths=None, one=None, ba=None): 

741 

742 PROV_2_VAL = { 

743 'Resolved': 90, 

744 'Ephys aligned histology track': 70, 

745 'Histology track': 50, 

746 'Micro-manipulator': 30, 

747 'Planned': 10} 

748 

749 one = one or ONE() 

750 ba = ba or atlas.AllenAtlas() 

751 traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id']) 

752 val = [PROV_2_VAL[tr['provenance']] for tr in traj] 

753 idx = np.argmax(val) 

754 traj = traj[idx] 

755 if depths is None: 

756 depths = trace_header(version=1)['y']  # channel depths along the probe in um 

757 if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator': 

758 ins = atlas.Insertion.from_dict(traj) 

759 # Deepest coordinate first 

760 xyz = np.c_[ins.tip, ins.entry].T 

761 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

762 TIP_SIZE_UM) / 1e6) 

763 else: 

764 xyz = np.array(ins['json']['xyz_picks']) / 1e6 

765 if traj['provenance'] == 'Histology track': 

766 xyz = xyz[np.argsort(xyz[:, 2]), :] 

767 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

768 TIP_SIZE_UM) / 1e6) 

769 else: 

770 align_key = ins['json']['extended_qc']['alignment_stored'] 

771 feature = traj['json'][align_key][0] 

772 track = traj['json'][align_key][1] 

773 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

774 feature_prev=feature, 

775 brain_atlas=ba, speedy=True) 

776 xyz_channels = ephysalign.get_channel_locations(feature, track) 

777 return xyz_channels 

778 

779 

780@dataclass 

781class SpikeSortingLoader: 

782 """ 

783 Object that will load spike sorting data for a given probe insertion. 

784 This class can be instantiated in several manners 

785 - With Alyx database probe id: 

786 SpikeSortingLoader(pid=pid, one=one) 

787 - With Alyx database eid and probe name: 

788 SpikeSortingLoader(eid=eid, pname='probe00', one=one) 

789 - From a local session and probe name: 

790 SpikeSortingLoader(session_path=session_path, pname='probe00') 

791 NB: When no ONE instance is passed, any datasets that are loaded will not be recorded. 

792 """ 

793 one: One = None 

794 atlas: None = None 

795 pid: str = None 

796 eid: str = '' 

797 pname: str = '' 

798 session_path: Path = '' 

799 # the following properties are the outcome of the post init function 

800 collections: list = None 

801 datasets: list = None # list of all datasets belonging to the session 

802 # the following properties are the outcome of a reading function 

803 files: dict = None 

804 raw_data_files: list = None # list of raw ap and lf files corresponding to the recording 

805 collection: str = '' 

806 histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced' 

807 spike_sorter: str = 'pykilosort' 

808 spike_sorting_path: Path = None 

809 _sync: dict = None 

810 

811 def __post_init__(self): 

812 # pid gets precedence 

813 if self.pid is not None: 1cbedk

814 try: 1dk

815 self.eid, self.pname = self.one.pid2eid(self.pid) 1dk

816 except NotImplementedError: 

817 if self.eid == '' or self.pname == '': 

818 raise IOError("Cannot infer session id and probe name from pid. " 

819 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") 

820 self.session_path = self.one.eid2path(self.eid) 1dk

821 # then eid / pname combination 

822 elif self.session_path is None or self.session_path == '': 1cbe

823 self.session_path = self.one.eid2path(self.eid) 1cbe

824 # fully local providing a session path 

825 else: 

826 if self.one: 

827 self.eid = self.one.to_eid(self.session_path) 

828 else: 

829 self.one = One(cache_dir=self.session_path.parents[2], mode='local') 

830 df_sessions = cache._make_sessions_df(self.session_path) 

831 self.one._cache['sessions'] = df_sessions.set_index('id') 

832 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) 

833 self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) 

834 # populates default properties 

835 self.collections = self.one.list_collections( 1cbedk

836 self.eid, filename='spikes*', collection=f"alf/{self.pname}*") 

837 self.datasets = self.one.list_datasets(self.eid) 1cbedk

838 if self.atlas is None: 1cbedk

839 self.atlas = AllenAtlas() 1cbek

840 self.files = {} 1cbedk

841 self.raw_data_files = [] 1cbedk

842 

843 def _load_object(self, *args, **kwargs): 

844 """ 

845 This function is a wrapper around alfio.load_object that will remove the UUID in the 

846 filename if the object is on SDSC. 

847 """ 

848 remove_uuids = getattr(self.one, 'uuid_filenames', False) 1cbed

849 d = alfio.load_object(*args, **kwargs) 1cbed

850 if remove_uuids: 1cbed

851 # pops the UUID in the key names 

852 keys = list(d.keys()) 

853 for k in keys: 

854 d[k[:-37]] = d.pop(k) 

855 return d 1cbed

856 

857 @staticmethod 

858 def _get_attributes(dataset_types): 

859 """returns attributes to load for spikes and clusters objects""" 

860 dataset_types = [] if dataset_types is None else dataset_types 1cbed

861 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1cbed

862 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1cbed

863 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1cbed

864 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1cbed

865 waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] 1cbed

866 waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) 1cbed

867 return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} 1cbed

868 

869 def _get_spike_sorting_collection(self, spike_sorter=None): 

870 """ 

871 Filters a list or array of collections to get the relevant spike sorting dataset 

872 preference order: the requested spike sorter, then iblsorter, then pykilosort; otherwise the shortest collection 

873 """ 

874 for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): 1cbed

875 if sorter is None: 1cbed

876 continue 

877 if sorter == "": 1cbed

878 collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) 1cbe

879 else: 

880 collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) 1bd

881 if collection is not None: 1cbed

882 return collection 1cbed

883 # if none is found amongst the defaults, prefers the shortest 

884 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) 

885 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") 

886 return collection 

887 

888 def load_spike_sorting_object(self, obj, *args, **kwargs): 

889 """ 

890 Loads an ALF object 

891 :param obj: object name, one of 'spikes', 'clusters' or 'channels' 

892 :param spike_sorter: (defaults to 'pykilosort') 

893 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

894 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

895 :param kwargs: additional arguments to be passed to one.api.One.load_object 

896 :param missing: 'raise' (default) or 'ignore' 

897 :return: 

898 """ 

899 self.download_spike_sorting_object(obj, *args, **kwargs) 

900 return self._load_object(self.files[obj]) 

901 

902 def get_version(self, spike_sorter=None): 

903 spike_sorter = (spike_sorter or self.spike_sorter) or 'iblsorter' 

904 collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 

905 dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy') 

906 return dset[0]['version'] if len(dset) else 'unknown' 

907 

908 def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None, 

909 attribute=None, missing='raise', **kwargs): 

910 """ 

911 Downloads an ALF object 

912 :param obj: object name, one of 'spikes', 'clusters' or 'channels' 

913 :param spike_sorter: (defaults to 'pykilosort') 

914 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

915 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

916 :param kwargs: additional arguments to be passed to one.api.One.load_object 

917 :param attribute: list of attributes to load for the object 

918 :param missing: 'raise' (default) or 'ignore' 

919 :return: 

920 """ 

921 if spike_sorter is None: 1cbed

922 spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter' 1cbed

923 if len(self.collections) == 0: 1cbed

924 return {}, {}, {} 

925 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1cbed

926 collection = collection or self.collection 1cbed

927 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1cbed

928 attributes = self._get_attributes(dataset_types) 1cbed

929 try: 1cbed

930 self.files[obj] = self.one.load_object( 1cbed

931 self.eid, obj=obj, attribute=attributes.get(obj, None), 

932 collection=collection, download_only=True, **kwargs) 

933 except ALFObjectNotFound as e: 1cbe

934 if missing == 'raise': 1cbe

935 raise e 

936 

937 def download_spike_sorting(self, objects=None, **kwargs): 

938 """ 

939 Downloads spikes, clusters and channels 

940 :param spike_sorter: (defaults to 'pykilosort') 

941 :param dataset_types: list of extra dataset types 

942 :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels'] 

943 :return: 

944 """ 

945 objects = ['spikes', 'clusters', 'channels'] if objects is None else objects 1cbed

946 for obj in objects: 1cbed

947 self.download_spike_sorting_object(obj=obj, **kwargs) 1cbed

948 self.spike_sorting_path = self.files['clusters'][0].parent 1cbed

949 

950 def download_raw_electrophysiology(self, band='ap'): 

951 """ 

952 Downloads raw electrophysiology data files on local disk. 

953 :param band: "ap" (default) or "lf" for LFP band 

954 :return: list of raw data files full paths (ch, meta and cbin files) 

955 """ 

956 raw_data_files = [] 

957 for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']: 

958 try: 

959 # FIXME: this will fail if multiple LFP segments are found 

960 raw_data_files.append(self.one.load_dataset( 

961 self.eid, 

962 download_only=True, 

963 collection=f'raw_ephys_data/{self.pname}', 

964 dataset=suffix, 

965 check_hash=False, 

966 )) 

967 except ALFObjectNotFound: 

968 _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}") 

969 self.raw_data_files = list(set(self.raw_data_files + raw_data_files)) 

970 return raw_data_files 

971 

972 def raw_electrophysiology(self, stream=True, band='ap', **kwargs): 

973 """ 

974 Returns a reader for the raw electrophysiology data 

975 By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having 

976 downloaded the raw data file if necessary 

977 :param stream: 

978 :param band: 

979 :param kwargs: 

980 :return: 

981 """ 

982 if stream: 1k

983 return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) 1k

984 else: 

985 raw_data_files = self.download_raw_electrophysiology(band=band) 

986 cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None) 

987 if cbin_file is not None: 

988 return spikeglx.Reader(cbin_file) 

989 
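A sketch of streaming a short snippet of raw AP data (pid placeholder); the slicing convention mirrors the one used in plot_rawdata_snippet further down:

ssl = SpikeSortingLoader(pid='aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee', one=one)
sr = ssl.raw_electrophysiology(band='ap', stream=True)
t0 = 100  # arbitrary time in seconds within the recording
raw = sr[int(t0 * sr.fs):int((t0 + 1) * sr.fs), :-sr.nsync].T  # (n_channels, n_samples) snippet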

990 def download_raw_waveforms(self, **kwargs): 

991 """ 

992 Downloads raw waveforms extracted from sorting to local disk. 

993 """ 

994 _logger.debug(f"loading waveforms from {self.collection}") 

995 return self.one.load_object( 

996 id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"], 

997 collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs 

998 ) 

999 

1000 def raw_waveforms(self, **kwargs): 

1001 wf_paths = self.download_raw_waveforms(**kwargs) 

1002 return WaveformsLoader(wf_paths[0].parent) 

1003 

1004 def load_channels(self, **kwargs): 

1005 """ 

1006 Loads channels 

1007 The channel locations can come from several sources, it will load the most advanced version of the histology available, 

1008 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

1009 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

1010 - resolved: channel locations alignments have been agreed upon 

1011 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

1012 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

1013 

1014 :param spike_sorter: (defaults to 'pykilosort') 

1015 :param dataset_types: list of extra dataset types 

1016 :return: 

1017 """ 

1018 # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting 

1019 self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') 1cbed

1020 self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs) 1cbed

1021 channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) 1cbed

1022 if 'electrodeSites' in self.files: # if common dict keys, electrodeSites prevails 1cbed

1023 esites = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 1d

1024 if alfio.check_dimensions(esites) != 0: 1d

1025 esites = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 

1026 esites['rawInd'] = np.arange(esites[list(esites.keys())[0]].shape[0]) 

1027 if 'brainLocationIds_ccf_2017' not in channels: 1cbed

1028 _logger.debug(f"loading channels from alyx for {self.files['channels']}") 1b

1029 _channels, self.histology = _load_channel_locations_traj( 1b

1030 self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True) 

1031 if _channels: 1b

1032 channels = _channels[self.pname] 

1033 else: 

1034 channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) 1cbed

1035 self.histology = 'alf' 1cbed

1036 return Bunch(channels) 1cbed

1037 

1038 def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs): 

1039 """ 

1040 Loads spikes, clusters and channels 

1041 

1042 There could be several spike sorting collections; by default the loader gets the iblsorter collection, falling back to pykilosort 

1043 

1044 The channel locations can come from several sources, it will load the most advanced version of the histology available, 

1045 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

1046 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

1047 - resolved: channel locations alignments have been agreed upon 

1048 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

1049 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

1050 

1051 :param spike_sorter: (defaults to 'iblsorter') 

1052 :param revision: for example "2024-05-06", (defaults to None): 

1053 :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one 

1054 :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates'] 

1055 :param good_units: (False) if True, loads only the good units, possibly by downloading a smaller spikes table 

1056 :param kwargs: additional arguments to be passed to one.api.One.load_object 

1057 :return: 

1058 """ 

1059 if len(self.collections) == 0: 1cbed

1060 return {}, {}, {} 

1061 self.files = {} 1cbed

1062 self.spike_sorter = spike_sorter 1cbed

1063 self.revision = revision 1cbed

1064 objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None 1cbed

1065 self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) 1cbed

1066 channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs) 1cbed

1067 clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) 1cbed

1068 if good_units: 1cbed

1069 spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards) 

1070 else: 

1071 spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) 1cbed

1072 if enforce_version: 1cbed

1073 self._assert_version_consistency() 

1074 return spikes, clusters, channels 1cbed

1075 
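The typical end-to-end pattern for this loader, assuming a remote ONE instance and a placeholder probe insertion UUID (pid):

from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(pid='aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee', one=one)
spikes, clusters, channels = ssl.load_spike_sorting()
clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
# clusters now carries per-unit metrics plus the brain-region fields copied over from channels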

1076 def _assert_version_consistency(self): 

1077 """ 

1078 Makes sure the state of the spike sorting object matches the files downloaded 

1079 :return: None 

1080 """ 

1081 for k in ['spikes', 'clusters', 'channels', 'passingSpikes']: 

1082 for fn in self.files.get(k, []): 

1083 if self.spike_sorter: 

1084 assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \ 

1085 f"You required strict version {self.spike_sorter}, {fn} does not match" 

1086 if self.revision: 

1087 assert full_path_parts(fn)[5] == self.revision, \ 

1088 f"You required strict revision {self.revision}, {fn} does not match" 

1089 

1090 @staticmethod 

1091 def compute_metrics(spikes, clusters=None): 

1092 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size 

1093 metrics = pd.DataFrame(quick_unit_metrics( 

1094 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) 

1095 return metrics 

1096 

1097 @staticmethod 

1098 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False): 

1099 """ 

1100 Merge the metrics and the channel information into the clusters dictionary 

1101 :param spikes: 

1102 :param clusters: 

1103 :param channels: 

1104 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used 

1105 for clusters or analysis applications (defaults to None). 

1106 :param compute_metrics: if True, will explicitly recompute metrics (defaults to false) 

1107 :return: cluster dictionary containing metrics and histology 

1108 """ 

1109 if spikes == {}: 1bd

1110 return 

1111 nc = clusters['channels'].size 1bd

1112 # recompute metrics if they are not available 

1113 metrics = None 1bd

1114 if 'metrics' in clusters: 1bd

1115 metrics = clusters.pop('metrics') 1bd

1116 if metrics.shape[0] != nc: 1bd

1117 metrics = None 

1118 if metrics is None or compute_metrics is True: 1bd

1119 _logger.debug("recompute clusters metrics") 

1120 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) 

1121 if isinstance(cache_dir, Path): 

1122 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) 

1123 for k in metrics.keys(): 1bd

1124 clusters[k] = metrics[k].to_numpy() 1bd

1125 for k in channels.keys(): 1bd

1126 clusters[k] = channels[k][clusters['channels']] 1bd

1127 if cache_dir is not None: 1bd

1128 _logger.debug(f'caching clusters metrics in {cache_dir}') 

1129 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt')) 

1130 return clusters 1bd

1131 

1132 @property 

1133 def url(self): 

1134 """Gets flatiron URL for the session""" 

1135 webclient = getattr(self.one, '_web_client', None) 

1136 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None 

1137 

1138 def _get_probe_info(self): 

1139 if self._sync is None: 1e

1140 timestamps = self.one.load_dataset( 1e

1141 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}') 

1142 _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks 1e

1143 self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}') 

1144 try: 1e

1145 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1e

1146 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) 

1147 fs = spikeglx._get_fs_from_meta(ap_meta) 1e

1148 except ALFObjectNotFound: 

1149 ap_meta = None 

1150 fs = 30_000 

1151 self._sync = { 1e

1152 'timestamps': timestamps, 

1153 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'), 

1154 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), 

1155 'ap_meta': ap_meta, 

1156 'fs': fs, 

1157 } 

1158 

1159 def timesprobe2times(self, values, direction='forward'): 

1160 self._get_probe_info() 

1161 if direction == 'forward': 

1162 return self._sync['forward'](values * self._sync['fs']) 

1163 elif direction == 'reverse': 

1164 return self._sync['reverse'](values) / self._sync['fs'] 

1165 

1166 def samples2times(self, values, direction='forward'): 

1167 """ 

1168 Converts ephys sample values to session main clock seconds 

1169 :param values: numpy array of times in seconds or samples to resync 

1170 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse' 

1171 (seconds main time to samples probe time) 

1172 :return: 

1173 """ 

1174 self._get_probe_info() 1e

1175 return self._sync[direction](values) 1e

1176 
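A sketch of re-synchronising spike samples to the session clock, assuming spikes.samples was requested as an extra dataset type (see load_spike_sorting above):

spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
t_session = ssl.samples2times(spikes['samples'])  # probe samples -> session-clock seconds
s_probe = ssl.samples2times(spikes['times'], direction='reverse')  # session seconds -> probe samples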

1177 @property 

1178 def pid2ref(self): 

1179 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1cb

1180 

1181 def _default_plot_title(self, spikes): 

1182 title = f"{self.pid2ref}, {self.pid} \n" \ 1c

1183 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" 

1184 return title 1c

1185 

1186 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, 

1187 drift=None, title=None, **kwargs): 

1188 """ 

1189 :param spikes: spikes dictionary or Bunch 

1190 :param channels: channels dictionary or Bunch. 

1191 :param save_dir: if specified save to this directory as "{pid}_{pid2ref}_{label}.png". 

1192 Otherwise, plot. 

1193 :param br: brain regions object (optional) 

1194 :param label: label for saved image (optional, default="raster") 

1195 :param time_series: timeseries dictionary for behavioral event times (optional) 

1196 :param **kwargs: kwargs passed to `driftmap()` (optional) 

1197 :return: 

1198 """ 

1199 br = br or BrainRegions() 1c

1200 time_series = time_series or {} 1c

1201 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c

1202 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') 

1203 axs[0, 1].set_axis_off() 1c

1204 # axs[0, 0].set_xticks([]) 

1205 if not kwargs: 1c

1206 # set default raster plot parameters 

1207 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} 

1208 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c

1209 if title is None: 1c

1210 title = self._default_plot_title(spikes) 1c

1211 axs[0, 0].title.set_text(title) 1c

1212 for k, ts in time_series.items(): 1c

1213 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) 

1214 if 'atlas_id' in channels: 1c

1215 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c

1216 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) 

1217 axs[1, 0].set_ylim(0, 3800) 1c

1218 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c

1219 fig.tight_layout() 1c

1220 

1221 if drift is None: 1c

1222 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c

1223 if 'drift' in self.files: 1c

1224 drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) 

1225 if isinstance(drift, dict): 1c

1226 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) 

1227 axs[0, 0].set(ylim=[-15, 15]) 

1228 

1229 if save_dir is not None: 1c

1230 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1231 fig.savefig(png_file) 

1232 plt.close(fig) 

1233 gc.collect() 

1234 else: 

1235 return fig, axs 1c

1236 
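# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Plotting a driftmap raster with brain regions alongside. Assumes `ssl` is a
# SpikeSortingLoader with a valid probe insertion; spike sorting is loaded first:
# >>> spikes, clusters, channels = ssl.load_spike_sorting()
# >>> fig, axs = ssl.raster(spikes, channels)
# Passing save_dir writes "{pid}_{pid2ref}_raster.png" to that directory instead of
# returning the figure.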

1237 def plot_rawdata_snippet(self, sr, spikes, clusters, t0, 

1238 channels=None, 

1239 br: BrainRegions = None, 

1240 save_dir=None, 

1241 label='raster', 

1242 gain=-93, 

1243 title=None): 

1244 

1245 # compute the raw data sample offsets and destripe; we take 400 ms around t0 

1246 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs)) 

1247 raw = sr[first_sample:last_sample, :-sr.nsync].T 

1248 channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True 

1249 destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels) 

1250 # filter out the spikes according to good/bad clusters and to the time slice 

1251 spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample])) 

1252 ss = spikes['samples'][spike_sel] 

1253 sc = clusters['channels'][spikes['clusters'][spike_sel]] 

1254 sok = clusters['label'][spikes['clusters'][spike_sel]] == 1 

1255 if title is None: 

1256 title = self._default_plot_title(spikes) 

1257 # display the raw data snippet with spikes overlaid 

1258 fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col') 

1259 Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s') 

1260 axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5) 

1261 axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5) 

1262 axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035]) 

1263 # adds the channel locations if available 

1264 if (channels is not None) and ('atlas_id' in channels): 

1265 br = br or BrainRegions() 

1266 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 

1267 brain_regions=br, display=True, ax=axs[1], title=self.histology) 

1268 axs[1].get_yaxis().set_visible(False) 

1269 fig.tight_layout() 

1270 

1271 if save_dir is not None: 

1272 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1273 fig.savefig(png_file) 

1274 plt.close(fig) 

1275 gc.collect() 

1276 else: 

1277 return fig, axs 

1278 
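# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Plotting ~70 ms of destriped raw data around t0 with sorted spikes overlaid. Assumes
# `ssl` is a SpikeSortingLoader, `sr` is an AP-band spikeglx reader / Streamer for the same
# probe, and that `spikes` carries a 'samples' vector and `clusters` a 'label' quality
# column (neither is part of the default attribute set loaded above):
# >>> fig, axs = ssl.plot_rawdata_snippet(sr, spikes, clusters, t0=600., channels=channels)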

1279 

1280@dataclass 

1281class SessionLoader: 

1282 """ 

1283 Object to load session data for a given session in the recommended way. 

1284 

1285 Parameters 

1286 ---------- 

1287 one: one.api.ONE instance 

1288 Can be in remote or local mode (required) 

1289 session_path: string or pathlib.Path 

1290 The absolute path to the session (one of session_path or eid is required) 

1291 eid: string 

1292 database UUID of the session (one of session_path or eid is required) 

1293 

1294 If both are provided, session_path takes precedence over eid. 

1295 

1296 Examples 

1297 -------- 

1298 1) Load all available session data for one session: 

1299 >>> from one.api import ONE 

1300 >>> from brainbox.io.one import SessionLoader 

1301 >>> one = ONE() 

1302 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') 

1303 # The object is initialised, but no data is loaded yet, as you can see in the data_info attribute 

1304 >>> sess_loader.data_info 

1305 name is_loaded 

1306 0 trials False 

1307 1 wheel False 

1308 2 pose False 

1309 3 motion_energy False 

1310 4 pupil False 

1311 

1312 # Loading all available session data, the data_info attribute now shows which data has been loaded 

1313 >>> sess_loader.load_session_data() 

1314 >>> sess_loader.data_info 

1315 name is_loaded 

1316 0 trials True 

1317 1 wheel True 

1318 2 pose True 

1319 3 motion_energy True 

1320 4 pupil False 

1321 

1322 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. 

1323 >>> type(sess_loader.trials) 

1324 pandas.core.frame.DataFrame 

1325 >>> sess_loader.trials.shape 

1326 (626, 18) 

1327 # Each data type comes with its own timestamps in a column called 'times' 

1328 >>> sess_loader.wheel['times'] 

1329 0 0.134286 

1330 1 0.135286 

1331 2 0.136286 

1332 3 0.137286 

1333 4 0.138286 

1334 ... 

1335 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. 

1336 # The dataframes of all cameras are collected in a dictionary 

1337 >>> type(sess_loader.pose) 

1338 dict 

1339 >>> sess_loader.pose.keys() 

1340 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) 

1341 >>> sess_loader.pose['bodyCamera'].columns 

1342 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') 

1343 # To control the loading of specific data, e.g. by specifying parameters, use the individual loading 

1344 # functions: 

1345 >>> sess_loader.load_wheel(sampling_rate=100) 

1346 """ 

1347 one: One = None 

1348 session_path: Path = '' 

1349 eid: str = '' 

1350 revision: str = '' 

1351 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1352 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1353 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1354 pose: dict = field(default_factory=dict, repr=False) 

1355 motion_energy: dict = field(default_factory=dict, repr=False) 

1356 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1357 

1358 def __post_init__(self): 

1359 """ 

1360 Function that runs automatically after initialisation of the dataclass attributes. 

1361 Checks for required inputs, sets session_path and eid, and creates the data_info table. 

1362 """ 

1363 if self.one is None: 1af

1364 raise ValueError("An input to one is required. If not connection to a database is desired, it can be " 

1365 "a fully local instance of One.") 

1366 # If session path is given, takes precedence over eid 

1367 if self.session_path is not None and self.session_path != '': 1af

1368 self.eid = self.one.to_eid(self.session_path) 1af

1369 self.session_path = Path(self.session_path) 1af

1370 # If no session path is provided, try to infer it from the eid 

1371 else: 

1372 if self.eid is not None and self.eid != '': 

1373 self.session_path = self.one.eid2path(self.eid) 

1374 else: 

1375 raise ValueError("If no session path is given, eid is required.") 

1376 

1377 data_names = [ 1af

1378 'trials', 

1379 'wheel', 

1380 'pose', 

1381 'motion_energy', 

1382 'pupil' 

1383 ] 

1384 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1af

1385 

1386 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): 

1387 """ 

1388 Function to load available session data into the SessionLoader object. Input parameters control which 

1389 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input 

1390 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored 

1391 in SessionLoader.data_info 

1392 

1393 Parameters 

1394 ---------- 

1395 trials: boolean 

1396 Whether to load all trials data into SessionLoader.trials, default is True 

1397 wheel: boolean 

1398 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True 

1399 pose: boolean 

1400 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose, 

1401 default is True 

1402 motion_energy: boolean 

1403 Whether to load motion energy data (whisker pad for left/right camera, body for body camera) 

1404 into SessionLoader.motion_energy, default is True 

1405 pupil: boolean 

1406 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil, 

1407 default is True 

1408 reload: boolean 

1409 Whether to reload data that has already been loaded into this SessionLoader object, default is False 

1410 """ 

1411 load_df = self.data_info.copy() 1f

1412 load_df['to_load'] = [ 1f

1413 trials, 

1414 wheel, 

1415 pose, 

1416 motion_energy, 

1417 pupil 

1418 ] 

1419 load_df['load_func'] = [ 1f

1420 self.load_trials, 

1421 self.load_wheel, 

1422 self.load_pose, 

1423 self.load_motion_energy, 

1424 self.load_pupil 

1425 ] 

1426 

1427 for idx, row in load_df.iterrows(): 1f

1428 if row['to_load'] is False: 1f

1429 _logger.debug(f"Not loading {row['name']} data, set to False.") 

1430 elif row['is_loaded'] is True and reload is False: 1f

1431 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1f

1432 else: 

1433 try: 1f

1434 _logger.info(f"Loading {row['name']} data") 1f

1435 row['load_func']() 1f

1436 self.data_info.loc[idx, 'is_loaded'] = True 1f

1437 except BaseException as e: 

1438 _logger.warning(f"Could not load {row['name']} data.") 

1439 _logger.debug(e) 

1440 
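# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Selectively loading only trials and wheel data. `sess_loader` is an assumed, already
# instantiated SessionLoader; data_info reflects what was actually loaded (entries stay
# False if a dataset is missing for the session):
# >>> sess_loader.load_session_data(pose=False, motion_energy=False, pupil=False)
# >>> sess_loader.data_info.loc[sess_loader.data_info['is_loaded'], 'name'].tolist()
# e.g. ['trials', 'wheel']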

1441 def _find_behaviour_collection(self, obj): 

1442 """ 

1443 Function to find the trial or wheel collection 

1444 

1445 Parameters 

1446 ---------- 

1447 obj: str 

1448 Alf object to load, either 'trials' or 'wheel' 

1449 """ 

1450 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj

1451 dsets = self.one.list_datasets(self.eid, dataset) 1fnj

1452 if len(dsets) == 0: 1fnj

1453 return 'alf' 1fn

1454 else: 

1455 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj

1456 if len(set(collections)) == 1: 1fj

1457 return collections[0] 1fj

1458 else: 

1459 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, ' 

1460 f'e.g sl.load_{obj}(collection="{collections[0]}")') 

1461 raise ALFMultipleCollectionsFound 

1462 

1463 def load_trials(self, collection=None): 

1464 """ 

1465 Function to load trials data into SessionLoader.trials 

1466 

1467 Parameters 

1468 ---------- 

1469 collection: str 

1470 Alf collection of trials data 

1471 """ 

1472 

1473 if not collection: 1fn

1474 collection = self._find_behaviour_collection('trials') 1fn

1475 # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex 

1476 self.one.wildcards = False 1fn

1477 self.trials = self.one.load_object( 1fn

1478 self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df() 

1479 self.one.wildcards = True 1fn

1480 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1fn

1481 
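# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Loading the trials table and computing a simple behavioural summary. Column names follow
# the IBL trials table convention (e.g. 'feedbackType' is +1 for correct, -1 for error):
# >>> sess_loader.load_trials()
# >>> trials = sess_loader.trials
# >>> performance = (trials['feedbackType'] == 1).mean()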

1482 def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None): 

1483 """ 

1484 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position 

1485 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which 

1486 a Butterworth low-pass filter is applied. 

1487 

1488 Parameters 

1489 ---------- 

1490 fs: int, float 

1491 Sampling frequency for the wheel position, default is 1000 Hz 

1492 corner_frequency: int, float 

1493 Corner frequency of Butterworth low-pass filter, default is 20 

1494 order: int, float 

1495 Order of Butterworth low_pass filter, default is 8 

1496 collection: str 

1497 Alf collection of wheel data 

1498 """ 

1499 if not collection: 1fj

1500 collection = self._find_behaviour_collection('wheel') 1fj

1501 wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None) 1fj

1502 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1fj

1503 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps") 

1504 # resample the wheel position and compute velocity, acceleration 

1505 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1fj

1506 self.wheel['position'], self.wheel['times'] = interpolate_position( 1fj

1507 wheel_raw['timestamps'], wheel_raw['position'], freq=fs) 

1508 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1fj

1509 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order) 

1510 self.wheel = self.wheel.apply(np.float32) 1fj

1511 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1fj

1512 
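# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Loading wheel data at a lower sampling rate and finding the time of peak wheel speed;
# `sess_loader` is an assumed SessionLoader instance:
# >>> sess_loader.load_wheel(fs=100)
# >>> wheel = sess_loader.wheel  # columns: times, position, velocity, acceleration
# >>> t_peak = wheel['times'][wheel['velocity'].abs().idxmax()]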

1513 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']): 

1514 """ 

1515 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a 

1516 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas 

1517 Dataframes with the timestamps and pose data, one row for each body part tracked for that camera. 

1518 

1519 Parameters 

1520 ---------- 

1521 likelihood_thr: float 

1522 The position of each tracked body part comes with a likelihood of that estimate for each time point. 

1523 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set 

1524 likelihood_thr=1. Default is 0.9 

1525 views: list 

1526 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1527 """ 

1528 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger 

1529 self.pose = {} 1mhf

1530 for view in views: 1mhf

1531 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None) 1mhf

1532 # Double check if video timestamps are correct length or can be fixed 

1533 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1mhf

1534 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1mhf

1535 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1mhf

1536 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1mhf

1537 
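# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Loading DLC pose for the left camera only, with a stricter likelihood threshold; samples
# below the threshold are set to NaN by likelihood_threshold():
# >>> sess_loader.load_pose(views=['left'], likelihood_thr=0.95)
# >>> sess_loader.pose['leftCamera'].columns  # 'times' plus x/y/likelihood per body part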

1538 def load_motion_energy(self, views=['left', 'right', 'body']): 

1539 """ 

1540 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a 

1541 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are 

1542 pandas Dataframes with the timestamps and motion energy data. 

1543 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad 

1544 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the 

1545 body (bodyMotionEnergy). 

1546 

1547 Parameters 

1548 ---------- 

1549 views: list 

1550 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1551 """ 

1552 names = {'left': 'whiskerMotionEnergy', 1lf

1553 'right': 'whiskerMotionEnergy', 

1554 'body': 'bodyMotionEnergy'} 

1555 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger 

1556 self.motion_energy = {} 1lf

1557 for view in views: 1lf

1558 me_raw = self.one.load_object( 1lf

1559 self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None) 

1560 # Double check if video timestamps are correct length or can be fixed 

1561 times_fixed, motion_energy = self._check_video_timestamps( 1lf

1562 view, me_raw['times'], me_raw['ROIMotionEnergy']) 

1563 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1lf

1564 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1lf

1565 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1lf

1566 
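# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Loading whisker-pad motion energy from the left camera only:
# >>> sess_loader.load_motion_energy(views=['left'])
# >>> me = sess_loader.motion_energy['leftCamera']  # columns: times, whiskerMotionEnergy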

1567 def load_licks(self): 

1568 """ 

1569 Not yet implemented 

1570 """ 

1571 pass 

1572 

1573 def load_pupil(self, snr_thresh=5.): 

1574 """ 

1575 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. 

1576 

1577 Parameters 

1578 ---------- 

1579 snr_thresh: float 

1580 An SNR is calculated from the raw and smoothed pupil diameter. If this SNR is below snr_thresh, the data 

1581 will be considered unusable and will be discarded. 

1582 """ 

1583 # Try to load from features 

1584 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None) 1hf

1585 if 'features' in feat_raw.keys(): 1hf

1586 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features']) 

1587 self.pupil = feats.copy() 

1588 self.pupil.insert(0, 'times', times_fixed) 

1589 

1590 # If unavailable compute on the fly 

1591 else: 

1592 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1hf

1593 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1hf

1594 and 'leftCamera' in self.pose.keys()): 

1595 # If pose data is already loaded, we don't know if it was thresholded at 0.9, so we need a small workaround 

1596 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1hf

1597 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1hf

1598 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1hf

1599 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1hf

1600 else: 

1601 self.load_pose(views=['left'], likelihood_thr=0.9) 

1602 dlc_thr = self.pose['leftCamera'].copy() 

1603 

1604 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1hf

1605 try: 1hf

1606 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1hf

1607 except BaseException as e: 

1608 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. " 

1609 "Saving all NaNs for pupilDiameter_smooth.") 

1610 _logger.debug(e) 

1611 self.pupil['pupilDiameter_smooth'] = np.nan 

1612 

1613 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1hf

1614 good_idxs = np.where( 1hf

1615 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] 

1616 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1hf

1617 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) 

1618 if snr < snr_thresh: 1hf

1619 self.pupil = pd.DataFrame() 1h

1620 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h

1621 
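# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Loading pupil diameter; if the smoothed trace fails the SNR check, a ValueError is raised
# and SessionLoader.pupil is emptied:
# >>> sess_loader.load_pupil(snr_thresh=5.)
# >>> sess_loader.pupil['pupilDiameter_smooth'].head()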

1622 def _check_video_timestamps(self, view, video_timestamps, video_data): 

1623 """ 

1624 Helper function to check the length of the video data against the video timestamps and fix cases 

1625 where the timestamps are longer than the video data. 

1626 """ 

1627 # If camera times are shorter than video data, or empty, no current fix 

1628 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf

1629 if video_timestamps.shape[0] == 0: 

1630 msg = f'Camera times empty for {view}Camera.' 

1631 else: 

1632 msg = f'Camera times are shorter than video data for {view}Camera.' 

1633 _logger.warning(msg) 

1634 raise ValueError(msg) 

1635 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. 

1636 # This is because the first few frames are sometimes not recorded. We can remove the first few 

1637 # timestamps in this case 

1638 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf

1639 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf

1640 return video_timestamps_fixed, video_data 1lmhf

1641 else: 

1642 return video_timestamps, video_data 

1643 

1644 

1645class EphysSessionLoader(SessionLoader): 

1646 """ 

1647 Spike-sorting-enhanced version of SessionLoader. 

1648 Loads spike sorting data for all probes in the session into the self.ephys dict. 

1649 >>> EphysSessionLoader(eid=eid, one=one) 

1650 To select for a specific probe 

1651 >>> EphysSessionLoader(eid=eid, one=one, pid=pid) 

1652 """ 

1653 def __init__(self, *args, pname=None, pid=None, **kwargs): 

1654 """ 

1655 Needs an active connection in order to get the list of insertions in the session 

1656 :param pname: optional probe name to restrict loading to a single probe; other args passed to SessionLoader 

1657 :param pid: optional probe insertion UUID to restrict loading to a single probe 

1658 """ 

1659 super().__init__(*args, **kwargs) 

1660 # if necessary, restrict the query 

1661 qargs = {} if pname is None else {'name': pname} 

1662 qargs = qargs or ({} if pid is None else {'id': pid}) 

1663 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs) 

1664 self.ephys = {} 

1665 for ins in insertions: 

1666 self.ephys[ins['name']] = {} 

1667 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one) 

1668 

1669 def load_session_data(self, *args, **kwargs): 

1670 super().load_session_data(*args, **kwargs) 

1671 self.load_spike_sorting() 

1672 

1673 def load_spike_sorting(self, pnames=None): 

1674 pnames = pnames or list(self.ephys.keys()) 

1675 for pname in pnames: 

1676 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting() 

1677 self.ephys[pname]['spikes'] = spikes 

1678 self.ephys[pname]['clusters'] = clusters 

1679 self.ephys[pname]['channels'] = channels 

1680 

1681 @property 

1682 def probes(self): 

1683 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}
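# --- Illustrative usage sketch (not part of the source) ---------------------------------
# Loading behaviour and spike sorting for every probe of a session in one go; `eid` and
# `one` are assumed to be a valid session UUID and a connected ONE instance:
# >>> esl = EphysSessionLoader(eid=eid, one=one)
# >>> esl.load_session_data()
# >>> for pname, probe_data in esl.ephys.items():
# ...     print(pname, probe_data['spikes']['times'].size)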