Coverage for brainbox/io/one.py: 57%
768 statements
coverage.py v7.8.0, created at 2025-05-07 14:26 +0100
1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2from dataclasses import dataclass, field
3import gc
4import logging
5import re
6import os
7from pathlib import Path
8from collections import defaultdict
10import numpy as np
11import pandas as pd
12from scipy.interpolate import interp1d
13import matplotlib.pyplot as plt
15from one.api import ONE, One
16from one.alf.path import get_alf_path, full_path_parts, filename_parts
17from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
18from one.alf import cache
19import one.alf.io as alfio
20from neuropixel import TIP_SIZE_UM, trace_header
21import spikeglx
23import ibldsp.voltage
24from ibldsp.waveform_extraction import WaveformsLoader
25from iblutil.util import Bunch
26from iblatlas.atlas import AllenAtlas, BrainRegions
27from iblatlas import atlas
28from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
29from ibllib.pipes import histology
30from ibllib.pipes.ephys_alignment import EphysAlignment
31from ibllib.plots import vertical_lines, Density
33import brainbox.plot
34from brainbox.io.spikeglx import Streamer
35from brainbox.ephys_plots import plot_brain_regions
36from brainbox.metrics.single_units import quick_unit_metrics
37from brainbox.behavior.wheel import interpolate_position, velocity_filtered
38from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
40_logger = logging.getLogger('ibllib')
43SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
44CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
45WAVEFORMS_ATTRIBUTES = ['templates']
def load_lfp(eid, one=None, dataset_types=None, **kwargs):
    """
    Download the standard set of LFP datasets for a session and open spikeglx readers on them.

    TODO: verify this works end to end.
    NOTE(review): the dataset patterns contain wildcards; if one matches several datasets
    (e.g. several probes) `one.load_dataset` may raise - confirm.

    :param eid: experiment session identifier accepted by ONE
    :param one: an instance of ONE; a default instance is created if None
    :param dataset_types: additional dataset types to download alongside the LFP files
    :param kwargs: keyword arguments forwarded to spikeglx.Reader
    :return: list of spikeglx.Reader, one per LF file found in the session folder
    """
    one = one or ONE()
    dtypes = list(dataset_types or []) + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*']
    # download only: the files are read from disk below through spikeglx
    for dset in dtypes:
        one.load_dataset(eid, dset, download_only=True)
    session_path = one.eid2path(eid)
    efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False)
              if ef.get('lf', None)]
    return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles]
69def _collection_filter_from_args(probe, spike_sorter=None):
70 collection = f'alf/{probe}/{spike_sorter}' 1g
71 collection = collection.replace('None', '*') 1g
72 collection = collection.replace('/*', '*') 1g
73 collection = collection[:-1] if collection.endswith('/') else collection 1g
74 return collection 1g
def _get_spike_sorting_collection(collections, pname):
    """
    Select the spike sorting collection to load for one probe.

    The collection 'alf/<pname>/pykilosort' is preferred when present; otherwise the shortest
    collection belonging to the probe ('alf/<pname>' or 'alf/<pname>/...') is returned.

    :param collections: list or array of collection strings, e.g. 'alf/probe00/pykilosort'
    :param pname: probe label, e.g. 'probe00'
    :return: the selected collection, or None if no collection belongs to the probe
    """
    prefix = f'alf/{pname}'
    # match on the whole probe path component so that 'probe00' does not select 'alf/probe00a'
    candidates = [c for c in collections if c == prefix or c.startswith(prefix + '/')]
    collection = next((c for c in candidates if c == f'{prefix}/pykilosort'), None)
    # otherwise, prefers the shortest (first one in case of ties)
    collection = collection or min(candidates, key=len, default=None)
    _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}")
    return collection
def _channels_alyx2bunch(chans):
    """
    Convert a sequence of Alyx channel records into a Bunch of arrays.

    :param chans: sequence of dicts with keys 'brain_region', 'x', 'y', 'z', 'axial', 'lateral'
    :return: Bunch with keys 'atlas_id', 'x', 'y', 'z' (values divided by 1e6, presumably um to m),
        'axial_um' and 'lateral_um'
    """
    def _column(key):
        # one array per record field, in the order of the records
        return np.array([ch[key] for ch in chans])

    return Bunch({
        'atlas_id': _column('brain_region'),
        'x': _column('x') / 1e6,
        'y': _column('y') / 1e6,
        'z': _column('z') / 1e6,
        'axial_um': _column('axial'),
        'lateral_um': _column('lateral'),
    })
def _channels_traj2bunch(xyz_chans, brain_atlas):
    """
    Label channel coordinates with their brain regions.

    :param xyz_chans: array of channel coordinates, one row per channel with x, y, z columns
    :param brain_atlas: atlas object providing `get_labels` and `regions.get`
    :return: dict with keys 'x', 'y', 'z', 'acronym', 'atlas_id'
    """
    labels = brain_atlas.get_labels(xyz_chans)
    regions = brain_atlas.regions.get(labels)
    x, y, z = (xyz_chans[:, axis] for axis in range(3))
    return {'x': x, 'y': y, 'z': z, 'acronym': regions['acronym'], 'atlas_id': regions['id']}
