Coverage for brainbox/io/one.py: 53%
747 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2from dataclasses import dataclass, field
3import gc
4import logging
5import re
6import os
7from pathlib import Path
9import numpy as np
10import pandas as pd
11from scipy.interpolate import interp1d
12import matplotlib.pyplot as plt
14from one.api import ONE, One
15from one.alf.path import get_alf_path, full_path_parts
16from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
17from one.alf import cache
18import one.alf.io as alfio
19from neuropixel import TIP_SIZE_UM, trace_header
20import spikeglx
22import ibldsp.voltage
23from ibldsp.waveform_extraction import WaveformsLoader
24from iblutil.util import Bunch
25from iblatlas.atlas import AllenAtlas, BrainRegions
26from iblatlas import atlas
27from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
28from ibllib.pipes import histology
29from ibllib.pipes.ephys_alignment import EphysAlignment
30from ibllib.plots import vertical_lines, Density
32import brainbox.plot
33from brainbox.io.spikeglx import Streamer
34from brainbox.ephys_plots import plot_brain_regions
35from brainbox.metrics.single_units import quick_unit_metrics
36from brainbox.behavior.wheel import interpolate_position, velocity_filtered
37from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
39_logger = logging.getLogger('ibllib')
42SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
43CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
44WAVEFORMS_ATTRIBUTES = ['templates']
47def load_lfp(eid, one=None, dataset_types=None, **kwargs):
48 """
49 TODO Verify works
50 From an eid, hits the Alyx database and downloads the standard set of datasets
51 needed for LFP
52 :param eid:
53 :param dataset_types: additional dataset types to add to the list
54 :param open: if True, spikeglx readers are opened
55 :return: spikeglx.Reader
56 """
57 if dataset_types is None:
58 dataset_types = []
59 dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*']
60 [one.load_dataset(eid, dset, download_only=True) for dset in dtypes]
61 session_path = one.eid2path(eid)
63 efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False)
64 if ef.get('lf', None)]
65 return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles]
68def _collection_filter_from_args(probe, spike_sorter=None):
69 collection = f'alf/{probe}/{spike_sorter}' 1g
70 collection = collection.replace('None', '*') 1g
71 collection = collection.replace('/*', '*') 1g
72 collection = collection[:-1] if collection.endswith('/') else collection 1g
73 return collection 1g
76def _get_spike_sorting_collection(collections, pname):
77 """
78 Filters a list or array of collections to get the relevant spike sorting dataset
79 if there is a pykilosort, load it
80 """
81 #
82 collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) 1gb
83 # otherwise, prefers the shortest
84 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) 1gb
85 _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") 1gb
86 return collection 1gb
89def _channels_alyx2bunch(chans):
90 channels = Bunch({
91 'atlas_id': np.array([ch['brain_region'] for ch in chans]),
92 'x': np.array([ch['x'] for ch in chans]) / 1e6,
93 'y': np.array([ch['y'] for ch in chans]) / 1e6,
94 'z': np.array([ch['z'] for ch in chans]) / 1e6,
95 'axial_um': np.array([ch['axial'] for ch in chans]),
96 'lateral_um': np.array([ch['lateral'] for ch in chans])
97 })
98 return channels
101def _channels_traj2bunch(xyz_chans, brain_atlas):
102 brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans))
103 channels = {
104 'x': xyz_chans[:, 0],
105 'y': xyz_chans[:, 1],
106 'z': xyz_chans[:, 2],
107 'acronym': brain_regions['acronym'],
108 'atlas_id': brain_regions['id']
109 }
111 return channels
114def _channels_bunch2alf(channels):
115 channels_ = { 1i
116 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6,
117 'brainLocationIds_ccf_2017': channels['atlas_id'],
118 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]}
119 return channels_ 1i
122def _channels_alf2bunch(channels, brain_regions=None):
123 # reformat the dictionary according to the standard that comes out of Alyx
124 channels_ = { 1icbed
125 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6,
126 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6,
127 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6,
128 'acronym': None,
129 'atlas_id': channels['brainLocationIds_ccf_2017'],
130 'axial_um': channels['localCoordinates'][:, 1],
131 'lateral_um': channels['localCoordinates'][:, 0],
132 }
133 # here if we have some extra keys, they will carry over to the next dictionary
134 for k in channels: 1icbed
135 if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: 1icbed
136 channels_[k] = channels[k] 1icbed
137 if brain_regions: 1icbed
138 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icbed
139 return channels_ 1icbed
142def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
143 brain_regions=None):
144 """
145 Generic function to load spike sorting according data using ONE.
147 Will try to load one spike sorting for any probe present for the eid matching the collection
148 For each probe it will load a spike sorting:
149 - if there is one version: loads this one
150 - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX)
152 Parameters
153 ----------
154 eid : [str, UUID, Path, dict]
155 Experiment session identifier; may be a UUID, URL, experiment reference string
156 details dict or Path
157 one : one.api.OneAlyx
158 An instance of ONE (may be in 'local' mode)
159 collection : str
160 collection filter word - accepts wildcards - can be a combination of spike sorter and
161 probe. See `ALF documentation`_ for details.
162 revision : str
163 A particular revision return (defaults to latest revision). See `ALF documentation`_ for
164 details.
165 return_channels : bool
166 Defaults to False otherwise loads channels from disk
168 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components
170 Returns
171 -------
172 spikes : dict of one.alf.io.AlfBunch
173 A dict with probe labels as keys, contains bunch(es) of spike data for the provided
174 session and spike sorter, with keys ('clusters', 'times')
175 clusters : dict of one.alf.io.AlfBunch
176 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
177 ('channels', 'depths', 'metrics')
178 channels : dict of one.alf.io.AlfBunch
179 A dict with probe labels as keys, contains channel locations with keys ('acronym',
180 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs
181 non-lateralized.
182 """
183 one = one or ONE() 1gb
184 # enumerate probes and load according to the name
185 collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) 1gb
186 if len(collections) == 0: 1gb
187 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
188 pnames = list(set(c.split('/')[1] for c in collections)) 1gb
189 spikes, clusters, channels = ({} for _ in range(3)) 1gb
191 spike_attributes, cluster_attributes = _get_attributes(dataset_types) 1gb
193 for pname in pnames: 1gb
194 probe_collection = _get_spike_sorting_collection(collections, pname) 1gb
195 spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', 1gb
196 attribute=spike_attributes)
197 clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', 1g
198 attribute=cluster_attributes)
199 if return_channels: 1g
200 channels = _load_channels_locations_from_disk(
201 eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
202 return spikes, clusters, channels
203 else:
204 return spikes, clusters 1g
207def _get_attributes(dataset_types):
208 if dataset_types is None: 1gb
209 return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES 1gb
210 else:
211 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
212 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
213 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
214 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
215 return spike_attributes, cluster_attributes
218def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None):
219 _logger.debug('loading spike sorting from disk')
220 channels = Bunch({})
221 collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision)
222 if len(collections) == 0:
223 _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
224 probes = list(set([c.split('/')[1] for c in collections]))
225 for probe in probes:
226 probe_collection = _get_spike_sorting_collection(collections, probe)
227 channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels')
228 # if the spike sorter has not aligned data, try and get the alignment available
229 if 'brainLocationIds_ccf_2017' not in channels[probe].keys():
230 aligned_channel_collections = one.list_collections(
231 eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
232 if len(aligned_channel_collections) == 0:
233 _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
234 continue
235 _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
236 ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
237 channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
238 channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
239 # only have to reformat channels if we were able to load coordinates from disk
240 channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
241 return channels
244def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
245 """
246 oftentimes the channel map for different spike sorters may be different so interpolate the alignment onto
247 if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field
248 so we reconstruct from the Neuropixel map. This only happens for early pykilosort sorts
249 :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
250 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
251 OR
252 'x', 'y', 'z', 'acronym', 'axial_um'
253 those are the guide for the interpolation
254 :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates'
255 :param brain_regions: None (default) or iblatlas.regions.BrainRegions object
256 if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017
257 if a brain region object is provided, outputts a dict with keys
258 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
259 :return: Bunch or dictionary of channels with brain coordinates keys
260 """
261 NEUROPIXEL_VERSION = 1 1i
262 h = trace_header(version=NEUROPIXEL_VERSION) 1i
263 if channels is None: 1i
264 channels = {'localCoordinates': np.c_[h['x'], h['y']]}
265 nch = channels['localCoordinates'].shape[0] 1i
266 if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): 1i
267 channels_aligned = _channels_bunch2alf(channels_aligned) 1i
268 if 'localCoordinates' in channels_aligned.keys(): 1i
269 aligned_depths = channels_aligned['localCoordinates'][:, 1] 1i
270 else: # this is a edge case for a few spike sorting sessions
271 assert channels_aligned['mlapdv'].shape[0] == 384
272 aligned_depths = h['y']
273 depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) 1i
274 depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) 1i
275 channels['mlapdv'] = np.zeros((nch, 3)) 1i
276 for i in np.arange(3): 1i
277 channels['mlapdv'][:, i] = np.interp( 1i
278 depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
279 # the brain locations have to be interpolated by nearest neighbour
280 fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') 1i
281 channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) 1i
282 if brain_regions is not None: 1i
283 return _channels_alf2bunch(channels, brain_regions=brain_regions) 1i
284 else:
285 return channels 1i
288def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
289 brain_atlas=None, return_source=False):
290 if not hasattr(one, 'alyx'): 1b
291 return {}, None 1b
292 _logger.debug(f"trying to load from traj {probe}")
293 channels = Bunch()
294 brain_atlas = brain_atlas or AllenAtlas
295 # need to find the collection bruh
296 insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0]
297 collection = _collection_filter_from_args(probe=probe)
298 collections = one.list_collections(eid, filename='channels*', collection=collection,
299 revision=revision)
300 probe_collection = _get_spike_sorting_collection(collections, probe)
301 chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
302 depths = chn_coords[:, 1]
304 tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
305 get('tracing_exists', False)
306 resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
307 get('alignment_resolved', False)
308 counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
309 get('alignment_count', 0)
311 if tracing:
312 xyz = np.array(insertion['json']['xyz_picks']) / 1e6
313 if resolved:
315 _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
316 f'Channel and cluster locations obtained from ephys aligned histology '
317 f'track.')
318 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
319 provenance='Ephys aligned histology track')[0]
320 align_key = insertion['json']['extended_qc']['alignment_stored']
321 feature = traj['json'][align_key][0]
322 track = traj['json'][align_key][1]
323 ephysalign = EphysAlignment(xyz, depths, track_prev=track,
324 feature_prev=feature,
325 brain_atlas=brain_atlas, speedy=True)
326 chans = ephysalign.get_channel_locations(feature, track)
327 channels[probe] = _channels_traj2bunch(chans, brain_atlas)
328 source = 'resolved'
329 elif counts > 0 and aligned:
330 _logger.debug(f'Channel locations for {eid}/{probe} have not been '
331 f'resolved. However, alignment flag set to True so channel and cluster'
332 f' locations will be obtained from latest available ephys aligned '
333 f'histology track.')
334 # get the latest user aligned channels
335 traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
336 provenance='Ephys aligned histology track')[0]
337 align_key = insertion['json']['extended_qc']['alignment_stored']
338 feature = traj['json'][align_key][0]
339 track = traj['json'][align_key][1]
340 ephysalign = EphysAlignment(xyz, depths, track_prev=track,
341 feature_prev=feature,
342 brain_atlas=brain_atlas, speedy=True)
343 chans = ephysalign.get_channel_locations(feature, track)
345 channels[probe] = _channels_traj2bunch(chans, brain_atlas)
346 source = 'aligned'
347 else:
348 _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
349 f'Channel and cluster locations obtained from histology track.')
350 # get the channels from histology tracing
351 xyz = xyz[np.argsort(xyz[:, 2]), :]
352 chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
353 channels[probe] = _channels_traj2bunch(chans, brain_atlas)
354 source = 'traced'
355 channels[probe]['axial_um'] = chn_coords[:, 1]
356 channels[probe]['lateral_um'] = chn_coords[:, 0]
358 else:
359 _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
360 source = ''
361 channels = None
363 if return_source:
364 return channels, source
365 else:
366 return channels
369def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
370 """
371 Load the brain locations of each channel for a given session/probe
373 Parameters
374 ----------
375 eid : [str, UUID, Path, dict]
376 Experiment session identifier; may be a UUID, URL, experiment reference string
377 details dict or Path
378 probe : [str, list of str]
379 The probe label(s), e.g. 'probe01'
380 one : one.api.OneAlyx
381 An instance of ONE (shouldn't be in 'local' mode)
382 aligned : bool
383 Whether to get the latest user aligned channel when not resolved or use histology track
384 brain_atlas : iblatlas.BrainAtlas
385 Brain atlas object (default: Allen atlas)
386 Returns
387 -------
388 dict of one.alf.io.AlfBunch
389 A dict with probe labels as keys, contains channel locations with keys ('acronym',
390 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
391 optional: string 'resolved', 'aligned', 'traced' or ''
392 """
393 one = one or ONE()
394 brain_atlas = brain_atlas or AllenAtlas()
395 if isinstance(eid, dict):
396 ses = eid
397 eid = ses['url'][-36:]
398 else:
399 eid = one.to_eid(eid)
400 collection = _collection_filter_from_args(probe=probe)
401 channels = _load_channels_locations_from_disk(eid, one=one, collection=collection,
402 brain_regions=brain_atlas.regions)
403 incomplete_probes = [k for k in channels if 'x' not in channels[k]]
404 for iprobe in incomplete_probes:
405 channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
406 brain_atlas=brain_atlas, return_source=True)
407 if channels_ is not None:
408 channels[iprobe] = channels_[iprobe]
409 return channels
412def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
413 brain_regions=None, nested=True, collection=None, return_collection=False):
414 """
415 From an eid, loads spikes and clusters for all probes
416 The following set of dataset types are loaded:
417 'clusters.channels',
418 'clusters.depths',
419 'clusters.metrics',
420 'spikes.clusters',
421 'spikes.times',
422 'probes.description'
423 :param eid: experiment UUID or pathlib.Path of the local session
424 :param one: an instance of OneAlyx
425 :param probe: name of probe to load in, if not given all probes for session will be loaded
426 :param dataset_types: additional spikes/clusters objects to add to the standard default list
427 :param spike_sorter: name of the spike sorting you want to load (None for default)
428 :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
429 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
430 :param nested: if a single probe is required, do not output a dictionary with the probe name as key
431 :param return_collection: (False) if True, will return the collection used to load
432 :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
433 """
434 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.'
435 'Use brainbox.io.one.SpikeSortingLoader instead')
436 if collection is None:
437 collection = _collection_filter_from_args(probe, spike_sorter)
438 _logger.debug(f"load spike sorting with collection filter {collection}")
439 kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
440 brain_regions=brain_regions)
441 spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
442 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
443 if nested is False and len(spikes.keys()) == 1:
444 k = list(spikes.keys())[0]
445 channels = channels[k]
446 clusters = clusters[k]
447 spikes = spikes[k]
448 if return_collection:
449 return spikes, clusters, channels, collection
450 else:
451 return spikes, clusters, channels
454def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
455 brain_regions=None, return_collection=False):
456 """
457 From an eid, loads spikes and clusters for all probes
458 The following set of dataset types are loaded:
459 'clusters.channels',
460 'clusters.depths',
461 'clusters.metrics',
462 'spikes.clusters',
463 'spikes.times',
464 'probes.description'
465 :param eid: experiment UUID or pathlib.Path of the local session
466 :param one: an instance of OneAlyx
467 :param probe: name of probe to load in, if not given all probes for session will be loaded
468 :param dataset_types: additional spikes/clusters objects to add to the standard default list
469 :param spike_sorter: name of the spike sorting you want to load (None for default)
470 :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
471 :param return_collection:(bool - False) if True, returns the collection for loading the data
472 :return: spikes, clusters (dict of bunch, 1 bunch per probe)
473 """
474 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.' 1g
475 'Use brainbox.io.one.SpikeSortingLoader instead')
476 collection = _collection_filter_from_args(probe, spike_sorter) 1g
477 _logger.debug(f"load spike sorting with collection filter {collection}") 1g
478 spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, 1g
479 return_channels=False, dataset_types=dataset_types,
480 brain_regions=brain_regions)
481 if return_collection: 1g
482 return spikes, clusters, collection
483 else:
484 return spikes, clusters 1g
487def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
488 spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
489 """
490 For a given eid, get spikes, clusters and channels information, and merges clusters
491 and channels information before returning all three variables.
493 Parameters
494 ----------
495 eid : [str, UUID, Path, dict]
496 Experiment session identifier; may be a UUID, URL, experiment reference string
497 details dict or Path
498 one : one.api.OneAlyx
499 An instance of ONE (shouldn't be in 'local' mode)
500 probe : [str, list of str]
501 The probe label(s), e.g. 'probe01'
502 aligned : bool
503 Whether to get the latest user aligned channel when not resolved or use histology track
504 dataset_types : list of str
505 Optional additional spikes/clusters objects to add to the standard default list
506 spike_sorter : str
507 Name of the spike sorting you want to load (None for default which is pykilosort if it's
508 available otherwise the default MATLAB kilosort)
509 brain_atlas : iblatlas.atlas.BrainAtlas
510 Brain atlas object (default: Allen atlas)
511 return_collection: bool
512 Returns an extra argument with the collection chosen
514 Returns
515 -------
516 spikes : dict of one.alf.io.AlfBunch
517 A dict with probe labels as keys, contains bunch(es) of spike data for the provided
518 session and spike sorter, with keys ('clusters', 'times')
519 clusters : dict of one.alf.io.AlfBunch
520 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
521 ('channels', 'depths', 'metrics')
522 channels : dict of one.alf.io.AlfBunch
523 A dict with probe labels as keys, contains channel locations with keys ('acronym',
524 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
525 """
526 # --- Get spikes and clusters data
527 _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
528 'Use brainbox.io.one.SpikeSortingLoader instead')
529 one = one or ONE()
530 brain_atlas = brain_atlas or AllenAtlas()
531 spikes, clusters, collection = load_spike_sorting(
532 eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
533 # -- Get brain regions and assign to clusters
534 channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
535 brain_atlas=brain_atlas)
536 clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
537 if nested is False and len(spikes.keys()) == 1:
538 k = list(spikes.keys())[0]
539 channels = channels[k]
540 clusters = clusters[k]
541 spikes = spikes[k]
542 if return_collection:
543 return spikes, clusters, channels, collection
544 else:
545 return spikes, clusters, channels
548def load_ephys_session(eid, one=None):
549 """
550 From an eid, hits the Alyx database and downloads a standard default set of dataset types
551 From a local session Path (pathlib.Path), loads a standard default set of dataset types
552 to perform analysis:
553 'clusters.channels',
554 'clusters.depths',
555 'clusters.metrics',
556 'spikes.clusters',
557 'spikes.times',
558 'probes.description'
560 Parameters
561 ----------
562 eid : [str, UUID, Path, dict]
563 Experiment session identifier; may be a UUID, URL, experiment reference string
564 details dict or Path
565 one : oneibl.one.OneAlyx, optional
566 ONE object to use for loading. Will generate internal one if not used, by default None
568 Returns
569 -------
570 spikes : dict of one.alf.io.AlfBunch
571 A dict with probe labels as keys, contains bunch(es) of spike data for the provided
572 session and spike sorter, with keys ('clusters', 'times')
573 clusters : dict of one.alf.io.AlfBunch
574 A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
575 ('channels', 'depths', 'metrics')
576 trials : one.alf.io.AlfBunch of numpy.ndarray
577 The session trials data
578 """
579 assert one 1g
580 spikes, clusters = load_spike_sorting(eid, one=one) 1g
581 trials = one.load_object(eid, 'trials') 1g
582 return spikes, clusters, trials 1g
585def _remove_old_clusters(session_path, probe):
586 # gets clusters and spikes from a local session folder
587 probe_path = session_path.joinpath('alf', probe)
589 # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
590 cluster_file = probe_path.joinpath('clusters.metrics.csv')
592 if cluster_file.exists():
593 os.remove(cluster_file)
594 _logger.info('Deleting old clusters.metrics.csv file')
597def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
598 """
599 Takes (default and any extra) values in given keys from channels and assign them to clusters.
600 If channels does not contain any data, the new keys are added to clusters but left empty.
602 Parameters
603 ----------
604 dic_clus : dict of one.alf.io.AlfBunch
605 1 bunch per probe, containing cluster information
606 channels : dict of one.alf.io.AlfBunch
607 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', z', 'localCoordinates')
608 keys_to_add_extra : list of str
609 Any extra keys to load into channels bunches
611 Returns
612 -------
613 dict of one.alf.io.AlfBunch
614 clusters (1 bunch per probe) with new keys values.
615 """
616 probe_labels = list(channels.keys()) # Convert dict_keys into list
617 keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']
619 if keys_to_add_extra is None:
620 keys_to_add = keys_to_add_default
621 else:
622 # Append extra optional keys
623 keys_to_add = list(set(keys_to_add_extra + keys_to_add_default))
625 for label in probe_labels:
626 clu_ch = dic_clus[label]['channels']
627 for key in keys_to_add:
628 try:
629 assert key in channels[label].keys() # Check key is in channels
630 ch_key = channels[label][key]
631 nch_key = len(ch_key) if ch_key is not None else 0
632 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index
633 dic_clus[label][key] = ch_key[clu_ch]
634 else:
635 _logger.warning(
636 f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels'
637 f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.')
638 dic_clus[label][key] = []
639 except AssertionError:
640 _logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
641 continue
643 return dic_clus
646def load_passive_rfmap(eid, one=None):
647 """
648 For a given eid load in the passive receptive field mapping protocol data
650 Parameters
651 ----------
652 eid : [str, UUID, Path, dict]
653 Experiment session identifier; may be a UUID, URL, experiment reference string
654 details dict or Path
655 one : oneibl.one.OneAlyx, optional
656 An instance of ONE (may be in 'local' - offline - mode)
658 Returns
659 -------
660 one.alf.io.AlfBunch
661 Passive receptive field mapping data
662 """
663 one = one or ONE()
665 # Load in the receptive field mapping data
666 rf_map = one.load_object(eid, obj='passiveRFM', collection='alf')
667 frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin',
668 collection='raw_passive_data'), dtype="uint8")
669 y_pix, x_pix = 15, 15
670 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0])
671 rf_map['frames'] = frames
673 return rf_map
676def load_wheel_reaction_times(eid, one=None):
677 """
678 Return the calculated reaction times for session. Reaction times are defined as the time
679 between the go cue (onset tone) and the onset of the first substantial wheel movement. A
680 movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the
681 distance to threshold (~0.1 radians).
683 Negative times mean the onset of the movement occurred before the go cue. Nans may occur if
684 there was no detected movement withing the period, or when the goCue_times or feedback_times
685 are nan.
687 Parameters
688 ----------
689 eid : [str, UUID, Path, dict]
690 Experiment session identifier; may be a UUID, URL, experiment reference string
691 details dict or Path
692 one : one.api.OneAlyx, optional
693 one object to use for loading. Will generate internal one if not used, by default None
695 Returns
696 ----------
697 array-like
698 reaction times
699 """
700 if one is None:
701 one = ONE()
703 trials = one.load_object(eid, 'trials')
704 # If already extracted, load and return
705 if trials and 'firstMovement_times' in trials:
706 return trials['firstMovement_times'] - trials['goCue_times']
707 # Otherwise load wheelMoves object and calculate
708 moves = one.load_object(eid, 'wheelMoves')
709 # Re-extract wheel moves if necessary
710 if not moves or 'peakAmplitude' not in moves:
711 wheel = one.load_object(eid, 'wheel')
712 moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
713 assert trials and moves, 'unable to load trials and wheelMoves data'
714 firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials)
715 return firstMove_times - trials['goCue_times']
718def load_iti(trials):
719 """
720 The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey
721 screen commencing at stimulus off and lasting until the quiescent period at the start of the
722 following trial. Note that the ITI for the first trial is the time between the first trial
723 and the next, therefore the last value is NaN.
725 Parameters
726 ----------
727 trials : one.alf.io.AlfBunch
728 An ALF trials object containing the keys {'intervals', 'stimOff_times'}.
730 Returns
731 -------
732 np.array
733 An array of inter-trial intervals, the last value being NaN.
734 """
735 if not {'intervals', 'stimOff_times'} <= set(trials.keys()): 1o
736 raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') 1o
737 return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan] 1o
740def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
742 PROV_2_VAL = {
743 'Resolved': 90,
744 'Ephys aligned histology track': 70,
745 'Histology track': 50,
746 'Micro-manipulator': 30,
747 'Planned': 10}
749 one = one or ONE()
750 ba = ba or atlas.AllenAtlas()
751 traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id'])
752 val = [PROV_2_VAL[tr['provenance']] for tr in traj]
753 idx = np.argmax(val)
754 traj = traj[idx]
755 if depths is None:
756 depths = trace_header(version=1)[:, 1]
757 if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator':
758 ins = atlas.Insertion.from_dict(traj)
759 # Deepest coordinate first
760 xyz = np.c_[ins.tip, ins.entry].T
761 xyz_channels = histology.interpolate_along_track(xyz, (depths +
762 TIP_SIZE_UM) / 1e6)
763 else:
764 xyz = np.array(ins['json']['xyz_picks']) / 1e6
765 if traj['provenance'] == 'Histology track':
766 xyz = xyz[np.argsort(xyz[:, 2]), :]
767 xyz_channels = histology.interpolate_along_track(xyz, (depths +
768 TIP_SIZE_UM) / 1e6)
769 else:
770 align_key = ins['json']['extended_qc']['alignment_stored']
771 feature = traj['json'][align_key][0]
772 track = traj['json'][align_key][1]
773 ephysalign = EphysAlignment(xyz, depths, track_prev=track,
774 feature_prev=feature,
775 brain_atlas=ba, speedy=True)
776 xyz_channels = ephysalign.get_channel_locations(feature, track)
777 return xyz_channels
780@dataclass
781class SpikeSortingLoader:
782 """
783 Object that will load spike sorting data for a given probe insertion.
784 This class can be instantiated in several manners
785 - With Alyx database probe id:
786 SpikeSortingLoader(pid=pid, one=one)
787 - With Alyx database eic and probe name:
788 SpikeSortingLoader(eid=eid, pname='probe00', one=one)
789 - From a local session and probe name:
790 SpikeSortingLoader(session_path=session_path, pname='probe00')
791 NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
792 """
793 one: One = None
794 atlas: None = None
795 pid: str = None
796 eid: str = ''
797 pname: str = ''
798 session_path: Path = ''
799 # the following properties are the outcome of the post init function
800 collections: list = None
801 datasets: list = None # list of all datasets belonging to the session
802 # the following properties are the outcome of a reading function
803 files: dict = None
804 raw_data_files: list = None # list of raw ap and lf files corresponding to the recording
805 collection: str = ''
806 histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced'
807 spike_sorter: str = 'pykilosort'
808 spike_sorting_path: Path = None
809 _sync: dict = None
811 def __post_init__(self):
812 # pid gets precedence
813 if self.pid is not None: 1cbedk
814 try: 1dk
815 self.eid, self.pname = self.one.pid2eid(self.pid) 1dk
816 except NotImplementedError:
817 if self.eid == '' or self.pname == '':
818 raise IOError("Cannot infer session id and probe name from pid. "
819 "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.")
820 self.session_path = self.one.eid2path(self.eid) 1dk
821 # then eid / pname combination
822 elif self.session_path is None or self.session_path == '': 1cbe
823 self.session_path = self.one.eid2path(self.eid) 1cbe
824 # fully local providing a session path
825 else:
826 if self.one:
827 self.eid = self.one.to_eid(self.session_path)
828 else:
829 self.one = One(cache_dir=self.session_path.parents[2], mode='local')
830 df_sessions = cache._make_sessions_df(self.session_path)
831 self.one._cache['sessions'] = df_sessions.set_index('id')
832 self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
833 self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
834 # populates default properties
835 self.collections = self.one.list_collections( 1cbedk
836 self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
837 self.datasets = self.one.list_datasets(self.eid) 1cbedk
838 if self.atlas is None: 1cbedk
839 self.atlas = AllenAtlas() 1cbek
840 self.files = {} 1cbedk
841 self.raw_data_files = [] 1cbedk
843 def _load_object(self, *args, **kwargs):
844 """
845 This function is a wrapper around alfio.load_object that will remove the UUID in the
846 filename if the object is on SDSC.
847 """
848 remove_uuids = getattr(self.one, 'uuid_filenames', False) 1cbed
849 d = alfio.load_object(*args, **kwargs) 1cbed
850 if remove_uuids: 1cbed
851 # pops the UUID in the key names
852 keys = list(d.keys())
853 for k in keys:
854 d[k[:-37]] = d.pop(k)
855 return d 1cbed
857 @staticmethod
858 def _get_attributes(dataset_types):
859 """returns attributes to load for spikes and clusters objects"""
860 dataset_types = [] if dataset_types is None else dataset_types 1cbed
861 spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] 1cbed
862 spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) 1cbed
863 cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] 1cbed
864 cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) 1cbed
865 waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] 1cbed
866 waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) 1cbed
867 return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} 1cbed
869 def _get_spike_sorting_collection(self, spike_sorter=None):
870 """
871 Filters a list or array of collections to get the relevant spike sorting dataset
872 if there is a pykilosort, load it
873 """
874 for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): 1cbed
875 if sorter is None: 1cbed
876 continue
877 if sorter == "": 1cbed
878 collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) 1cbe
879 else:
880 collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) 1bd
881 if collection is not None: 1cbed
882 return collection 1cbed
883 # if none is found amongst the defaults, prefers the shortest
884 collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
885 _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
886 return collection
888 def load_spike_sorting_object(self, obj, *args, **kwargs):
889 """
890 Loads an ALF object
891 :param obj: object name, str between 'spikes', 'clusters' or 'channels'
892 :param spike_sorter: (defaults to 'pykilosort')
893 :param dataset_types: list of extra dataset types, for example ['spikes.samples']
894 :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort'
895 :param kwargs: additional arguments to be passed to one.api.One.load_object
896 :param missing: 'raise' (default) or 'ignore'
897 :return:
898 """
899 self.download_spike_sorting_object(obj, *args, **kwargs)
900 return self._load_object(self.files[obj])
902 def get_version(self, spike_sorter=None):
903 spike_sorter = (spike_sorter or self.spike_sorter) or 'iblsorter'
904 collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
905 dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy')
906 return dset[0]['version'] if len(dset) else 'unknown'
908 def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None,
909 attribute=None, missing='raise', **kwargs):
910 """
911 Downloads an ALF object
912 :param obj: object name, str between 'spikes', 'clusters' or 'channels'
913 :param spike_sorter: (defaults to 'pykilosort')
914 :param dataset_types: list of extra dataset types, for example ['spikes.samples']
915 :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort'
916 :param kwargs: additional arguments to be passed to one.api.One.load_object
917 :param attribute: list of attributes to load for the object
918 :param missing: 'raise' (default) or 'ignore'
919 :return:
920 """
921 if spike_sorter is None: 1cbed
922 spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter' 1cbed
923 if len(self.collections) == 0: 1cbed
924 return {}, {}, {}
925 self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) 1cbed
926 collection = collection or self.collection 1cbed
927 _logger.debug(f"loading spike sorting object {obj} from {collection}") 1cbed
928 attributes = self._get_attributes(dataset_types) 1cbed
929 try: 1cbed
930 self.files[obj] = self.one.load_object( 1cbed
931 self.eid, obj=obj, attribute=attributes.get(obj, None),
932 collection=collection, download_only=True, **kwargs)
933 except ALFObjectNotFound as e: 1cbe
934 if missing == 'raise': 1cbe
935 raise e
937 def download_spike_sorting(self, objects=None, **kwargs):
938 """
939 Downloads spikes, clusters and channels
940 :param spike_sorter: (defaults to 'pykilosort')
941 :param dataset_types: list of extra dataset types
942 :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
943 :return:
944 """
945 objects = ['spikes', 'clusters', 'channels'] if objects is None else objects 1cbed
946 for obj in objects: 1cbed
947 self.download_spike_sorting_object(obj=obj, **kwargs) 1cbed
948 self.spike_sorting_path = self.files['clusters'][0].parent 1cbed
950 def download_raw_electrophysiology(self, band='ap'):
951 """
952 Downloads raw electrophysiology data files on local disk.
953 :param band: "ap" (default) or "lf" for LFP band
954 :return: list of raw data files full paths (ch, meta and cbin files)
955 """
956 raw_data_files = []
957 for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']:
958 try:
959 # FIXME: this will fail if multiple LFP segments are found
960 raw_data_files.append(self.one.load_dataset(
961 self.eid,
962 download_only=True,
963 collection=f'raw_ephys_data/{self.pname}',
964 dataset=suffix,
965 check_hash=False,
966 ))
967 except ALFObjectNotFound:
968 _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}")
969 self.raw_data_files = list(set(self.raw_data_files + raw_data_files))
970 return raw_data_files
972 def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
973 """
974 Returns a reader for the raw electrophysiology data
975 By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having
976 downloaded the raw data file if necessary
977 :param stream:
978 :param band:
979 :param kwargs:
980 :return:
981 """
982 if stream: 1k
983 return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) 1k
984 else:
985 raw_data_files = self.download_raw_electrophysiology(band=band)
986 cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
987 if cbin_file is not None:
988 return spikeglx.Reader(cbin_file)
990 def download_raw_waveforms(self, **kwargs):
991 """
992 Downloads raw waveforms extracted from sorting to local disk.
993 """
994 _logger.debug(f"loading waveforms from {self.collection}")
995 return self.one.load_object(
996 id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"],
997 collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
998 )
1000 def raw_waveforms(self, **kwargs):
1001 wf_paths = self.download_raw_waveforms(**kwargs)
1002 return WaveformsLoader(wf_paths[0].parent)
1004 def load_channels(self, **kwargs):
1005 """
1006 Loads channels
1007 The channel locations can come from several sources, it will load the most advanced version of the histology available,
1008 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
1009 - alf: the final version of channel locations, same as resolved with the difference that data is on file
1010 - resolved: channel locations alignments have been agreed upon
1011 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
1012 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
1014 :param spike_sorter: (defaults to 'pykilosort')
1015 :param dataset_types: list of extra dataset types
1016 :return:
1017 """
1018 # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
1019 self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') 1cbed
1020 self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs) 1cbed
1021 channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) 1cbed
1022 if 'electrodeSites' in self.files: # if common dict keys, electrodeSites prevails 1cbed
1023 esites = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) 1d
1024 if alfio.check_dimensions(esites) != 0: 1d
1025 esites = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
1026 esites['rawInd'] = np.arange(esites[list(esites.keys())[0]].shape[0])
1027 if 'brainLocationIds_ccf_2017' not in channels: 1cbed
1028 _logger.debug(f"loading channels from alyx for {self.files['channels']}") 1b
1029 _channels, self.histology = _load_channel_locations_traj( 1b
1030 self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
1031 if _channels: 1b
1032 channels = _channels[self.pname]
1033 else:
1034 channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) 1cbed
1035 self.histology = 'alf' 1cbed
1036 return Bunch(channels) 1cbed
1038 def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs):
1039 """
1040 Loads spikes, clusters and channels
1042 There could be several spike sorting collections, by default the loader will get the pykilosort collection
1044 The channel locations can come from several sources, it will load the most advanced version of the histology available,
1045 regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
1046 - alf: the final version of channel locations, same as resolved with the difference that data is on file
1047 - resolved: channel locations alignments have been agreed upon
1048 - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
1049 - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
1051 :param spike_sorter: (defaults to 'pykilosort')
1052 :param revision: for example "2024-05-06", (defaults to None):
1053 :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
1054 :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
1055 :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
1056 :param kwargs: additional arguments to be passed to one.api.One.load_object
1057 :return:
1058 """
1059 if len(self.collections) == 0: 1cbed
1060 return {}, {}, {}
1061 self.files = {} 1cbed
1062 self.spike_sorter = spike_sorter 1cbed
1063 self.revision = revision 1cbed
1064 objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None 1cbed
1065 self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) 1cbed
1066 channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs) 1cbed
1067 clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) 1cbed
1068 if good_units: 1cbed
1069 spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
1070 else:
1071 spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) 1cbed
1072 if enforce_version: 1cbed
1073 self._assert_version_consistency()
1074 return spikes, clusters, channels 1cbed
1076 def _assert_version_consistency(self):
1077 """
1078 Makes sure the state of the spike sorting object matches the files downloaded
1079 :return: None
1080 """
1081 for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
1082 for fn in self.files.get(k, []):
1083 if self.spike_sorter:
1084 assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
1085 f"You required strict version {self.spike_sorter}, {fn} does not match"
1086 if self.revision:
1087 assert full_path_parts(fn)[5] == self.revision, \
1088 f"You required strict revision {self.revision}, {fn} does not match"
1090 @staticmethod
1091 def compute_metrics(spikes, clusters=None):
1092 nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
1093 metrics = pd.DataFrame(quick_unit_metrics(
1094 spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
1095 return metrics
1097 @staticmethod
1098 def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False):
1099 """
1100 Merge the metrics and the channel information into the clusters dictionary
1101 :param spikes:
1102 :param clusters:
1103 :param channels:
1104 :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used
1105 for clusters or analysis applications (defaults to None).
1106 :param compute_metrics: if True, will explicitly recompute metrics (defaults to false)
1107 :return: cluster dictionary containing metrics and histology
1108 """
1109 if spikes == {}: 1bd
1110 return
1111 nc = clusters['channels'].size 1bd
1112 # recompute metrics if they are not available
1113 metrics = None 1bd
1114 if 'metrics' in clusters: 1bd
1115 metrics = clusters.pop('metrics') 1bd
1116 if metrics.shape[0] != nc: 1bd
1117 metrics = None
1118 if metrics is None or compute_metrics is True: 1bd
1119 _logger.debug("recompute clusters metrics")
1120 metrics = SpikeSortingLoader.compute_metrics(spikes, clusters)
1121 if isinstance(cache_dir, Path):
1122 metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
1123 for k in metrics.keys(): 1bd
1124 clusters[k] = metrics[k].to_numpy() 1bd
1125 for k in channels.keys(): 1bd
1126 clusters[k] = channels[k][clusters['channels']] 1bd
1127 if cache_dir is not None: 1bd
1128 _logger.debug(f'caching clusters metrics in {cache_dir}')
1129 pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
1130 return clusters 1bd
1132 @property
1133 def url(self):
1134 """Gets flatiron URL for the session"""
1135 webclient = getattr(self.one, '_web_client', None)
1136 return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None
1138 def _get_probe_info(self):
1139 if self._sync is None: 1e
1140 timestamps = self.one.load_dataset( 1e
1141 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
1142 _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks 1e
1143 self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
1144 try: 1e
1145 ap_meta = spikeglx.read_meta_data(self.one.load_dataset( 1e
1146 self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
1147 fs = spikeglx._get_fs_from_meta(ap_meta) 1e
1148 except ALFObjectNotFound:
1149 ap_meta = None
1150 fs = 30_000
1151 self._sync = { 1e
1152 'timestamps': timestamps,
1153 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'),
1154 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
1155 'ap_meta': ap_meta,
1156 'fs': fs,
1157 }
1159 def timesprobe2times(self, values, direction='forward'):
1160 self._get_probe_info()
1161 if direction == 'forward':
1162 return self._sync['forward'](values * self._sync['fs'])
1163 elif direction == 'reverse':
1164 return self._sync['reverse'](values) / self._sync['fs']
1166 def samples2times(self, values, direction='forward'):
1167 """
1168 Converts ephys sample values to session main clock seconds
1169 :param values: numpy array of times in seconds or samples to resync
1170 :param direction: 'forward' (samples probe time to seconds main time) or 'reverse'
1171 (seconds main time to samples probe time)
1172 :return:
1173 """
1174 self._get_probe_info() 1e
1175 return self._sync[direction](values) 1e
1177 @property
1178 def pid2ref(self):
1179 return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" 1cb
1181 def _default_plot_title(self, spikes):
1182 title = f"{self.pid2ref}, {self.pid} \n" \ 1c
1183 f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
1184 return title 1c
1186 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
1187 drift=None, title=None, **kwargs):
1188 """
1189 :param spikes: spikes dictionary or Bunch
1190 :param channels: channels dictionary or Bunch.
1191 :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png".
1192 Otherwise, plot.
1193 :param br: brain regions object (optional)
1194 :param label: label for saved image (optional, default="raster")
1195 :param time_series: timeseries dictionary for behavioral event times (optional)
1196 :param **kwargs: kwargs passed to `driftmap()` (optional)
1197 :return:
1198 """
1199 br = br or BrainRegions() 1c
1200 time_series = time_series or {} 1c
1201 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c
1202 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col')
1203 axs[0, 1].set_axis_off() 1c
1204 # axs[0, 0].set_xticks([])
1205 if kwargs is None: 1c
1206 # set default raster plot parameters
1207 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
1208 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c
1209 if title is None: 1c
1210 title = self._default_plot_title(spikes) 1c
1211 axs[0, 0].title.set_text(title) 1c
1212 for k, ts in time_series.items(): 1c
1213 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
1214 if 'atlas_id' in channels: 1c
1215 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c
1216 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology)
1217 axs[1, 0].set_ylim(0, 3800) 1c
1218 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c
1219 fig.tight_layout() 1c
1221 if drift is None: 1c
1222 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c
1223 if 'drift' in self.files: 1c
1224 drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
1225 if isinstance(drift, dict): 1c
1226 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
1227 axs[0, 0].set(ylim=[-15, 15])
1229 if save_dir is not None: 1c
1230 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
1231 fig.savefig(png_file)
1232 plt.close(fig)
1233 gc.collect()
1234 else:
1235 return fig, axs 1c
1237 def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
1238 channels=None,
1239 br: BrainRegions = None,
1240 save_dir=None,
1241 label='raster',
1242 gain=-93,
1243 title=None):
1245 # compute the raw data offset and destripe, we take 400ms around t0
1246 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
1247 raw = sr[first_sample:last_sample, :-sr.nsync].T
1248 channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
1249 destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
1250 # filter out the spikes according to good/bad clusters and to the time slice
1251 spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
1252 ss = spikes['samples'][spike_sel]
1253 sc = clusters['channels'][spikes['clusters'][spike_sel]]
1254 sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
1255 if title is None:
1256 title = self._default_plot_title(spikes)
1257 # display the raw data snippet with spikes overlaid
1258 fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
1259 Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
1260 axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
1261 axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
1262 axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
1263 # adds the channel locations if available
1264 if (channels is not None) and ('atlas_id' in channels):
1265 br = br or BrainRegions()
1266 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
1267 brain_regions=br, display=True, ax=axs[1], title=self.histology)
1268 axs[1].get_yaxis().set_visible(False)
1269 fig.tight_layout()
1271 if save_dir is not None:
1272 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
1273 fig.savefig(png_file)
1274 plt.close(fig)
1275 gc.collect()
1276 else:
1277 return fig, axs
1280@dataclass
1281class SessionLoader:
1282 """
1283 Object to load session data for a give session in the recommended way.
1285 Parameters
1286 ----------
1287 one: one.api.ONE instance
1288 Can be in remote or local mode (required)
1289 session_path: string or pathlib.Path
1290 The absolute path to the session (one of session_path or eid is required)
1291 eid: string
1292 database UUID of the session (one of session_path or eid is required)
1294 If both are provided, session_path takes precedence over eid.
1296 Examples
1297 --------
1298 1) Load all available session data for one session:
1299 >>> from one.api import ONE
1300 >>> from brainbox.io.one import SessionLoader
1301 >>> one = ONE()
1302 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1303 # Object is initiated, but no data is loaded as you can see in the data_info attribute
1304 >>> sess_loader.data_info
1305 name is_loaded
1306 0 trials False
1307 1 wheel False
1308 2 pose False
1309 3 motion_energy False
1310 4 pupil False
1312 # Loading all available session data, the data_info attribute now shows which data has been loaded
1313 >>> sess_loader.load_session_data()
1314 >>> sess_loader.data_info
1315 name is_loaded
1316 0 trials True
1317 1 wheel True
1318 2 pose True
1319 3 motion_energy True
1320 4 pupil False
1322 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
1323 >>> type(sess_loader.trials)
1324 pandas.core.frame.DataFrame
1325 >>> sess_loader.trials.shape
1326 (626, 18)
1327 # Each data comes with its own timestamps in a column called 'times'
1328 >>> sess_loader.wheel['times']
1329 0 0.134286
1330 1 0.135286
1331 2 0.136286
1332 3 0.137286
1333 4 0.138286
1334 ...
1335 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
1336 # The dataframes of all cameras are collected in a dictionary
1337 >>> type(sess_loader.pose)
1338 dict
1339 >>> sess_loader.pose.keys()
1340 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
1341 >>> sess_loader.pose['bodyCamera'].columns
1342 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
1343 # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
1344 functions:
1345 >>> sess_loader.load_wheel(sampling_rate=100)
1346 """
1347 one: One = None
1348 session_path: Path = ''
1349 eid: str = ''
1350 revision: str = ''
1351 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1352 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1353 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1354 pose: dict = field(default_factory=dict, repr=False)
1355 motion_energy: dict = field(default_factory=dict, repr=False)
1356 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1358 def __post_init__(self):
1359 """
1360 Function that runs automatically after initiation of the dataclass attributes.
1361 Checks for required inputs, sets session_path and eid, creates data_info table.
1362 """
1363 if self.one is None: 1af
1364 raise ValueError("An input to one is required. If not connection to a database is desired, it can be "
1365 "a fully local instance of One.")
1366 # If session path is given, takes precedence over eid
1367 if self.session_path is not None and self.session_path != '': 1af
1368 self.eid = self.one.to_eid(self.session_path) 1af
1369 self.session_path = Path(self.session_path) 1af
1370 # Providing no session path, try to infer from eid
1371 else:
1372 if self.eid is not None and self.eid != '':
1373 self.session_path = self.one.eid2path(self.eid)
1374 else:
1375 raise ValueError("If no session path is given, eid is required.")
1377 data_names = [ 1af
1378 'trials',
1379 'wheel',
1380 'pose',
1381 'motion_energy',
1382 'pupil'
1383 ]
1384 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1af
1386 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
1387 """
1388 Function to load available session data into the SessionLoader object. Input parameters allow to control which
1389 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
1390 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
1391 in SessionLoader.data_info
1393 Parameters
1394 ----------
1395 trials: boolean
1396 Whether to load all trials data into SessionLoader.trials, default is True
1397 wheel: boolean
1398 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
1399 pose: boolean
1400 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
1401 default is True
1402 motion_energy: boolean
1403 Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
1404 into SessionLoader.motion_energy, default is True
1405 pupil: boolean
1406 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
1407 default is True
1408 reload: boolean
1409 Whether to reload data that has already been loaded into this SessionLoader object, default is False
1410 """
1411 load_df = self.data_info.copy() 1f
1412 load_df['to_load'] = [ 1f
1413 trials,
1414 wheel,
1415 pose,
1416 motion_energy,
1417 pupil
1418 ]
1419 load_df['load_func'] = [ 1f
1420 self.load_trials,
1421 self.load_wheel,
1422 self.load_pose,
1423 self.load_motion_energy,
1424 self.load_pupil
1425 ]
1427 for idx, row in load_df.iterrows(): 1f
1428 if row['to_load'] is False: 1f
1429 _logger.debug(f"Not loading {row['name']} data, set to False.")
1430 elif row['is_loaded'] is True and reload is False: 1f
1431 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1f
1432 else:
1433 try: 1f
1434 _logger.info(f"Loading {row['name']} data") 1f
1435 row['load_func']() 1f
1436 self.data_info.loc[idx, 'is_loaded'] = True 1f
1437 except BaseException as e:
1438 _logger.warning(f"Could not load {row['name']} data.")
1439 _logger.debug(e)
1441 def _find_behaviour_collection(self, obj):
1442 """
1443 Function to find the trial or wheel collection
1445 Parameters
1446 ----------
1447 obj: str
1448 Alf object to load, either 'trials' or 'wheel'
1449 """
1450 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj
1451 dsets = self.one.list_datasets(self.eid, dataset) 1fnj
1452 if len(dsets) == 0: 1fnj
1453 return 'alf' 1fn
1454 else:
1455 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj
1456 if len(set(collections)) == 1: 1fj
1457 return collections[0] 1fj
1458 else:
1459 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, '
1460 f'e.g sl.load_{obj}(collection="{collections[0]}")')
1461 raise ALFMultipleCollectionsFound
1463 def load_trials(self, collection=None):
1464 """
1465 Function to load trials data into SessionLoader.trials
1467 Parameters
1468 ----------
1469 collection: str
1470 Alf collection of trials data
1471 """
1473 if not collection: 1fn
1474 collection = self._find_behaviour_collection('trials') 1fn
1475 # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
1476 self.one.wildcards = False 1fn
1477 self.trials = self.one.load_object( 1fn
1478 self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df()
1479 self.one.wildcards = True 1fn
1480 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1fn
1482 def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
1483 """
1484 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
1485 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1486 a Butterworth low-pass filter is applied.
1488 Parameters
1489 ----------
1490 fs: int, float
1491 Sampling frequency for the wheel position, default is 1000 Hz
1492 corner_frequency: int, float
1493 Corner frequency of Butterworth low-pass filter, default is 20
1494 order: int, float
1495 Order of Butterworth low_pass filter, default is 8
1496 collection: str
1497 Alf collection of wheel data
1498 """
1499 if not collection: 1fj
1500 collection = self._find_behaviour_collection('wheel') 1fj
1501 wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None) 1fj
1502 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1fj
1503 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
1504 # resample the wheel position and compute velocity, acceleration
1505 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1fj
1506 self.wheel['position'], self.wheel['times'] = interpolate_position( 1fj
1507 wheel_raw['timestamps'], wheel_raw['position'], freq=fs)
1508 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1fj
1509 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
1510 self.wheel = self.wheel.apply(np.float32) 1fj
1511 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1fj
1513 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1514 """
1515 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
1516 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
1517 Dataframes with the timestamps and pose data, one row for each body part tracked for that camera.
1519 Parameters
1520 ----------
1521 likelihood_thr: float
1522 The position of each tracked body part come with a likelihood of that estimate for each time point.
1523 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
1524 likelihood_thr=1. Default is 0.9
1525 views: list
1526 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
1527 """
1528 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1529 self.pose = {} 1mhf
1530 for view in views: 1mhf
1531 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None) 1mhf
1532 # Double check if video timestamps are correct length or can be fixed
1533 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1mhf
1534 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1mhf
1535 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1mhf
1536 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1mhf
1538 def load_motion_energy(self, views=['left', 'right', 'body']):
1539 """
1540 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
1541 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
1542 pandas Dataframes with the timestamps and motion energy data.
1543 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
1544 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
1545 body (bodyMotionEnergy).
1547 Parameters
1548 ----------
1549 views: list
1550 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
1551 """
1552 names = {'left': 'whiskerMotionEnergy', 1lf
1553 'right': 'whiskerMotionEnergy',
1554 'body': 'bodyMotionEnergy'}
1555 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1556 self.motion_energy = {} 1lf
1557 for view in views: 1lf
1558 me_raw = self.one.load_object( 1lf
1559 self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None)
1560 # Double check if video timestamps are correct length or can be fixed
1561 times_fixed, motion_energy = self._check_video_timestamps( 1lf
1562 view, me_raw['times'], me_raw['ROIMotionEnergy'])
1563 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1lf
1564 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1lf
1565 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1lf
1567 def load_licks(self):
1568 """
1569 Not yet implemented
1570 """
1571 pass
1573 def load_pupil(self, snr_thresh=5.):
1574 """
1575 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.
1577 Parameters
1578 ----------
1579 snr_thresh: float
1580 An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data
1581 will be considered unusable and will be discarded.
1582 """
1583 # Try to load from features
1584 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None) 1hf
1585 if 'features' in feat_raw.keys(): 1hf
1586 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
1587 self.pupil = feats.copy()
1588 self.pupil.insert(0, 'times', times_fixed)
1590 # If unavailable compute on the fly
1591 else:
1592 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1hf
1593 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1hf
1594 and 'leftCamera' in self.pose.keys()):
1595 # If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt
1596 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1hf
1597 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1hf
1598 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1hf
1599 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1hf
1600 else:
1601 self.load_pose(views=['left'], likelihood_thr=0.9)
1602 dlc_thr = self.pose['leftCamera'].copy()
1604 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1hf
1605 try: 1hf
1606 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1hf
1607 except BaseException as e:
1608 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. "
1609 "Saving all NaNs for pupilDiameter_smooth.")
1610 _logger.debug(e)
1611 self.pupil['pupilDiameter_smooth'] = np.nan
1613 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1hf
1614 good_idxs = np.where( 1hf
1615 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
1616 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1hf
1617 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
1618 if snr < snr_thresh: 1hf
1619 self.pupil = pd.DataFrame() 1h
1620 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h
1622 def _check_video_timestamps(self, view, video_timestamps, video_data):
1623 """
1624 Helper function to check for the length of the video frames vs video timestamps and fix in case
1625 timestamps are longer than video frames.
1626 """
1627 # If camera times are shorter than video data, or empty, no current fix
1628 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf
1629 if video_timestamps.shape[0] == 0:
1630 msg = f'Camera times empty for {view}Camera.'
1631 else:
1632 msg = f'Camera times are shorter than video data for {view}Camera.'
1633 _logger.warning(msg)
1634 raise ValueError(msg)
1635 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
1636 # This is because the first few frames are sometimes not recorded. We can remove the first few
1637 # timestamps in this case
1638 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf
1639 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf
1640 return video_timestamps_fixed, video_data 1lmhf
1641 else:
1642 return video_timestamps, video_data
1645class EphysSessionLoader(SessionLoader):
1646 """
1647 Spike sorting enhanced version of SessionLoader
1648 Loads spike sorting data for all probes in the session, in the self.ephys dict
1649 >>> EphysSessionLoader(eid=eid, one=one)
1650 To select for a specific probe
1651 >>> EphysSessionLoader(eid=eid, one=one, pid=pid)
1652 """
1653 def __init__(self, *args, pname=None, pid=None, **kwargs):
1654 """
1655 Needs an active connection in order to get the list of insertions in the session
1656 :param args:
1657 :param kwargs:
1658 """
1659 super().__init__(*args, **kwargs)
1660 # if necessary, restrict the query
1661 qargs = {} if pname is None else {'name': pname}
1662 qargs = qargs or ({} if pid is None else {'id': pid})
1663 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs)
1664 self.ephys = {}
1665 for ins in insertions:
1666 self.ephys[ins['name']] = {}
1667 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one)
1669 def load_session_data(self, *args, **kwargs):
1670 super().load_session_data(*args, **kwargs)
1671 self.load_spike_sorting()
1673 def load_spike_sorting(self, pnames=None):
1674 pnames = pnames or list(self.ephys.keys())
1675 for pname in pnames:
1676 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting()
1677 self.ephys[pname]['spikes'] = spikes
1678 self.ephys[pname]['clusters'] = clusters
1679 self.ephys[pname]['channels'] = channels
1681 @property
1682 def probes(self):
1683 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}