Coverage for brainbox/io/one.py: 56%

626 statements  

coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" 

2from dataclasses import dataclass, field 

3import gc 

4import logging 

5import os 

6from pathlib import Path 

7 

8 

9import numpy as np 

10import pandas as pd 

11from scipy.interpolate import interp1d 

12import matplotlib.pyplot as plt 

13 

14from one.api import ONE, One 

15import one.alf.io as alfio 

16from one.alf.files import get_alf_path 

17from one.alf.exceptions import ALFObjectNotFound 

18from one.alf import cache 

19from neuropixel import TIP_SIZE_UM, trace_header 

20import spikeglx 

21 

22from iblutil.util import Bunch 

23from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times 

24from iblatlas.atlas import AllenAtlas, BrainRegions 

25from iblatlas import atlas 

26from ibllib.pipes import histology 

27from ibllib.pipes.ephys_alignment import EphysAlignment 

28from ibllib.plots import vertical_lines 

29 

30import brainbox.plot 

31from brainbox.ephys_plots import plot_brain_regions 

32from brainbox.metrics.single_units import quick_unit_metrics 

33from brainbox.behavior.wheel import interpolate_position, velocity_filtered 

34from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter 

35 

36_logger = logging.getLogger('ibllib') 

37 

38 

39SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths'] 

40CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids'] 

41 

42 

43def load_lfp(eid, one=None, dataset_types=None, **kwargs): 

44 """ 

45 TODO Verify works 

46 From an eid, hits the Alyx database and downloads the standard set of datasets 

47 needed for LFP 

48 :param eid: 

49 :param dataset_types: additional dataset types to add to the list 

50 :param open: if True, spikeglx readers are opened 

51 :return: list of spikeglx.Reader 

52 """ 

53 if dataset_types is None: 

54 dataset_types = [] 

55 dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*'] 

56 [one.load_dataset(eid, dset, download_only=True) for dset in dtypes] 

57 session_path = one.eid2path(eid) 

58 

59 efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) 

60 if ef.get('lf', None)] 

61 return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles] 
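A minimal usage sketch (the function itself is flagged TODO in the source), assuming a connected ONE instance and a placeholder eid; each returned reader wraps one probe's LFP file.

    from one.api import ONE
    from brainbox.io.one import load_lfp

    one = ONE()
    readers = load_lfp('your-session-eid', one=one)  # placeholder eid
    for sr in readers:
        print(sr.fs, sr.nc, sr.ns)  # sampling rate, channel count, sample count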

62 

63 

64def _collection_filter_from_args(probe, spike_sorter=None): 

65 collection = f'alf/{probe}/{spike_sorter}' 1g

66 collection = collection.replace('None', '*') 1g

67 collection = collection.replace('/*', '*') 1g

68 collection = collection[:-1] if collection.endswith('/') else collection 1g

69 return collection 1g

70 

71 

72def _get_spike_sorting_collection(collections, pname): 

73 """ 

74 Filters a list or array of collections to get the relevant spike sorting dataset; 

75 if a pykilosort collection is present it is loaded, otherwise the shortest matching collection is used 

76 """ 

77 # 

78 collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) 1gb

79 # otherwise, prefers the shortest 

80 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) 1gb

81 _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") 1gb

82 return collection 1gb
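An illustrative call with made-up candidate collections, showing the pykilosort-first, then shortest-match selection implemented above.

    candidates = ['alf/probe00/ks2_matlab', 'alf/probe00/pykilosort', 'alf/probe00']
    _get_spike_sorting_collection(candidates, 'probe00')                           # -> 'alf/probe00/pykilosort'
    _get_spike_sorting_collection(['alf/probe01/ks2', 'alf/probe01'], 'probe01')   # -> 'alf/probe01'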

83 

84 

85def _channels_alyx2bunch(chans): 

86 channels = Bunch({ 

87 'atlas_id': np.array([ch['brain_region'] for ch in chans]), 

88 'x': np.array([ch['x'] for ch in chans]) / 1e6, 

89 'y': np.array([ch['y'] for ch in chans]) / 1e6, 

90 'z': np.array([ch['z'] for ch in chans]) / 1e6, 

91 'axial_um': np.array([ch['axial'] for ch in chans]), 

92 'lateral_um': np.array([ch['lateral'] for ch in chans]) 

93 }) 

94 return channels 

95 

96 

97def _channels_traj2bunch(xyz_chans, brain_atlas): 

98 brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans)) 

99 channels = { 

100 'x': xyz_chans[:, 0], 

101 'y': xyz_chans[:, 1], 

102 'z': xyz_chans[:, 2], 

103 'acronym': brain_regions['acronym'], 

104 'atlas_id': brain_regions['id'] 

105 } 

106 

107 return channels 

108 

109 

110def _channels_bunch2alf(channels): 

111 channels_ = { 1i

112 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6, 

113 'brainLocationIds_ccf_2017': channels['atlas_id'], 

114 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]} 

115 return channels_ 1i

116 

117 

118def _channels_alf2bunch(channels, brain_regions=None): 

119 # reformat the dictionary according to the standard that comes out of Alyx 

120 channels_ = { 1icbdf

121 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6, 

122 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6, 

123 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6, 

124 'acronym': None, 

125 'atlas_id': channels['brainLocationIds_ccf_2017'], 

126 'axial_um': channels['localCoordinates'][:, 1], 

127 'lateral_um': channels['localCoordinates'][:, 0], 

128 } 

129 if brain_regions: 1icbdf

130 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icbdf

131 return channels_ 1icbdf

132 

133 

134def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, 

135 brain_regions=None): 

136 """ 

137 Generic function to load spike sorting data using ONE. 

138 

139 Will try to load one spike sorting for each probe present for the eid matching the collection filter. 

140 For each probe it will load a spike sorting: 

141 - if there is one version: loads this one 

142 - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX) 

143 

144 Parameters 

145 ---------- 

146 eid : [str, UUID, Path, dict] 

147 Experiment session identifier; may be a UUID, URL, experiment reference string 

148 details dict or Path 

149 one : one.api.OneAlyx 

150 An instance of ONE (may be in 'local' mode) 

151 collection : str 

152 collection filter word - accepts wildcards - can be a combination of spike sorter and 

153 probe. See `ALF documentation`_ for details. 

154 revision : str 

155 A particular revision return (defaults to latest revision). See `ALF documentation`_ for 

156 details. 

157 return_channels : bool 

158 Defaults to True; when True, channel locations are also loaded from disk and returned 

159 

160 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components 

161 

162 Returns 

163 ------- 

164 spikes : dict of one.alf.io.AlfBunch 

165 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

166 session and spike sorter, with keys ('clusters', 'times') 

167 clusters : dict of one.alf.io.AlfBunch 

168 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

169 ('channels', 'depths', 'metrics') 

170 channels : dict of one.alf.io.AlfBunch 

171 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

172 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs 

173 non-lateralized. 

174 """ 

175 one = one or ONE() 1gb

176 # enumerate probes and load according to the name 

177 collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) 1gb

178 if len(collections) == 0: 1gb

179 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

180 pnames = list(set(c.split('/')[1] for c in collections)) 1gb

181 spikes, clusters, channels = ({} for _ in range(3)) 1gb

182 

183 spike_attributes, cluster_attributes = _get_attributes(dataset_types) 1gb

184 

185 for pname in pnames: 1gb

186 probe_collection = _get_spike_sorting_collection(collections, pname) 1gb

187 spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', 1gb

188 attribute=spike_attributes) 

189 clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', 1gb

190 attribute=cluster_attributes) 

191 if return_channels: 1gb

192 channels = _load_channels_locations_from_disk( 1b

193 eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions) 

194 return spikes, clusters, channels 1b

195 else: 

196 return spikes, clusters 1g

197 

198 

199def _get_attributes(dataset_types): 

200 if dataset_types is None: 1gb

201 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1gb

202 else: 

203 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 

204 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 

205 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 

206 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 

207 return spike_attributes, cluster_attributes 

208 

209 

210def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None): 

211 _logger.debug('loading spike sorting from disk') 1b

212 channels = Bunch({}) 1b

213 collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) 1b

214 if len(collections) == 0: 1b

215 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") 

216 probes = list(set([c.split('/')[1] for c in collections])) 1b

217 for probe in probes: 1b

218 probe_collection = _get_spike_sorting_collection(collections, probe) 1b

219 channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels') 1b

220 # if the spike sorter has not aligned data, try and get the alignment available 

221 if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): 1b

222 aligned_channel_collections = one.list_collections( 

223 eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision) 

224 if len(aligned_channel_collections) == 0: 

225 _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}") 

226 continue 

227 _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}") 

228 ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe) 

229 channels_aligned = one.load_object(eid, 'channels', collection=ac_collection) 

230 channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe]) 

231 # only have to reformat channels if we were able to load coordinates from disk 

232 channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions) 1b

233 return channels 1b

234 

235 

236def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None): 

237 """ 

238 Oftentimes the channel maps of different spike sorters differ, so the alignment is interpolated onto the channel map at hand. 

239 If there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field, 

240 so it is reconstructed from the Neuropixel map. This only happens for early pykilosort sorts 

241 :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys 

242 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017' 

243 OR 

244 'x', 'y', 'z', 'acronym', 'axial_um' 

245 those are the guide for the interpolation 

246 :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates' 

247 :param brain_regions: None (default) or iblatlas.regions.BrainRegions object 

248 if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017' 

249 if a brain region object is provided, outputs a dict with keys 

250 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um' 

251 :return: Bunch or dictionary of channels with brain coordinates keys 

252 """ 

253 NEUROPIXEL_VERSION = 1 1i

254 h = trace_header(version=NEUROPIXEL_VERSION) 1i

255 if channels is None: 1i

256 channels = {'localCoordinates': np.c_[h['x'], h['y']]} 

257 nch = channels['localCoordinates'].shape[0] 1i

258 if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): 1i

259 channels_aligned = _channels_bunch2alf(channels_aligned) 1i

260 if 'localCoordinates' in channels_aligned.keys(): 1i

261 aligned_depths = channels_aligned['localCoordinates'][:, 1] 1i

262 else: # this is an edge case for a few spike sorting sessions 

263 assert channels_aligned['mlapdv'].shape[0] == 384 

264 aligned_depths = h['y'] 

265 depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) 1i

266 depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) 1i

267 channels['mlapdv'] = np.zeros((nch, 3)) 1i

268 for i in np.arange(3): 1i

269 channels['mlapdv'][:, i] = np.interp( 1i

270 depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] 

271 # the brain locations have to be interpolated by nearest neighbour 

272 fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') 1i

273 channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) 1i

274 if brain_regions is not None: 1i

275 return _channels_alf2bunch(channels, brain_regions=brain_regions) 1i

276 else: 

277 return channels 1i
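A self-contained sketch with synthetic channel maps (all numbers made up) showing the depth-wise interpolation; without a brain_regions object the ALF-style keys are returned.

    import numpy as np
    from brainbox.io.one import channel_locations_interpolation

    depths_aligned = np.array([20., 40., 60., 80.])
    channels_aligned = {
        'localCoordinates': np.c_[np.zeros(4), depths_aligned],
        'mlapdv': np.c_[depths_aligned, depths_aligned, depths_aligned],
        'brainLocationIds_ccf_2017': np.array([0, 0, 997, 997]),
    }
    # target channel map with intermediate depths
    channels = {'localCoordinates': np.c_[np.zeros(3), np.array([30., 50., 70.])]}
    out = channel_locations_interpolation(channels_aligned, channels)
    print(out['mlapdv'], out['brainLocationIds_ccf_2017'])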

278 

279 

280def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, 

281 brain_atlas=None, return_source=False): 

282 if not hasattr(one, 'alyx'): 1b

283 return {}, None 1b

284 _logger.debug(f"trying to load from traj {probe}") 

285 channels = Bunch() 

286 brain_atlas = brain_atlas or AllenAtlas 

287 # need to find the collection 

288 insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0] 

289 collection = _collection_filter_from_args(probe=probe) 

290 collections = one.list_collections(eid, filename='channels*', collection=collection, 

291 revision=revision) 

292 probe_collection = _get_spike_sorting_collection(collections, probe) 

293 chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection) 

294 depths = chn_coords[:, 1] 

295 

296 tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

297 get('tracing_exists', False) 

298 resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

299 get('alignment_resolved', False) 

300 counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ 

301 get('alignment_count', 0) 

302 

303 if tracing: 

304 xyz = np.array(insertion['json']['xyz_picks']) / 1e6 

305 if resolved: 

306 

307 _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. ' 

308 f'Channel and cluster locations obtained from ephys aligned histology ' 

309 f'track.') 

310 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

311 provenance='Ephys aligned histology track')[0] 

312 align_key = insertion['json']['extended_qc']['alignment_stored'] 

313 feature = traj['json'][align_key][0] 

314 track = traj['json'][align_key][1] 

315 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

316 feature_prev=feature, 

317 brain_atlas=brain_atlas, speedy=True) 

318 chans = ephysalign.get_channel_locations(feature, track) 

319 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

320 source = 'resolved' 

321 elif counts > 0 and aligned: 

322 _logger.debug(f'Channel locations for {eid}/{probe} have not been ' 

323 f'resolved. However, alignment flag set to True so channel and cluster' 

324 f' locations will be obtained from latest available ephys aligned ' 

325 f'histology track.') 

326 # get the latest user aligned channels 

327 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, 

328 provenance='Ephys aligned histology track')[0] 

329 align_key = insertion['json']['extended_qc']['alignment_stored'] 

330 feature = traj['json'][align_key][0] 

331 track = traj['json'][align_key][1] 

332 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

333 feature_prev=feature, 

334 brain_atlas=brain_atlas, speedy=True) 

335 chans = ephysalign.get_channel_locations(feature, track) 

336 

337 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

338 source = 'aligned' 

339 else: 

340 _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. ' 

341 f'Channel and cluster locations obtained from histology track.') 

342 # get the channels from histology tracing 

343 xyz = xyz[np.argsort(xyz[:, 2]), :] 

344 chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) 

345 channels[probe] = _channels_traj2bunch(chans, brain_atlas) 

346 source = 'traced' 

347 channels[probe]['axial_um'] = chn_coords[:, 1] 

348 channels[probe]['lateral_um'] = chn_coords[:, 0] 

349 

350 else: 

351 _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}') 

352 source = '' 

353 channels = None 

354 

355 if return_source: 

356 return channels, source 

357 else: 

358 return channels 

359 

360 

361def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None): 

362 """ 

363 Load the brain locations of each channel for a given session/probe 

364 

365 Parameters 

366 ---------- 

367 eid : [str, UUID, Path, dict] 

368 Experiment session identifier; may be a UUID, URL, experiment reference string 

369 details dict or Path 

370 probe : [str, list of str] 

371 The probe label(s), e.g. 'probe01' 

372 one : one.api.OneAlyx 

373 An instance of ONE (shouldn't be in 'local' mode) 

374 aligned : bool 

375 Whether to get the latest user aligned channel when not resolved or use histology track 

376 brain_atlas : iblatlas.BrainAtlas 

377 Brain atlas object (default: Allen atlas) 

378 Returns 

379 ------- 

380 dict of one.alf.io.AlfBunch 

381 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

382 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

383 optional: string 'resolved', 'aligned', 'traced' or '' 

384 """ 

385 one = one or ONE() 

386 brain_atlas = brain_atlas or AllenAtlas() 

387 if isinstance(eid, dict): 

388 ses = eid 

389 eid = ses['url'][-36:] 

390 else: 

391 eid = one.to_eid(eid) 

392 collection = _collection_filter_from_args(probe=probe) 

393 channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, 

394 brain_regions=brain_atlas.regions) 

395 incomplete_probes = [k for k in channels if 'x' not in channels[k]] 

396 for iprobe in incomplete_probes: 

397 channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned, 

398 brain_atlas=brain_atlas, return_source=True) 

399 if channels_ is not None: 

400 channels[iprobe] = channels_[iprobe] 

401 return channels 
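A minimal usage sketch, assuming an Alyx-connected ONE instance and a placeholder eid; one bunch of channel locations is returned per probe.

    from one.api import ONE
    from brainbox.io.one import load_channel_locations

    one = ONE()
    channels = load_channel_locations('your-session-eid', one=one)
    for pname, ch in channels.items():
        print(pname, ch['acronym'][:5], ch['atlas_id'][:5])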

402 

403 

404def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

405 brain_regions=None, nested=True, collection=None, return_collection=False): 

406 """ 

407 From an eid, loads spikes and clusters for all probes 

408 The following set of dataset types are loaded: 

409 'clusters.channels', 

410 'clusters.depths', 

411 'clusters.metrics', 

412 'spikes.clusters', 

413 'spikes.times', 

414 'probes.description' 

415 :param eid: experiment UUID or pathlib.Path of the local session 

416 :param one: an instance of OneAlyx 

417 :param probe: name of probe to load in, if not given all probes for session will be loaded 

418 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

419 :param spike_sorter: name of the spike sorting you want to load (None for default) 

420 :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00" 

421 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

422 :param nested: if a single probe is required, do not output a dictionary with the probe name as key 

423 :param return_collection: (False) if True, will return the collection used to load 

424 :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe) 

425 """ 

426 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. ' 

427 'Use brainbox.io.one.SpikeSortingLoader instead') 

428 if collection is None: 

429 collection = _collection_filter_from_args(probe, spike_sorter) 

430 _logger.debug(f"load spike sorting with collection filter {collection}") 

431 kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types, 

432 brain_regions=brain_regions) 

433 spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True) 

434 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

435 if nested is False and len(spikes.keys()) == 1: 

436 k = list(spikes.keys())[0] 

437 channels = channels[k] 

438 clusters = clusters[k] 

439 spikes = spikes[k] 

440 if return_collection: 

441 return spikes, clusters, channels, collection 

442 else: 

443 return spikes, clusters, channels 

444 

445 

446def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, 

447 brain_regions=None, return_collection=False): 

448 """ 

449 From an eid, loads spikes and clusters for all probes 

450 The following set of dataset types are loaded: 

451 'clusters.channels', 

452 'clusters.depths', 

453 'clusters.metrics', 

454 'spikes.clusters', 

455 'spikes.times', 

456 'probes.description' 

457 :param eid: experiment UUID or pathlib.Path of the local session 

458 :param one: an instance of OneAlyx 

459 :param probe: name of probe to load in, if not given all probes for session will be loaded 

460 :param dataset_types: additional spikes/clusters objects to add to the standard default list 

461 :param spike_sorter: name of the spike sorting you want to load (None for default) 

462 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided 

463 :param return_collection: (bool, default False) if True, also returns the collection used to load the data 

464 :return: spikes, clusters (dict of bunch, 1 bunch per probe) 

465 """ 

466 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. ' 1g

467 'Use brainbox.io.one.SpikeSortingLoader instead') 

468 collection = _collection_filter_from_args(probe, spike_sorter) 1g

469 _logger.debug(f"load spike sorting with collection filter {collection}") 1g

470 spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, 1g

471 return_channels=False, dataset_types=dataset_types, 

472 brain_regions=brain_regions) 

473 if return_collection: 1g

474 return spikes, clusters, collection 

475 else: 

476 return spikes, clusters 1g

477 

478 

479def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None, 

480 spike_sorter=None, brain_atlas=None, nested=True, return_collection=False): 

481 """ 

482 For a given eid, get spikes, clusters and channels information, and merges clusters 

483 and channels information before returning all three variables. 

484 

485 Parameters 

486 ---------- 

487 eid : [str, UUID, Path, dict] 

488 Experiment session identifier; may be a UUID, URL, experiment reference string 

489 details dict or Path 

490 one : one.api.OneAlyx 

491 An instance of ONE (shouldn't be in 'local' mode) 

492 probe : [str, list of str] 

493 The probe label(s), e.g. 'probe01' 

494 aligned : bool 

495 Whether to get the latest user aligned channel when not resolved or use histology track 

496 dataset_types : list of str 

497 Optional additional spikes/clusters objects to add to the standard default list 

498 spike_sorter : str 

499 Name of the spike sorting you want to load (None for default which is pykilosort if it's 

500 available otherwise the default MATLAB kilosort) 

501 brain_atlas : iblatlas.atlas.BrainAtlas 

502 Brain atlas object (default: Allen atlas) 

503 return_collection: bool 

504 Returns an extra argument with the collection chosen 

505 

506 Returns 

507 ------- 

508 spikes : dict of one.alf.io.AlfBunch 

509 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

510 session and spike sorter, with keys ('clusters', 'times') 

511 clusters : dict of one.alf.io.AlfBunch 

512 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

513 ('channels', 'depths', 'metrics') 

514 channels : dict of one.alf.io.AlfBunch 

515 A dict with probe labels as keys, contains channel locations with keys ('acronym', 

516 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. 

517 """ 

518 # --- Get spikes and clusters data 

519 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future versions. ' 

520 'Use brainbox.io.one.SpikeSortingLoader instead') 

521 one = one or ONE() 

522 brain_atlas = brain_atlas or AllenAtlas() 

523 spikes, clusters, collection = load_spike_sorting( 

524 eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True) 

525 # -- Get brain regions and assign to clusters 

526 channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned, 

527 brain_atlas=brain_atlas) 

528 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) 

529 if nested is False and len(spikes.keys()) == 1: 

530 k = list(spikes.keys())[0] 

531 channels = channels[k] 

532 clusters = clusters[k] 

533 spikes = spikes[k] 

534 if return_collection: 

535 return spikes, clusters, channels, collection 

536 else: 

537 return spikes, clusters, channels 

538 

539 

540def load_ephys_session(eid, one=None): 

541 """ 

542 From an eid, hits the Alyx database and downloads a standard default set of dataset types. 

543 From a local session Path (pathlib.Path), loads a standard default set of dataset types 

544 to perform analysis: 

545 'clusters.channels', 

546 'clusters.depths', 

547 'clusters.metrics', 

548 'spikes.clusters', 

549 'spikes.times', 

550 'probes.description' 

551 

552 Parameters 

553 ---------- 

554 eid : [str, UUID, Path, dict] 

555 Experiment session identifier; may be a UUID, URL, experiment reference string 

556 details dict or Path 

557 one : one.api.OneAlyx, optional 

558 ONE object to use for loading. Will generate internal one if not used, by default None 

559 

560 Returns 

561 ------- 

562 spikes : dict of one.alf.io.AlfBunch 

563 A dict with probe labels as keys, contains bunch(es) of spike data for the provided 

564 session and spike sorter, with keys ('clusters', 'times') 

565 clusters : dict of one.alf.io.AlfBunch 

566 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys 

567 ('channels', 'depths', 'metrics') 

568 trials : one.alf.io.AlfBunch of numpy.ndarray 

569 The session trials data 

570 """ 

571 assert one 1g

572 spikes, clusters = load_spike_sorting(eid, one=one) 1g

573 trials = one.load_object(eid, 'trials') 1g

574 return spikes, clusters, trials 1g
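A minimal sketch of the documented outputs, assuming a placeholder eid; note the function asserts that a ONE instance is passed.

    from one.api import ONE
    from brainbox.io.one import load_ephys_session

    one = ONE()
    spikes, clusters, trials = load_ephys_session('your-session-eid', one=one)
    probe = next(iter(spikes))
    print(probe, spikes[probe]['times'].size, trials['goCue_times'].size)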

575 

576 

577def _remove_old_clusters(session_path, probe): 

578 # gets clusters and spikes from a local session folder 

579 probe_path = session_path.joinpath('alf', probe) 

580 

581 # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead 

582 cluster_file = probe_path.joinpath('clusters.metrics.csv') 

583 

584 if cluster_file.exists(): 

585 os.remove(cluster_file) 

586 _logger.info('Deleting old clusters.metrics.csv file') 

587 

588 

589def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None): 

590 """ 

591 Takes (default and any extra) values in given keys from channels and assign them to clusters. 

592 If channels does not contain any data, the new keys are added to clusters but left empty. 

593 

594 Parameters 

595 ---------- 

596 dic_clus : dict of one.alf.io.AlfBunch 

597 1 bunch per probe, containing cluster information 

598 channels : dict of one.alf.io.AlfBunch 

599 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z', 'localCoordinates') 

600 keys_to_add_extra : list of str 

601 Any extra keys to load into channels bunches 

602 

603 Returns 

604 ------- 

605 dict of one.alf.io.AlfBunch 

606 clusters (1 bunch per probe) with new keys values. 

607 """ 

608 probe_labels = list(channels.keys()) # Convert dict_keys into list 

609 keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um'] 

610 

611 if keys_to_add_extra is None: 

612 keys_to_add = keys_to_add_default 

613 else: 

614 # Append extra optional keys 

615 keys_to_add = list(set(keys_to_add_extra + keys_to_add_default)) 

616 

617 for label in probe_labels: 

618 clu_ch = dic_clus[label]['channels'] 

619 for key in keys_to_add: 

620 try: 

621 assert key in channels[label].keys() # Check key is in channels 

622 ch_key = channels[label][key] 

623 nch_key = len(ch_key) if ch_key is not None else 0 

624 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index 

625 dic_clus[label][key] = ch_key[clu_ch] 

626 else: 

627 _logger.warning( 

628 f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels' 

629 f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.') 

630 dic_clus[label][key] = [] 

631 except AssertionError: 

632 _logger.warning(f'Either clusters or channels does not have key {key}, could not merge') 

633 continue 

634 

635 return dic_clus 
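A self-contained sketch with synthetic bunches (probe label and values made up) showing how channel values are propagated to clusters via each cluster's channel index.

    import numpy as np
    from brainbox.io.one import merge_clusters_channels

    dic_clus = {'probe00': {'channels': np.array([0, 2, 1])}}
    channels = {'probe00': {
        'acronym': np.array(['VISa', 'CA1', 'DG']),
        'atlas_id': np.array([1, 2, 3]),
        'x': np.zeros(3), 'y': np.zeros(3), 'z': np.zeros(3),
        'axial_um': np.array([20., 40., 60.]),
        'lateral_um': np.array([16., 48., 16.])}}
    merged = merge_clusters_channels(dic_clus, channels)
    print(merged['probe00']['acronym'])  # ['VISa' 'DG' 'CA1']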

636 

637 

638def load_passive_rfmap(eid, one=None): 

639 """ 

640 For a given eid load in the passive receptive field mapping protocol data 

641 

642 Parameters 

643 ---------- 

644 eid : [str, UUID, Path, dict] 

645 Experiment session identifier; may be a UUID, URL, experiment reference string 

646 details dict or Path 

648 one : one.api.OneAlyx, optional 

648 An instance of ONE (may be in 'local' - offline - mode) 

649 

650 Returns 

651 ------- 

652 one.alf.io.AlfBunch 

653 Passive receptive field mapping data 

654 """ 

655 one = one or ONE() 

656 

657 # Load in the receptive field mapping data 

658 rf_map = one.load_object(eid, obj='passiveRFM', collection='alf') 

659 frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', 

660 collection='raw_passive_data'), dtype="uint8") 

661 y_pix, x_pix = 15, 15 

662 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0]) 

663 rf_map['frames'] = frames 

664 

665 return rf_map 
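A minimal usage sketch with a placeholder eid; 'frames' is added to the loaded passiveRFM object as an (n_frames, 15, 15) stimulus array.

    from one.api import ONE
    from brainbox.io.one import load_passive_rfmap

    one = ONE()
    rf_map = load_passive_rfmap('your-session-eid', one=one)
    print(rf_map['frames'].shape)  # (n_frames, 15, 15)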

666 

667 

668def load_wheel_reaction_times(eid, one=None): 

669 """ 

670 Return the calculated reaction times for session. Reaction times are defined as the time 

671 between the go cue (onset tone) and the onset of the first substantial wheel movement. A 

672 movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the 

673 distance to threshold (~0.1 radians). 

674 

675 Negative times mean the onset of the movement occurred before the go cue. Nans may occur if 

676 there was no detected movement within the period, or when the goCue_times or feedback_times 

677 are nan. 

678 

679 Parameters 

680 ---------- 

681 eid : [str, UUID, Path, dict] 

682 Experiment session identifier; may be a UUID, URL, experiment reference string 

683 details dict or Path 

684 one : one.api.OneAlyx, optional 

685 one object to use for loading. Will generate internal one if not used, by default None 

686 

687 Returns 

688 ---------- 

689 array-like 

690 reaction times 

691 """ 

692 if one is None: 

693 one = ONE() 

694 

695 trials = one.load_object(eid, 'trials') 

696 # If already extracted, load and return 

697 if trials and 'firstMovement_times' in trials: 

698 return trials['firstMovement_times'] - trials['goCue_times'] 

699 # Otherwise load wheelMoves object and calculate 

700 moves = one.load_object(eid, 'wheelMoves') 

701 # Re-extract wheel moves if necessary 

702 if not moves or 'peakAmplitude' not in moves: 

703 wheel = one.load_object(eid, 'wheel') 

704 moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) 

705 assert trials and moves, 'unable to load trials and wheelMoves data' 

706 firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials) 

707 return firstMove_times - trials['goCue_times'] 
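A minimal usage sketch with a placeholder eid; negative values mean movement onset preceded the go cue and NaNs mark trials without a detected movement.

    import numpy as np
    from one.api import ONE
    from brainbox.io.one import load_wheel_reaction_times

    rts = load_wheel_reaction_times('your-session-eid', one=ONE())
    print(np.nanmedian(rts), np.mean(np.isnan(rts)))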

708 

709 

710def load_iti(trials): 

711 """ 

712 The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey 

713 screen commencing at stimulus off and lasting until the quiescent period at the start of the 

714 following trial. The value at index i is the gap between trial i and trial i + 1; the last trial has 

715 no following trial, therefore the last value is NaN. 

716 

717 Parameters 

718 ---------- 

719 trials : one.alf.io.AlfBunch 

720 An ALF trials object containing the keys {'intervals', 'stimOff_times'}. 

721 

722 Returns 

723 ------- 

724 np.array 

725 An array of inter-trial intervals, the last value being NaN. 

726 """ 

727 if not {'intervals', 'stimOff_times'} <= set(trials.keys()): 1n

728 raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') 1n

729 return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan] 1n
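A worked example with a toy trials bunch (synthetic numbers): each ITI is the gap between a trial's stimulus off and the next trial's start, and the last value is NaN.

    import numpy as np
    from iblutil.util import Bunch
    from brainbox.io.one import load_iti

    trials = Bunch({'intervals': np.array([[0., 4.], [5., 9.], [10., 14.]]),
                    'stimOff_times': np.array([3.5, 8.5, 13.5])})
    load_iti(trials)  # -> array([1.5, 1.5, nan])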

730 

731 

732def load_channels_from_insertion(ins, depths=None, one=None, ba=None): 

733 

734 PROV_2_VAL = { 

735 'Resolved': 90, 

736 'Ephys aligned histology track': 70, 

737 'Histology track': 50, 

738 'Micro-manipulator': 30, 

739 'Planned': 10} 

740 

741 one = one or ONE() 

742 ba = ba or atlas.AllenAtlas() 

743 traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id']) 

744 val = [PROV_2_VAL[tr['provenance']] for tr in traj] 

745 idx = np.argmax(val) 

746 traj = traj[idx] 

747 if depths is None: 

748 depths = trace_header(version=1)['y']  # axial channel positions (um) from the Neuropixel geometry 

749 if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator': 

750 ins = atlas.Insertion.from_dict(traj) 

751 # Deepest coordinate first 

752 xyz = np.c_[ins.tip, ins.entry].T 

753 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

754 TIP_SIZE_UM) / 1e6) 

755 else: 

756 xyz = np.array(ins['json']['xyz_picks']) / 1e6 

757 if traj['provenance'] == 'Histology track': 

758 xyz = xyz[np.argsort(xyz[:, 2]), :] 

759 xyz_channels = histology.interpolate_along_track(xyz, (depths + 

760 TIP_SIZE_UM) / 1e6) 

761 else: 

762 align_key = ins['json']['extended_qc']['alignment_stored'] 

763 feature = traj['json'][align_key][0] 

764 track = traj['json'][align_key][1] 

765 ephysalign = EphysAlignment(xyz, depths, track_prev=track, 

766 feature_prev=feature, 

767 brain_atlas=ba, speedy=True) 

768 xyz_channels = ephysalign.get_channel_locations(feature, track) 

769 return xyz_channels 
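A hedged sketch, assuming an Alyx-connected ONE instance and a placeholder probe insertion UUID; the best available trajectory provenance is picked automatically.

    from one.api import ONE
    from brainbox.io.one import load_channels_from_insertion

    one = ONE()
    ins = one.alyx.rest('insertions', 'read', id='your-probe-insertion-uuid')
    xyz_channels = load_channels_from_insertion(ins, one=one)  # (n_channels, 3) xyz in metres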

770 

771 

772@dataclass 

773class SpikeSortingLoader: 

774 """ 

775 Object that will load spike sorting data for a given probe insertion. 

776 This class can be instantiated in several manners 

777 - With Alyx database probe id: 

778 SpikeSortingLoader(pid=pid, one=one) 

779 - With Alyx database eid and probe name: 

780 SpikeSortingLoader(eid=eid, pname='probe00', one=one) 

781 - From a local session and probe name: 

782 SpikeSortingLoader(session_path=session_path, pname='probe00') 

783 NB: When no ONE instance is passed, any datasets that are loaded will not be recorded. 

784 """ 

785 one: One = None 

786 atlas: None = None 

787 pid: str = None 

788 eid: str = '' 

789 pname: str = '' 

790 session_path: Path = '' 

791 # the following properties are the outcome of the post init function 

792 collections: list = None 

793 datasets: list = None # list of all datasets belonging to the session 

794 # the following properties are the outcome of a reading function 

795 files: dict = None 

796 collection: str = '' 

797 histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced' 

798 spike_sorter: str = 'pykilosort' 

799 spike_sorting_path: Path = None 

800 _sync: dict = None 

801 

802 def __post_init__(self): 

803 # pid gets precedence 

804 if self.pid is not None: 1cbdf

805 try: 1f

806 self.eid, self.pname = self.one.pid2eid(self.pid) 1f

807 except NotImplementedError: 

808 if self.eid == '' or self.pname == '': 

809 raise IOError("Cannot infer session id and probe name from pid. " 

810 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") 

811 self.session_path = self.one.eid2path(self.eid) 1f

812 # then eid / pname combination 

813 elif self.session_path is None or self.session_path == '': 1cbd

814 self.session_path = self.one.eid2path(self.eid) 1cbd

815 # fully local providing a session path 

816 else: 

817 if self.one: 

818 self.eid = self.one.to_eid(self.session_path) 

819 else: 

820 self.one = One(cache_dir=self.session_path.parents[2], mode='local') 

821 df_sessions = cache._make_sessions_df(self.session_path) 

822 self.one._cache['sessions'] = df_sessions.set_index('id') 

823 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) 

824 self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) 

825 # populates default properties 

826 self.collections = self.one.list_collections( 1cbdf

827 self.eid, filename='spikes*', collection=f"alf/{self.pname}*") 

828 self.datasets = self.one.list_datasets(self.eid) 1cbdf

829 if self.atlas is None: 1cbdf

830 self.atlas = AllenAtlas() 1cbd

831 self.files = {} 1cbdf

832 

833 @staticmethod 

834 def _get_attributes(dataset_types): 

835 """returns attributes to load for spikes and clusters objects""" 

836 if dataset_types is None: 1cbdf

837 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1cbdf

838 else: 

839 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1d

840 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1d

841 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1d

842 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1d

843 return spike_attributes, cluster_attributes 1d

844 

845 def _get_spike_sorting_collection(self, spike_sorter='pykilosort'): 

846 """ 

847 Filters a list or array of collections to get the relevant spike sorting dataset; 

848 if the requested spike sorter's collection is present it is loaded, otherwise the shortest matching collection is used 

849 """ 

850 collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None) 1cbdf

851 # otherwise, prefers the shortest 

852 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) 1cbdf

853 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") 1cbdf

854 return collection 1cbdf

855 

856 def load_spike_sorting_object(self, obj, *args, **kwargs): 

857 """ 

858 Loads an ALF object 

859 :param obj: object name, str between 'spikes', 'clusters' or 'channels' 

860 :param spike_sorter: (defaults to 'pykilosort') 

861 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

862 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

863 :param kwargs: additional arguments to be passed to one.api.One.load_object 

864 :param missing: 'raise' (default) or 'ignore' 

865 :return: 

866 """ 

867 self.download_spike_sorting_object(obj, *args, **kwargs) 

868 return alfio.load_object(self.files[obj]) 

869 

870 def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None, 

871 missing='raise', **kwargs): 

872 """ 

873 Downloads an ALF object 

874 :param obj: object name, str between 'spikes', 'clusters' or 'channels' 

875 :param spike_sorter: (defaults to 'pykilosort') 

876 :param dataset_types: list of extra dataset types, for example ['spikes.samples'] 

877 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort' 

878 :param kwargs: additional arguments to be passed to one.api.One.load_object 

879 :param missing: 'raise' (default) or 'ignore' 

880 :return: 

881 """ 

882 if len(self.collections) == 0: 1cbdf

883 return {}, {}, {} 

884 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1cbdf

885 collection = collection or self.collection 1cbdf

886 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1cbdf

887 spike_attributes, cluster_attributes = self._get_attributes(dataset_types) 1cbdf

888 attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes} 1cbdf

889 try: 1cbdf

890 self.files[obj] = self.one.load_object( 1cbdf

891 self.eid, obj=obj, attribute=attributes.get(obj, None), 

892 collection=collection, download_only=True, **kwargs) 

893 except ALFObjectNotFound as e: 1cbd

894 if missing == 'raise': 1cbd

895 raise e 

896 

897 def download_spike_sorting(self, **kwargs): 

898 """ 

899 Downloads spikes, clusters and channels 

900 :param spike_sorter: (defaults to 'pykilosort') 

901 :param dataset_types: list of extra dataset types 

902 :return: 

903 """ 

904 for obj in ['spikes', 'clusters', 'channels']: 1cbdf

905 self.download_spike_sorting_object(obj=obj, **kwargs) 1cbdf

906 self.spike_sorting_path = self.files['spikes'][0].parent 1cbdf

907 

908 def load_channels(self, **kwargs): 

909 """ 

910 Loads channels 

911 The channel locations can come from several sources; the most advanced version of the histology available is loaded, 

912 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

913 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

914 - resolved: channel locations alignments have been agreed upon 

915 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

916 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

917 

918 :param spike_sorter: (defaults to 'pykilosort') 

919 :param dataset_types: list of extra dataset types 

920 :return: 

921 """ 

922 # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting 

923 self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') 1cbdf

924 if 'electrodeSites' in self.files: 1cbdf

925 channels = alfio.load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 1f

926 else: # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology 

927 self.download_spike_sorting_object(obj='channels', **kwargs) 1cbd

928 channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards) 1cbd

929 if 'brainLocationIds_ccf_2017' not in channels: 1cbdf

930 _logger.debug(f"loading channels from alyx for {self.files['channels']}") 1b

931 _channels, self.histology = _load_channel_locations_traj( 1b

932 self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True) 

933 if _channels: 1b

934 channels = _channels[self.pname] 

935 else: 

936 channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) 1cbdf

937 self.histology = 'alf' 1cbdf

938 return channels 1cbdf

939 

940 def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs): 

941 """ 

942 Loads spikes, clusters and channels 

943 

944 There could be several spike sorting collections, by default the loader will get the pykilosort collection 

945 

946 The channel locations can come from several sources; the most advanced version of the histology available is loaded, 

947 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): 

948 - alf: the final version of channel locations, same as resolved with the difference that data is on file 

949 - resolved: channel locations alignments have been agreed upon 

950 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate 

951 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data 

952 

953 :param spike_sorter: (defaults to 'pykilosort') 

954 :param dataset_types: list of extra dataset types 

955 :return: 

956 """ 

957 if len(self.collections) == 0: 1cbdf

958 return {}, {}, {} 

959 self.files = {} 1cbdf

960 self.spike_sorter = spike_sorter 1cbdf

961 self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs) 1cbdf

962 channels = self.load_channels(spike_sorter=spike_sorter, **kwargs) 1cbdf

963 clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards) 1cbdf

964 spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards) 1cbdf

965 

966 return spikes, clusters, channels 1cbdf
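A minimal sketch of the recommended loading path, assuming a placeholder probe insertion id (pid) and a connected ONE instance; merge_clusters combines the metrics and channel locations into the clusters dictionary.

    from one.api import ONE
    from brainbox.io.one import SpikeSortingLoader

    one = ONE()
    ssl = SpikeSortingLoader(pid='your-probe-insertion-uuid', one=one)
    spikes, clusters, channels = ssl.load_spike_sorting()
    clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
    print(ssl.collection, ssl.histology, spikes['times'].size)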

967 

968 @staticmethod 

969 def compute_metrics(spikes, clusters=None): 

970 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size 

971 metrics = pd.DataFrame(quick_unit_metrics( 

972 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) 

973 return metrics 

974 

975 @staticmethod 

976 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False): 

977 """ 

978 Merge the metrics and the channel information into the clusters dictionary 

979 :param spikes: 

980 :param clusters: 

981 :param channels: 

982 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used 

983 for clusters or analysis applications (defaults to None). 

984 :param compute_metrics: if True, will explicitly recompute metrics (defaults to false) 

985 :return: cluster dictionary containing metrics and histology 

986 """ 

987 if spikes == {}: 1bf

988 return 

989 nc = clusters['channels'].size 1bf

990 # recompute metrics if they are not available 

991 metrics = None 1bf

992 if 'metrics' in clusters: 1bf

993 metrics = clusters.pop('metrics') 1bf

994 if metrics.shape[0] != nc: 1bf

995 metrics = None 

996 if metrics is None or compute_metrics is True: 1bf

997 _logger.debug("recompute clusters metrics") 

998 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) 

999 if isinstance(cache_dir, Path): 

1000 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) 

1001 for k in metrics.keys(): 1bf

1002 clusters[k] = metrics[k].to_numpy() 1bf

1003 for k in channels.keys(): 1bf

1004 clusters[k] = channels[k][clusters['channels']] 1bf

1005 if cache_dir is not None: 1bf

1006 _logger.debug(f'caching clusters metrics in {cache_dir}') 

1007 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt')) 

1008 return clusters 1bf

1009 

1010 @property 

1011 def url(self): 

1012 """Gets flatiron URL for the session""" 

1013 webclient = getattr(self.one, '_web_client', None) 

1014 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None 

1015 

1016 def _get_probe_info(self): 

1017 if self._sync is None: 1d

1018 timestamps = self.one.load_dataset( 1d

1019 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}') 

1020 try: 1d

1021 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1d

1022 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) 

1023 fs = spikeglx._get_fs_from_meta(ap_meta) 

1024 except ALFObjectNotFound: 1d

1025 ap_meta = None 1d

1026 fs = 30_000 1d

1027 self._sync = { 1d

1028 'timestamps': timestamps, 

1029 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'), 

1030 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), 

1031 'ap_meta': ap_meta, 

1032 'fs': fs, 

1033 } 

1034 

1035 def timesprobe2times(self, values, direction='forward'): 

1036 self._get_probe_info() 

1037 if direction == 'forward': 

1038 return self._sync['forward'](values * self._sync['fs']) 

1039 elif direction == 'reverse': 

1040 return self._sync['reverse'](values) / self._sync['fs'] 

1041 

1042 def samples2times(self, values, direction='forward'): 

1043 """ 

1044 Converts ephys sample values to session main clock seconds 

1045 :param values: numpy array of times in seconds or samples to resync 

1046 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse' 

1047 (seconds main time to samples probe time) 

1048 :return: 

1049 """ 

1050 self._get_probe_info() 1d

1051 return self._sync[direction](values) 1d
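A short sketch continuing the loader usage above, assuming 'spikes.samples' is available for the insertion; samples on the probe clock are converted to seconds on the session main clock and back.

    from one.api import ONE
    from brainbox.io.one import SpikeSortingLoader

    ssl = SpikeSortingLoader(pid='your-probe-insertion-uuid', one=ONE())  # placeholder pid
    spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
    t_main = ssl.samples2times(spikes['samples'])                   # samples -> seconds (main clock)
    samples_back = ssl.samples2times(t_main, direction='reverse')   # seconds -> samples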

1052 

1053 @property 

1054 def pid2ref(self): 

1055 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1cb

1056 

1057 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, **kwargs): 

1058 """ 

1059 :param spikes: spikes dictionary or Bunch 

1060 :param channels: channels dictionary or Bunch. 

1061 :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png". 

1062 Otherwise, plot. 

1063 :param br: brain regions object (optional) 

1064 :param label: label for saved image (optional, default="raster") 

1065 :param time_series: timeseries dictionary for behavioral event times (optional) 

1066 :param **kwargs: kwargs passed to `driftmap()` (optional) 

1067 :return: 

1068 """ 

1069 br = br or BrainRegions() 1c

1070 time_series = time_series or {} 1c

1071 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c

1072 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') 

1073 axs[0, 1].set_axis_off() 1c

1074 # axs[0, 0].set_xticks([]) 

1075 if not kwargs: 1c

1076 # set default raster plot parameters 

1077 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} 

1078 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c

1079 title_str = f"{self.pid2ref}, {self.pid} \n" \ 1c

1080 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" 

1081 axs[0, 0].title.set_text(title_str) 1c

1082 for k, ts in time_series.items(): 1c

1083 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) 

1084 if 'atlas_id' in channels: 1c

1085 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c

1086 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) 

1087 axs[1, 0].set_ylim(0, 3800) 1c

1088 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c

1089 fig.tight_layout() 1c

1090 

1091 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c

1092 if 'drift' in self.files: 1c

1093 drift = alfio.load_object(self.files['drift'], wildcards=self.one.wildcards) 

1094 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) 

1095 

1096 if save_dir is not None: 1c

1097 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) 

1098 fig.savefig(png_file) 

1099 plt.close(fig) 

1100 gc.collect() 

1101 else: 

1102 return fig, axs 1c
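A short sketch, assuming spikes and channels loaded with a SpikeSortingLoader as above; without save_dir the figure and axes are returned, otherwise a PNG is written to the given directory.

    from pathlib import Path

    fig, axs = ssl.raster(spikes, channels)                               # returns (fig, axs) when save_dir is None
    ssl.raster(spikes, channels, save_dir=Path('/tmp'), label='raster')   # saves "{pid}_{ref}_raster.png"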

1103 

1104 

1105@dataclass 

1106class SessionLoader: 

1107 """ 

1108 Object to load session data for a given session in the recommended way. 

1109 

1110 Parameters 

1111 ---------- 

1112 one: one.api.ONE instance 

1113 Can be in remote or local mode (required) 

1114 session_path: string or pathlib.Path 

1115 The absolute path to the session (one of session_path or eid is required) 

1116 eid: string 

1117 database UUID of the session (one of session_path or eid is required) 

1118 

1119 If both are provided, session_path takes precedence over eid. 

1120 

1121 Examples 

1122 -------- 

1123 1) Load all available session data for one session: 

1124 >>> from one.api import ONE 

1125 >>> from brainbox.io.one import SessionLoader 

1126 >>> one = ONE() 

1127 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') 

1128 # Object is initiated, but no data is loaded as you can see in the data_info attribute 

1129 >>> sess_loader.data_info 

1130 name is_loaded 

1131 0 trials False 

1132 1 wheel False 

1133 2 pose False 

1134 3 motion_energy False 

1135 4 pupil False 

1136 

1137 # Loading all available session data, the data_info attribute now shows which data has been loaded 

1138 >>> sess_loader.load_session_data() 

1139 >>> sess_loader.data_info 

1140 name is_loaded 

1141 0 trials True 

1142 1 wheel True 

1143 2 pose True 

1144 3 motion_energy True 

1145 4 pupil False 

1146 

1147 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. 

1148 >>> type(sess_loader.trials) 

1149 pandas.core.frame.DataFrame 

1150 >>> sess_loader.trials.shape 

1151 (626, 18) 

1152 # Each data comes with its own timestamps in a column called 'times' 

1153 >>> sess_loader.wheel['times'] 

1154 0 0.134286 

1155 1 0.135286 

1156 2 0.136286 

1157 3 0.137286 

1158 4 0.138286 

1159 ... 

1160 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. 

1161 # The dataframes of all cameras are collected in a dictionary 

1162 >>> type(sess_loader.pose) 

1163 dict 

1164 >>> sess_loader.pose.keys() 

1165 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) 

1166 >>> sess_loader.pose['bodyCamera'].columns 

1167 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') 

1168 # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading 

1169 functions: 

1170 >>> sess_loader.load_wheel(fs=100) 

1171 """ 

1172 one: One = None 

1173 session_path: Path = '' 

1174 eid: str = '' 

1175 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1176 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1177 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1178 pose: dict = field(default_factory=dict, repr=False) 

1179 motion_energy: dict = field(default_factory=dict, repr=False) 

1180 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) 

1181 

1182 def __post_init__(self): 

1183 """ 

1184 Function that runs automatically after initiation of the dataclass attributes. 

1185 Checks for required inputs, sets session_path and eid, creates data_info table. 

1186 """ 

1187 if self.one is None: 1ae

1188 raise ValueError("An input to one is required. If no connection to a database is desired, it can be " 

1189 "a fully local instance of One.") 

1190 # If session path is given, takes precedence over eid 

1191 if self.session_path is not None and self.session_path != '': 1ae

1192 self.eid = self.one.to_eid(self.session_path) 1ae

1193 self.session_path = Path(self.session_path) 1ae

1194 # Providing no session path, try to infer from eid 

1195 else: 

1196 if self.eid is not None and self.eid != '': 

1197 self.session_path = self.one.eid2path(self.eid) 

1198 else: 

1199 raise ValueError("If no session path is given, eid is required.") 

1200 

1201 data_names = [ 1ae

1202 'trials', 

1203 'wheel', 

1204 'pose', 

1205 'motion_energy', 

1206 'pupil' 

1207 ] 

1208 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1ae

1209 

1210 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): 

1211 """ 

1212 Function to load available session data into the SessionLoader object. Input parameters allow to control which 

1213 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input 

1214 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored 

1215 in SessionLoader.data_info 

1216 

1217 Parameters 

1218 ---------- 

1219 trials: boolean 

1220 Whether to load all trials data into SessionLoader.trials, default is True 

1221 wheel: boolean 

1222 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True 

1223 pose: boolean 

1224 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose, 

1225 default is True 

1226 motion_energy: boolean 

1227 Whether to load motion energy data (whisker pad for left/right camera, body for body camera) 

1228 into SessionLoader.motion_energy, default is True 

1229 pupil: boolean 

1230 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil, 

1231 default is True 

1232 reload: boolean 

1233 Whether to reload data that has already been loaded into this SessionLoader object, default is False 

1234 """ 

1235 load_df = self.data_info.copy() 1e

1236 load_df['to_load'] = [ 1e

1237 trials, 

1238 wheel, 

1239 pose, 

1240 motion_energy, 

1241 pupil 

1242 ] 

1243 load_df['load_func'] = [ 1e

1244 self.load_trials, 

1245 self.load_wheel, 

1246 self.load_pose, 

1247 self.load_motion_energy, 

1248 self.load_pupil 

1249 ] 

1250 

1251 for idx, row in load_df.iterrows(): 1e

1252 if row['to_load'] is False: 1e

1253 _logger.debug(f"Not loading {row['name']} data, set to False.") 

1254 elif row['is_loaded'] is True and reload is False: 1e

1255 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1e

1256 else: 

1257 try: 1e

1258 _logger.info(f"Loading {row['name']} data") 1e

1259 row['load_func']() 1e

1260 self.data_info.loc[idx, 'is_loaded'] = True 1e

1261 except BaseException as e: 

1262 _logger.warning(f"Could not load {row['name']} data.") 

1263 _logger.debug(e) 

1264 
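A minimal sketch of calling the loader above (sess_loader as constructed earlier); note that individual load failures are logged as warnings rather than raised:

    # Load everything except pose and pupil, reloading data that is already present
    sess_loader.load_session_data(pose=False, pupil=False, reload=True)
    print(sess_loader.data_info)  # shows which data are now loaded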

1265 def load_trials(self): 

1266 """ 

1267 Function to load trials data into SessionLoader.trials 

1268 """ 

1269 # itiDuration frequently has a mismatched dimension and we don't need it, so exclude it using a regex 

1270 self.one.wildcards = False 1em

1271 self.trials = self.one.load_object(self.eid, 'trials', collection='alf', attribute=r'(?!itiDuration).*').to_df() 1em

1272 self.one.wildcards = True 1em

1273 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1em

1274 
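For example (the column names below are the typical ALF trials attributes and may vary by session):

    sess_loader.load_trials()
    print(sess_loader.trials.shape)    # one row per trial
    print(sess_loader.trials.columns)  # e.g. 'choice', 'feedbackType', 'stimOn_times', ...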

1275 def load_wheel(self, fs=1000, corner_frequency=20, order=8): 

1276 """ 

1277 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position 

1278 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which 

1279 a Butterworth low-pass filter is applied. 

1280 

1281 Parameters 

1282 ---------- 

1283 fs: int, float 

1284 Sampling frequency for the wheel position, default is 1000 Hz 

1285 corner_frequency: int, float 

1286 Corner frequency of the Butterworth low-pass filter, default is 20 Hz 

1287 order: int 

1288 Order of the Butterworth low-pass filter, default is 8 

1289 """ 

1290 wheel_raw = self.one.load_object(self.eid, 'wheel') 1el

1291 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1el

1292 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps'") 

1293 # resample the wheel position and compute velocity, acceleration 

1294 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1el

1295 self.wheel['position'], self.wheel['times'] = interpolate_position( 1el

1296 wheel_raw['timestamps'], wheel_raw['position'], freq=fs) 

1297 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1el

1298 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order) 

1299 self.wheel = self.wheel.apply(np.float32) 1el

1300 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1el

1301 
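A sketch of re-loading the wheel at a lower sampling rate (the parameter is `fs`, not `sampling_rate`):

    # Interpolate the wheel position at 100 Hz, filter with a 20 Hz corner frequency, 8th-order Butterworth
    sess_loader.load_wheel(fs=100, corner_frequency=20, order=8)
    print(sess_loader.wheel.columns)  # 'times', 'position', 'velocity', 'acceleration'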

1302 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']): 

1303 """ 

1304 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a 

1305 dictionary where keys are the names of the cameras for which pose data are loaded, and values are pandas 

1306 Dataframes with one row per video frame, containing the timestamps and the traces of each tracked body part. 

1307 

1308 Parameters 

1309 ---------- 

1310 likelihood_thr: float 

1311 The position of each tracked body part comes with a likelihood of that estimate for each time point. 

1312 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set 

1313 likelihood_thr=1. Default is 0.9 

1314 views: list 

1315 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1316 """ 

1317 # Empty the dictionary so that, if a single view is loaded after several had been loaded, the others don't linger 

1318 self.pose = {} 1khe

1319 for view in views: 1khe

1320 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times']) 1khe

1321 # Double check if video timestamps are correct length or can be fixed 

1322 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1khe

1323 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1khe

1324 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1khe

1325 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1khe

1326 
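For instance, to load only the body camera view (column names taken from the class docstring example above):

    sess_loader.load_pose(views=['body'], likelihood_thr=0.9)
    body = sess_loader.pose['bodyCamera']  # rows are frames; columns are 'times' plus the DLC traces
    print(body.columns)                    # e.g. 'times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'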

1327 def load_motion_energy(self, views=['left', 'right', 'body']): 

1328 """ 

1329 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a 

1330 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are 

1331 pandas Dataframes with the timestamps and motion energy data. 

1332 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad 

1333 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the 

1334 body (bodyMotionEnergy). 

1335 

1336 Parameters 

1337 ---------- 

1338 views: list 

1339 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} 

1340 """ 

1341 names = {'left': 'whiskerMotionEnergy', 1je

1342 'right': 'whiskerMotionEnergy', 

1343 'body': 'bodyMotionEnergy'} 

1344 # Empty the dictionary so that, if a single view is loaded after several had been loaded, the others don't linger 

1345 self.motion_energy = {} 1je

1346 for view in views: 1je

1347 me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times']) 1je

1348 # Double check if video timestamps are correct length or can be fixed 

1349 times_fixed, motion_energy = self._check_video_timestamps( 1je

1350 view, me_raw['times'], me_raw['ROIMotionEnergy']) 

1351 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1je

1352 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1je

1353 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1je

1354 
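A brief usage sketch (the single data column is named per view, following the `names` mapping above):

    sess_loader.load_motion_energy(views=['left', 'body'])
    whisker_me = sess_loader.motion_energy['leftCamera']  # columns: 'times', 'whiskerMotionEnergy'
    body_me = sess_loader.motion_energy['bodyCamera']     # columns: 'times', 'bodyMotionEnergy'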

1355 def load_licks(self): 

1356 """ 

1357 Not yet implemented 

1358 """ 

1359 pass 

1360 

1361 def load_pupil(self, snr_thresh=5.): 

1362 """ 

1363 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. 

1364 

1365 Parameters 

1366 ---------- 

1367 snr_thresh: float 

1368 An SNR is calculated from the raw and smoothed pupil diameter. If this SNR is below snr_thresh the data 

1369 are considered unusable and are discarded. 

1370 """ 

1371 # Try to load from features 

1372 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features']) 1he

1373 if 'features' in feat_raw.keys(): 1he

1374 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features']) 

1375 self.pupil = feats.copy() 

1376 self.pupil.insert(0, 'times', times_fixed) 

1377 

1378 # If unavailable, compute on the fly 

1379 else: 

1380 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1he

1381 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1he

1382 and 'leftCamera' in self.pose.keys()): 

1383 # If pose data is already loaded, we don't know if it was thresholded at 0.9, so we need a small workaround 

1384 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1he

1385 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1he

1386 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1he

1387 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1he

1388 else: 

1389 self.load_pose(views=['left'], likelihood_thr=0.9) 

1390 dlc_thr = self.pose['leftCamera'].copy() 

1391 

1392 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1he

1393 try: 1he

1394 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1he

1395 except BaseException as e: 

1396 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. " 

1397 "Saving all NaNs for pupilDiameter_smooth.") 

1398 _logger.debug(e) 

1399 self.pupil['pupilDiameter_smooth'] = np.nan 

1400 

1401 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1he

1402 good_idxs = np.where( 1he

1403 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] 

1404 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1he

1405 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) 

1406 if snr < snr_thresh: 1he

1407 self.pupil = pd.DataFrame() 1h

1408 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h

1409 
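The quality criterion above amounts to SNR = var(smooth) / var(smooth - raw), computed over samples where both traces are defined. A standalone sketch of that check with synthetic data (purely illustrative, not the library's API):

    import numpy as np

    raw = np.random.randn(1000) * 0.5 + 20             # noisy pupil diameter trace
    smooth = np.convolve(raw, np.ones(25) / 25, mode='same')  # crude smoothing stand-in
    good = ~np.isnan(smooth) & ~np.isnan(raw)
    snr = np.var(smooth[good]) / np.var(smooth[good] - raw[good])
    usable = snr >= 5.                                  # 5. is the default snr_thresh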

1410 def _check_video_timestamps(self, view, video_timestamps, video_data): 

1411 """ 

1412 Helper function to compare the number of video frames with the number of video timestamps, and to fix the 

1413 timestamps in case they are longer than the video. 

1414 """ 

1415 # If camera times are shorter than the video data, or empty, there is currently no fix 

1416 if video_timestamps.shape[0] < video_data.shape[0]: 1jkhe

1417 if video_timestamps.shape[0] == 0: 

1418 msg = f'Camera times empty for {view}Camera.' 

1419 else: 

1420 msg = f'Camera times are shorter than video data for {view}Camera.' 

1421 _logger.warning(msg) 

1422 raise ValueError(msg) 

1423 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. 

1424 # This is because the first few frames are sometimes not recorded. We can remove the first few 

1425 # timestamps in this case 

1426 elif video_timestamps.shape[0] > video_data.shape[0]: 1jkhe

1427 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1jkhe

1428 return video_timestamps_fixed, video_data 1jkhe

1429 else: 

1430 return video_timestamps, video_data 

1431 
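The trimming rule above keeps only the last len(video_data) timestamps when the timestamps are longer, e.g.:

    import numpy as np

    times = np.arange(105)                   # 105 timestamps
    frames = np.zeros((100, 2))              # but only 100 frames of data
    times_fixed = times[-frames.shape[0]:]   # drop the first 5 timestamps
    assert times_fixed.shape[0] == frames.shape[0]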

1432 

1433class EphysSessionLoader(SessionLoader): 

1434 """ 

1435 Spike-sorting-enhanced version of SessionLoader. 

1436 Loads spike sorting data for all probes in the session into the self.ephys dict. 

1437 >>> EphysSessionLoader(eid=eid, one=one) 

1438 To select a specific probe: 

1439 >>> EphysSessionLoader(eid=eid, one=one, pid=pid) 

1440 """ 

1441 def __init__(self, *args, pname=None, pid=None, **kwargs): 

1442 """ 

1443 Needs an active database connection in order to get the list of probe insertions for the session. 

1444 :param args: 

1445 :param kwargs: 

1446 """ 

1447 super().__init__(*args, **kwargs) 

1448 # if necessary, restrict the query 

1449 qargs = {} if pname is None else {'name': pname} 

1450 qargs = qargs or ({} if pid is None else {'id': pid}) 

1451 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs) 

1452 self.ephys = {} 

1453 for ins in insertions: 

1454 self.ephys[ins['name']] = {} 

1455 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one) 

1456 

1457 def load_session_data(self, *args, **kwargs): 

1458 super().load_session_data(*args, **kwargs) 

1459 self.load_spike_sorting() 

1460 

1461 def load_spike_sorting(self, pnames=None): 

1462 pnames = pnames or list(self.ephys.keys()) 

1463 for pname in pnames: 

1464 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting() 

1465 self.ephys[pname]['spikes'] = spikes 

1466 self.ephys[pname]['clusters'] = clusters 

1467 self.ephys[pname]['channels'] = channels 

1468 

1469 @property 

1470 def probes(self): 

1471 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}
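A hedged usage sketch of the ephys loader (eid is a placeholder, an Alyx connection is required to list insertions, and the probe label 'probe00' is an assumption; check esl.ephys.keys()):

    esl = EphysSessionLoader(one=one, eid=eid)  # one insertion entry per probe in the session
    esl.load_session_data()                     # behavioural data plus spike sorting for every probe
    print(esl.probes)                           # mapping of probe name to insertion id, e.g. {'probe00': <pid>, ...}
    spikes = esl.ephys['probe00']['spikes']     # spikes, clusters, channels are stored per probe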