115def _channels_bunch2alf(channels):
116 channels_ = { 1i
117 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6,
118 'brainLocationIds_ccf_2017': channels['atlas_id'],
119 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]}
120 return channels_ 1i
123def _channels_alf2bunch(channels, brain_regions=None):
124 # reformat the dictionary according to the standard that comes out of Alyx
125 channels_ = { 1icaed
126 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6,
127 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6,
128 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6,
129 'acronym': None,
130 'atlas_id': channels['brainLocationIds_ccf_2017'],
131 'axial_um': channels['localCoordinates'][:, 1],
132 'lateral_um': channels['localCoordinates'][:, 0],
133 }
134 # here if we have some extra keys, they will carry over to the next dictionary
135 for k in channels: 1icaed
136 if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: 1icaed
137 channels_[k] = channels[k] 1icaed
138 if brain_regions: 1icaed
139 channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] 1icaed
140 return channels_ 1icaed
def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
                        brain_regions=None):
    """
    Generic function to load spike sorting data using ONE.

    Will try to load one spike sorting for any probe present for the eid matching the collection.
    For each probe it will load a spike sorting:
        - if there is one version: loads this one
        - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX)

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (may be in 'local' mode); a default instance is created if None
    collection : str
        collection filter word - accepts wildcards - can be a combination of spike sorter and
        probe. See `ALF documentation`_ for details.
    revision : str
        A particular revision return (defaults to latest revision). See `ALF documentation`_ for
        details.
    return_channels : bool
        If True (default), also loads the channel locations from disk and returns them
    dataset_types : list of str
        Optional additional 'spikes.*' / 'clusters.*' dataset types, whose attributes are loaded
        on top of the default ones
    brain_regions : iblatlas.regions.BrainRegions
        Optional, used to label the channel acronyms

    .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter; attributes ('clusters', 'times', 'amps', 'depths') by default
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data; attributes
        ('channels', 'depths', 'metrics', 'uuids') by default
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs
        non-lateralized.
    """
    one = one or ONE()
    # enumerate probes and load according to the name
    collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    # sorted so that the probes are loaded, and returned, in a deterministic order
    pnames = sorted(set(c.split('/')[1] for c in collections))
    spikes, clusters, channels = ({} for _ in range(3))

    spike_attributes, cluster_attributes = _get_attributes(dataset_types)

    for pname in pnames:
        probe_collection = _get_spike_sorting_collection(collections, pname)
        spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes',
                                        attribute=spike_attributes, namespace='')
        clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters',
                                          attribute=cluster_attributes, namespace='')
    if return_channels:
        channels = _load_channels_locations_from_disk(
            eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
        return spikes, clusters, channels
    return spikes, clusters
def _get_attributes(dataset_types):
    """
    Return the spikes and clusters attributes to load.

    :param dataset_types: None, or a list of dataset types such as 'spikes.amps' or
        'clusters.depths' whose attributes are added to the default ones
    :return: tuple (spike_attributes, cluster_attributes) of lists of attribute names, defaults
        first, without duplicates
    """
    if dataset_types is None:
        # copies, so that callers cannot mutate the module-level defaults
        return list(SPIKES_ATTRIBUTES), list(CLUSTERS_ATTRIBUTES)
    spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
    cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
    # dict.fromkeys drops duplicates while keeping a deterministic order (unlike set)
    spike_attributes = list(dict.fromkeys(SPIKES_ATTRIBUTES + spike_attributes))
    cluster_attributes = list(dict.fromkeys(CLUSTERS_ATTRIBUTES + cluster_attributes))
    return spike_attributes, cluster_attributes
def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None):
    """
    Load the channel locations of every probe matching a collection filter from ALF datasets.

    For each probe, the 'channels' object of the selected spike sorting collection is loaded. If it
    has no 'brainLocationIds_ccf_2017' attribute, a 'channels.brainLocationIds_ccf_2017' dataset is
    searched using the probe collection as filter; if found, the aligned locations are interpolated
    onto the loaded channel map, otherwise the probe's channels are kept as loaded (without the
    'x', 'y', 'z' keys).

    :param eid: experiment session identifier
    :param collection: collection filter, accepts wildcards
    :param one: an instance of ONE; a default instance is created if None
    :param revision: dataset revision to load
    :param brain_regions: optional iblatlas.regions.BrainRegions used to label acronyms
    :return: Bunch with probe labels as keys and channel dictionaries as values
    """
    one = one or ONE()
    _logger.debug('loading channel locations from disk')
    channels = Bunch({})
    collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    # sorted so that the probes are returned in a deterministic order
    probes = sorted(set(c.split('/')[1] for c in collections))
    for probe in probes:
        probe_collection = _get_spike_sorting_collection(collections, probe)
        channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels')
        # if the spike sorter has not aligned data, try and get the alignment available
        if 'brainLocationIds_ccf_2017' not in channels[probe].keys():
            aligned_channel_collections = one.list_collections(
                eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
            if len(aligned_channel_collections) == 0:
                _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
                continue
            _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
            ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
            channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
            channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
        # only have to reformat channels if we were able to load coordinates from disk
        channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
    return channels
def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
    """
    Interpolate aligned channel brain locations onto another channel map.

    The channel map may differ between spike sorters, so the brain coordinates of an aligned set
    of channels are interpolated along the probe depth (second column of 'localCoordinates') onto
    the channels of interest. If the aligned channels have no 'localCoordinates' field (this only
    happens for early pykilosort sorts without a spike sorting in the base folder), their depths
    are reconstructed from the Neuropixel 1.0 map.

    Coordinates are linearly interpolated; atlas ids are interpolated by nearest neighbour. For
    depths outside the aligned range, both take the value of the closest aligned channel.

    :param channels_aligned: Bunch or dictionary of aligned channels, used as the interpolation
        guide, containing at least the keys
            'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
        OR
            'x', 'y', 'z', 'atlas_id', 'axial_um', 'lateral_um'
    :param channels: Bunch or dictionary of channels containing at least the key 'localCoordinates';
        if None, the Neuropixel 1.0 channel map is used. NB: this object is modified in place.
    :param brain_regions: None (default) or iblatlas.regions.BrainRegions object
        if None, returns the channels with the keys 'localCoordinates', 'mlapdv',
        'brainLocationIds_ccf_2017' added
        if a brain region object is provided, outputs a dict with keys
        'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
    :return: Bunch or dictionary of channels with brain coordinates keys
    """
    NEUROPIXEL_VERSION = 1
    h = trace_header(version=NEUROPIXEL_VERSION)
    if channels is None:
        channels = {'localCoordinates': np.c_[h['x'], h['y']]}
    nch = channels['localCoordinates'].shape[0]
    if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())):
        channels_aligned = _channels_bunch2alf(channels_aligned)
    if 'localCoordinates' in channels_aligned.keys():
        aligned_depths = channels_aligned['localCoordinates'][:, 1]
    else:  # this is a edge case for a few spike sorting sessions
        assert channels_aligned['mlapdv'].shape[0] == 384
        aligned_depths = h['y']
    depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
    depths, iinv = np.unique(channels['localCoordinates'][:, 1], return_inverse=True)
    channels['mlapdv'] = np.zeros((nch, 3))
    for i in np.arange(3):
        # np.interp clamps depths outside the aligned range to the edge values
        channels['mlapdv'][:, i] = np.interp(
            depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
    # the brain locations have to be interpolated by nearest neighbour; out-of-range depths are
    # clamped to the edge atlas ids, as np.interp does for the coordinates, instead of raising
    atlas_ids_aligned = channels_aligned['brainLocationIds_ccf_2017'][ind_aligned]
    fcn_interp = interp1d(depth_aligned, atlas_ids_aligned, kind='nearest', bounds_error=False,
                          fill_value=(atlas_ids_aligned[0], atlas_ids_aligned[-1]))
    channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
    if brain_regions is not None:
        return _channels_alf2bunch(channels, brain_regions=brain_regions)
    return channels
def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
                                 brain_atlas=None, return_source=False):
    """
    Compute the channel brain locations of one probe from its Alyx insertion and trajectories.

    Depending on the insertion extended QC, the locations come from:
        - 'resolved': the stored ephys alignment, when the alignment is resolved
        - 'aligned': the stored ephys alignment, if `aligned` is True and alignments exist
        - 'traced': interpolation of the channel depths along the histology track

    :param eid: experiment session UUID
    :param probe: probe label, e.g. 'probe00'
    :param one: an instance of ONE; without database access ({} / ({}, None)) is returned
    :param revision: dataset revision used to find the channels local coordinates
    :param aligned: whether to use the stored alignment when it is not resolved
    :param brain_atlas: iblatlas.atlas.BrainAtlas, defaults to an AllenAtlas instance
    :param return_source: if True, also return the source of the locations
    :return: Bunch with the probe label as key and its channel locations as value (None when no
        histology tracing exists), followed, when return_source is True, by the source string
        ('resolved', 'aligned', 'traced' or '')
    """
    if not hasattr(one, 'alyx'):
        # no database access (e.g. local mode): nothing can be computed
        return ({}, None) if return_source else {}
    _logger.debug(f"trying to load from traj {probe}")
    channels = Bunch()
    # an atlas instance is required: the locations are labelled through `brain_atlas.regions`
    brain_atlas = brain_atlas or AllenAtlas()
    # find the spike sorting collection holding the channels local coordinates
    insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0]
    collection = _collection_filter_from_args(probe=probe)
    collections = one.list_collections(eid, filename='channels*', collection=collection,
                                       revision=revision)
    probe_collection = _get_spike_sorting_collection(collections, probe)
    chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
    depths = chn_coords[:, 1]

    # the insertion json and its extended QC may be missing or null
    extended_qc = (insertion.get('json') or {}).get('extended_qc') or {}
    tracing = extended_qc.get('tracing_exists', False)
    resolved = extended_qc.get('alignment_resolved', False)
    counts = extended_qc.get('alignment_count', 0)

    if not tracing:
        _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
        return (None, '') if return_source else None

    xyz = np.array(insertion['json']['xyz_picks']) / 1e6
    if resolved:
        _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
                      f'Channel and cluster locations obtained from ephys aligned histology '
                      f'track.')
        chans = _channels_from_ephys_alignment(one, eid, probe, insertion, xyz, depths, brain_atlas)
        source = 'resolved'
    elif counts > 0 and aligned:
        _logger.debug(f'Channel locations for {eid}/{probe} have not been '
                      f'resolved. However, alignment flag set to True so channel and cluster'
                      f' locations will be obtained from latest available ephys aligned '
                      f'histology track.')
        # get the latest user aligned channels
        chans = _channels_from_ephys_alignment(one, eid, probe, insertion, xyz, depths, brain_atlas)
        source = 'aligned'
    else:
        _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
                      f'Channel and cluster locations obtained from histology track.')
        # get the channels from histology tracing
        xyz = xyz[np.argsort(xyz[:, 2]), :]
        chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
        source = 'traced'
    channels[probe] = _channels_traj2bunch(chans, brain_atlas)
    channels[probe]['axial_um'] = chn_coords[:, 1]
    channels[probe]['lateral_um'] = chn_coords[:, 0]
    return (channels, source) if return_source else channels


