Coverage for brainbox/io/one.py: 56%
626 statements
1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2from dataclasses import dataclass, field
3import gc
4import logging
5import os
6from pathlib import Path
9import numpy as np
10import pandas as pd
11from scipy.interpolate import interp1d
12import matplotlib.pyplot as plt
14from one.api import ONE, One
15import one.alf.io as alfio
16from one.alf.files import get_alf_path
17from one.alf.exceptions import ALFObjectNotFound
18from one.alf import cache
19from neuropixel import TIP_SIZE_UM, trace_header
20import spikeglx
22from iblutil.util import Bunch
23from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
24from iblatlas.atlas import AllenAtlas, BrainRegions
25from iblatlas import atlas
26from ibllib.pipes import histology
27from ibllib.pipes.ephys_alignment import EphysAlignment
28from ibllib.plots import vertical_lines
30import brainbox.plot
31from brainbox.ephys_plots import plot_brain_regions
32from brainbox.metrics.single_units import quick_unit_metrics
33from brainbox.behavior.wheel import interpolate_position, velocity_filtered
34from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
36_logger = logging.getLogger('ibllib')
39SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
40CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
43def load_lfp(eid, one=None, dataset_types=None, **kwargs):
44 """
45 TODO: verify this works
46 From an eid, hits the Alyx database and downloads the standard set of datasets
47 needed for LFP analysis
48 :param eid: experiment session identifier
49 :param one: an instance of ONE
50 :param dataset_types: additional dataset types to add to the default list
51 :return: list of spikeglx.Reader; kwargs (e.g. open=True to open the readers) are passed to spikeglx.Reader
52 """
53 if dataset_types is None:
54 dataset_types = []
55 dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*']
56 [one.load_dataset(eid, dset, download_only=True) for dset in dtypes]
57 session_path = one.eid2path(eid)
59 efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False)
60 if ef.get('lf', None)]
61 return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles]
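# Example usage (illustrative sketch): `eid` below is a placeholder session identifier and a
# configured ONE instance is assumed; each returned spikeglx.Reader exposes the LFP sampling rate.
#     from one.api import ONE
#     one = ONE()
#     lfp_readers = load_lfp(eid, one=one)
#     print([sr.fs for sr in lfp_readers])  # one reader per probe with LFP data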
64def _collection_filter_from_args(probe, spike_sorter=None):
65 collection = f'alf/{probe}/{spike_sorter}' 1g
66 collection = collection.replace('None', '*') 1g
67 collection = collection.replace('/*', '*') 1g
68 collection = collection[:-1] if collection.endswith('/') else collection 1g
69 return collection 1g
72def _get_spike_sorting_collection(collections, pname):
73 """
74 Filters a list or array of collections to get the relevant spike sorting dataset.
75 If a pykilosort collection is present, it is loaded in priority.
76 """
77 # first, look for a pykilosort collection
78 collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) 1gb
79 # otherwise, prefers the shortest
80 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) 1gb
81 _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") 1gb
82 return collection 1gb
85def _channels_alyx2bunch(chans):
86 channels = Bunch({
87 'atlas_id': np.array([ch['brain_region'] for ch in chans]),
88 'x': np.array([ch['x'] for ch in chans]) / 1e6,
89 'y': np.array([ch['y'] for ch in chans]) / 1e6,
90 'z': np.array([ch['z'] for ch in chans]) / 1e6,
91 'axial_um': np.array([ch['axial'] for ch in chans]),
92 'lateral_um': np.array([ch['lateral'] for ch in chans])
93 })
94 return channels
97def _channels_traj2bunch(xyz_chans, brain_atlas):
98 brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans))
99 channels = {
100 'x': xyz_chans[:, 0],
101 'y': xyz_chans[:, 1],
102 'z': xyz_chans[:, 2],
103 'acronym': brain_regions['acronym'],
104 'atlas_id': brain_regions['id']
105 }
107 return channels
110def _channels_bunch2alf(channels):
111 channels_ = { 1i
112 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6,
113 'brainLocationIds_ccf_2017': channels['atlas_id'],
114 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]}
115 return channels_ 1i
118def _channels_alf2bunch(channels, brain_regions=None):
119 # reformat the dictionary according to the standard that comes out of Alyx
120 channels_ = { 1icbdf
121 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6,
122 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6,
123 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6,
124 'acronym': None,
125 'atlas_id': channels['brainLocationIds_ccf_2017'],
126 'axial_um': channels['localCoordinates'][:, 1],
127 'lateral_um': channels['localCoordinates'][:, 0],
128 }
129 if brain_regions: 1icbdf
130 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icbdf
131 return channels_ 1icbdf
134def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
135 brain_regions=None):
136 """
137 Generic function to load spike sorting data using ONE.
139 Will try to load one spike sorting for each probe of the eid matching the collection.
140 For each probe it will load a spike sorting:
141 - if there is only one version: loads that one
142 - if there are several versions: loads pykilosort; if not found, the shortest collection (alf/probeXX)
144 Parameters
145 ----------
146 eid : [str, UUID, Path, dict]
147 Experiment session identifier; may be a UUID, URL, experiment reference string
148 details dict or Path
149 one : one.api.OneAlyx
150 An instance of ONE (may be in 'local' mode)
151 collection : str
152 collection filter word - accepts wildcards - can be a combination of spike sorter and
153 probe. See `ALF documentation`_ for details.
154 revision : str
155 A particular revision return (defaults to latest revision). See `ALF documentation`_ for
156 details.
157 return_channels : bool
158 Defaults to True; when True, the channel locations are also loaded from disk
160 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components
162 Returns
163 -------
164 spikes : dict of one.alf.io.AlfBunch
165 A dict with probe labels as keys, contains bunch(es) of spike data for the provided
166 session and spike sorter, with keys ('clusters', 'times')
167 clusters : dict of one.alf.io.AlfBunch
168 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
169 ('channels', 'depths', 'metrics')
170 channels : dict of one.alf.io.AlfBunch
171 A dict with probe labels as keys, contains channel locations with keys ('acronym',
172 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs
173 non-lateralized.
174 """
175 one = one or ONE() 1gb
176 # enumerate probes and load according to the name
177 collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) 1gb
178 if len(collections) == 0: 1gb
179 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
180 pnames = list(set(c.split('/')[1] for c in collections)) 1gb
181 spikes, clusters, channels = ({} for _ in range(3)) 1gb
183 spike_attributes, cluster_attributes = _get_attributes(dataset_types) 1gb
185 for pname in pnames: 1gb
186 probe_collection = _get_spike_sorting_collection(collections, pname) 1gb
187 spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', 1gb
188 attribute=spike_attributes)
189 clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', 1gb
190 attribute=cluster_attributes)
191 if return_channels: 1gb
192 channels = _load_channels_locations_from_disk( 1b
193 eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
194 return spikes, clusters, channels 1b
195 else:
196 return spikes, clusters 1g
199def _get_attributes(dataset_types):
200 if dataset_types is None: 1gb
201 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1gb
202 else:
203 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
204 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
205 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
206 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
207 return spike_attributes, cluster_attributes
210def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None):
211 _logger.debug('loading spike sorting from disk') 1b
212 channels = Bunch({}) 1b
213 collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) 1b
214 if len(collections) == 0: 1b
215 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
216 probes = list(set([c.split('/')[1] for c in collections])) 1b
217 for probe in probes: 1b
218 probe_collection = _get_spike_sorting_collection(collections, probe) 1b
219 channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels') 1b
220 # if the spike sorter has not aligned data, try and get the alignment available
221 if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): 1b
222 aligned_channel_collections = one.list_collections(
223 eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
224 if len(aligned_channel_collections) == 0:
225 _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
226 continue
227 _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
228 ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
229 channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
230 channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
231 # only have to reformat channels if we were able to load coordinates from disk
232 channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions) 1b
233 return channels 1b
236def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
237 """
238 Oftentimes the channel maps of different spike sorters differ, so the alignment is interpolated onto the target channel map.
239 If there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field,
240 so it is reconstructed from the Neuropixel map. This only happens for early pykilosort sorts.
241 :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
242 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
243 OR
244 'x', 'y', 'z', 'acronym', 'axial_um'
245 these are the guide for the interpolation
246 :param channels: Bunch or dictionary of channels containing at least the key 'localCoordinates'
247 :param brain_regions: None (default) or iblatlas.regions.BrainRegions object
248 if None, returns a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017'
249 if a brain regions object is provided, outputs a dict with keys
250 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
251 :return: Bunch or dictionary of channels with brain coordinates keys
252 """
253 NEUROPIXEL_VERSION = 1 1i
254 h = trace_header(version=NEUROPIXEL_VERSION) 1i
255 if channels is None: 1i
256 channels = {'localCoordinates': np.c_[h['x'], h['y']]}
257 nch = channels['localCoordinates'].shape[0] 1i
258 if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): 1i
259 channels_aligned = _channels_bunch2alf(channels_aligned) 1i
260 if 'localCoordinates' in channels_aligned.keys(): 1i
261 aligned_depths = channels_aligned['localCoordinates'][:, 1] 1i
262 else: # this is an edge case for a few spike sorting sessions
263 assert channels_aligned['mlapdv'].shape[0] == 384
264 aligned_depths = h['y']
265 depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) 1i
266 depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) 1i
267 channels['mlapdv'] = np.zeros((nch, 3)) 1i
268 for i in np.arange(3): 1i
269 channels['mlapdv'][:, i] = np.interp( 1i
270 depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
271 # the brain locations have to be interpolated by nearest neighbour
272 fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') 1i
273 channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) 1i
274 if brain_regions is not None: 1i
275 return _channels_alf2bunch(channels, brain_regions=brain_regions) 1i
276 else:
277 return channels 1i
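# Example usage (illustrative sketch): interpolate a finished alignment onto the channel map of
# another spike sorting. `channels_aligned` and `channels` are assumed to be ALF channel
# dictionaries as described in the docstring above.
#     from iblatlas.atlas import BrainRegions
#     br = BrainRegions()
#     channels_interp = channel_locations_interpolation(channels_aligned, channels, brain_regions=br)
#     # with brain_regions provided, the output contains 'x', 'y', 'z', 'acronym', 'atlas_id', ...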
280def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
281 brain_atlas=None, return_source=False):
282 if not hasattr(one, 'alyx'): 1b
283 return {}, None 1b
284 _logger.debug(f"trying to load from traj {probe}")
285 channels = Bunch()
286 brain_atlas = brain_atlas or AllenAtlas()  # instantiate the default atlas if none is provided
287 # need to find the collection
288 insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0]
289 collection = _collection_filter_from_args(probe=probe)
290 collections = one.list_collections(eid, filename='channels*', collection=collection,
291 revision=revision)
292 probe_collection = _get_spike_sorting_collection(collections, probe)
293 chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
294 depths = chn_coords[:, 1]
296 tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
297 get('tracing_exists', False)
298 resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
299 get('alignment_resolved', False)
300 counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
301 get('alignment_count', 0)
303 if tracing:
304 xyz = np.array(insertion['json']['xyz_picks']) / 1e6
305 if resolved:
307 _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
308 f'Channel and cluster locations obtained from ephys aligned histology '
309 f'track.')
310 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
311 provenance='Ephys aligned histology track')[0]
312 align_key = insertion['json']['extended_qc']['alignment_stored']
313 feature = traj['json'][align_key][0]
314 track = traj['json'][align_key][1]
315 ephysalign = EphysAlignment(xyz, depths, track_prev=track,
316 feature_prev=feature,
317 brain_atlas=brain_atlas, speedy=True)
318 chans = ephysalign.get_channel_locations(feature, track)
319 channels[probe] = _channels_traj2bunch(chans, brain_atlas)
320 source = 'resolved'
321 elif counts > 0 and aligned:
322 _logger.debug(f'Channel locations for {eid}/{probe} have not been '
323 f'resolved. However, alignment flag set to True so channel and cluster'
324 f' locations will be obtained from latest available ephys aligned '
325 f'histology track.')
326 # get the latest user aligned channels
327 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
328 provenance='Ephys aligned histology track')[0]
329 align_key = insertion['json']['extended_qc']['alignment_stored']
330 feature = traj['json'][align_key][0]
331 track = traj['json'][align_key][1]
332 ephysalign = EphysAlignment(xyz, depths, track_prev=track,
333 feature_prev=feature,
334 brain_atlas=brain_atlas, speedy=True)
335 chans = ephysalign.get_channel_locations(feature, track)
337 channels[probe] = _channels_traj2bunch(chans, brain_atlas)
338 source = 'aligned'
339 else:
340 _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
341 f'Channel and cluster locations obtained from histology track.')
342 # get the channels from histology tracing
343 xyz = xyz[np.argsort(xyz[:, 2]), :]
344 chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
345 channels[probe] = _channels_traj2bunch(chans, brain_atlas)
346 source = 'traced'
347 channels[probe]['axial_um'] = chn_coords[:, 1]
348 channels[probe]['lateral_um'] = chn_coords[:, 0]
350 else:
351 _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
352 source = ''
353 channels = None
355 if return_source:
356 return channels, source
357 else:
358 return channels
361def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
362 """
363 Load the brain locations of each channel for a given session/probe
365 Parameters
366 ----------
367 eid : [str, UUID, Path, dict]
368 Experiment session identifier; may be a UUID, URL, experiment reference string
369 details dict or Path
370 probe : [str, list of str]
371 The probe label(s), e.g. 'probe01'
372 one : one.api.OneAlyx
373 An instance of ONE (shouldn't be in 'local' mode)
374 aligned : bool
375 Whether to get the latest user aligned channel when not resolved or use histology track
376 brain_atlas : iblatlas.BrainAtlas
377 Brain atlas object (default: Allen atlas)
378 Returns
379 -------
380 dict of one.alf.io.AlfBunch
381 A dict with probe labels as keys, contains channel locations with keys ('acronym',
382 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
384 """
385 one = one or ONE()
386 brain_atlas = brain_atlas or AllenAtlas()
387 if isinstance(eid, dict):
388 ses = eid
389 eid = ses['url'][-36:]
390 else:
391 eid = one.to_eid(eid)
392 collection = _collection_filter_from_args(probe=probe)
393 channels = _load_channels_locations_from_disk(eid, one=one, collection=collection,
394 brain_regions=brain_atlas.regions)
395 incomplete_probes = [k for k in channels if 'x' not in channels[k]]
396 for iprobe in incomplete_probes:
397 channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
398 brain_atlas=brain_atlas, return_source=True)
399 if channels_ is not None:
400 channels[iprobe] = channels_[iprobe]
401 return channels
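# Example usage (illustrative sketch): `eid` is a placeholder session identifier; the returned
# dictionary is keyed by probe label.
#     from one.api import ONE
#     one = ONE()
#     channels = load_channel_locations(eid, probe='probe00', one=one)
#     print(channels['probe00']['acronym'][:5])  # brain region acronyms of the first channels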
404def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
405 brain_regions=None, nested=True, collection=None, return_collection=False):
406 """
407 From an eid, loads spikes and clusters for all probes
408 The following set of dataset types are loaded:
409 'clusters.channels',
410 'clusters.depths',
411 'clusters.metrics',
412 'spikes.clusters',
413 'spikes.times',
414 'probes.description'
415 :param eid: experiment UUID or pathlib.Path of the local session
416 :param one: an instance of OneAlyx
417 :param probe: name of probe to load in, if not given all probes for session will be loaded
418 :param dataset_types: additional spikes/clusters objects to add to the standard default list
419 :param spike_sorter: name of the spike sorting you want to load (None for default)
420 :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
421 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
422 :param nested: if False and only one probe is present, returns spikes, clusters and channels directly rather than dictionaries keyed by probe name
423 :param return_collection: (False) if True, will return the collection used to load
424 :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
425 """
426 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.'
427 ' Use brainbox.io.one.SpikeSortingLoader instead')
428 if collection is None:
429 collection = _collection_filter_from_args(probe, spike_sorter)
430 _logger.debug(f"load spike sorting with collection filter {collection}")
431 kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
432 brain_regions=brain_regions)
433 spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
434 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
435 if nested is False and len(spikes.keys()) == 1:
436 k = list(spikes.keys())[0]
437 channels = channels[k]
438 clusters = clusters[k]
439 spikes = spikes[k]
440 if return_collection:
441 return spikes, clusters, channels, collection
442 else:
443 return spikes, clusters, channels
446def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
447 brain_regions=None, return_collection=False):
448 """
449 From an eid, loads spikes and clusters for all probes
450 The following set of dataset types are loaded:
451 'clusters.channels',
452 'clusters.depths',
453 'clusters.metrics',
454 'spikes.clusters',
455 'spikes.times',
456 'probes.description'
457 :param eid: experiment UUID or pathlib.Path of the local session
458 :param one: an instance of OneAlyx
459 :param probe: name of probe to load in, if not given all probes for session will be loaded
460 :param dataset_types: additional spikes/clusters objects to add to the standard default list
461 :param spike_sorter: name of the spike sorting you want to load (None for default)
462 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
463 :param return_collection: (bool, default False) if True, also returns the collection used to load the data
464 :return: spikes, clusters (dict of bunch, 1 bunch per probe)
465 """
466 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.' 1g
467 ' Use brainbox.io.one.SpikeSortingLoader instead')
468 collection = _collection_filter_from_args(probe, spike_sorter) 1g
469 _logger.debug(f"load spike sorting with collection filter {collection}") 1g
470 spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, 1g
471 return_channels=False, dataset_types=dataset_types,
472 brain_regions=brain_regions)
473 if return_collection: 1g
474 return spikes, clusters, collection
475 else:
476 return spikes, clusters 1g
479def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
480 spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
481 """
482 For a given eid, get spikes, clusters and channels information, and merges clusters
483 and channels information before returning all three variables.
485 Parameters
486 ----------
487 eid : [str, UUID, Path, dict]
488 Experiment session identifier; may be a UUID, URL, experiment reference string
489 details dict or Path
490 one : one.api.OneAlyx
491 An instance of ONE (shouldn't be in 'local' mode)
492 probe : [str, list of str]
493 The probe label(s), e.g. 'probe01'
494 aligned : bool
495 Whether to get the latest user aligned channel when not resolved or use histology track
496 dataset_types : list of str
497 Optional additional spikes/clusters objects to add to the standard default list
498 spike_sorter : str
499 Name of the spike sorting you want to load (None for default which is pykilosort if it's
500 available otherwise the default MATLAB kilosort)
501 brain_atlas : iblatlas.atlas.BrainAtlas
502 Brain atlas object (default: Allen atlas)
503 return_collection : bool
504 Returns an extra argument with the collection chosen
nested : bool
If False and only one probe is present, spikes, clusters and channels are returned directly rather than in dictionaries keyed by probe name
506 Returns
507 -------
508 spikes : dict of one.alf.io.AlfBunch
509 A dict with probe labels as keys, contains bunch(es) of spike data for the provided
510 session and spike sorter, with keys ('clusters', 'times')
511 clusters : dict of one.alf.io.AlfBunch
512 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
513 ('channels', 'depths', 'metrics')
514 channels : dict of one.alf.io.AlfBunch
515 A dict with probe labels as keys, contains channel locations with keys ('acronym',
516 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
517 """
518 # --- Get spikes and clusters data
519 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future versions.'
520 ' Use brainbox.io.one.SpikeSortingLoader instead')
521 one = one or ONE()
522 brain_atlas = brain_atlas or AllenAtlas()
523 spikes, clusters, collection = load_spike_sorting(
524 eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
525 # -- Get brain regions and assign to clusters
526 channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
527 brain_atlas=brain_atlas)
528 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
529 if nested is False and len(spikes.keys()) == 1:
530 k = list(spikes.keys())[0]
531 channels = channels[k]
532 clusters = clusters[k]
533 spikes = spikes[k]
534 if return_collection:
535 return spikes, clusters, channels, collection
536 else:
537 return spikes, clusters, channels
540def load_ephys_session(eid, one=None):
541 """
542 From an eid, hits the Alyx database and downloads a standard default set of dataset types
543 From a local session Path (pathlib.Path), loads a standard default set of dataset types
544 to perform analysis:
545 'clusters.channels',
546 'clusters.depths',
547 'clusters.metrics',
548 'spikes.clusters',
549 'spikes.times',
550 'probes.description'
552 Parameters
553 ----------
554 eid : [str, UUID, Path, dict]
555 Experiment session identifier; may be a UUID, URL, experiment reference string
556 details dict or Path
557 one : one.api.OneAlyx
558 ONE object to use for loading (required; the function asserts that an instance is passed)
560 Returns
561 -------
562 spikes : dict of one.alf.io.AlfBunch
563 A dict with probe labels as keys, contains bunch(es) of spike data for the provided
564 session and spike sorter, with keys ('clusters', 'times')
565 clusters : dict of one.alf.io.AlfBunch
566 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
567 ('channels', 'depths', 'metrics')
568 trials : one.alf.io.AlfBunch of numpy.ndarray
569 The session trials data
570 """
571 assert one 1g
572 spikes, clusters = load_spike_sorting(eid, one=one) 1g
573 trials = one.load_object(eid, 'trials') 1g
574 return spikes, clusters, trials 1g
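# Example usage (illustrative sketch): `eid` is a placeholder session identifier and a connected
# ONE instance is required by this function.
#     from one.api import ONE
#     one = ONE()
#     spikes, clusters, trials = load_ephys_session(eid, one=one)
#     print(spikes.keys())  # one bunch per probe, e.g. dict_keys(['probe00', 'probe01'])
#     print(trials['goCue_times'][:3])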
577def _remove_old_clusters(session_path, probe):
578 # removes the deprecated clusters.metrics.csv file from a local session folder (superseded by the .pqt file)
579 probe_path = session_path.joinpath('alf', probe)
581 # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
582 cluster_file = probe_path.joinpath('clusters.metrics.csv')
584 if cluster_file.exists():
585 os.remove(cluster_file)
586 _logger.info('Deleting old clusters.metrics.csv file')
589def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
590 """
591 Takes (default and any extra) values in given keys from channels and assigns them to clusters.
592 If channels does not contain any data, the new keys are added to clusters but left empty.
594 Parameters
595 ----------
596 dic_clus : dict of one.alf.io.AlfBunch
597 1 bunch per probe, containing cluster information
598 channels : dict of one.alf.io.AlfBunch
599 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z', 'localCoordinates')
600 keys_to_add_extra : list of str
601 Any extra keys to load into channels bunches
603 Returns
604 -------
605 dict of one.alf.io.AlfBunch
606 clusters (1 bunch per probe) with the new keys' values.
607 """
608 probe_labels = list(channels.keys()) # Convert dict_keys into list
609 keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']
611 if keys_to_add_extra is None:
612 keys_to_add = keys_to_add_default
613 else:
614 # Append extra optional keys
615 keys_to_add = list(set(keys_to_add_extra + keys_to_add_default))
617 for label in probe_labels:
618 clu_ch = dic_clus[label]['channels']
619 for key in keys_to_add:
620 try:
621 assert key in channels[label].keys() # Check key is in channels
622 ch_key = channels[label][key]
623 nch_key = len(ch_key) if ch_key is not None else 0
624 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index
625 dic_clus[label][key] = ch_key[clu_ch]
626 else:
627 _logger.warning(
628 f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels'
629 f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.')
630 dic_clus[label][key] = []
631 except AssertionError:
632 _logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
633 continue
635 return dic_clus
638def load_passive_rfmap(eid, one=None):
639 """
640 For a given eid load in the passive receptive field mapping protocol data
642 Parameters
643 ----------
644 eid : [str, UUID, Path, dict]
645 Experiment session identifier; may be a UUID, URL, experiment reference string
646 details dict or Path
647 one : one.api.OneAlyx, optional
648 An instance of ONE (may be in 'local' - offline - mode)
650 Returns
651 -------
652 one.alf.io.AlfBunch
653 Passive receptive field mapping data
654 """
655 one = one or ONE()
657 # Load in the receptive field mapping data
658 rf_map = one.load_object(eid, obj='passiveRFM', collection='alf')
659 frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin',
660 collection='raw_passive_data'), dtype="uint8")
661 y_pix, x_pix = 15, 15
662 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0])
663 rf_map['frames'] = frames
665 return rf_map
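# Example usage (illustrative sketch): `eid` is a placeholder identifier of a session that ran
# the passive protocol; `one` is an existing ONE instance.
#     rf_map = load_passive_rfmap(eid, one=one)
#     print(rf_map['frames'].shape)  # (n_frames, 15, 15) receptive field mapping stimulus frames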
668def load_wheel_reaction_times(eid, one=None):
669 """
670 Return the calculated reaction times for session. Reaction times are defined as the time
671 between the go cue (onset tone) and the onset of the first substantial wheel movement. A
672 movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the
673 distance to threshold (~0.1 radians).
675 Negative times mean the onset of the movement occurred before the go cue. NaNs may occur if
676 there was no detected movement within the period, or when the goCue_times or feedback_times
677 are NaN.
679 Parameters
680 ----------
681 eid : [str, UUID, Path, dict]
682 Experiment session identifier; may be a UUID, URL, experiment reference string
683 details dict or Path
684 one : one.api.OneAlyx, optional
685 one object to use for loading. Will generate internal one if not used, by default None
687 Returns
688 -------
689 array-like
690 reaction times
691 """
692 if one is None:
693 one = ONE()
695 trials = one.load_object(eid, 'trials')
696 # If already extracted, load and return
697 if trials and 'firstMovement_times' in trials:
698 return trials['firstMovement_times'] - trials['goCue_times']
699 # Otherwise load wheelMoves object and calculate
700 moves = one.load_object(eid, 'wheelMoves')
701 # Re-extract wheel moves if necessary
702 if not moves or 'peakAmplitude' not in moves:
703 wheel = one.load_object(eid, 'wheel')
704 moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
705 assert trials and moves, 'unable to load trials and wheelMoves data'
706 firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials)
707 return firstMove_times - trials['goCue_times']
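# Example usage (illustrative sketch): `eid` is a placeholder session identifier and `one` an
# existing ONE instance; the returned array has one reaction time per trial.
#     rt = load_wheel_reaction_times(eid, one=one)
#     trials = one.load_object(eid, 'trials')
#     assert rt.size == trials['goCue_times'].size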
710def load_iti(trials):
711 """
712 The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey
713 screen commencing at stimulus off and lasting until the quiescent period at the start of the
714 following trial. Note that the ITI of each trial is measured relative to the start of the
715 following trial, therefore the last trial has no ITI and the last value is NaN.
717 Parameters
718 ----------
719 trials : one.alf.io.AlfBunch
720 An ALF trials object containing the keys {'intervals', 'stimOff_times'}.
722 Returns
723 -------
724 np.array
725 An array of inter-trial intervals, the last value being NaN.
726 """
727 if not {'intervals', 'stimOff_times'} <= set(trials.keys()): 1n
728 raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') 1n
729 return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan] 1n
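# Example usage (illustrative sketch): `eid` is a placeholder session identifier and `one` an
# existing ONE instance; the last ITI is always NaN.
#     trials = one.load_object(eid, 'trials')
#     iti = load_iti(trials)
#     assert iti.size == trials['intervals'].shape[0] and np.isnan(iti[-1])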
732def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
734 PROV_2_VAL = {
735 'Resolved': 90,
736 'Ephys aligned histology track': 70,
737 'Histology track': 50,
738 'Micro-manipulator': 30,
739 'Planned': 10}
741 one = one or ONE()
742 ba = ba or atlas.AllenAtlas()
743 traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id'])
744 val = [PROV_2_VAL[tr['provenance']] for tr in traj]
745 idx = np.argmax(val)
746 traj = traj[idx]
747 if depths is None:
748 depths = trace_header(version=1)['y']  # site depths along the probe (um)
749 if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator':
750 ins = atlas.Insertion.from_dict(traj)
751 # Deepest coordinate first
752 xyz = np.c_[ins.tip, ins.entry].T
753 xyz_channels = histology.interpolate_along_track(xyz, (depths +
754 TIP_SIZE_UM) / 1e6)
755 else:
756 xyz = np.array(ins['json']['xyz_picks']) / 1e6
757 if traj['provenance'] == 'Histology track':
758 xyz = xyz[np.argsort(xyz[:, 2]), :]
759 xyz_channels = histology.interpolate_along_track(xyz, (depths +
760 TIP_SIZE_UM) / 1e6)
761 else:
762 align_key = ins['json']['extended_qc']['alignment_stored']
763 feature = traj['json'][align_key][0]
764 track = traj['json'][align_key][1]
765 ephysalign = EphysAlignment(xyz, depths, track_prev=track,
766 feature_prev=feature,
767 brain_atlas=ba, speedy=True)
768 xyz_channels = ephysalign.get_channel_locations(feature, track)
769 return xyz_channels
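# Example usage (illustrative sketch): `pid` is a placeholder probe insertion UUID and `one` an
# existing ONE instance; coordinates are returned in meters.
#     ins = one.alyx.rest('insertions', 'read', id=pid)
#     xyz_channels = load_channels_from_insertion(ins, one=one)
#     print(xyz_channels.shape)  # (n_channels, 3) ml/ap/dv coordinates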
772@dataclass
773class SpikeSortingLoader:
774 """
775 Object that will load spike sorting data for a given probe insertion.
776 This class can be instantiated in several ways:
777 - With an Alyx database probe id:
778 SpikeSortingLoader(pid=pid, one=one)
779 - With an Alyx database eid and probe name:
780 SpikeSortingLoader(eid=eid, pname='probe00', one=one)
781 - From a local session and probe name:
782 SpikeSortingLoader(session_path=session_path, pname='probe00')
783 NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
784 """
785 one: One = None
786 atlas: None = None
787 pid: str = None
788 eid: str = ''
789 pname: str = ''
790 session_path: Path = ''
791 # the following properties are the outcome of the post init function
792 collections: list = None
793 datasets: list = None # list of all datasets belonging to the session
794 # the following properties are the outcome of a reading function
795 files: dict = None
796 collection: str = ''
797 histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced'
798 spike_sorter: str = 'pykilosort'
799 spike_sorting_path: Path = None
800 _sync: dict = None
802 def __post_init__(self):
803 # pid gets precedence
804 if self.pid is not None: 1cbdf
805 try: 1f
806 self.eid, self.pname = self.one.pid2eid(self.pid) 1f
807 except NotImplementedError:
808 if self.eid == '' or self.pname == '':
809 raise IOError("Cannot infer session id and probe name from pid. "
810 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.")
811 self.session_path = self.one.eid2path(self.eid) 1f
812 # then eid / pname combination
813 elif self.session_path is None or self.session_path == '': 1cbd
814 self.session_path = self.one.eid2path(self.eid) 1cbd
815 # fully local providing a session path
816 else:
817 if self.one:
818 self.eid = self.one.to_eid(self.session_path)
819 else:
820 self.one = One(cache_dir=self.session_path.parents[2], mode='local')
821 df_sessions = cache._make_sessions_df(self.session_path)
822 self.one._cache['sessions'] = df_sessions.set_index('id')
823 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
824 self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
825 # populates default properties
826 self.collections = self.one.list_collections( 1cbdf
827 self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
828 self.datasets = self.one.list_datasets(self.eid) 1cbdf
829 if self.atlas is None: 1cbdf
830 self.atlas = AllenAtlas() 1cbd
831 self.files = {} 1cbdf
833 @staticmethod
834 def _get_attributes(dataset_types):
835 """returns attributes to load for spikes and clusters objects"""
836 if dataset_types is None: 1cbdf
837 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1cbdf
838 else:
839 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1d
840 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1d
841 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1d
842 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1d
843 return spike_attributes, cluster_attributes 1d
845 def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
846 """
847 Filters a list or array of collections to get the relevant spike sorting dataset
848 if there is a pykilosort, load it
849 """
850 collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None) 1cbdf
851 # otherwise, prefers the shortest
852 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) 1cbdf
853 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") 1cbdf
854 return collection 1cbdf
856 def load_spike_sorting_object(self, obj, *args, **kwargs):
857 """
858 Loads an ALF object
859 :param obj: object name, one of 'spikes', 'clusters' or 'channels'
860 :param spike_sorter: (defaults to 'pykilosort')
861 :param dataset_types: list of extra dataset types, for example ['spikes.samples']
862 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
863 :param kwargs: additional arguments to be passed to one.api.One.load_object
864 :param missing: 'raise' (default) or 'ignore'
865 :return:
866 """
867 self.download_spike_sorting_object(obj, *args, **kwargs)
868 return alfio.load_object(self.files[obj])
870 def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
871 missing='raise', **kwargs):
872 """
873 Downloads an ALF object
874 :param obj: object name, one of 'spikes', 'clusters' or 'channels'
875 :param spike_sorter: (defaults to 'pykilosort')
876 :param dataset_types: list of extra dataset types, for example ['spikes.samples']
877 :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
878 :param kwargs: additional arguments to be passed to one.api.One.load_object
879 :param missing: 'raise' (default) or 'ignore'
880 :return:
881 """
882 if len(self.collections) == 0: 1cbdf
883 return {}, {}, {}
884 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1cbdf
885 collection = collection or self.collection 1cbdf
886 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1cbdf
887 spike_attributes, cluster_attributes = self._get_attributes(dataset_types) 1cbdf
888 attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes} 1cbdf
889 try: 1cbdf
890 self.files[obj] = self.one.load_object( 1cbdf
891 self.eid, obj=obj, attribute=attributes.get(obj, None),
892 collection=collection, download_only=True, **kwargs)
893 except ALFObjectNotFound as e: 1cbd
894 if missing == 'raise': 1cbd
895 raise e
897 def download_spike_sorting(self, **kwargs):
898 """
899 Downloads spikes, clusters and channels
900 :param spike_sorter: (defaults to 'pykilosort')
901 :param dataset_types: list of extra dataset types
902 :return:
903 """
904 for obj in ['spikes', 'clusters', 'channels']: 1cbdf
905 self.download_spike_sorting_object(obj=obj, **kwargs) 1cbdf
906 self.spike_sorting_path = self.files['spikes'][0].parent 1cbdf
908 def load_channels(self, **kwargs):
909 """
910 Loads channels
911 The channel locations can come from several sources, it will load the most advanced version of the histology available,
912 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
913 - alf: the final version of channel locations, same as resolved with the difference that data is on file
914 - resolved: channel locations alignments have been agreed upon
915 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
916 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
918 :param spike_sorter: (defaults to 'pykilosort')
919 :param dataset_types: list of extra dataset types
920 :return:
921 """
922 # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
923 self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') 1cbdf
924 if 'electrodeSites' in self.files: 1cbdf
925 channels = alfio.load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 1f
926 else: # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology
927 self.download_spike_sorting_object(obj='channels', **kwargs) 1cbd
928 channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards) 1cbd
929 if 'brainLocationIds_ccf_2017' not in channels: 1cbdf
930 _logger.debug(f"loading channels from alyx for {self.files['channels']}") 1b
931 _channels, self.histology = _load_channel_locations_traj( 1b
932 self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
933 if _channels: 1b
934 channels = _channels[self.pname]
935 else:
936 channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) 1cbdf
937 self.histology = 'alf' 1cbdf
938 return channels 1cbdf
940 def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
941 """
942 Loads spikes, clusters and channels
944 There could be several spike sorting collections, by default the loader will get the pykilosort collection
946 The channel locations can come from several sources, it will load the most advanced version of the histology available,
947 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
948 - alf: the final version of channel locations, same as resolved with the difference that data is on file
949 - resolved: channel locations alignments have been agreed upon
950 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
951 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
953 :param spike_sorter: (defaults to 'pykilosort')
954 :param dataset_types: list of extra dataset types
955 :return:
956 """
957 if len(self.collections) == 0: 1cbdf
958 return {}, {}, {}
959 self.files = {} 1cbdf
960 self.spike_sorter = spike_sorter 1cbdf
961 self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs) 1cbdf
962 channels = self.load_channels(spike_sorter=spike_sorter, **kwargs) 1cbdf
963 clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards) 1cbdf
964 spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards) 1cbdf
966 return spikes, clusters, channels 1cbdf
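# Example usage (illustrative sketch): the typical loading pattern combines load_spike_sorting
# with the merge_clusters static method below to attach metrics and channel locations to the
# clusters table. `pid` is a placeholder probe insertion UUID and `one` an existing ONE instance.
#     ssl = SpikeSortingLoader(pid=pid, one=one)
#     spikes, clusters, channels = ssl.load_spike_sorting()
#     clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
#     print(clusters['acronym'][:5])  # clusters now carry brain region acronyms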
968 @staticmethod
969 def compute_metrics(spikes, clusters=None):
970 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
971 metrics = pd.DataFrame(quick_unit_metrics(
972 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
973 return metrics
975 @staticmethod
976 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False):
977 """
978 Merge the metrics and the channel information into the clusters dictionary
979 :param spikes:
980 :param clusters:
981 :param channels:
982 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used
983 for clusters or analysis applications (defaults to None).
984 :param compute_metrics: if True, will explicitly recompute metrics (defaults to False)
985 :return: cluster dictionary containing metrics and histology
986 """
987 if spikes == {}: 1bf
988 return
989 nc = clusters['channels'].size 1bf
990 # recompute metrics if they are not available
991 metrics = None 1bf
992 if 'metrics' in clusters: 1bf
993 metrics = clusters.pop('metrics') 1bf
994 if metrics.shape[0] != nc: 1bf
995 metrics = None
996 if metrics is None or compute_metrics is True: 1bf
997 _logger.debug("recompute clusters metrics")
998 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters)
999 if isinstance(cache_dir, Path):
1000 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
1001 for k in metrics.keys(): 1bf
1002 clusters[k] = metrics[k].to_numpy() 1bf
1003 for k in channels.keys(): 1bf
1004 clusters[k] = channels[k][clusters['channels']] 1bf
1005 if cache_dir is not None: 1bf
1006 _logger.debug(f'caching clusters metrics in {cache_dir}')
1007 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
1008 return clusters 1bf
1010 @property
1011 def url(self):
1012 """Gets flatiron URL for the session"""
1013 webclient = getattr(self.one, '_web_client', None)
1014 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None
1016 def _get_probe_info(self):
1017 if self._sync is None: 1d
1018 timestamps = self.one.load_dataset( 1d
1019 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
1020 try: 1d
1021 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1d
1022 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
1023 fs = spikeglx._get_fs_from_meta(ap_meta)
1024 except ALFObjectNotFound: 1d
1025 ap_meta = None 1d
1026 fs = 30_000 1d
1027 self._sync = { 1d
1028 'timestamps': timestamps,
1029 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'),
1030 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
1031 'ap_meta': ap_meta,
1032 'fs': fs,
1033 }
1035 def timesprobe2times(self, values, direction='forward'):
1036 self._get_probe_info()
1037 if direction == 'forward':
1038 return self._sync['forward'](values * self._sync['fs'])
1039 elif direction == 'reverse':
1040 return self._sync['reverse'](values) / self._sync['fs']
1042 def samples2times(self, values, direction='forward'):
1043 """
1044 Converts ephys sample values to session main clock seconds
1045 :param values: numpy array of times in seconds or samples to resync
1046 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse'
1047 (seconds main time to samples probe time)
1048 :return:
1049 """
1050 self._get_probe_info() 1d
1051 return self._sync[direction](values) 1d
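# Example usage (illustrative sketch): convert spike samples on the probe clock to seconds on
# the session main clock; this assumes the loader `ssl` was used with
# dataset_types=['spikes.samples'] so that the samples attribute is available.
#     spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
#     spike_times_main = ssl.samples2times(spikes['samples'], direction='forward')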
1053 @property
1054 def pid2ref(self):
1055 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1cb
1057 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, **kwargs):
1058 """
1059 :param spikes: spikes dictionary or Bunch
1060 :param channels: channels dictionary or Bunch.
1061 :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png".
1062 Otherwise, plot.
1063 :param br: brain regions object (optional)
1064 :param label: label for saved image (optional, default="raster")
1065 :param time_series: timeseries dictionary for behavioral event times (optional)
1066 :param **kwargs: kwargs passed to `driftmap()` (optional)
1067 :return:
1068 """
1069 br = br or BrainRegions() 1c
1070 time_series = time_series or {} 1c
1071 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c
1072 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col')
1073 axs[0, 1].set_axis_off() 1c
1074 # axs[0, 0].set_xticks([])
1075 if not kwargs:  # **kwargs collects into a dict and is never None; apply defaults when empty 1c
1076 # set default raster plot parameters
1077 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
1078 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c
1079 title_str = f"{self.pid2ref}, {self.pid} \n" \ 1c
1080 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
1081 axs[0, 0].title.set_text(title_str) 1c
1082 for k, ts in time_series.items(): 1c
1083 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
1084 if 'atlas_id' in channels: 1c
1085 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c
1086 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology)
1087 axs[1, 0].set_ylim(0, 3800) 1c
1088 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c
1089 fig.tight_layout() 1c
1091 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c
1092 if 'drift' in self.files: 1c
1093 drift = alfio.load_object(self.files['drift'], wildcards=self.one.wildcards)
1094 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
1096 if save_dir is not None: 1c
1097 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
1098 fig.savefig(png_file)
1099 plt.close(fig)
1100 gc.collect()
1101 else:
1102 return fig, axs 1c
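# Example usage (illustrative sketch): plot a driftmap raster with brain regions for a loaded
# insertion; when `save_dir` is omitted the figure and axes are returned for further tweaking.
#     spikes, clusters, channels = ssl.load_spike_sorting()
#     fig, axs = ssl.raster(spikes, channels)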
1105@dataclass
1106class SessionLoader:
1107 """
1108 Object to load session data for a given session in the recommended way.
1110 Parameters
1111 ----------
1112 one: one.api.ONE instance
1113 Can be in remote or local mode (required)
1114 session_path: string or pathlib.Path
1115 The absolute path to the session (one of session_path or eid is required)
1116 eid: string
1117 database UUID of the session (one of session_path or eid is required)
1119 If both are provided, session_path takes precedence over eid.
1121 Examples
1122 --------
1123 1) Load all available session data for one session:
1124 >>> from one.api import ONE
1125 >>> from brainbox.io.one import SessionLoader
1126 >>> one = ONE()
1127 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1128 # Object is initiated, but no data is loaded as you can see in the data_info attribute
1129 >>> sess_loader.data_info
1130 name is_loaded
1131 0 trials False
1132 1 wheel False
1133 2 pose False
1134 3 motion_energy False
1135 4 pupil False
1137 # Loading all available session data, the data_info attribute now shows which data has been loaded
1138 >>> sess_loader.load_session_data()
1139 >>> sess_loader.data_info
1140 name is_loaded
1141 0 trials True
1142 1 wheel True
1143 2 pose True
1144 3 motion_energy True
1145 4 pupil False
1147 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
1148 >>> type(sess_loader.trials)
1149 pandas.core.frame.DataFrame
1150 >>> sess_loader.trials.shape
1151 (626, 18)
1152 # Each data comes with its own timestamps in a column called 'times'
1153 >>> sess_loader.wheel['times']
1154 0 0.134286
1155 1 0.135286
1156 2 0.136286
1157 3 0.137286
1158 4 0.138286
1159 ...
1160 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
1161 # The dataframes of all cameras are collected in a dictionary
1162 >>> type(sess_loader.pose)
1163 dict
1164 >>> sess_loader.pose.keys()
1165 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
1166 >>> sess_loader.pose['bodyCamera'].columns
1167 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
1168 # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
1169 functions:
1170 >>> sess_loader.load_wheel(sampling_rate=100)
1171 """
1172 one: One = None
1173 session_path: Path = ''
1174 eid: str = ''
1175 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1176 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1177 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1178 pose: dict = field(default_factory=dict, repr=False)
1179 motion_energy: dict = field(default_factory=dict, repr=False)
1180 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1182 def __post_init__(self):
1183 """
1184 Function that runs automatically after initiation of the dataclass attributes.
1185 Checks for required inputs, sets session_path and eid, creates data_info table.
1186 """
1187 if self.one is None: 1ae
1188 raise ValueError("An input to one is required. If no connection to a database is desired, it can be "
1189 "a fully local instance of One.")
1190 # If session path is given, takes precedence over eid
1191 if self.session_path is not None and self.session_path != '': 1ae
1192 self.eid = self.one.to_eid(self.session_path) 1ae
1193 self.session_path = Path(self.session_path) 1ae
1194 # Providing no session path, try to infer from eid
1195 else:
1196 if self.eid is not None and self.eid != '':
1197 self.session_path = self.one.eid2path(self.eid)
1198 else:
1199 raise ValueError("If no session path is given, eid is required.")
1201 data_names = [ 1ae
1202 'trials',
1203 'wheel',
1204 'pose',
1205 'motion_energy',
1206 'pupil'
1207 ]
1208 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1ae
1210 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
1211 """
1212 Function to load available session data into the SessionLoader object. Input parameters allow to control which
1213 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
1214 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
1215 in SessionLoader.data_info
1217 Parameters
1218 ----------
1219 trials: boolean
1220 Whether to load all trials data into SessionLoader.trials, default is True
1221 wheel: boolean
1222 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
1223 pose: boolean
1224 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
1225 default is True
1226 motion_energy: boolean
1227 Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
1228 into SessionLoader.motion_energy, default is True
1229 pupil: boolean
1230 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
1231 default is True
1232 reload: boolean
1233 Whether to reload data that has already been loaded into this SessionLoader object, default is False
1234 """
1235 load_df = self.data_info.copy() 1e
1236 load_df['to_load'] = [ 1e
1237 trials,
1238 wheel,
1239 pose,
1240 motion_energy,
1241 pupil
1242 ]
1243 load_df['load_func'] = [ 1e
1244 self.load_trials,
1245 self.load_wheel,
1246 self.load_pose,
1247 self.load_motion_energy,
1248 self.load_pupil
1249 ]
1251 for idx, row in load_df.iterrows(): 1e
1252 if row['to_load'] is False: 1e
1253 _logger.debug(f"Not loading {row['name']} data, set to False.")
1254 elif row['is_loaded'] is True and reload is False: 1e
1255 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1e
1256 else:
1257 try: 1e
1258 _logger.info(f"Loading {row['name']} data") 1e
1259 row['load_func']() 1e
1260 self.data_info.loc[idx, 'is_loaded'] = True 1e
1261 except BaseException as e:
1262 _logger.warning(f"Could not load {row['name']} data.")
1263 _logger.debug(e)
1265 def load_trials(self):
1266 """
1267 Function to load trials data into SessionLoader.trials
1268 """
1269 # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
1270 self.one.wildcards = False 1em
1271 self.trials = self.one.load_object(self.eid, 'trials', collection='alf', attribute=r'(?!itiDuration).*').to_df() 1em
1272 self.one.wildcards = True 1em
1273 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1em
1275 def load_wheel(self, fs=1000, corner_frequency=20, order=8):
1276 """
1277 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
1278 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1279 a Butterworth low-pass filter is applied.
1281 Parameters
1282 ----------
1283 fs: int, float
1284 Sampling frequency for the wheel position, default is 1000 Hz
1285 corner_frequency: int, float
1286 Corner frequency of Butterworth low-pass filter, default is 20
1287 order: int, float
1288 Order of Butterworth low_pass filter, default is 8
1289 """
1290 wheel_raw = self.one.load_object(self.eid, 'wheel') 1el
1291 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1el
1292 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps'")
1293 # resample the wheel position and compute velocity, acceleration
1294 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1el
1295 self.wheel['position'], self.wheel['times'] = interpolate_position( 1el
1296 wheel_raw['timestamps'], wheel_raw['position'], freq=fs)
1297 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1el
1298 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
1299 self.wheel = self.wheel.apply(np.float32) 1el
1300 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1el
1302 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1303 """
1304 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
1305 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
1306 Dataframes with the timestamps and pose data, one row for each body part tracked for that camera.
1308 Parameters
1309 ----------
1310 likelihood_thr: float
1311 The position of each tracked body part comes with a likelihood for that estimate at each time point.
1312 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
1313 likelihood_thr=1. Default is 0.9
1314 views: list
1315 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
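Examples
--------
A minimal sketch; 'leftCamera' is the dictionary key created when the 'left' view is requested:
>>> sl.load_pose(views=['left'], likelihood_thr=0.9)
>>> sl.pose['leftCamera']  # 'times' column plus the DLC columns for each tracked body part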
1316 """
1317 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1318 self.pose = {} 1khe
1319 for view in views: 1khe
1320 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times']) 1khe
1321 # Double check if video timestamps are correct length or can be fixed
1322 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1khe
1323 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1khe
1324 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1khe
1325 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1khe
1327 def load_motion_energy(self, views=['left', 'right', 'body']):
1328 """
1329 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
1330 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
1331 pandas Dataframes with the timestamps and motion energy data.
1332 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
1333 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
1334 body (bodyMotionEnergy).
1336 Parameters
1337 ----------
1338 views: list
1339 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
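Examples
--------
A minimal sketch:
>>> sl.load_motion_energy(views=['left', 'body'])
>>> sl.motion_energy['leftCamera']  # columns: times, whiskerMotionEnergy
>>> sl.motion_energy['bodyCamera']  # columns: times, bodyMotionEnergy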
1340 """
1341 names = {'left': 'whiskerMotionEnergy', 1je
1342 'right': 'whiskerMotionEnergy',
1343 'body': 'bodyMotionEnergy'}
1344 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1345 self.motion_energy = {} 1je
1346 for view in views: 1je
1347 me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times']) 1je
1348 # Double check if video timestamps are correct length or can be fixed
1349 times_fixed, motion_energy = self._check_video_timestamps( 1je
1350 view, me_raw['times'], me_raw['ROIMotionEnergy'])
1351 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1je
1352 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1je
1353 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1je
1355 def load_licks(self):
1356 """
1357 Not yet implemented
1358 """
1359 pass
1361 def load_pupil(self, snr_thresh=5.):
1362 """
1363 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.
1365 Parameters
1366 ----------
1367 snr_thresh: float
1368 An SNR is computed as var(smooth) / var(smooth - raw) over time points where both estimates are valid.
1369 If this SNR is below snr_thresh the data are considered unusable and are discarded. Default is 5.
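Examples
--------
A minimal sketch; raises a ValueError and empties SessionLoader.pupil if the SNR check fails:
>>> sl.load_pupil(snr_thresh=5.)
>>> sl.pupil[['pupilDiameter_raw', 'pupilDiameter_smooth']]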
1370 """
1371 # Try to load from features
1372 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features']) 1he
1373 if 'features' in feat_raw.keys(): 1he
1374 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
1375 self.pupil = feats.copy()
1376 self.pupil.insert(0, 'times', times_fixed)
1378 # If unavailable, compute on the fly
1379 else:
1380 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1he
1381 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1he
1382 and 'leftCamera' in self.pose.keys()):
1383 # If pose data is already loaded, we don't know whether it was thresholded at 0.9, so use a small workaround
1384 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1he
1385 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1he
1386 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1he
1387 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1he
1388 else:
1389 self.load_pose(views=['left'], likelihood_thr=0.9)
1390 dlc_thr = self.pose['leftCamera'].copy()
1392 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1he
1393 try: 1he
1394 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1he
1395 except BaseException as e:
1396 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. "
1397 "Saving all NaNs for pupilDiameter_smooth.")
1398 _logger.debug(e)
1399 self.pupil['pupilDiameter_smooth'] = np.nan
1401 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1he
1402 good_idxs = np.where( 1he
1403 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
1404 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1he
1405 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
1406 if snr < snr_thresh: 1he
1407 self.pupil = pd.DataFrame() 1h
1408 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h
1410 def _check_video_timestamps(self, view, video_timestamps, video_data):
1411 """
1412 Helper function to check for the length of the video frames vs video timestamps and fix in case
1413 timestamps are longer than video frames.
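For illustration (shapes are hypothetical): with 103 timestamps and 100 video frames, the fix keeps the
trailing timestamps, equivalent to:
>>> times_fixed = video_timestamps[-video_data.shape[0]:]  # drops the first 3 timestamps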
1414 """
1415 # If camera times are shorter than the video data, or empty, there is currently no fix
1416 if video_timestamps.shape[0] < video_data.shape[0]: 1jkhe
1417 if video_timestamps.shape[0] == 0:
1418 msg = f'Camera times empty for {view}Camera.'
1419 else:
1420 msg = f'Camera times are shorter than video data for {view}Camera.'
1421 _logger.warning(msg)
1422 raise ValueError(msg)
1423 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
1424 # This is because the first few frames are sometimes not recorded. We can remove the first few
1425 # timestamps in this case
1426 elif video_timestamps.shape[0] > video_data.shape[0]: 1jkhe
1427 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1jkhe
1428 return video_timestamps_fixed, video_data 1jkhe
1429 else:
1430 return video_timestamps, video_data
1433class EphysSessionLoader(SessionLoader):
1434 """
1435 Spike sorting enhanced version of SessionLoader
1436 Loads spike sorting data for all probes in the session into the self.ephys dict
1437 >>> EphysSessionLoader(eid=eid, one=one)
1438 To select a specific probe
1439 >>> EphysSessionLoader(eid=eid, one=one, pid=pid)
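Spike sorting data can then be loaded per probe (the probe name below is an example):
>>> esl = EphysSessionLoader(eid=eid, one=one)
>>> esl.load_spike_sorting()
>>> spikes, clusters, channels = (esl.ephys['probe00'][k] for k in ('spikes', 'clusters', 'channels'))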
1440 """
1441 def __init__(self, *args, pname=None, pid=None, **kwargs):
1442 """
1443 Needs an active Alyx connection in order to get the list of probe insertions in the session
1444 :param pname: (optional) probe name, restricts loading to a single probe
1445 :param pid: (optional) probe insertion UUID, restricts loading to a single probe
1446 """
1447 super().__init__(*args, **kwargs)
1448 # if necessary, restrict the query
1449 qargs = {} if pname is None else {'name': pname}
1450 qargs = qargs or ({} if pid is None else {'id': pid})
1451 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs)
1452 self.ephys = {}
1453 for ins in insertions:
1454 self.ephys[ins['name']] = {}
1455 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one)
1457 def load_session_data(self, *args, **kwargs):
1458 super().load_session_data(*args, **kwargs)
1459 self.load_spike_sorting()
1461 def load_spike_sorting(self, pnames=None):
1462 pnames = pnames or list(self.ephys.keys())
1463 for pname in pnames:
1464 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting()
1465 self.ephys[pname]['spikes'] = spikes
1466 self.ephys[pname]['clusters'] = clusters
1467 self.ephys[pname]['channels'] = channels
1469 @property
1470 def probes(self):
1471 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}