Coverage for brainbox/io/one.py: 56%

726 statements  

coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" 

2from dataclasses import dataclass, field 

3import gc 

4import logging 

5import re 

6import os 

7from pathlib import Path 

8 

9import numpy as np 

10import pandas as pd 

11from scipy.interpolate import interp1d 

12import matplotlib.pyplot as plt 

13 

14from one.api import ONE, One 

15from one.alf.files import get_alf_path, full_path_parts 

16from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound 

17from one.alf import cache 

18import one.alf.io as alfio 

19from neuropixel import TIP_SIZE_UM, trace_header 

20import spikeglx 

21 

22import ibldsp.voltage 

23from iblutil.util import Bunch 

24from iblatlas.atlas import AllenAtlas, BrainRegions 

25from iblatlas import atlas 

26from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times 

27from ibllib.pipes import histology 

28from ibllib.pipes.ephys_alignment import EphysAlignment 

29from ibllib.plots import vertical_lines, Density 

30 

31import brainbox.plot 

32from brainbox.io.spikeglx import Streamer 

33from brainbox.ephys_plots import plot_brain_regions 

34from brainbox.metrics.single_units import quick_unit_metrics 

35from brainbox.behavior.wheel import interpolate_position, velocity_filtered 

36from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter 

37 

38_logger = logging.getLogger('ibllib') 

39 

40 

41SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths'] 

42CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids'] 

43WAVEFORMS_ATTRIBUTES = ['templates'] 

44 

45 

46def load_lfp(eid, one=None, dataset_types=None, **kwargs): 

47 """ 

48 TODO Verify works 

49 From an eid, hits the Alyx database and downloads the standard set of datasets 

50 needed for LFP 

51 :param eid: experiment UUID or pathlib.Path of the local session 

52 :param dataset_types: additional dataset types to add to the list 

53 :param kwargs: additional keyword arguments passed to the spikeglx.Reader constructor 

54 :return: list of spikeglx.Reader 

55 """ 

56 if dataset_types is None: 

57 dataset_types = [] 

58 dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*'] 

59 [one.load_dataset(eid, dset, download_only=True) for dset in dtypes] 

60 session_path = one.eid2path(eid) 

61 

62 efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) 

63 if ef.get('lf', None)] 

64 return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles] 

65 
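# Usage sketch (not part of the module, added for illustration): assumes a configured ONE instance and a
# session whose raw LFP files are available; `eid` is a placeholder identifier.
#     one = ONE()
#     readers = load_lfp(eid, one=one)     # one spikeglx.Reader per probe with an LFP file
#     sr = readers[0]
#     first_second = sr[:int(sr.fs), :]    # (n_samples, n_channels) array of the first second of LFP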

66 

67def _collection_filter_from_args(probe, spike_sorter=None): 

68 collection = f'alf/{probe}/{spike_sorter}' 1g

69 collection = collection.replace('None', '*') 1g

70 collection = collection.replace('/*', '*') 1g

71 collection = collection[:-1] if collection.endswith('/') else collection 1g

72 return collection 1g

73 
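# For reference (illustrative, not part of the module), two input/output pairs of the helper above:
#     _collection_filter_from_args('probe00')                 # -> 'alf/probe00*'
#     _collection_filter_from_args('probe00', 'pykilosort')   # -> 'alf/probe00/pykilosort'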

74 

75def _get_spike_sorting_collection(collections, pname): 

76 """ 

77 Filters a list or array of collections to get the relevant spike sorting dataset. 

78 If a pykilosort collection is present it is loaded, otherwise the shortest collection is preferred. 

79 """ 

80 # 

81 collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) 1gb

82 # otherwise, prefers the shortest 

83 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) 1gb

84 _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") 1gb

85 return collection 1gb

86 

87 

88def _channels_alyx2bunch(chans): 

89 channels = Bunch({ 

90 'atlas_id': np.array([ch['brain_region'] for ch in chans]), 

91 'x': np.array([ch['x'] for ch in chans]) / 1e6, 

92 'y': np.array([ch['y'] for ch in chans]) / 1e6, 

93 'z': np.array([ch['z'] for ch in chans]) / 1e6, 

94 'axial_um': np.array([ch['axial'] for ch in chans]), 

95 'lateral_um': np.array([ch['lateral'] for ch in chans]) 

96 }) 

97 return channels 

98 

99 

100def _channels_traj2bunch(xyz_chans, brain_atlas): 

101 brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans)) 

102 channels = { 

103 'x': xyz_chans[:, 0], 

104 'y': xyz_chans[:, 1], 

105 'z': xyz_chans[:, 2], 

106 'acronym': brain_regions['acronym'], 

107 'atlas_id': brain_regions['id'] 

108 } 

109 

110 return channels 

111 

112 

113def _channels_bunch2alf(channels): 

114 channels_ = { 1i

115 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6, 

116 'brainLocationIds_ccf_2017': channels['atlas_id'], 

117 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]} 

118 return channels_ 1i

119 

120 

121def _channels_alf2bunch(channels, brain_regions=None): 

122 # reformat the dictionary according to the standard that comes out of Alyx 

123 channels_ = { 1icbed

124 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6, 

125 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6, 

126 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6, 

127 'acronym': None, 

128 'atlas_id': channels['brainLocationIds_ccf_2017'], 

129 'axial_um': channels['localCoordinates'][:, 1], 

130 'lateral_um': channels['localCoordinates'][:, 0], 

131 } 

132 # any extra keys present in the input are carried over to the output dictionary 

133 for k in channels: 1icbed

134 if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: 1icbed

135 channels_[k] = channels[k] 1icbed

136 if brain_regions: 1icbed

137 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icbed

138 return channels_ 1icbed

139 

140 

141def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, 

142 brain_regions=None): 

143 """ 

144 Generic function to load spike sorting data using ONE. 

145 

146 Will try to load one spike sorting for each probe present for the eid matching the collection. 

147 For each probe it will load a spike sorting: 

148 - if there is a single version: loads that one 

149 - if there are several versions: loads pykilosort; if not found, the shortest collection (alf/probeXX) 

150 

151 Parameters 

152 ---------- 

153 eid : [str, UUID, Path, dict] 

154 Experiment session identifier; may be a UUID, URL, experiment reference string 

155 details dict or Path 

156 one : one.api.OneAlyx 

157 An instance of ONE (may be in 'local' mode) 

158 collection : str 

159 collection filter word - accepts wildcards - can be a combination of spike sorter and 

160 probe. See `ALF documentation`_ for details. 

161 revision : str 

162 A particular revision return (defaults to latest revision). See `ALF documentation`_ for 

163 details. 

164 return_channels : bool 

165 Defaults to True; if True, channel locations are also loaded from disk and returned 

166 

167 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components 

168 

169 Returns 

170 ------- 

171 spikes : dict of one.alf.io.AlfBunch 

172 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

173 session and spike sorter, with keys ('clusters', 'times') 

174 clusters : dict of one.alf.io.AlfBunch 

175 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

176 ('channels', 'depths', 'metrics') 

177 channels : dict of one.alf.io.AlfBunch 

178 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

179 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs 

180 non-lateralized. 

181 """ 

182 one = one or ONE() 1gb

183 # enumerate probes and load according to the name 

184 collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) 1gb

185 if len(collections) == 0: 1gb

186 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

187 pnames = list(set(c.split('/')[1] for c in collections)) 1gb

188 spikes, clusters, channels = ({} for _ in range(3)) 1gb

189 

190 spike_attributes, cluster_attributes = _get_attributes(dataset_types) 1gb

191 

192 for pname in pnames: 1gb

193 probe_collection = _get_spike_sorting_collection(collections, pname) 1gb

194 spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', 1gb

195 attribute=spike_attributes) 

196 clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', 1gb

197 attribute=cluster_attributes) 

198 if return_channels: 1gb

199 channels = _load_channels_locations_from_disk( 1b

200 eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions) 

201 return spikes, clusters, channels 1b

202 else: 

203 return spikes, clusters 1g

204 

205 

206def _get_attributes(dataset_types): 

207 if dataset_types is None: 1gb

208 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1gb

209 else: 

210 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 

211 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 

212 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 

213 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 

214 return spike_attributes, cluster_attributes 

215 

216 

217def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None): 

218 _logger.debug('loading spike sorting from disk') 1b

219 channels = Bunch({}) 1b

220 collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) 1b

221 if len(collections) == 0: 1b

222 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

223 probes = list(set([c.split('/')[1] for c in collections])) 1b

224 for probe in probes: 1b

225 probe_collection = _get_spike_sorting_collection(collections, probe) 1b

226 channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels') 1b

227 # if the spike sorting output has no aligned data, try to get an available alignment 

228 if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): 1b

229 aligned_channel_collections = one.list_collections( 

230 eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision) 

231 if len(aligned_channel_collections) == 0: 

232 _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}") 

233 continue 

234 _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}") 

235 ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe) 

236 channels_aligned = one.load_object(eid, 'channels', collection=ac_collection) 

237 channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe]) 

238 # only have to reformat channels if we were able to load coordinates from disk 

239 channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions) 1b

240 return channels 1b

241 

242 

243def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None): 

244 """ 

245 the channel maps of different spike sorters may differ, so the alignment is interpolated onto the given channels 

246 if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field, 

247 so we reconstruct it from the Neuropixel map. This only happens for early pykilosort sorts 

248 :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys 

249 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017' 

250 OR 

251 'x', 'y', 'z', 'acronym', 'axial_um' 

252 those are the guide for the interpolation 

253 :param channels: Bunch or dictionary of target channels containing at least the key 'localCoordinates' 

254 :param brain_regions: None (default) or iblatlas.regions.BrainRegions object 

255 if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017' 

256 if a brain region object is provided, outputs a dict with keys 

257 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um' 

258 :return: Bunch or dictionary of channels with brain coordinates keys 

259 """ 

260 NEUROPIXEL_VERSION = 1 1i

261 h = trace_header(version=NEUROPIXEL_VERSION) 1i

262 if channels is None: 1i

263 channels = {'localCoordinates': np.c_[h['x'], h['y']]} 

264 nch = channels['localCoordinates'].shape[0] 1i

265 if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): 1i

266 channels_aligned = _channels_bunch2alf(channels_aligned) 1i

267 if 'localCoordinates' in channels_aligned.keys(): 1i

268 aligned_depths = channels_aligned['localCoordinates'][:, 1] 1i

269 else: # this is an edge case for a few spike sorting sessions 

270 assert channels_aligned['mlapdv'].shape[0] == 384 

271 aligned_depths = h['y'] 

272 depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) 1i

273 depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) 1i

274 channels['mlapdv'] = np.zeros((nch, 3)) 1i

275 for i in np.arange(3): 1i

276 channels['mlapdv'][:, i] = np.interp( 1i

277 depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] 

278 # the brain locations have to be interpolated by nearest neighbour 

279 fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') 1i

280 channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) 1i

281 if brain_regions is not None: 1i

282 return _channels_alf2bunch(channels, brain_regions=brain_regions) 1i

283 else: 

284 return channels 1i

285 

286 

287def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, 

288 brain_atlas=None, return_source=False): 

289 if not hasattr(one, 'alyx'): 1b

290 return {}, None 1b

291 _logger.debug(f"trying to load from traj {probe}") 

292 channels = Bunch() 

293 brain_atlas = brain_atlas or AllenAtlas() 

294 # need to find the collection 

295 insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0] 

296 collection = _collection_filter_from_args(probe=probe) 

297 collections = one.list_collections(eid, filename='channels*', collection=collection, 

298 revision=revision) 

299 probe_collection = _get_spike_sorting_collection(collections, probe) 

300 chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection) 

301 depths = chn_coords[:, 1] 

302 

303 tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

304 get('tracing_exists', False) 

305 resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

306 get('alignment_resolved', False) 

307 counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

308 get('alignment_count', 0) 

309 

310 if tracing: 

311 xyz = np.array(insertion['json']['xyz_picks']) / 1e6 

312 if resolved: 

313 

314 _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. ' 

315 f'Channel and cluster locations obtained from ephys aligned histology ' 

316 f'track.') 

317 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

318 provenance='Ephys aligned histology track')[0] 

319 align_key = insertion['json']['extended_qc']['alignment_stored'] 

320 feature = traj['json'][align_key][0] 

321 track = traj['json'][align_key][1] 

322 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

323 feature_prev=feature, 

324 brain_atlas=brain_atlas, speedy=True) 

325 chans = ephysalign.get_channel_locations(feature, track) 

326 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

327 source = 'resolved' 

328 elif counts > 0 and aligned: 

329 _logger.debug(f'Channel locations for {eid}/{probe} have not been ' 

330 f'resolved. However, alignment flag set to True so channel and cluster' 

331 f' locations will be obtained from latest available ephys aligned ' 

332 f'histology track.') 

333 # get the latest user aligned channels 

334 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

335 provenance='Ephys aligned histology track')[0] 

336 align_key = insertion['json']['extended_qc']['alignment_stored'] 

337 feature = traj['json'][align_key][0] 

338 track = traj['json'][align_key][1] 

339 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

340 feature_prev=feature, 

341 brain_atlas=brain_atlas, speedy=True) 

342 chans = ephysalign.get_channel_locations(feature, track) 

343 

344 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

345 source = 'aligned' 

346 else: 

347 _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. ' 

348 f'Channel and cluster locations obtained from histology track.') 

349 # get the channels from histology tracing 

350 xyz = xyz[np.argsort(xyz[:, 2]), :] 

351 chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) 

352 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

353 source = 'traced' 

354 channels[probe]['axial_um'] = chn_coords[:, 1] 

355 channels[probe]['lateral_um'] = chn_coords[:, 0] 

356 

357 else: 

358 _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}') 

359 source = '' 

360 channels = None 

361 

362 if return_source: 

363 return channels, source 

364 else: 

365 return channels 

366 

367 

368def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None): 

369 """ 

370 Load the brain locations of each channel for a given session/probe 

371 

372 Parameters 

373 ---------- 

374 eid : [str, UUID, Path, dict] 

375 Experiment session identifier; may be a UUID, URL, experiment reference string 

376 details dict or Path 

377 probe : [str, list of str] 

378 The probe label(s), e.g. 'probe01' 

379 one : one.api.OneAlyx 

380 An instance of ONE (shouldn't be in 'local' mode) 

381 aligned : bool 

382 Whether to get the latest user aligned channel when not resolved or use histology track 

383 brain_atlas : iblatlas.BrainAtlas 

384 Brain atlas object (default: Allen atlas) 

385 Returns 

386 ------- 

387 dict of one.alf.io.AlfBunch 

388 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

389 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

390 optional: string 'resolved', 'aligned', 'traced' or '' 

391 """ 

392 one = one or ONE() 

393 brain_atlas = brain_atlas or AllenAtlas() 

394 if isinstance(eid, dict): 

395 ses = eid 

396 eid = ses['url'][-36:] 

397 else: 

398 eid = one.to_eid(eid) 

399 collection = _collection_filter_from_args(probe=probe) 

400 channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, 

401 brain_regions=brain_atlas.regions) 

402 incomplete_probes = [k for k in channels if 'x' not in channels[k]] 

403 for iprobe in incomplete_probes: 

404 channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned, 

405 brain_atlas=brain_atlas, return_source=True) 

406 if channels_ is not None: 

407 channels[iprobe] = channels_[iprobe] 

408 return channels 

409 
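# Usage sketch (not part of the module, added for illustration); the session identifier and probe name
# below are placeholders.
#     channels = load_channel_locations(eid, probe='probe00', one=ONE())
#     channels['probe00']['acronym']     # brain region acronym for each channel
#     channels['probe00']['atlas_id']    # corresponding (non-lateralized) Allen atlas ids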

410 

411def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

412 brain_regions=None, nested=True, collection=None, return_collection=False): 

413 """ 

414 From an eid, loads spikes and clusters for all probes 

415 The following set of dataset types are loaded: 

416 'clusters.channels', 

417 'clusters.depths', 

418 'clusters.metrics', 

419 'spikes.clusters', 

420 'spikes.times', 

421 'probes.description' 

422 :param eid: experiment UUID or pathlib.Path of the local session 

423 :param one: an instance of OneAlyx 

424 :param probe: name of probe to load in, if not given all probes for session will be loaded 

425 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

426 :param spike_sorter: name of the spike sorting you want to load (None for default) 

427 :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00" 

428 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

429 :param nested: if a single probe is required, do not output a dictionary with the probe name as key 

430 :param return_collection: (False) if True, will return the collection used to load 

431 :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe) 

432 """ 

433 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. ' 

434 'Use brainbox.io.one.SpikeSortingLoader instead') 

435 if collection is None: 

436 collection = _collection_filter_from_args(probe, spike_sorter) 

437 _logger.debug(f"load spike sorting with collection filter {collection}") 

438 kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types, 

439 brain_regions=brain_regions) 

440 spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True) 

441 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

442 if nested is False and len(spikes.keys()) == 1: 

443 k = list(spikes.keys())[0] 

444 channels = channels[k] 

445 clusters = clusters[k] 

446 spikes = spikes[k] 

447 if return_collection: 

448 return spikes, clusters, channels, collection 

449 else: 

450 return spikes, clusters, channels 

451 

452 

453def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

454 brain_regions=None, return_collection=False): 

455 """ 

456 From an eid, loads spikes and clusters for all probes 

457 The following set of dataset types are loaded: 

458 'clusters.channels', 

459 'clusters.depths', 

460 'clusters.metrics', 

461 'spikes.clusters', 

462 'spikes.times', 

463 'probes.description' 

464 :param eid: experiment UUID or pathlib.Path of the local session 

465 :param one: an instance of OneAlyx 

466 :param probe: name of probe to load in, if not given all probes for session will be loaded 

467 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

468 :param spike_sorter: name of the spike sorting you want to load (None for default) 

469 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

470 :param return_collection: (bool - False) if True, returns the collection for loading the data 

471 :return: spikes, clusters (dict of bunch, 1 bunch per probe) 

472 """ 

473 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. ' 1g

474 'Use brainbox.io.one.SpikeSortingLoader instead') 

475 collection = _collection_filter_from_args(probe, spike_sorter) 1g

476 _logger.debug(f"load spike sorting with collection filter {collection}") 1g

477 spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, 1g

478 return_channels=False, dataset_types=dataset_types, 

479 brain_regions=brain_regions) 

480 if return_collection: 1g

481 return spikes, clusters, collection 

482 else: 

483 return spikes, clusters 1g

484 
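# Usage sketch (not part of the module, added for illustration). This loader is deprecated in favour of
# SpikeSortingLoader (see the class further down); shown only to illustrate the return structure.
#     spikes, clusters = load_spike_sorting(eid, one=ONE(), probe='probe00')
#     spikes['probe00']['times']       # spike times in seconds, keyed by probe label
#     clusters['probe00']['channels']  # channel index of each cluster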

485 

486def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None, 

487 spike_sorter=None, brain_atlas=None, nested=True, return_collection=False): 

488 """ 

489 For a given eid, get spikes, clusters and channels information, and merges clusters 

490 and channels information before returning all three variables. 

491 

492 Parameters 

493 ---------- 

494 eid : [str, UUID, Path, dict] 

495 Experiment session identifier; may be a UUID, URL, experiment reference string 

496 details dict or Path 

497 one : one.api.OneAlyx 

498 An instance of ONE (shouldn't be in 'local' mode) 

499 probe : [str, list of str] 

500 The probe label(s), e.g. 'probe01' 

501 aligned : bool 

502 Whether to get the latest user aligned channel when not resolved or use histology track 

503 dataset_types : list of str 

504 Optional additional spikes/clusters objects to add to the standard default list 

505 spike_sorter : str 

506 Name of the spike sorting you want to load (None for default which is pykilosort if it's 

507 available otherwise the default MATLAB kilosort) 

508 brain_atlas : iblatlas.atlas.BrainAtlas 

509 Brain atlas object (default: Allen atlas) 

510 return_collection: bool 

511 Returns an extra argument with the collection chosen 

512 

513 Returns 

514 ------- 

515 spikes : dict of one.alf.io.AlfBunch 

516 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

517 session and spike sorter, with keys ('clusters', 'times') 

518 clusters : dict of one.alf.io.AlfBunch 

519 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

520 ('channels', 'depths', 'metrics') 

521 channels : dict of one.alf.io.AlfBunch 

522 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

523 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

524 """ 

525 # --- Get spikes and clusters data 

526 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future versions. ' 

527 'Use brainbox.io.one.SpikeSortingLoader instead') 

528 one = one or ONE() 

529 brain_atlas = brain_atlas or AllenAtlas() 

530 spikes, clusters, collection = load_spike_sorting( 

531 eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True) 

532 # -- Get brain regions and assign to clusters 

533 channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned, 

534 brain_atlas=brain_atlas) 

535 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

536 if nested is False and len(spikes.keys()) == 1: 

537 k = list(spikes.keys())[0] 

538 channels = channels[k] 

539 clusters = clusters[k] 

540 spikes = spikes[k] 

541 if return_collection: 

542 return spikes, clusters, channels, collection 

543 else: 

544 return spikes, clusters, channels 

545 

546 

547def load_ephys_session(eid, one=None): 

548 """ 

549 From an eid, hits the Alyx database and downloads a standard default set of dataset types 

550 From a local session Path (pathlib.Path), loads a standard default set of dataset types 

551 to perform analysis: 

552 'clusters.channels', 

553 'clusters.depths', 

554 'clusters.metrics', 

555 'spikes.clusters', 

556 'spikes.times', 

557 'probes.description' 

558 

559 Parameters 

560 ---------- 

561 eid : [str, UUID, Path, dict] 

562 Experiment session identifier; may be a UUID, URL, experiment reference string 

563 details dict or Path 

564 one : oneibl.one.OneAlyx, optional 

565 ONE object to use for loading. Will generate internal one if not used, by default None 

566 

567 Returns 

568 ------- 

569 spikes : dict of one.alf.io.AlfBunch 

570 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

571 session and spike sorter, with keys ('clusters', 'times') 

572 clusters : dict of one.alf.io.AlfBunch 

573 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

574 ('channels', 'depths', 'metrics') 

575 trials : one.alf.io.AlfBunch of numpy.ndarray 

576 The session trials data 

577 """ 

578 assert one 1g

579 spikes, clusters = load_spike_sorting(eid, one=one) 1g

580 trials = one.load_object(eid, 'trials') 1g

581 return spikes, clusters, trials 1g

582 
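# Usage sketch (not part of the module, added for illustration); requires an ONE instance, `eid` is a placeholder.
#     spikes, clusters, trials = load_ephys_session(eid, one=ONE())
#     trials['goCue_times']    # trial events returned alongside the per-probe spike sorting bunches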

583 

584def _remove_old_clusters(session_path, probe): 

585 # gets clusters and spikes from a local session folder 

586 probe_path = session_path.joinpath('alf', probe) 

587 

588 # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead 

589 cluster_file = probe_path.joinpath('clusters.metrics.csv') 

590 

591 if cluster_file.exists(): 

592 os.remove(cluster_file) 

593 _logger.info('Deleting old clusters.metrics.csv file') 

594 

595 

596def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None): 

597 """ 

598 Takes (default and any extra) values in given keys from channels and assign them to clusters. 

599 If channels does not contain any data, the new keys are added to clusters but left empty. 

600 

601 Parameters 

602 ---------- 

603 dic_clus : dict of one.alf.io.AlfBunch 

604 1 bunch per probe, containing cluster information 

605 channels : dict of one.alf.io.AlfBunch 

606 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', z', 'localCoordinates') 

607 keys_to_add_extra : list of str 

608 Any extra keys to load into channels bunches 

609 

610 Returns 

611 ------- 

612 dict of one.alf.io.AlfBunch 

613 clusters (1 bunch per probe) with new keys values. 

614 """ 

615 probe_labels = list(channels.keys()) # Convert dict_keys into list 

616 keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um'] 

617 

618 if keys_to_add_extra is None: 

619 keys_to_add = keys_to_add_default 

620 else: 

621 # Append extra optional keys 

622 keys_to_add = list(set(keys_to_add_extra + keys_to_add_default)) 

623 

624 for label in probe_labels: 

625 clu_ch = dic_clus[label]['channels'] 

626 for key in keys_to_add: 

627 try: 

628 assert key in channels[label].keys() # Check key is in channels 

629 ch_key = channels[label][key] 

630 nch_key = len(ch_key) if ch_key is not None else 0 

631 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index 

632 dic_clus[label][key] = ch_key[clu_ch] 

633 else: 

634 _logger.warning( 

635 f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels' 

636 f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.') 

637 dic_clus[label][key] = [] 

638 except AssertionError: 

639 _logger.warning(f'Either clusters or channels does not have key {key}, could not merge') 

640 continue 

641 

642 return dic_clus 

643 

644 

645def load_passive_rfmap(eid, one=None): 

646 """ 

647 For a given eid load in the passive receptive field mapping protocol data 

648 

649 Parameters 

650 ---------- 

651 eid : [str, UUID, Path, dict] 

652 Experiment session identifier; may be a UUID, URL, experiment reference string 

653 details dict or Path 

654 one : oneibl.one.OneAlyx, optional 

655 An instance of ONE (may be in 'local' - offline - mode) 

656 

657 Returns 

658 ------- 

659 one.alf.io.AlfBunch 

660 Passive receptive field mapping data 

661 """ 

662 one = one or ONE() 

663 

664 # Load in the receptive field mapping data 

665 rf_map = one.load_object(eid, obj='passiveRFM', collection='alf') 

666 frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', 

667 collection='raw_passive_data'), dtype="uint8") 

668 y_pix, x_pix = 15, 15 

669 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0]) 

670 rf_map['frames'] = frames 

671 

672 return rf_map 

673 

674 

675def load_wheel_reaction_times(eid, one=None): 

676 """ 

677 Return the calculated reaction times for session. Reaction times are defined as the time 

678 between the go cue (onset tone) and the onset of the first substantial wheel movement. A 

679 movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the 

680 distance to threshold (~0.1 radians). 

681 

682 Negative times mean the onset of the movement occurred before the go cue. Nans may occur if 

683 there was no detected movement within the period, or when the goCue_times or feedback_times 

684 are nan. 

685 

686 Parameters 

687 ---------- 

688 eid : [str, UUID, Path, dict] 

689 Experiment session identifier; may be a UUID, URL, experiment reference string 

690 details dict or Path 

691 one : one.api.OneAlyx, optional 

692 one object to use for loading. Will generate internal one if not used, by default None 

693 

694 Returns 

695 ------- 

696 array-like 

697 reaction times 

698 """ 

699 if one is None: 

700 one = ONE() 

701 

702 trials = one.load_object(eid, 'trials') 

703 # If already extracted, load and return 

704 if trials and 'firstMovement_times' in trials: 

705 return trials['firstMovement_times'] - trials['goCue_times'] 

706 # Otherwise load wheelMoves object and calculate 

707 moves = one.load_object(eid, 'wheelMoves') 

708 # Re-extract wheel moves if necessary 

709 if not moves or 'peakAmplitude' not in moves: 

710 wheel = one.load_object(eid, 'wheel') 

711 moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) 

712 assert trials and moves, 'unable to load trials and wheelMoves data' 

713 firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials) 

714 return firstMove_times - trials['goCue_times'] 

715 
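# Usage sketch (not part of the module, added for illustration); `eid` is a placeholder.
#     rt = load_wheel_reaction_times(eid, one=ONE())   # one value per trial, in seconds
#     np.nanmedian(rt)                                 # nans occur when no movement was detected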

716 

717def load_iti(trials): 

718 """ 

719 The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey 

720 screen commencing at stimulus off and lasting until the quiescent period at the start of the 

721 following trial. Note that the ITI of each trial is the time between that trial and the next, 

722 therefore the last value is NaN. 

723 

724 Parameters 

725 ---------- 

726 trials : one.alf.io.AlfBunch 

727 An ALF trials object containing the keys {'intervals', 'stimOff_times'}. 

728 

729 Returns 

730 ------- 

731 np.array 

732 An array of inter-trial intervals, the last value being NaN. 

733 """ 

734 if not {'intervals', 'stimOff_times'} <= set(trials.keys()): 1o

735 raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') 1o

736 return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan] 1o

737 
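# Worked example for the computation above (illustrative): with trial start times
# trials['intervals'][:, 0] = [0., 10., 20.] and trials['stimOff_times'] = [8., 18., 28.] (seconds),
# the returned ITIs are [10. - 8., 20. - 18., nan] = [2., 2., nan].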

738 

739def load_channels_from_insertion(ins, depths=None, one=None, ba=None): 

740 

741 PROV_2_VAL = { 

742 'Resolved': 90, 

743 'Ephys aligned histology track': 70, 

744 'Histology track': 50, 

745 'Micro-manipulator': 30, 

746 'Planned': 10} 

747 

748 one = one or ONE() 

749 ba = ba or atlas.AllenAtlas() 

750 traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id']) 

751 val = [PROV_2_VAL[tr['provenance']] for tr in traj] 

752 idx = np.argmax(val) 

753 traj = traj[idx] 

754 if depths is None: 

755 depths = trace_header(version=1)['y'] 

756 if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator': 

757 ins = atlas.Insertion.from_dict(traj) 

758 # Deepest coordinate first 

759 xyz = np.c_[ins.tip, ins.entry].T 

760 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

761 TIP_SIZE_UM) / 1e6) 

762 else: 

763 xyz = np.array(ins['json']['xyz_picks']) / 1e6 

764 if traj['provenance'] == 'Histology track': 

765 xyz = xyz[np.argsort(xyz[:, 2]), :] 

766 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

767 TIP_SIZE_UM) / 1e6) 

768 else: 

769 align_key = ins['json']['extended_qc']['alignment_stored'] 

770 feature = traj['json'][align_key][0] 

771 track = traj['json'][align_key][1] 

772 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

773 feature_prev=feature, 

774 brain_atlas=ba, speedy=True) 

775 xyz_channels = ephysalign.get_channel_locations(feature, track) 

776 return xyz_channels 

777 
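# Usage sketch (not part of the module, added for illustration); the insertion record is fetched the same
# way as elsewhere in this module, `eid` and the probe name are placeholders.
#     ins = one.alyx.rest('insertions', 'list', session=eid, name='probe00')[0]
#     xyz_channels = load_channels_from_insertion(ins, one=one)   # (n_channels, 3) xyz coordinates in meters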

778 

779@dataclass 

780class SpikeSortingLoader: 

781 """ 

782 Object that will load spike sorting data for a given probe insertion. 

783 This class can be instantiated in several manners 

784 - With Alyx database probe id: 

785 SpikeSortingLoader(pid=pid, one=one) 

786 - With Alyx database eid and probe name: 

787 SpikeSortingLoader(eid=eid, pname='probe00', one=one) 

788 - From a local session and probe name: 

789 SpikeSortingLoader(session_path=session_path, pname='probe00') 

790 NB: When no ONE instance is passed, any datasets that are loaded will not be recorded. 

791 """ 

792 one: One = None 

793 atlas: None = None 

794 pid: str = None 

795 eid: str = '' 

796 pname: str = '' 

797 session_path: Path = '' 

798 # the following properties are the outcome of the post init function 

799 collections: list = None 

800 datasets: list = None # list of all datasets belonging to the session 

801 # the following properties are the outcome of a reading function 

802 files: dict = None 

803 raw_data_files: list = None # list of raw ap and lf files corresponding to the recording 

804 collection: str = '' 

805 histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced' 

806 spike_sorter: str = 'pykilosort' 

807 spike_sorting_path: Path = None 

808 _sync: dict = None 

809 

810 def __post_init__(self): 

811 # pid gets precedence 

812 if self.pid is not None: 1cbedk

813 try: 1dk

814 self.eid, self.pname = self.one.pid2eid(self.pid) 1dk

815 except NotImplementedError: 

816 if self.eid == '' or self.pname == '': 

817 raise IOError("Cannot infer session id and probe name from pid. " 

818 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") 

819 self.session_path = self.one.eid2path(self.eid) 1dk

820 # then eid / pname combination 

821 elif self.session_path is None or self.session_path == '': 1cbe

822 self.session_path = self.one.eid2path(self.eid) 1cbe

823 # fully local providing a session path 

824 else: 

825 if self.one: 

826 self.eid = self.one.to_eid(self.session_path) 

827 else: 

828 self.one = One(cache_dir=self.session_path.parents[2], mode='local') 

829 df_sessions = cache._make_sessions_df(self.session_path) 

830 self.one._cache['sessions'] = df_sessions.set_index('id') 

831 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) 

832 self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) 

833 # populates default properties 

834 self.collections = self.one.list_collections( 1cbedk

835 self.eid, filename='spikes*', collection=f"alf/{self.pname}*") 

836 self.datasets = self.one.list_datasets(self.eid) 1cbedk

837 if self.atlas is None: 1cbedk

838 self.atlas = AllenAtlas() 1cbek

839 self.files = {} 1cbedk

840 self.raw_data_files = [] 1cbedk

841 

842 def _load_object(self, *args, **kwargs): 

843 """ 

844 This function is a wrapper around alfio.load_object that will remove the UUID in the 

845 filename if the object is on SDSC. 

846 """ 

847 remove_uuids = getattr(self.one, 'uuid_filenames', False) 1cbed

848 d = alfio.load_object(*args, **kwargs) 1cbed

849 if remove_uuids: 1cbed

850 # pops the UUID in the key names 

851 keys = list(d.keys()) 

852 for k in keys: 

853 d[k[:-37]] = d.pop(k) 

854 return d 1cbed

855 

856 @staticmethod 

857 def _get_attributes(dataset_types): 

858 """returns attributes to load for spikes and clusters objects""" 

859 dataset_types = [] if dataset_types is None else dataset_types 1cbed

860 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1cbed

861 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1cbed

862 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1cbed

863 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1cbed

864 waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] 1cbed

865 waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) 1cbed

866 return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} 1cbed

867 

868 def _get_spike_sorting_collection(self, spike_sorter='pykilosort'): 

869 """ 

870 Filters a list or array of collections to get the relevant spike sorting dataset 

871 if there is a pykilosort, load it 

872 """ 

873 collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None) 1cbed

874 # otherwise, prefers the shortest 

875 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) 1cbed

876 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") 1cbed

877 return collection 1cbed

878 

879 def load_spike_sorting_object(self, obj, *args, **kwargs): 

880 """ 

881 Loads an ALF object 

882 :param obj: object name, one of 'spikes', 'clusters' or 'channels' 

883 :param spike_sorter: (defaults to 'pykilosort') 

884 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

885 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

886 :param kwargs: additional arguments to be passed to one.api.One.load_object 

887 :param missing: 'raise' (default) or 'ignore' 

888 :return: 

889 """ 

890 self.download_spike_sorting_object(obj, *args, **kwargs) 

891 return self._load_object(self.files[obj]) 

892 

893 def get_version(self, spike_sorter='pykilosort'): 

894 collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 

895 dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy') 

896 return dset[0]['version'] if len(dset) else 'unknown' 

897 

898 def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None, 

899 attribute=None, missing='raise', **kwargs): 

900 """ 

901 Downloads an ALF object 

902 :param obj: object name, one of 'spikes', 'clusters' or 'channels' 

903 :param spike_sorter: (defaults to 'pykilosort') 

904 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

905 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

906 :param kwargs: additional arguments to be passed to one.api.One.load_object 

907 :param attribute: list of attributes to load for the object 

908 :param missing: 'raise' (default) or 'ignore' 

909 :return: 

910 """ 

911 if len(self.collections) == 0: 1cbed

912 return {}, {}, {} 

913 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1cbed

914 collection = collection or self.collection 1cbed

915 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1cbed

916 attributes = self._get_attributes(dataset_types) 1cbed

917 try: 1cbed

918 self.files[obj] = self.one.load_object( 1cbed

919 self.eid, obj=obj, attribute=attributes.get(obj, None), 

920 collection=collection, download_only=True, **kwargs) 

921 except ALFObjectNotFound as e: 1cbe

922 if missing == 'raise': 1cbe

923 raise e 

924 

925 def download_spike_sorting(self, objects=None, **kwargs): 

926 """ 

927 Downloads spikes, clusters and channels 

928 :param spike_sorter: (defaults to 'pykilosort') 

929 :param dataset_types: list of extra dataset types 

930 :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels'] 

931 :return: 

932 """ 

933 objects = ['spikes', 'clusters', 'channels'] if objects is None else objects 1cbed

934 for obj in objects: 1cbed

935 self.download_spike_sorting_object(obj=obj, **kwargs) 1cbed

936 self.spike_sorting_path = self.files['clusters'][0].parent 1cbed

937 

938 def download_raw_electrophysiology(self, band='ap'): 

939 """ 

940 Downloads raw electrophysiology data files on local disk. 

941 :param band: "ap" (default) or "lf" for LFP band 

942 :return: list of raw data files full paths (ch, meta and cbin files) 

943 """ 

944 raw_data_files = [] 1b

945 for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']: 1b

946 try: 1b

947 # FIXME: this will fail if multiple LFP segments are found 

948 raw_data_files.append(self.one.load_dataset( 1b

949 self.eid, 

950 download_only=True, 

951 collection=f'raw_ephys_data/{self.pname}', 

952 dataset=suffix, 

953 check_hash=False, 

954 )) 

955 except ALFObjectNotFound: 1b

956 _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}") 1b

957 self.raw_data_files = list(set(self.raw_data_files + raw_data_files)) 1b

958 return raw_data_files 1b

959 

960 def raw_electrophysiology(self, stream=True, band='ap', **kwargs): 

961 """ 

962 Returns a reader for the raw electrophysiology data 

963 By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having 

964 downloaded the raw data file if necessary 

965 :param stream: 

966 :param band: 

967 :param kwargs: 

968 :return: 

969 """ 

970 if stream: 1k

971 return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) 1k

972 else: 

973 raw_data_files = self.download_raw_electrophysiology(band=band) 

974 cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None) 

975 if cbin_file is not None: 

976 return spikeglx.Reader(cbin_file) 

977 

978 def load_channels(self, **kwargs): 

979 """ 

980 Loads channels 

981 The channel locations can come from several sources, it will load the most advanced version of the histology available, 

982 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

983 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

984 - resolved: channel locations alignments have been agreed upon 

985 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

986 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

987 

988 :param spike_sorter: (defaults to 'pykilosort') 

989 :param dataset_types: list of extra dataset types 

990 :return: 

991 """ 

992 # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting 

993 self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') 1cbed

994 self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs) 1cbed

995 channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) 1cbed

996 if 'electrodeSites' in self.files: # if common dict keys, electrodeSites prevails 1cbed

997 channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 1d

998 if 'brainLocationIds_ccf_2017' not in channels: 1cbed

999 _logger.debug(f"loading channels from alyx for {self.files['channels']}") 1b

1000 _channels, self.histology = _load_channel_locations_traj( 1b

1001 self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True) 

1002 if _channels: 1b

1003 channels = _channels[self.pname] 

1004 else: 

1005 channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) 1cbed

1006 self.histology = 'alf' 1cbed

1007 return Bunch(channels) 1cbed

1008 

1009 def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs): 

1010 """ 

1011 Loads spikes, clusters and channels 

1012 

1013 There could be several spike sorting collections, by default the loader will get the pykilosort collection 

1014 

1015 The channel locations can come from several sources, it will load the most advanced version of the histology available, 

1016 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

1017 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

1018 - resolved: channel locations alignments have been agreed upon 

1019 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

1020 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

1021 

1022 :param spike_sorter: (defaults to 'pykilosort') 

1023 :param revision: for example "2024-05-06", (defaults to None): 

1024 :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one 

1025 :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates'] 

1026 :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table 

1027 :param kwargs: additional arguments to be passed to one.api.One.load_object 

1028 :return: 

1029 """ 

1030 if len(self.collections) == 0: 1cbed

1031 return {}, {}, {} 

1032 self.files = {} 1cbed

1033 self.spike_sorter = spike_sorter 1cbed

1034 self.revision = revision 1cbed

1035 objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None 1cbed

1036 self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) 1cbed

1037 channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs) 1cbed

1038 clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) 1cbed

1039 if good_units: 1cbed

1040 spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards) 

1041 else: 

1042 spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) 1cbed

1043 if enforce_version: 1cbed

1044 self._assert_version_consistency() 1cbed

1045 return spikes, clusters, channels 1cbed

1046 
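# Usage sketch (not part of the module, added for illustration); `pid` is a placeholder probe insertion UUID.
#     ssl = SpikeSortingLoader(pid=pid, one=ONE())
#     spikes, clusters, channels = ssl.load_spike_sorting()
#     clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)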

1047 def _assert_version_consistency(self): 

1048 """ 

1049 Makes sure the state of the spike sorting object matches the files downloaded 

1050 :return: None 

1051 """ 

1052 for k in ['spikes', 'clusters', 'channels', 'passingSpikes']: 1cbed

1053 for fn in self.files.get(k, []): 1cbed

1054 if self.spike_sorter: 1cbed

1055 assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \ 1bd

1056 f"You required strict version {self.spike_sorter}, {fn} does not match" 

1057 if self.revision: 1cbed

1058 assert fn.relative_to(self.session_path).parts[3] == f"#{self.revision}#", \ 

1059 f"You required strict revision {self.revision}, {fn} does not match" 

1060 

1061 @staticmethod 

1062 def compute_metrics(spikes, clusters=None): 

1063 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size 

1064 metrics = pd.DataFrame(quick_unit_metrics( 

1065 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) 

1066 return metrics 

1067 

1068 @staticmethod 

1069 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False): 

1070 """ 

1071 Merge the metrics and the channel information into the clusters dictionary 

1072 :param spikes: 

1073 :param clusters: 

1074 :param channels: 

1075 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used 

1076 for clusters or analysis applications (defaults to None). 

1077 :param compute_metrics: if True, will explicitly recompute metrics (defaults to false) 

1078 :return: cluster dictionary containing metrics and histology 

1079 """ 

1080 if spikes == {}: 1bd

1081 return 

1082 nc = clusters['channels'].size 1bd

1083 # recompute metrics if they are not available 

1084 metrics = None 1bd

1085 if 'metrics' in clusters: 1bd

1086 metrics = clusters.pop('metrics') 1bd

1087 if metrics.shape[0] != nc: 1bd

1088 metrics = None 

1089 if metrics is None or compute_metrics is True: 1bd

1090 _logger.debug("recompute clusters metrics") 

1091 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) 

1092 if isinstance(cache_dir, Path): 

1093 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) 

1094 for k in metrics.keys(): 1bd

1095 clusters[k] = metrics[k].to_numpy() 1bd

1096 for k in channels.keys(): 1bd

1097 clusters[k] = channels[k][clusters['channels']] 1bd

1098 if cache_dir is not None: 1bd

1099 _logger.debug(f'caching clusters metrics in {cache_dir}') 

1100 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt')) 

1101 return clusters 1bd

1102 

1103 @property 

1104 def url(self): 

1105 """Gets flatiron URL for the session""" 

1106 webclient = getattr(self.one, '_web_client', None) 

1107 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None 

1108 

1109 def _get_probe_info(self): 

1110 if self._sync is None: 1e

1111 timestamps = self.one.load_dataset( 1e

1112 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}') 

1113 _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks 1e

1114 self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}') 

1115 try: 1e

1116 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1e

1117 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) 

1118 fs = spikeglx._get_fs_from_meta(ap_meta) 1e

1119 except ALFObjectNotFound: 

1120 ap_meta = None 

1121 fs = 30_000 

1122 self._sync = { 1e

1123 'timestamps': timestamps, 

1124 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'), 

1125 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), 

1126 'ap_meta': ap_meta, 

1127 'fs': fs, 

1128 } 

1129 

1130 def timesprobe2times(self, values, direction='forward'): 

1131 self._get_probe_info() 

1132 if direction == 'forward': 

1133 return self._sync['forward'](values * self._sync['fs']) 

1134 elif direction == 'reverse': 

1135 return self._sync['reverse'](values) / self._sync['fs'] 

1136 

1137 def samples2times(self, values, direction='forward'): 

1138 """ 

1139 Converts ephys sample values to session main clock seconds 

1140 :param values: numpy array of times in seconds or samples to resync 

1141 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse' 

1142 (seconds main time to samples probe time) 

1143 :return: 

1144 """ 

1145 self._get_probe_info() 1e

1146 return self._sync[direction](values) 1e

1147 
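# Usage sketch (not part of the module, added for illustration), following the docstring above; `ssl` is
# assumed to be a SpikeSortingLoader instance and `spikes` a loaded spikes bunch with a 'samples' key.
#     t_main = ssl.samples2times(spikes['samples'])               # probe samples -> seconds on the main clock
#     samples = ssl.samples2times(t_main, direction='reverse')    # seconds on the main clock -> probe samples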

1148 @property 

1149 def pid2ref(self): 

1150 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1cb

1151 

1152 def _default_plot_title(self, spikes): 

1153 title = f"{self.pid2ref}, {self.pid} \n" \ 1c

1154 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" 

1155 return title 1c

1156 

1157 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, 

1158 drift=None, title=None, **kwargs): 

1159 """ 

1160 :param spikes: spikes dictionary or Bunch 

1161 :param channels: channels dictionary or Bunch. 

1162 :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png". 

1163 Otherwise, plot. 

1164 :param br: brain regions object (optional) 

1165 :param label: label for saved image (optional, default="raster") 

1166 :param time_series: timeseries dictionary for behavioral event times (optional) 

1167 :param **kwargs: kwargs passed to `driftmap()` (optional) 

1168 :return: 

1169 """ 

1170 br = br or BrainRegions() 1c

1171 time_series = time_series or {} 1c

1172 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c

1173 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') 

1174 axs[0, 1].set_axis_off() 1c

1175 # axs[0, 0].set_xticks([]) 

1176 if not kwargs: 1c

1177 # set default raster plot parameters 

1178 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} 

1179 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c

1180 if title is None: 1c

1181 title = self._default_plot_title(spikes) 1c

1182 axs[0, 0].title.set_text(title) 1c

1183 for k, ts in time_series.items(): 1c

1184 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) 

1185 if 'atlas_id' in channels: 1c

1186 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c

1187 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) 

1188 axs[1, 0].set_ylim(0, 3800) 1c

1189 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c

1190 fig.tight_layout() 1c

1191 

1192 if drift is None: 1c

1193 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c

1194 if 'drift' in self.files: 1c

1195 drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) 

1196 if isinstance(drift, dict): 1c

1197 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) 

1198 axs[0, 0].set(ylim=[-15, 15]) 

1199 

1200 if save_dir is not None: 1c

1201 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1202 fig.savefig(png_file) 

1203 plt.close(fig) 

1204 gc.collect() 

1205 else: 

1206 return fig, axs 1c
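
Editor's note: a hedged illustration of the two calling modes, assuming `ssl`, `spikes` and `channels` come from a prior `ssl.load_spike_sorting()` call:

>>> fig, axs = ssl.raster(spikes, channels)                          # no save_dir: the figure is returned
>>> ssl.raster(spikes, channels, save_dir=Path('/tmp'), label='qc')  # saves "<pid>_<pid2ref>_qc.png" and closes the figure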

1207 

1208 def plot_rawdata_snippet(self, sr, spikes, clusters, t0, 

1209 channels=None, 

1210 br: BrainRegions = None, 

1211 save_dir=None, 

1212 label='raster', 

1213 gain=-93, 

1214 title=None): 

1215 

1216 # compute the raw data sample offsets and destripe; we take 400 ms around t0 

1217 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs)) 

1218 raw = sr[first_sample:last_sample, :-sr.nsync].T 

1219 channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True 

1220 destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels) 

1221 # filter out the spikes according to good/bad clusters and to the time slice 

1222 spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample])) 

1223 ss = spikes['samples'][spike_sel] 

1224 sc = clusters['channels'][spikes['clusters'][spike_sel]] 

1225 sok = clusters['label'][spikes['clusters'][spike_sel]] == 1 

1226 if title is None: 

1227 title = self._default_plot_title(spikes) 

1228 # display the raw data snippet with spikes overlaid 

1229 fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col') 

1230 Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s') 

1231 axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5) 

1232 axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5) 

1233 axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035]) 

1234 # adds the channel locations if available 

1235 if (channels is not None) and ('atlas_id' in channels): 

1236 br = br or BrainRegions() 

1237 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 

1238 brain_regions=br, display=True, ax=axs[1], title=self.histology) 

1239 axs[1].get_yaxis().set_visible(False) 

1240 fig.tight_layout() 

1241 

1242 if save_dir is not None: 

1243 png_file = Path(save_dir).joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1244 fig.savefig(png_file) 

1245 plt.close(fig) 

1246 gc.collect() 

1247 else: 

1248 return fig, axs 
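
Editor's note: a sketch of a typical call, assuming `sr` is a `spikeglx.Reader` (or `Streamer`) on the AP band and `spikes`, `clusters`, `channels` were loaded for the same insertion (all names are assumptions):

>>> fig, axs = ssl.plot_rawdata_snippet(sr, spikes, clusters, t0=100., channels=channels)
>>> fig.savefig('/tmp/snippet.png')   # or pass save_dir= to let the method handle saving and closing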

1249 

1250 

1251@dataclass 

1252class SessionLoader: 

1253 """ 

1254 Object to load session data for a given session in the recommended way. 

1255 

1256 Parameters 

1257 ---------- 

1258 one: one.api.ONE instance 

1259 Can be in remote or local mode (required) 

1260 session_path: string or pathlib.Path 

1261 The absolute path to the session (one of session_path or eid is required) 

1262 eid: string 

1263 database UUID of the session (one of session_path or eid is required) 

1264 

1265 If both are provided, session_path takes precedence over eid. 

1266 

1267 Examples 

1268 -------- 

1269 1) Load all available session data for one session: 

1270 >>> from one.api import ONE 

1271 >>> from brainbox.io.one import SessionLoader 

1272 >>> one = ONE() 

1273 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') 

1274 # The object is instantiated, but no data is loaded yet, as the data_info attribute shows 

1275 >>> sess_loader.data_info 

1276 name is_loaded 

1277 0 trials False 

1278 1 wheel False 

1279 2 pose False 

1280 3 motion_energy False 

1281 4 pupil False 

1282 

1283 # Loading all available session data, the data_info attribute now shows which data has been loaded 

1284 >>> sess_loader.load_session_data() 

1285 >>> sess_loader.data_info 

1286 name is_loaded 

1287 0 trials True 

1288 1 wheel True 

1289 2 pose True 

1290 3 motion_energy True 

1291 4 pupil False 

1292 

1293 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. 

1294 >>> type(sess_loader.trials) 

1295 pandas.core.frame.DataFrame 

1296 >>> sess_loader.trials.shape 

1297 (626, 18) 

1298 # Each data type comes with its own timestamps in a column called 'times' 

1299 >>> sess_loader.wheel['times'] 

1300 0 0.134286 

1301 1 0.135286 

1302 2 0.136286 

1303 3 0.137286 

1304 4 0.138286 

1305 ... 

1306 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. 

1307 # The dataframes of all cameras are collected in a dictionary 

1308 >>> type(sess_loader.pose) 

1309 dict 

1310 >>> sess_loader.pose.keys() 

1311 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) 

1312 >>> sess_loader.pose['bodyCamera'].columns 

1313 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') 

1314 # To control the loading of specific data, e.g. by specifying parameters, use the individual loading 

1315 functions: 

1316 >>> sess_loader.load_wheel(sampling_rate=100) 

1317 """ 

1318 one: One = None 

1319 session_path: Path = '' 

1320 eid: str = '' 

1321 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1322 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1323 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1324 pose: dict = field(default_factory=dict, repr=False) 

1325 motion_energy: dict = field(default_factory=dict, repr=False) 

1326 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1327 

1328 def __post_init__(self): 

1329 """ 

1330 Function that runs automatically after initialization of the dataclass attributes. 

1331 Checks for required inputs, sets session_path and eid, creates data_info table. 

1332 """ 

1333 if self.one is None: 1af

1334 raise ValueError("An input to one is required. If not connection to a database is desired, it can be " 

1335 "a fully local instance of One.") 

1336 # If session path is given, takes precedence over eid 

1337 if self.session_path is not None and self.session_path != '': 1af

1338 self.eid = self.one.to_eid(self.session_path) 1af

1339 self.session_path = Path(self.session_path) 1af

1340 # If no session path is provided, try to infer it from the eid 

1341 else: 

1342 if self.eid is not None and self.eid != '': 

1343 self.session_path = self.one.eid2path(self.eid) 

1344 else: 

1345 raise ValueError("If no session path is given, eid is required.") 

1346 

1347 data_names = [ 1af

1348 'trials', 

1349 'wheel', 

1350 'pose', 

1351 'motion_energy', 

1352 'pupil' 

1353 ] 

1354 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1af
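
Editor's note: a minimal sketch of the two construction paths, assuming `one`, `eid` and `session_path` are already defined:

>>> sl = SessionLoader(one=one, session_path=session_path)   # eid is inferred via one.to_eid()
>>> sl = SessionLoader(one=one, eid=eid)                     # session_path is inferred via one.eid2path()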

1355 

1356 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): 

1357 """ 

1358 Function to load available session data into the SessionLoader object. Input parameters control which 

1359 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input 

1360 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored 

1361 in SessionLoader.data_info 

1362 

1363 Parameters 

1364 ---------- 

1365 trials: boolean 

1366 Whether to load all trials data into SessionLoader.trials, default is True 

1367 wheel: boolean 

1368 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True 

1369 pose: boolean 

1370 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose, 

1371 default is True 

1372 motion_energy: boolean 

1373 Whether to load motion energy data (whisker pad for left/right camera, body for body camera) 

1374 into SessionLoader.motion_energy, default is True 

1375 pupil: boolean 

1376 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil, 

1377 default is True 

1378 reload: boolean 

1379 Whether to reload data that has already been loaded into this SessionLoader object, default is False 

1380 """ 

1381 load_df = self.data_info.copy() 1f

1382 load_df['to_load'] = [ 1f

1383 trials, 

1384 wheel, 

1385 pose, 

1386 motion_energy, 

1387 pupil 

1388 ] 

1389 load_df['load_func'] = [ 1f

1390 self.load_trials, 

1391 self.load_wheel, 

1392 self.load_pose, 

1393 self.load_motion_energy, 

1394 self.load_pupil 

1395 ] 

1396 

1397 for idx, row in load_df.iterrows(): 1f

1398 if row['to_load'] is False: 1f

1399 _logger.debug(f"Not loading {row['name']} data, set to False.") 

1400 elif row['is_loaded'] is True and reload is False: 1f

1401 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1f

1402 else: 

1403 try: 1f

1404 _logger.info(f"Loading {row['name']} data") 1f

1405 row['load_func']() 1f

1406 self.data_info.loc[idx, 'is_loaded'] = True 1f

1407 except BaseException as e: 

1408 _logger.warning(f"Could not load {row['name']} data.") 

1409 _logger.debug(e) 
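
Editor's note: for example (a sketch, assuming `sess_loader` is an initialised SessionLoader):

>>> sess_loader.load_session_data(pose=False, pupil=False)   # skip the camera-derived data
>>> sess_loader.load_session_data(reload=True)               # force re-loading of data already marked as loaded
>>> sess_loader.data_info                                    # inspect which loads succeeded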

1410 

1411 def _find_behaviour_collection(self, obj): 

1412 """ 

1413 Function to find the trial or wheel collection 

1414 

1415 Parameters 

1416 ---------- 

1417 obj: str 

1418 Alf object to load, either 'trials' or 'wheel' 

1419 """ 

1420 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj

1421 dsets = self.one.list_datasets(self.eid, dataset) 1fnj

1422 if len(dsets) == 0: 1fnj

1423 return 'alf' 1fn

1424 else: 

1425 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj

1426 if len(set(collections)) == 1: 1fj

1427 return collections[0] 1fj

1428 else: 

1429 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, ' 

1430 f'e.g. sl.load_{obj}(collection="{collections[0]}")') 

1431 raise ALFMultipleCollectionsFound 

1432 

1433 def load_trials(self, collection=None): 

1434 """ 

1435 Function to load trials data into SessionLoader.trials 

1436 

1437 Parameters 

1438 ---------- 

1439 collection: str 

1440 Alf collection of trials data 

1441 """ 

1442 

1443 if not collection: 1fn

1444 collection = self._find_behaviour_collection('trials') 1fn

1445 # itiDuration frequently has a mismatched dimension, and we don't need it, so exclude it using a regex 

1446 self.one.wildcards = False 1fn

1447 self.trials = self.one.load_object( 1fn

1448 self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*').to_df() 

1449 self.one.wildcards = True 1fn

1450 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1fn
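
Editor's note: for example (same `sess_loader` assumption as above):

>>> sess_loader.load_trials()                   # collection resolved automatically
>>> sess_loader.load_trials(collection='alf')   # or pinned explicitly, e.g. when several collections exist
>>> sess_loader.trials.columns                  # one column per trials attribute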

1451 

1452 def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None): 

1453 """ 

1454 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position 

1455 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which 

1456 a Butterworth low-pass filter is applied. 

1457 

1458 Parameters 

1459 ---------- 

1460 fs: int, float 

1461 Sampling frequency for the wheel position, default is 1000 Hz 

1462 corner_frequency: int, float 

1463 Corner frequency of the Butterworth low-pass filter in Hz, default is 20 

1464 order: int, float 

1465 Order of the Butterworth low-pass filter, default is 8 

1466 collection: str 

1467 Alf collection of wheel data 

1468 """ 

1469 if not collection: 1fj

1470 collection = self._find_behaviour_collection('wheel') 1fj

1471 wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection) 1fj

1472 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1fj

1473 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps") 

1474 # resample the wheel position and compute velocity, acceleration 

1475 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1fj

1476 self.wheel['position'], self.wheel['times'] = interpolate_position( 1fj

1477 wheel_raw['timestamps'], wheel_raw['position'], freq=fs) 

1478 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1fj

1479 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order) 

1480 self.wheel = self.wheel.apply(np.float32) 1fj

1481 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1fj
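
Editor's note: for example, a hedged sketch with a lower sampling rate and cutoff:

>>> sess_loader.load_wheel(fs=250, corner_frequency=10, order=8)
>>> sess_loader.wheel[['times', 'position', 'velocity', 'acceleration']].head()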

1482 

1483 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']): 

1484 """ 

1485 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a 

1486 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas 

1487 Dataframes with the timestamps and pose data, one row per video frame and one set of x, y and likelihood columns per tracked body part. 

1488 

1489 Parameters 

1490 ---------- 

1491 likelihood_thr: float 

1492 The position of each tracked body part comes with a likelihood of that estimate for each time point. 

1493 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set 

1494 likelihood_thr=1. Default is 0.9 

1495 views: list 

1496 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1497 """ 

1498 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger 

1499 self.pose = {} 1mhf

1500 for view in views: 1mhf

1501 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times']) 1mhf

1502 # Double check if video timestamps are correct length or can be fixed 

1503 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1mhf

1504 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1mhf

1505 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1mhf

1506 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1mhf
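
Editor's note: for example (a sketch; camera availability depends on the session):

>>> sess_loader.load_pose(views=['left'], likelihood_thr=1)   # left camera only, no likelihood masking
>>> sess_loader.pose.keys()
dict_keys(['leftCamera'])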

1507 

1508 def load_motion_energy(self, views=['left', 'right', 'body']): 

1509 """ 

1510 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a 

1511 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are 

1512 pandas Dataframes with the timestamps and motion energy data. 

1513 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad 

1514 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the 

1515 body (bodyMotionEnergy). 

1516 

1517 Parameters 

1518 ---------- 

1519 views: list 

1520 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1521 """ 

1522 names = {'left': 'whiskerMotionEnergy', 1lf

1523 'right': 'whiskerMotionEnergy', 

1524 'body': 'bodyMotionEnergy'} 

1525 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger 

1526 self.motion_energy = {} 1lf

1527 for view in views: 1lf

1528 me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times']) 1lf

1529 # Double check if video timestamps are correct length or can be fixed 

1530 times_fixed, motion_energy = self._check_video_timestamps( 1lf

1531 view, me_raw['times'], me_raw['ROIMotionEnergy']) 

1532 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1lf

1533 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1lf

1534 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1lf
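
Editor's note: for example (same assumptions as above):

>>> sess_loader.load_motion_energy(views=['left', 'body'])
>>> sess_loader.motion_energy['leftCamera'].columns
Index(['times', 'whiskerMotionEnergy'], dtype='object')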

1535 

1536 def load_licks(self): 

1537 """ 

1538 Not yet implemented 

1539 """ 

1540 pass 

1541 

1542 def load_pupil(self, snr_thresh=5.): 

1543 """ 

1544 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. 

1545 

1546 Parameters 

1547 ---------- 

1548 snr_thresh: float 

1549 An SNR is calculated from the raw and smoothed pupil diameter. If this SNR is below snr_thresh, the data 

1550 will be considered unusable and will be discarded. 

1551 """ 

1552 # Try to load from features 

1553 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features']) 1hf

1554 if 'features' in feat_raw.keys(): 1hf

1555 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features']) 

1556 self.pupil = feats.copy() 

1557 self.pupil.insert(0, 'times', times_fixed) 

1558 

1559 # If unavailable compute on the fly 

1560 else: 

1561 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1hf

1562 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1hf

1563 and 'leftCamera' in self.pose.keys()): 

1564 # If pose data is already loaded, we don't know if it was thresholded at 0.9, so we need a small workaround 

1565 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1hf

1566 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1hf

1567 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1hf

1568 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1hf

1569 else: 

1570 self.load_pose(views=['left'], likelihood_thr=0.9) 

1571 dlc_thr = self.pose['leftCamera'].copy() 

1572 

1573 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1hf

1574 try: 1hf

1575 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1hf

1576 except BaseException as e: 

1577 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. " 

1578 "Saving all NaNs for pupilDiameter_smooth.") 

1579 _logger.debug(e) 

1580 self.pupil['pupilDiameter_smooth'] = np.nan 

1581 

1582 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1hf

1583 good_idxs = np.where( 1hf

1584 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] 

1585 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1hf

1586 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) 

1587 if snr < snr_thresh: 1hf

1588 self.pupil = pd.DataFrame() 1h

1589 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h
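
Editor's note: the acceptance criterion above amounts to SNR = var(smooth) / var(smooth - raw) over the samples where both traces are finite; if the SNR falls below snr_thresh, self.pupil is reset to an empty DataFrame and a ValueError is raised. A hedged usage sketch:

>>> sess_loader.load_pupil(snr_thresh=5.)
>>> sess_loader.pupil[['pupilDiameter_raw', 'pupilDiameter_smooth']].head()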

1590 

1591 def _check_video_timestamps(self, view, video_timestamps, video_data): 

1592 """ 

1593 Helper function to check for the length of the video frames vs video timestamps and fix in case 

1594 timestamps are longer than video frames. 

1595 """ 

1596 # If camera times are shorter than video data, or empty, no current fix 

1597 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf

1598 if video_timestamps.shape[0] == 0: 

1599 msg = f'Camera times empty for {view}Camera.' 

1600 else: 

1601 msg = f'Camera times are shorter than video data for {view}Camera.' 

1602 _logger.warning(msg) 

1603 raise ValueError(msg) 

1604 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. 

1605 # This is because the first few frames are sometimes not recorded. We can remove the first few 

1606 # timestamps in this case 

1607 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf

1608 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf

1609 return video_timestamps_fixed, video_data 1lmhf

1610 else: 

1611 return video_timestamps, video_data 
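
Editor's note: a small sketch of the trimming behaviour, using hypothetical arrays (5 timestamps but only 3 frames of data):

>>> ts = np.arange(5) / 60.
>>> data = np.zeros((3, 2))
>>> ts_fixed, data_out = sess_loader._check_video_timestamps('left', ts, data)
>>> ts_fixed.shape[0] == data_out.shape[0]
True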

1612 

1613 

1614class EphysSessionLoader(SessionLoader): 

1615 """ 

1616 Spike-sorting-enhanced version of SessionLoader. 

1617 Loads spike sorting data for all probes in the session into the self.ephys dict. 

1618 >>> EphysSessionLoader(eid=eid, one=one) 

1619 To select a specific probe: 

1620 >>> EphysSessionLoader(eid=eid, one=one, pid=pid) 

1621 """ 

1622 def __init__(self, *args, pname=None, pid=None, **kwargs): 

1623 """ 

1624 Needs an active connection in order to get the list of insertions in the session 

1625 :param args: 

1626 :param kwargs: 

1627 """ 

1628 super().__init__(*args, **kwargs) 

1629 # if necessary, restrict the query 

1630 qargs = {} if pname is None else {'name': pname} 

1631 qargs = qargs or ({} if pid is None else {'id': pid}) 

1632 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs) 

1633 self.ephys = {} 

1634 for ins in insertions: 

1635 self.ephys[ins['name']] = {} 

1636 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one) 

1637 

1638 def load_session_data(self, *args, **kwargs): 

1639 super().load_session_data(*args, **kwargs) 

1640 self.load_spike_sorting() 

1641 

1642 def load_spike_sorting(self, pnames=None): 

1643 pnames = pnames or list(self.ephys.keys()) 

1644 for pname in pnames: 

1645 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting() 

1646 self.ephys[pname]['spikes'] = spikes 

1647 self.ephys[pname]['clusters'] = clusters 

1648 self.ephys[pname]['channels'] = channels 

1649 

1650 @property 

1651 def probes(self): 

1652 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}
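
Editor's note: for example (a sketch, assuming `one` and `eid` are defined and the session has registered insertions):

>>> esl = EphysSessionLoader(eid=eid, one=one)
>>> esl.load_spike_sorting()   # fills esl.ephys[pname]['spikes' / 'clusters' / 'channels'] per probe
>>> esl.probes                 # maps probe name to insertion id (pid)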