def _channels_from_ephys_alignment(one, eid, probe, insertion, xyz, depths, brain_atlas):
    """Return the channel coordinates from the stored ephys alignment of a probe insertion."""
    traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
                         provenance='Ephys aligned histology track')[0]
    align_key = insertion['json']['extended_qc']['alignment_stored']
    feature, track = traj['json'][align_key][:2]
    ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                feature_prev=feature,
                                brain_atlas=brain_atlas, speedy=True)
    return ephysalign.get_channel_locations(feature, track)
def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
    """
    Load the brain locations of each channel for a given session/probe

    Channel locations are first loaded from the ALF datasets on disk; for any probe whose locations
    are not available there, they are computed from the Alyx trajectories and histology tracing.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    probe : [str, list of str]
        The probe label(s), e.g. 'probe01'. If None, all probes are loaded
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    brain_atlas : iblatlas.BrainAtlas
        Brain atlas object (default: Allen atlas)

    Returns
    -------
    dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    if isinstance(probe, (list, tuple)):
        # load each probe on its own and merge the per-probe dictionaries
        channels = Bunch()
        for pname in probe:
            channels.update(load_channel_locations(
                eid, probe=pname, one=one, aligned=aligned, brain_atlas=brain_atlas))
        return channels
    if isinstance(eid, dict):
        # session details dict: the eid is the last 36 characters of the session URL
        eid = eid['url'][-36:]
    else:
        eid = one.to_eid(eid)
    collection = _collection_filter_from_args(probe=probe)
    channels = _load_channels_locations_from_disk(eid, one=one, collection=collection,
                                                  brain_regions=brain_atlas.regions)
    # probes without 'x' had no aligned locations on disk: fall back to the Alyx trajectories
    incomplete_probes = [k for k in channels if 'x' not in channels[k]]
    for iprobe in incomplete_probes:
        channels_, _ = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
                                                    brain_atlas=brain_atlas, return_source=True)
        # an empty dict (no database access) or None (no histology tracing) leaves the channels
        # as loaded from disk
        if channels_:
            channels[iprobe] = channels_[iprobe]
    return channels
def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                            brain_regions=None, nested=True, collection=None, return_collection=False):
    """
    From an eid, loads spikes, clusters and channels for all probes.

    By default the spikes attributes ('clusters', 'times', 'amps', 'depths') and the clusters
    attributes ('channels', 'depths', 'metrics', 'uuids') are loaded; the channel locations are
    merged onto the clusters.

    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param revision: a particular dataset revision to load (defaults to latest revision)
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param nested: if False and a single probe is loaded, do not output a dictionary with the probe name as key
    :param collection: name of the spike sorting collection to load, e.g. "alf/probe00"; when given,
        `probe` and `spike_sorter` are ignored
    :param return_collection: (False) if True, will return the collection filter used to load
    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe), and the collection
        filter if return_collection is True
    """
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    if collection is None:
        collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
                  brain_regions=brain_regions)
    spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    return spikes, clusters, channels
def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                       brain_regions=None, return_collection=False):
    """
    From an eid, loads spikes and clusters for all probes.

    By default the spikes attributes ('clusters', 'times', 'amps', 'depths') and the clusters
    attributes ('channels', 'depths', 'metrics', 'uuids') are loaded.

    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param revision: a particular dataset revision to load (defaults to latest revision)
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param return_collection: (bool - False) if True, returns the collection filter used for loading the data
    :return: spikes, clusters (dict of bunch, 1 bunch per probe), and the collection filter if
        return_collection is True
    """
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
                                           return_channels=False, dataset_types=dataset_types,
                                           brain_regions=brain_regions)
    if return_collection:
        return spikes, clusters, collection
    return spikes, clusters
def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
                                    spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
    """
    For a given eid, get spikes, clusters and channels information, and merges clusters
    and channels information before returning all three variables.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode); a default instance is created if None
    probe : str
        The probe label, e.g. 'probe01'. If None, all probes are loaded
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    dataset_types : list of str
        Optional additional spikes/clusters objects to add to the standard default list
    spike_sorter : str
        Name of the spike sorting you want to load (None for default which is pykilosort if it's
        available otherwise the default MATLAB kilosort)
    brain_atlas : iblatlas.atlas.BrainAtlas
        Brain atlas object (default: Allen atlas)
    nested : bool
        If False and a single probe is loaded, the outputs are not nested in dicts keyed by probe
    return_collection: bool
        Returns an extra argument with the collection filter chosen

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    # --- Get spikes and clusters data
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_with_channel will be removed in future '
                    'versions. Use brainbox.io.one.SpikeSortingLoader instead')
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    spikes, clusters, collection = load_spike_sorting(
        eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
    # -- Get brain regions and assign to clusters
    channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
                                      brain_atlas=brain_atlas)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    return spikes, clusters, channels
def load_ephys_session(eid, one=None):
    """
    Load the spike sorting and trials data of a session.

    The default spikes attributes ('clusters', 'times', 'amps', 'depths') and clusters
    attributes ('channels', 'depths', 'metrics', 'uuids') of every probe are loaded, together
    with the trials object.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        ONE object to use for loading. Will generate internal one if not used, by default None

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data
    trials : one.alf.io.AlfBunch of numpy.ndarray
        The session trials data
    """
    # create a default instance as documented, rather than asserting one was given
    one = one or ONE()
    spikes, clusters = load_spike_sorting(eid, one=one)
    trials = one.load_object(eid, 'trials')
    return spikes, clusters, trials
586def _remove_old_clusters(session_path, probe):
587 # gets clusters and spikes from a local session folder
588 probe_path = session_path.joinpath('alf', probe)
590 # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
591 cluster_file = probe_path.joinpath('clusters.metrics.csv')
593 if cluster_file.exists():
594 os.remove(cluster_file)
595 _logger.info('Deleting old clusters.metrics.csv file')
def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
    """
    Takes (default and any extra) values in given keys from channels and assign them to clusters.
    If channels does not contain any data, the new keys are added to clusters but left empty.

    The clusters dictionary is modified in place and also returned.

    Parameters
    ----------
    dic_clus : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing cluster information, including the 'channels' index array
    channels : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z',
        'axial_um', 'lateral_um')
    keys_to_add_extra : list of str
        Any extra keys to copy from the channels bunches onto the clusters

    Returns
    -------
    dict of one.alf.io.AlfBunch
        clusters (1 bunch per probe) with new keys values.
    """
    keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']
    if keys_to_add_extra is None:
        keys_to_add = keys_to_add_default
    else:
        # append extra optional keys, dropping duplicates while keeping a deterministic order
        keys_to_add = list(dict.fromkeys(list(keys_to_add_extra) + keys_to_add_default))

    for label in list(channels.keys()):
        if label not in dic_clus:
            _logger.warning(f'Probe {label}: no clusters found, could not merge channels')
            continue
        clu_ch = dic_clus[label]['channels']
        # highest channel index referenced by the clusters (-1 when there are no clusters)
        max_ch = int(np.max(clu_ch)) if len(clu_ch) > 0 else -1
        for key in keys_to_add:
            if key not in channels[label].keys():
                _logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
                continue
            ch_key = channels[label][key]
            nch_key = len(ch_key) if ch_key is not None else 0
            # clu_ch is used as an index into ch_key, so every index must be in range
            if ch_key is not None and max_ch < nch_key:
                dic_clus[label][key] = ch_key[clu_ch]
            else:
                _logger.warning(
                    f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels'
                    f' but expected {max_ch}. Data in new cluster key "{key}" is returned empty.')
                dic_clus[label][key] = []
    return dic_clus
def load_passive_rfmap(eid, one=None):
    """
    Load the passive receptive field mapping protocol data of a session.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        An instance of ONE (may be in 'local' - offline - mode)

    Returns
    -------
    one.alf.io.AlfBunch
        The 'passiveRFM' ALF object, with an extra 'frames' key holding the stimulus as a
        uint8 array of shape (n_frames, 15, 15)
    """
    one = one or ONE()
    rf_map = one.load_object(eid, obj='passiveRFM', collection='alf')
    stim_file = one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', collection='raw_passive_data')
    raw_bytes = np.fromfile(stim_file, dtype="uint8")
    ny, nx = 15, 15
    # the raw bytes are reshaped in column-major (Fortran) order, then the axes are reversed
    stacked = np.reshape(raw_bytes, [ny, nx, -1], order="F")
    rf_map['frames'] = np.transpose(stacked, [2, 1, 0])
    return rf_map
def load_wheel_reaction_times(eid, one=None):
    """
    Return the calculated reaction times for session. Reaction times are defined as the time
    between the go cue (onset tone) and the onset of the first substantial wheel movement. A
    movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the
    distance to threshold (~0.1 radians).

    Negative times mean the onset of the movement occurred before the go cue. Nans may occur if
    there was no detected movement within the period, or when the goCue_times or feedback_times
    are nan.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        one object to use for loading. Will generate internal one if not used, by default None

    Returns
    -------
    array-like
        reaction times
    """
    if one is None:
        one = ONE()

    trials = one.load_object(eid, 'trials')
    # If already extracted, load and return
    if trials and 'firstMovement_times' in trials:
        return trials['firstMovement_times'] - trials['goCue_times']
    # Otherwise load wheelMoves object and calculate; a missing object is re-extracted below
    try:
        moves = one.load_object(eid, 'wheelMoves')
    except ALFObjectNotFound:
        moves = None
    # Re-extract wheel moves if necessary
    if not moves or 'peakAmplitude' not in moves:
        wheel = one.load_object(eid, 'wheel')
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
    assert trials and moves, 'unable to load trials and wheelMoves data'
    firstMove_times, *_ = extract_first_movement_times(moves, trials)
    return firstMove_times - trials['goCue_times']
def load_iti(trials):
    """
    Compute the inter-trial interval (ITI) of each trial.

    The ITI is the period of open-loop grey screen starting at stimulus off and lasting until the
    start of the following trial: for trial i it is intervals[i + 1, 0] - stimOff_times[i]. The
    last trial has no following trial, so its value is NaN.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'intervals', 'stimOff_times'}.

    Returns
    -------
    np.array
        An array of inter-trial intervals, the last value being NaN.

    Raises
    ------
    ValueError
        If one of the required keys is missing.
    """
    missing = {'intervals', 'stimOff_times'} - set(trials.keys())
    if missing:
        raise ValueError('trials must contain keys {"intervals", "stimOff_times"}')
    next_trial_start = np.roll(trials['intervals'][:, 0], -1)
    gaps = next_trial_start - trials['stimOff_times']
    return np.r_[gaps[:-1], np.nan]
def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
    """
    Compute channel coordinates for a probe insertion from its highest-provenance trajectory.

    The trajectories registered on Alyx for the insertion are ranked by provenance
    (Resolved > Ephys aligned histology track > Histology track > Micro-manipulator > Planned)
    and the best one is used:
    - Planned / Micro-manipulator: channels are interpolated along the straight tip-entry line.
    - Histology track: channels are interpolated along the traced xyz picks.
    - Otherwise (aligned / resolved): the stored alignment is applied with EphysAlignment.

    Parameters
    ----------
    ins : dict
        Alyx probe insertion record. Must contain 'id'; non-planned branches also read
        ins['json']['xyz_picks'] and, for aligned/resolved, ins['json']['extended_qc']['alignment_stored'].
    depths : np.array, optional
        Channel depths; defaults to the second column of neuropixel.trace_header(version=1).
        Presumably in micrometers (TIP_SIZE_UM is added and the sum divided by 1e6) - TODO confirm.
    one : one.api.OneAlyx, optional
        ONE instance; a default ONE() is created when not provided.
    ba : iblatlas.atlas.AllenAtlas, optional
        Atlas used by EphysAlignment; only instantiated when the aligned branch needs it.

    Returns
    -------
    np.array
        Channel coordinates as returned by histology.interpolate_along_track or
        EphysAlignment.get_channel_locations.

    Raises
    ------
    ValueError
        If no trajectory is registered for the insertion.
    KeyError
        If a trajectory has a provenance not listed in PROV_2_VAL.
    """
    PROV_2_VAL = {
        'Resolved': 90,
        'Ephys aligned histology track': 70,
        'Histology track': 50,
        'Micro-manipulator': 30,
        'Planned': 10}

    one = one or ONE()
    trajectories = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id'])
    if len(trajectories) == 0:
        raise ValueError(f"No trajectory found for probe insertion {ins['id']}")
    # pick the trajectory with the most advanced provenance
    ranks = [PROV_2_VAL[tr['provenance']] for tr in trajectories]
    traj = trajectories[int(np.argmax(ranks))]
    if depths is None:
        depths = trace_header(version=1)[:, 1]

    if traj['provenance'] in ('Planned', 'Micro-manipulator'):
        insertion = atlas.Insertion.from_dict(traj)
        # Deepest coordinate first
        xyz = np.c_[insertion.tip, insertion.entry].T
        return histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)

    # xyz picks are divided by 1e6, presumably converting micrometers to meters - TODO confirm
    xyz = np.array(ins['json']['xyz_picks']) / 1e6
    if traj['provenance'] == 'Histology track':
        # sort picks along the third (z) coordinate before interpolating
        xyz = xyz[np.argsort(xyz[:, 2]), :]
        return histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)

    align_key = ins['json']['extended_qc']['alignment_stored']
    feature = traj['json'][align_key][0]
    track = traj['json'][align_key][1]
    # the atlas is costly to build, so only create it when this branch actually needs it
    brain_atlas = ba if ba is not None else atlas.AllenAtlas()
    ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                feature_prev=feature,
                                brain_atlas=brain_atlas, speedy=True)
    return ephysalign.get_channel_locations(feature, track)
@dataclass
class SpikeSortingLoader:
    """
    Object that will load spike sorting data for a given probe insertion.
    This class can be instantiated in several manners
    - With Alyx database probe id:
            SpikeSortingLoader(pid=pid, one=one)
    - With Alyx database eid and probe name:
            SpikeSortingLoader(eid=eid, pname='probe00', one=one)
    - From a local session and probe name (session_path may be a str or a Path):
            SpikeSortingLoader(session_path=session_path, pname='probe00')
    NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
    """
    one: One = None
    atlas: None = None
    pid: str = None
    eid: str = ''
    pname: str = ''
    session_path: Path = ''
    # the following properties are the outcome of the post init function
    collections: list = None
    datasets: list = None  # list of all datasets belonging to the session
    # the following properties are the outcome of a reading function
    files: dict = None
    raw_data_files: list = None  # list of raw ap and lf files corresponding to the recording
    collection: str = ''
    histology: str = ''  # 'alf', 'resolved', 'aligned' or 'traced'
    spike_sorter: str = 'pykilosort'
    spike_sorting_path: Path = None
    _sync: dict = None

    def __post_init__(self):
        """Resolve eid / pname / session_path, then list the spike sorting collections and datasets."""
        # pid gets precedence
        if self.pid is not None:
            try:
                self.eid, self.pname = self.one.pid2eid(self.pid)
            except NotImplementedError:
                if self.eid == '' or self.pname == '':
                    raise IOError("Cannot infer session id and probe name from pid. "
                                  "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.")
            self.session_path = self.one.eid2path(self.eid)
        # then eid / pname combination
        elif self.session_path is None or self.session_path == '':
            self.session_path = self.one.eid2path(self.eid)
        # fully local providing a session path
        else:
            # accept str as well as Path: the local branch relies on Path.parents / relative_to
            self.session_path = Path(self.session_path)
            if self.one:
                self.eid = self.one.to_eid(self.session_path)
            else:
                # build a local-only One whose cache is made from the files of this single session
                self.one = One(cache_dir=self.session_path.parents[2], mode='local')
                df_sessions = cache._make_sessions_df(self.session_path)
                self.one._cache['sessions'] = df_sessions.set_index('id')
                self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
                self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
        # populates default properties
        self.collections = self.one.list_collections(
            self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
        self.datasets = self.one.list_datasets(self.eid)
        if self.atlas is None:
            self.atlas = AllenAtlas()
        self.files = {}
        self.raw_data_files = []

    def _load_object(self, *args, **kwargs):
        """
        This function is a wrapper around alfio.load_object that will remove the UUID in the
        filename if the object is on SDSC.
        """
        remove_uuids = getattr(self.one, 'uuid_filenames', False)
        d = alfio.load_object(*args, **kwargs)
        if remove_uuids:
            # strip the trailing 37 characters of each key (presumably '.' + a 36-character UUID)
            for k in list(d.keys()):
                d[k[:-37]] = d.pop(k)
        return d

    @staticmethod
    def _get_attributes(dataset_types):
        """
        Returns attributes to load for the spikes, clusters and waveforms objects: the module defaults
        plus the attributes of any matching entries in dataset_types (e.g. 'spikes.samples').
        """
        dataset_types = [] if dataset_types is None else dataset_types
        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
        waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl]
        waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
        return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}

    def _get_spike_sorting_collection(self, spike_sorter=None):
        """
        Filters the list of collections to get the relevant spike sorting collection.

        The requested sorter is tried first, then 'iblsorter', then 'pykilosort'; an empty string
        sorter means the collection 'alf/{pname}' itself. If none of these exist, the shortest
        collection containing 'alf/{pname}' is returned (None if there is none).
        """
        for sorter in (spike_sorter, 'iblsorter', 'pykilosort'):
            if sorter is None:
                continue
            target = f'alf/{self.pname}' if sorter == "" else f'alf/{self.pname}/{sorter}'
            if target in self.collections:
                return target
        # if none is found amongst the defaults, prefers the shortest
        candidates = sorted((c for c in self.collections if f'alf/{self.pname}' in c), key=len)
        collection = candidates[0] if candidates else None
        _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
        return collection

    def load_spike_sorting_object(self, obj, *args, **kwargs):
        """
        Loads an ALF object
        :param obj: object name, str between 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :param missing: 'raise' (default) or 'ignore'
        :return: the loaded ALF object
        """
        self.download_spike_sorting_object(obj, *args, **kwargs)
        return self._load_object(self.files[obj])

    def get_version(self, spike_sorter=None):
        """Returns the Alyx version string of the spikes.times dataset, or 'unknown' if not registered."""
        spike_sorter = (spike_sorter or self.spike_sorter) or 'iblsorter'
        collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy')
        return dset[0]['version'] if len(dset) else 'unknown'

    def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None,
                                      attribute=None, missing='raise', **kwargs):
        """
        Downloads an ALF object and stores the downloaded file paths in self.files[obj]
        :param obj: object name, str between 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to self.spike_sorter, or 'iblsorter' if that is None)
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :param attribute: accepted for backward compatibility; the attributes actually requested are
         derived from dataset_types
        :param missing: 'raise' (default) or 'ignore'
        :return: None; ({}, {}, {}) when the session has no spike sorting collection
        """
        if spike_sorter is None:
            spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter'
        if len(self.collections) == 0:
            return {}, {}, {}
        self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        collection = collection or self.collection
        _logger.debug(f"loading spike sorting object {obj} from {collection}")
        attributes = self._get_attributes(dataset_types)
        try:
            self.files[obj] = self.one.load_object(
                self.eid, obj=obj, attribute=attributes.get(obj, None),
                collection=collection, download_only=True, **kwargs)
        except ALFObjectNotFound as e:
            if missing == 'raise':
                raise e

    def download_spike_sorting(self, objects=None, **kwargs):
        """
        Downloads spikes, clusters and channels
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
        :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
        :return: None; sets self.spike_sorting_path when clusters files were downloaded
        """
        objects = ['spikes', 'clusters', 'channels'] if objects is None else objects
        for obj in objects:
            self.download_spike_sorting_object(obj=obj, **kwargs)
        # clusters may be absent (not requested, or no spike sorting collection): avoid a KeyError
        if self.files.get('clusters'):
            self.spike_sorting_path = self.files['clusters'][0].parent

    def download_raw_electrophysiology(self, band='ap'):
        """
        Downloads raw electrophysiology data files on local disk.
        :param band: "ap" (default) or "lf" for LFP band
        :return: list of raw data files full paths (ch, meta and cbin files)
        """
        raw_data_files = []
        for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']:
            try:
                # FIXME: this will fail if multiple LFP segments are found
                raw_data_files.append(self.one.load_dataset(
                    self.eid,
                    download_only=True,
                    collection=f'raw_ephys_data/{self.pname}',
                    dataset=suffix,
                    check_hash=False,
                ))
            except ALFObjectNotFound:
                _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}")
        self.raw_data_files = list(set(self.raw_data_files + raw_data_files))
        return raw_data_files

    def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
        """
        Returns a reader for the raw electrophysiology data
        By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having
        downloaded the raw data file if necessary
        :param stream: if True (default) return a brainbox.io.spikeglx.Streamer, otherwise a spikeglx.Reader
        :param band: "ap" (default) or "lf"
        :param kwargs: passed to the Streamer constructor (stream=True only)
        :return: the reader, or None when stream is False and no cbin file could be downloaded
        """
        if stream:
            return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs)
        raw_data_files = self.download_raw_electrophysiology(band=band)
        cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
        if cbin_file is not None:
            return spikeglx.Reader(cbin_file)

    def download_raw_waveforms(self, **kwargs):
        """
        Downloads raw waveforms extracted from sorting to local disk.
        """
        collection = self._get_spike_sorting_collection("pykilosort")
        # log the collection actually used rather than self.collection, which may differ
        _logger.debug(f"loading waveforms from {collection}")
        return self.one.load_object(
            id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"],
            collection=collection, download_only=True, **kwargs
        )

    def raw_waveforms(self, **kwargs):
        """Downloads the raw waveforms and returns a WaveformsLoader on their directory."""
        wf_paths = self.download_raw_waveforms(**kwargs)
        return WaveformsLoader(wf_paths[0].parent)

    def load_channels(self, **kwargs):
        """
        Loads channels
        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        - alf: the final version of channel locations, same as resolved with the difference that data is on file
        - resolved: channel locations alignments have been agreed upon
        - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
        :return: Bunch of channel properties; self.histology is set to the source used
        """
        # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
        self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
            if alfio.check_dimensions(channels) != 0:
                # inconsistent dimensions after the merge: fall back to electrodeSites alone
                channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
                channels['rawInd'] = np.arange(channels[list(channels.keys())[0]].shape[0])
        if 'brainLocationIds_ccf_2017' not in channels:
            _logger.debug(f"loading channels from alyx for {self.files['channels']}")
            _channels, self.histology = _load_channel_locations_traj(
                self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
            if _channels:
                channels = _channels[self.pname]
        else:
            channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
            self.histology = 'alf'
        return Bunch(channels)

    @staticmethod
    def filter_files_by_namespace(all_files, namespace):
        """
        For each object.attribute, keep the file with the requested namespace, falling back to the
        file without namespace. Files with neither are dropped.
        """
        # Create dict for each file with available namespaces, no namespace is stored under the key None
        namespace_files = defaultdict(dict)
        available_namespaces = set()
        for file in all_files:
            fparts = filename_parts(file.name, as_dict=True)
            fname = f"{fparts['object']}.{fparts['attribute']}"
            nspace = fparts['namespace']
            available_namespaces.add(nspace)
            namespace_files[fname][nspace] = file

        if namespace not in available_namespaces:
            _logger.info(f'Could not find manual curation results for {namespace}, returning default'
                         f' non manually curated spikesorting data')

        # Return the files with the chosen namespace.
        files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()]
        # remove any None files
        return [f for f in files if f]

    def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False,
                           namespace=None, **kwargs):
        """
        Loads spikes, clusters and channels

        There could be several spike sorting collections, by default the loader will get the pykilosort collection

        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        - alf: the final version of channel locations, same as resolved with the difference that data is on file
        - resolved: channel locations alignments have been agreed upon
        - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'iblsorter')
        :param revision: for example "2024-05-06", (defaults to None):
        :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
        :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
        :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
        :param namespace: None, if given will load the manually curated spikesorting with the given namespace,
         e.g to load '_av_.clusters.depths use namespace='av'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :return: spikes, clusters, channels
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.files = {}
        self.spike_sorter = spike_sorter
        self.revision = revision

        if good_units and namespace is not None:
            _logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with'
                         'good_units=False and filter the spikes post hoc by the good clusters.')
            return [None] * 3
        objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
        self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
        channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
        self.files['clusters'] = self.filter_files_by_namespace(self.files['clusters'], namespace)
        clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)

        if good_units:
            spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
        else:
            self.files['spikes'] = self.filter_files_by_namespace(self.files['spikes'], namespace)
            spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
        if enforce_version:
            self._assert_version_consistency()
        return spikes, clusters, channels

    def _assert_version_consistency(self):
        """
        Makes sure the state of the spike sorting object matches the files downloaded
        :return: None
        """
        # revision is only set by load_spike_sorting; tolerate a call before it
        revision = getattr(self, 'revision', None)
        for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
            for fn in self.files.get(k, []):
                if self.spike_sorter:
                    assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
                        f"You required strict version {self.spike_sorter}, {fn} does not match"
                if revision:
                    assert full_path_parts(fn)[5] == revision, \
                        f"You required strict revision {revision}, {fn} does not match"

    @staticmethod
    def compute_metrics(spikes, clusters=None):
        """
        Computes quick unit metrics as a DataFrame, one row per cluster id in range(nc), where nc is
        the size of clusters['channels'] or, without clusters, the number of unique spike cluster ids.
        """
        # NOTE(review): with sparse cluster ids and no clusters object, np.arange(unique count) may not
        # cover the highest ids - confirm the intended behavior
        nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
        metrics = pd.DataFrame(quick_unit_metrics(
            spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
        return metrics

    @staticmethod
    def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False):
        """
        Merge the metrics and the channel information into the clusters dictionary
        :param spikes:
        :param clusters: modified in place ('metrics' is popped and metric/channel columns are added)
        :param channels:
        :param cache_dir: if specified (str or Path), recomputed metrics and the merged clusters table are
         written there as parquet files. This is to be used for clusters or analysis applications (defaults to None).
        :param compute_metrics: if True, will explicitly recompute metrics (defaults to false)
        :return: cluster dictionary containing metrics and histology, None if spikes is empty
        """
        if spikes == {}:
            return
        nc = clusters['channels'].size
        # recompute metrics if they are not available
        metrics = None
        if 'metrics' in clusters:
            metrics = clusters.pop('metrics')
            if metrics.shape[0] != nc:
                metrics = None
        if metrics is None or compute_metrics is True:
            _logger.debug("recompute clusters metrics")
            metrics = SpikeSortingLoader.compute_metrics(spikes, clusters)
            # same condition as the clusters.pqt caching below, so str cache dirs are honoured too
            if cache_dir is not None:
                metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
        for k in metrics.keys():
            clusters[k] = metrics[k].to_numpy()
        for k in channels.keys():
            clusters[k] = channels[k][clusters['channels']]
        if cache_dir is not None:
            _logger.debug(f'caching clusters metrics in {cache_dir}')
            pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
        return clusters

    @property
    def url(self):
        """Gets flatiron URL for the session"""
        webclient = getattr(self.one, '_web_client', None)
        return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None

    def _get_probe_info(self):
        """Lazily loads the probe sync timestamps, AP metadata and sampling rate into self._sync."""
        if self._sync is None:
            timestamps = self.one.load_dataset(
                self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
            _ = self.one.load_dataset(  # this is not used here but we want to trigger the download for potential tasks
                self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
            try:
                ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
                    self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
                fs = spikeglx._get_fs_from_meta(ap_meta)
            except ALFObjectNotFound:
                ap_meta = None
                fs = 30_000
            self._sync = {
                'timestamps': timestamps,
                'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'),
                'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
                'ap_meta': ap_meta,
                'fs': fs,
            }

    def timesprobe2times(self, values, direction='forward'):
        """
        Converts between probe-clock seconds and session main clock seconds.
        :param values: numpy array of times in seconds
        :param direction: 'forward' (probe seconds to main seconds) or 'reverse' (main seconds to probe seconds)
        :return: converted times
        :raises ValueError: if direction is neither 'forward' nor 'reverse'
        """
        self._get_probe_info()
        if direction == 'forward':
            return self._sync['forward'](values * self._sync['fs'])
        elif direction == 'reverse':
            return self._sync['reverse'](values) / self._sync['fs']
        raise ValueError(f"direction must be 'forward' or 'reverse', got {direction!r}")

    def samples2times(self, values, direction='forward'):
        """
        Converts ephys sample values to session main clock seconds
        :param values: numpy array of times in seconds or samples to resync
        :param direction: 'forward' (samples probe time to seconds main time) or 'reverse'
         (seconds main time to samples probe time)
        :return:
        """
        self._get_probe_info()
        return self._sync[direction](values)

    @property
    def pid2ref(self):
        """Human readable reference for the insertion: '{session ref}_{pname}'."""
        return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"

    def _default_plot_title(self, spikes):
        title = f"{self.pid2ref}, {self.pid} \n" \
                f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
        return title

    def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
               drift=None, title=None, **kwargs):
        """
        :param spikes: spikes dictionary or Bunch
        :param channels: channels dictionary or Bunch.
        :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png".
         Otherwise, plot. A str or Path; if it is not an existing directory it is used as the file path.
        :param br: brain regions object (optional)
        :param label: label for saved image (optional, default="raster")
        :param time_series: timeseries dictionary for behavioral event times (optional)
        :param drift: drift dictionary with 'times' and 'um' keys; loaded from disk if None and available
        :param title: plot title (optional)
        :param **kwargs: kwargs passed to `driftmap()` (optional); when none are given the defaults
         t_bin=0.007, d_bin=10, vmax=0.5 are used
        :return: fig, axs when save_dir is None
        """
        br = br or BrainRegions()
        time_series = time_series or {}
        fig, axs = plt.subplots(2, 2, gridspec_kw={
            'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col')
        axs[0, 1].set_axis_off()
        # **kwargs is always a dict (never None): test for emptiness so the defaults actually apply
        if not kwargs:
            # set default raster plot parameters
            kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
        brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs)
        if title is None:
            title = self._default_plot_title(spikes)
        axs[0, 0].title.set_text(title)
        for ts in time_series.values():
            vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
        if 'atlas_id' in channels:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=br, display=True, ax=axs[1, 1], title=self.histology)
        axs[1, 0].set_ylim(0, 3800)
        axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1])
        fig.tight_layout()

        if drift is None:
            self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
            if 'drift' in self.files:
                drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
        if isinstance(drift, dict):
            axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
            axs[0, 0].set(ylim=[-15, 15])

        if save_dir is not None:
            save_dir = Path(save_dir)  # str inputs have no joinpath
            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if save_dir.is_dir() else save_dir
            fig.savefig(png_file)
            plt.close(fig)
            gc.collect()
        else:
            return fig, axs

    def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
                             channels=None,
                             br: BrainRegions = None,
                             save_dir=None,
                             label='raster',
                             gain=-93,
                             title=None):
        """
        Plots a 400 ms destriped raw data snippet around t0 with detected spikes overlaid
        (green for clusters with label == 1, red otherwise), zoomed on +/- 35 ms around t0.
        :param sr: raw data reader exposing fs, nsync and slicing (e.g. spikeglx.Reader)
        :param spikes: spikes dictionary with 'samples' and 'clusters'
        :param clusters: clusters dictionary with 'channels' and 'label'
        :param t0: center time of the snippet, in seconds
        :param channels: optional channels dictionary ('labels', 'atlas_id', 'axial_um')
        :param br: brain regions object (optional)
        :param save_dir: if specified (str or Path) save the figure there, otherwise return it
        :param label: label for the saved image
        :param gain: display gain passed to Density
        :param title: plot title (optional)
        :return: fig, axs when save_dir is None
        """
        # compute the raw data offset and destripe, we take 400ms around t0
        first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
        raw = sr[first_sample:last_sample, :-sr.nsync].T
        channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
        destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
        # filter out the spikes according to good/bad clusters and to the time slice
        spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
        ss = spikes['samples'][spike_sel]
        sc = clusters['channels'][spikes['clusters'][spike_sel]]
        sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
        if title is None:
            title = self._default_plot_title(spikes)
        # display the raw data snippet with spikes overlaid
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
        Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
        axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
        axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
        axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
        # adds the channel locations if available
        if (channels is not None) and ('atlas_id' in channels):
            br = br or BrainRegions()
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=br, display=True, ax=axs[1], title=self.histology)
            axs[1].get_yaxis().set_visible(False)
        fig.tight_layout()

        if save_dir is not None:
            save_dir = Path(save_dir)  # str inputs have no joinpath
            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if save_dir.is_dir() else save_dir
            fig.savefig(png_file)
            plt.close(fig)
            gc.collect()
        else:
            return fig, axs
1315@dataclass
1316class SessionLoader:
1317 """
1318 Object to load session data for a give session in the recommended way.
1320 Parameters
1321 ----------
1322 one: one.api.ONE instance
1323 Can be in remote or local mode (required)
1324 session_path: string or pathlib.Path
1325 The absolute path to the session (one of session_path or eid is required)
1326 eid: string
1327 database UUID of the session (one of session_path or eid is required)
1329 If both are provided, session_path takes precedence over eid.
1331 Examples
1332 --------
1333 1) Load all available session data for one session:
1334 >>> from one.api import ONE
1335 >>> from brainbox.io.one import SessionLoader
1336 >>> one = ONE()
1337 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1338 # Object is initiated, but no data is loaded as you can see in the data_info attribute
1339 >>> sess_loader.data_info
1340 name is_loaded
1341 0 trials False
1342 1 wheel False
1343 2 pose False
1344 3 motion_energy False
1345 4 pupil False
1347 # Loading all available session data, the data_info attribute now shows which data has been loaded
1348 >>> sess_loader.load_session_data()
1349 >>> sess_loader.data_info
1350 name is_loaded
1351 0 trials True
1352 1 wheel True
1353 2 pose True
1354 3 motion_energy True
1355 4 pupil False
1357 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
1358 >>> type(sess_loader.trials)
1359 pandas.core.frame.DataFrame
1360 >>> sess_loader.trials.shape
1361 (626, 18)
1362 # Each data comes with its own timestamps in a column called 'times'
1363 >>> sess_loader.wheel['times']
1364 0 0.134286
1365 1 0.135286
1366 2 0.136286
1367 3 0.137286
1368 4 0.138286
1369 ...
1370 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
1371 # The dataframes of all cameras are collected in a dictionary
1372 >>> type(sess_loader.pose)
1373 dict
1374 >>> sess_loader.pose.keys()
1375 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
1376 >>> sess_loader.pose['bodyCamera'].columns
1377 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
1378 # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
1379 functions:
1380 >>> sess_loader.load_wheel(sampling_rate=100)
1381 """
1382 one: One = None
1383 session_path: Path = ''
1384 eid: str = ''
1385 revision: str = ''
1386 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1387 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1388 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1389 pose: dict = field(default_factory=dict, repr=False)
1390 motion_energy: dict = field(default_factory=dict, repr=False)
1391 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1393 def __post_init__(self):
1394 """
1395 Function that runs automatically after initiation of the dataclass attributes.
1396 Checks for required inputs, sets session_path and eid, creates data_info table.
1397 """
1398 if self.one is None: 1bf
1399 raise ValueError("An input to one is required. If not connection to a database is desired, it can be "
1400 "a fully local instance of One.")
1401 # If session path is given, takes precedence over eid
1402 if self.session_path is not None and self.session_path != '': 1bf
1403 self.eid = self.one.to_eid(self.session_path) 1bf
1404 self.session_path = Path(self.session_path) 1bf
1405 # Providing no session path, try to infer from eid
1406 else:
1407 if self.eid is not None and self.eid != '':
1408 self.session_path = self.one.eid2path(self.eid)
1409 else:
1410 raise ValueError("If no session path is given, eid is required.")
1412 data_names = [ 1bf
1413 'trials',
1414 'wheel',
1415 'pose',
1416 'motion_energy',
1417 'pupil'
1418 ]
1419 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1bf
def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
    """
    Function to load available session data into the SessionLoader object. Input parameters allow to control which
    data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
    parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
    in SessionLoader.data_info

    Each loader is run independently: if one fails, a warning is logged (the exception itself at debug level) and
    the remaining data types are still loaded. Flags are interpreted by truthiness, so e.g. ``None`` or ``0``
    means "do not load".

    Parameters
    ----------
    trials: boolean
        Whether to load all trials data into SessionLoader.trials, default is True
    wheel: boolean
        Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
    pose: boolean
        Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
        default is True
    motion_energy: boolean
        Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
        into SessionLoader.motion_energy, default is True
    pupil: boolean
        Whether to load pupil diameter (raw and smooth) for the left camera into SessionLoader.pupil,
        default is True
    reload: boolean
        Whether to reload data that has already been loaded into this SessionLoader object, default is False
    """
    load_df = self.data_info.copy()
    # The order of these lists must match the row order of data_info built in __post_init__
    load_df['to_load'] = [
        trials,
        wheel,
        pose,
        motion_energy,
        pupil
    ]
    load_df['load_func'] = [
        self.load_trials,
        self.load_wheel,
        self.load_pose,
        self.load_motion_energy,
        self.load_pupil
    ]

    for idx, row in load_df.iterrows():
        # Truthiness rather than identity checks: rows may hold numpy bools, and callers may pass None/0
        if not row['to_load']:
            _logger.debug(f"Not loading {row['name']} data, set to False.")
        elif row['is_loaded'] and not reload:
            _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.")
        else:
            try:
                _logger.info(f"Loading {row['name']} data")
                row['load_func']()
                self.data_info.loc[idx, 'is_loaded'] = True
            # Exception (not BaseException) so that KeyboardInterrupt / SystemExit still abort the loading
            except Exception as e:
                _logger.warning(f"Could not load {row['name']} data.")
                _logger.debug(e)
def _find_behaviour_collection(self, obj):
    """
    Work out which ALF collection holds the trials or wheel data of this session.

    Parameters
    ----------
    obj: str
        Alf object to load, either 'trials' or 'wheel' (any other value is treated as 'wheel')

    Returns
    -------
    str
        The single collection containing the dataset, or 'alf' when no matching dataset is listed.

    Raises
    ------
    ALFMultipleCollectionsFound
        When the dataset exists in more than one collection.
    """
    if obj == 'trials':
        target = '_ibl_trials.table.pqt'
    else:
        target = '_ibl_wheel.position.npy'
    datasets = self.one.list_datasets(self.eid, target)
    if len(datasets) == 0:
        # Nothing listed: fall back to the default collection
        return 'alf'

    found = []
    for rel_path in datasets:
        parts = full_path_parts(self.session_path.joinpath(rel_path), as_dict=True)
        found.append(parts['collection'])

    if len(set(found)) != 1:
        _logger.error(f'Multiple collections found {found}. Specify collection when loading, '
                      f'e.g sl.load_{obj}(collection="{found[0]}")')
        raise ALFMultipleCollectionsFound
    return found[0]
def load_trials(self, collection=None):
    """
    Function to load trials data into SessionLoader.trials

    The ``itiDuration`` attribute is excluded. The wildcards setting of the ONE instance is temporarily switched
    off for the query and is always restored afterwards, including when loading fails.

    Parameters
    ----------
    collection: str
        Alf collection of trials data. If not given, it is looked up with _find_behaviour_collection
    """
    if not collection:
        collection = self._find_behaviour_collection('trials')
    # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex.
    # The attribute filter below is a regular expression, so wildcard (glob) matching is switched off on the
    # ONE instance; since that instance is shared, its previous setting is restored even if loading raises.
    previous_wildcards = getattr(self.one, 'wildcards', True)
    self.one.wildcards = False
    try:
        trials = self.one.load_object(
            self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None)
    finally:
        self.one.wildcards = previous_wildcards
    self.trials = trials.to_df()
    self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
    """
    Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
    is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
    a Butterworth low-pass filter is applied. All resulting columns are cast to float32.

    Parameters
    ----------
    fs: int, float
        Sampling frequency for the wheel position, default is 1000 Hz
    corner_frequency: int, float
        Corner frequency of Butterworth low-pass filter, default is 20
    order: int, float
        Order of Butterworth low_pass filter, default is 8
    collection: str
        Alf collection of wheel data. If not given, it is looked up with _find_behaviour_collection

    Raises
    ------
    ValueError
        If wheel.position and wheel.timestamps do not have the same length.
    """
    if not collection:
        collection = self._find_behaviour_collection('wheel')
    wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None)
    if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
        raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps'")
    # resample the wheel position and compute velocity, acceleration
    self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration'])
    self.wheel['position'], self.wheel['times'] = interpolate_position(
        wheel_raw['timestamps'], wheel_raw['position'], freq=fs)
    self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered(
        self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
    self.wheel = self.wheel.apply(np.float32)
    self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True
def load_pose(self, likelihood_thr=0.9, views=('left', 'right', 'body')):
    """
    Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
    dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
    Dataframes with the timestamps and pose data, one row for each body part tracked for that camera.

    SessionLoader.pose is emptied before loading. If loading a view raises, the exception propagates and
    SessionLoader.pose only contains the views loaded before the failure.

    Parameters
    ----------
    likelihood_thr: float
        The position of each tracked body part come with a likelihood of that estimate for each time point.
        Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
        likelihood_thr=1. Default is 0.9
    views: list, tuple or str
        Camera view(s) for which to try and load data. Possible options are {'left', 'right', 'body'}.
        A single view may be given as a string.
    """
    # A bare string would otherwise be iterated character by character
    if isinstance(views, str):
        views = [views]
    # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
    self.pose = {}
    for view in views:
        pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None)
        # Double check if video timestamps are correct length or can be fixed
        times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
        self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
        self.pose[f'{view}Camera'].insert(0, 'times', times_fixed)
    self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True
def load_motion_energy(self, views=('left', 'right', 'body')):
    """
    Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
    dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
    pandas Dataframes with the timestamps and motion energy data.
    The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
    (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
    body (bodyMotionEnergy).

    SessionLoader.motion_energy is emptied before loading. If loading a view raises, the exception propagates and
    only the views loaded before the failure remain.

    Parameters
    ----------
    views: list, tuple or str
        Camera view(s) for which to try and load data. Possible options are {'left', 'right', 'body'}.
        A single view may be given as a string.
    """
    names = {'left': 'whiskerMotionEnergy',
             'right': 'whiskerMotionEnergy',
             'body': 'bodyMotionEnergy'}
    # A bare string would otherwise be iterated character by character
    if isinstance(views, str):
        views = [views]
    # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
    self.motion_energy = {}
    for view in views:
        me_raw = self.one.load_object(
            self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None)
        # Double check if video timestamps are correct length or can be fixed
        times_fixed, motion_energy = self._check_video_timestamps(
            view, me_raw['times'], me_raw['ROIMotionEnergy'])
        self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy)
        self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed)
    self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True
def load_licks(self):
    """
    Placeholder for loading lick data; not implemented yet.

    Does nothing and returns None.
    """
    return None
def load_pupil(self, snr_thresh=5.):
    """
    Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.

    The leftCamera 'features' dataset is used when available. Otherwise the raw diameter is computed on the fly
    from left camera pose data thresholded at a likelihood of 0.9. Pose views already loaded into
    SessionLoader.pose are preserved by this computation.

    Parameters
    ----------
    snr_thresh: float
        An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data
        will be considered unusable and will be discarded.

    Raises
    ------
    ValueError
        If the SNR is below snr_thresh (SessionLoader.pupil is then reset to an empty DataFrame).
    """
    # Try to load from features
    feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None)
    if 'features' in feat_raw.keys():
        times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
        self.pupil = feats.copy()
        self.pupil.insert(0, 'times', times_fixed)

    # If unavailable compute on the fly
    else:
        _logger.info('Pupil diameter not available, trying to compute on the fly.')
        pose_loaded = bool(self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'].values[0])
        # Previously loaded pose data may have been thresholded differently, so the left view is reloaded at 0.9.
        # load_pose() empties self.pose first, so keep every previously loaded view to put it back afterwards.
        previous_pose = dict(self.pose) if pose_loaded else None
        try:
            self.load_pose(views=['left'], likelihood_thr=0.9)
            dlc_thr = self.pose['leftCamera'].copy()
        finally:
            if previous_pose is not None:
                new_left = self.pose.get('leftCamera')
                self.pose = previous_pose
                # Only add the 0.9-thresholded left view if no left view had been loaded before
                if 'leftCamera' not in self.pose and new_left is not None:
                    self.pose['leftCamera'] = new_left

        # Start from an empty frame so that columns left over from a previous call cannot clash in length
        self.pupil = pd.DataFrame()
        self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr)
        try:
            self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left')
        # Exception (not BaseException) so that KeyboardInterrupt / SystemExit are not swallowed
        except Exception as e:
            _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. "
                          "Saving all NaNs for pupilDiameter_smooth.")
            _logger.debug(e)
            self.pupil['pupilDiameter_smooth'] = np.nan

    if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])):
        good_idxs = np.where(
            ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
        # SNR: variance of the smoothed trace over variance of the residual (smooth - raw) on non-NaN samples
        snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) /
               (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
        if snr < snr_thresh:
            self.pupil = pd.DataFrame()
            raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')
def _check_video_timestamps(self, view, video_timestamps, video_data):
    """
    Helper function to check for the length of the video frames vs video timestamps and fix in case
    timestamps are longer than video frames.

    Parameters
    ----------
    view: str
        Camera view name, only used in log / error messages (e.g. 'left' -> 'leftCamera')
    video_timestamps: array-like with a shape attribute
        Camera timestamps
    video_data: array-like with a shape attribute
        Per-frame data; its first dimension is the number of frames

    Returns
    -------
    tuple
        (timestamps, video_data). When there are more timestamps than frames, only the last
        video_data.shape[0] timestamps are kept; video_data is returned unchanged.

    Raises
    ------
    ValueError
        If there are fewer timestamps than frames (including no timestamps at all).
    """
    n_times = video_timestamps.shape[0]
    n_frames = video_data.shape[0]
    # If camera times are shorter than video data, or empty, no current fix
    if n_times < n_frames:
        if n_times == 0:
            msg = f'Camera times empty for {view}Camera.'
        else:
            msg = f'Camera times are shorter than video data for {view}Camera.'
        _logger.warning(msg)
        raise ValueError(msg)
    # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
    # This is because the first few frames are sometimes not recorded. We can remove the first few
    # timestamps in this case. Slice from an explicit start index: a negative slice [-0:] would keep every
    # timestamp when there are no frames.
    if n_times > n_frames:
        return video_timestamps[n_times - n_frames:], video_data
    return video_timestamps, video_data
class EphysSessionLoader(SessionLoader):
    """
    Spike sorting enhanced version of SessionLoader
    Loads spike sorting data for all probes in the session, in the self.ephys dict
    >>> EphysSessionLoader(eid=eid, one=one)
    To select for a specific probe
    >>> EphysSessionLoader(eid=eid, one=one, pid=pid)
    """
    def __init__(self, *args, pname=None, pid=None, **kwargs):
        """
        Needs an active connection in order to get the list of insertions in the session

        :param args: positional arguments forwarded to SessionLoader
        :param pname: optional insertion name used to restrict the Alyx query; takes precedence over pid
        :param pid: optional insertion id used to restrict the Alyx query when pname is not given
        :param kwargs: keyword arguments forwarded to SessionLoader (e.g. one, eid, session_path)
        """
        super().__init__(*args, **kwargs)
        # if necessary, restrict the query: pname first, then pid, otherwise all insertions of the session
        if pname is not None:
            qargs = {'name': pname}
        elif pid is not None:
            qargs = {'id': pid}
        else:
            qargs = {}
        insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs)
        self.ephys = {}
        for ins in insertions:
            self.ephys[ins['name']] = {}
            self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one)

    def load_session_data(self, *args, **kwargs):
        """
        Load the behavioural session data (see SessionLoader.load_session_data), then the spike sorting
        of every probe in self.ephys.
        """
        super().load_session_data(*args, **kwargs)
        self.load_spike_sorting()

    def load_spike_sorting(self, pnames=None):
        """
        Load spikes, clusters and channels into self.ephys[pname] for the requested probes.

        :param pnames: a probe name or a list of probe names; defaults to all probes in self.ephys
        """
        # A single probe name would otherwise be iterated character by character
        if isinstance(pnames, str):
            pnames = [pnames]
        pnames = pnames or list(self.ephys.keys())
        for pname in pnames:
            spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting()
            self.ephys[pname]['spikes'] = spikes
            self.ephys[pname]['clusters'] = clusters
            self.ephys[pname]['channels'] = channels

    @property
    def probes(self):
        """Dictionary mapping each probe name to its insertion id (pid)."""
        return {k: self.ephys[k]['ssl'].pid for k in self.ephys}