Coverage for brainbox/io/one.py: 56%
726 statements
1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2from dataclasses import dataclass, field
3import gc
4import logging
5import re
6import os
7from pathlib import Path
9import numpy as np
10import pandas as pd
11from scipy.interpolate import interp1d
12import matplotlib.pyplot as plt
14from one.api import ONE, One
15from one.alf.files import get_alf_path, full_path_parts
16from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
17from one.alf import cache
18import one.alf.io as alfio
19from neuropixel import TIP_SIZE_UM, trace_header
20import spikeglx
22import ibldsp.voltage
23from iblutil.util import Bunch
24from iblatlas.atlas import AllenAtlas, BrainRegions
25from iblatlas import atlas
26from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
27from ibllib.pipes import histology
28from ibllib.pipes.ephys_alignment import EphysAlignment
29from ibllib.plots import vertical_lines, Density
31import brainbox.plot
32from brainbox.io.spikeglx import Streamer
33from brainbox.ephys_plots import plot_brain_regions
34from brainbox.metrics.single_units import quick_unit_metrics
35from brainbox.behavior.wheel import interpolate_position, velocity_filtered
36from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter

_logger = logging.getLogger('ibllib')


SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
WAVEFORMS_ATTRIBUTES = ['templates']


def load_lfp(eid, one=None, dataset_types=None, **kwargs):
    """
    TODO Verify works
    From an eid, hits the Alyx database and downloads the standard set of datasets
    needed for LFP analysis.

    :param eid: experiment session identifier
    :param one: an instance of ONE (created if not provided)
    :param dataset_types: additional dataset types to add to the list
    :param kwargs: keyword arguments passed to the spikeglx.Reader constructor
    :return: list of spikeglx.Reader, one per LFP file
    """
    one = one or ONE()
    if dataset_types is None:
        dataset_types = []
    dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*']
    for dset in dtypes:
        one.load_dataset(eid, dset, download_only=True)
    session_path = one.eid2path(eid)

    efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False)
              if ef.get('lf', None)]
    return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles]
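
# Usage sketch (an illustration, not executed with the module; assumes a configured
# ONE instance and a session `eid` with raw LFP data):
#   one = ONE()
#   readers = load_lfp(eid, one=one)
#   for sr in readers:
#       print(sr.fs, sr.nc)  # sampling rate and channel count of each LFP file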


def _collection_filter_from_args(probe, spike_sorter=None):
    collection = f'alf/{probe}/{spike_sorter}'
    collection = collection.replace('None', '*')
    collection = collection.replace('/*', '*')
    collection = collection[:-1] if collection.endswith('/') else collection
    return collection


def _get_spike_sorting_collection(collections, pname):
    """
    Filters a list or array of collections to get the relevant spike sorting dataset;
    if there is a pykilosort collection, it takes precedence.
    """
    collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None)
    # otherwise, prefer the shortest collection (alf/probeXX)
    collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None)
    _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}")
    return collection


def _channels_alyx2bunch(chans):
    channels = Bunch({
        'atlas_id': np.array([ch['brain_region'] for ch in chans]),
        'x': np.array([ch['x'] for ch in chans]) / 1e6,
        'y': np.array([ch['y'] for ch in chans]) / 1e6,
        'z': np.array([ch['z'] for ch in chans]) / 1e6,
        'axial_um': np.array([ch['axial'] for ch in chans]),
        'lateral_um': np.array([ch['lateral'] for ch in chans])
    })
    return channels


def _channels_traj2bunch(xyz_chans, brain_atlas):
    brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans))
    channels = {
        'x': xyz_chans[:, 0],
        'y': xyz_chans[:, 1],
        'z': xyz_chans[:, 2],
        'acronym': brain_regions['acronym'],
        'atlas_id': brain_regions['id']
    }

    return channels


def _channels_bunch2alf(channels):
    channels_ = {
        'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6,
        'brainLocationIds_ccf_2017': channels['atlas_id'],
        'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]}
    return channels_


def _channels_alf2bunch(channels, brain_regions=None):
    # reformat the dictionary according to the standard that comes out of Alyx
    channels_ = {
        'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6,
        'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6,
        'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6,
        'acronym': None,
        'atlas_id': channels['brainLocationIds_ccf_2017'],
        'axial_um': channels['localCoordinates'][:, 1],
        'lateral_um': channels['localCoordinates'][:, 0],
    }
    # if there are extra keys, they carry over to the output dictionary
    for k in channels:
        if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']:
            channels_[k] = channels[k]
    if brain_regions:
        channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym']
    return channels_


def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
                        brain_regions=None):
    """
    Generic function to load spike sorting data using ONE.

    Will try to load one spike sorting for any probe present for the eid matching the collection.
    For each probe it will load a spike sorting:
        - if there is one version: loads this one
        - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX)

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (may be in 'local' mode)
    collection : str
        collection filter word - accepts wildcards - can be a combination of spike sorter and
        probe. See `ALF documentation`_ for details.
    revision : str
        A particular revision to return (defaults to latest revision). See `ALF documentation`_ for
        details.
    return_channels : bool
        Defaults to True; when True, also loads the channels from disk
    dataset_types : list of str
        Optional additional spikes/clusters objects to add to the standard default list
    brain_regions : iblatlas.regions.BrainRegions
        Optional BrainRegions object; if provided, acronyms are labelled

    .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs
        non-lateralized.
    """
    one = one or ONE()
    # enumerate probes and load according to the name
    collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    pnames = list(set(c.split('/')[1] for c in collections))
    spikes, clusters, channels = ({} for _ in range(3))

    spike_attributes, cluster_attributes = _get_attributes(dataset_types)

    for pname in pnames:
        probe_collection = _get_spike_sorting_collection(collections, pname)
        spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes',
                                        attribute=spike_attributes)
        clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters',
                                          attribute=cluster_attributes)
    if return_channels:
        channels = _load_channels_locations_from_disk(
            eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
        return spikes, clusters, channels
    else:
        return spikes, clusters


def _get_attributes(dataset_types):
    if dataset_types is None:
        return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
    else:
        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
        return spike_attributes, cluster_attributes


def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None):
    _logger.debug('loading spike sorting from disk')
    channels = Bunch({})
    collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    probes = list(set([c.split('/')[1] for c in collections]))
    for probe in probes:
        probe_collection = _get_spike_sorting_collection(collections, probe)
        channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels')
        # if the spike sorter has not aligned data, try to get the alignment available
        if 'brainLocationIds_ccf_2017' not in channels[probe].keys():
            aligned_channel_collections = one.list_collections(
                eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
            if len(aligned_channel_collections) == 0:
                _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
                continue
            _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
            ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
            channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
            channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
        # only have to reformat channels if we were able to load coordinates from disk
        channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
    return channels


def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
    """
    Oftentimes the channel maps of different spike sorters differ, so we interpolate the
    alignment onto the channel map at hand. If there is no spike sorting in the base folder,
    the alignment doesn't have the localCoordinates field, so we reconstruct it from the
    Neuropixel map. This only happens for early pykilosort sorts.

    :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
     'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
     OR
     'x', 'y', 'z', 'acronym', 'axial_um'
     those are the guide for the interpolation
    :param channels: Bunch or dictionary of channels containing at least the key 'localCoordinates'
    :param brain_regions: None (default) or iblatlas.regions.BrainRegions object
     if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017'
     if a brain region object is provided, outputs a dict with keys
     'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
    :return: Bunch or dictionary of channels with brain coordinates keys
    """
    NEUROPIXEL_VERSION = 1
    h = trace_header(version=NEUROPIXEL_VERSION)
    if channels is None:
        channels = {'localCoordinates': np.c_[h['x'], h['y']]}
    nch = channels['localCoordinates'].shape[0]
    if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())):
        channels_aligned = _channels_bunch2alf(channels_aligned)
    if 'localCoordinates' in channels_aligned.keys():
        aligned_depths = channels_aligned['localCoordinates'][:, 1]
    else:  # this is an edge case for a few spike sorting sessions
        assert channels_aligned['mlapdv'].shape[0] == 384
        aligned_depths = h['y']
    depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
    depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True)
    channels['mlapdv'] = np.zeros((nch, 3))
    for i in np.arange(3):
        channels['mlapdv'][:, i] = np.interp(
            depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
    # the brain locations have to be interpolated by nearest neighbour
    fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest')
    channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
    if brain_regions is not None:
        return _channels_alf2bunch(channels, brain_regions=brain_regions)
    else:
        return channels
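
# Usage sketch (an illustration; `channels_aligned` and `channels` are hypothetical
# dicts carrying the keys documented above):
#   channels = channel_locations_interpolation(
#       channels_aligned, channels, brain_regions=BrainRegions())
#   channels['acronym']  # per-channel brain region acronyms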


def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
                                 brain_atlas=None, return_source=False):
    if not hasattr(one, 'alyx'):
        return {}, None
    _logger.debug(f"trying to load from traj {probe}")
    channels = Bunch()
    brain_atlas = brain_atlas or AllenAtlas()
    # need to find the collection
    insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0]
    collection = _collection_filter_from_args(probe=probe)
    collections = one.list_collections(eid, filename='channels*', collection=collection,
                                       revision=revision)
    probe_collection = _get_spike_sorting_collection(collections, probe)
    chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
    depths = chn_coords[:, 1]

    extended_qc = (insertion.get('json') or {}).get('extended_qc') or {}
    tracing = extended_qc.get('tracing_exists', False)
    resolved = extended_qc.get('alignment_resolved', False)
    counts = extended_qc.get('alignment_count', 0)

    if tracing:
        xyz = np.array(insertion['json']['xyz_picks']) / 1e6
        if resolved:

            _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
                          f'Channel and cluster locations obtained from ephys aligned histology '
                          f'track.')
            traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
                                 provenance='Ephys aligned histology track')[0]
            align_key = insertion['json']['extended_qc']['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=brain_atlas, speedy=True)
            chans = ephysalign.get_channel_locations(feature, track)
            channels[probe] = _channels_traj2bunch(chans, brain_atlas)
            source = 'resolved'
        elif counts > 0 and aligned:
            _logger.debug(f'Channel locations for {eid}/{probe} have not been '
                          f'resolved. However, alignment flag set to True so channel and cluster'
                          f' locations will be obtained from latest available ephys aligned '
                          f'histology track.')
            # get the latest user aligned channels
            traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
                                 provenance='Ephys aligned histology track')[0]
            align_key = insertion['json']['extended_qc']['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=brain_atlas, speedy=True)
            chans = ephysalign.get_channel_locations(feature, track)

            channels[probe] = _channels_traj2bunch(chans, brain_atlas)
            source = 'aligned'
        else:
            _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
                          f'Channel and cluster locations obtained from histology track.')
            # get the channels from histology tracing
            xyz = xyz[np.argsort(xyz[:, 2]), :]
            chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
            channels[probe] = _channels_traj2bunch(chans, brain_atlas)
            source = 'traced'
        channels[probe]['axial_um'] = chn_coords[:, 1]
        channels[probe]['lateral_um'] = chn_coords[:, 0]

    else:
        _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
        source = ''
        channels = None

    if return_source:
        return channels, source
    else:
        return channels


def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
    """
    Load the brain locations of each channel for a given session/probe

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    probe : [str, list of str]
        The probe label(s), e.g. 'probe01'
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    brain_atlas : iblatlas.BrainAtlas
        Brain atlas object (default: Allen atlas)

    Returns
    -------
    dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    if isinstance(eid, dict):
        ses = eid
        eid = ses['url'][-36:]
    else:
        eid = one.to_eid(eid)
    collection = _collection_filter_from_args(probe=probe)
    channels = _load_channels_locations_from_disk(eid, one=one, collection=collection,
                                                  brain_regions=brain_atlas.regions)
    incomplete_probes = [k for k in channels if 'x' not in channels[k]]
    for iprobe in incomplete_probes:
        channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
                                                         brain_atlas=brain_atlas, return_source=True)
        if channels_ is not None:
            channels[iprobe] = channels_[iprobe]
    return channels
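
# Usage sketch (an illustration; assumes a configured OneAlyx instance and a valid
# `eid` with histology data):
#   channels = load_channel_locations(eid, probe='probe00', one=one)
#   channels['probe00']['acronym']  # brain region acronym for each channel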


def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                            brain_regions=None, nested=True, collection=None, return_collection=False):
    """
    From an eid, loads spikes and clusters for all probes
    The following set of dataset types are loaded:
        'clusters.channels',
        'clusters.depths',
        'clusters.metrics',
        'spikes.clusters',
        'spikes.times',
        'probes.description'
    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param nested: if a single probe is required, do not output a dictionary with the probe name as key
    :param return_collection: (False) if True, will return the collection used to load
    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
    """
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    if collection is None:
        collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
                  brain_regions=brain_regions)
    spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    else:
        return spikes, clusters, channels


def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                       brain_regions=None, return_collection=False):
    """
    From an eid, loads spikes and clusters for all probes
    The following set of dataset types are loaded:
        'clusters.channels',
        'clusters.depths',
        'clusters.metrics',
        'spikes.clusters',
        'spikes.times',
        'probes.description'
    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param return_collection: (bool - False) if True, returns the collection used for loading the data
    :return: spikes, clusters (dict of bunch, 1 bunch per probe)
    """
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
                                           return_channels=False, dataset_types=dataset_types,
                                           brain_regions=brain_regions)
    if return_collection:
        return spikes, clusters, collection
    else:
        return spikes, clusters
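
# Usage sketch for the deprecated loaders above (an illustration; prefer
# SpikeSortingLoader for new code; `eid` is assumed to be a valid session id):
#   spikes, clusters = load_spike_sorting(eid, one=ONE())
#   spikes['probe00']['times']  # spike times (s) for one probe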


def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
                                    spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
    """
    For a given eid, get spikes, clusters and channels information, and merges clusters
    and channels information before returning all three variables.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    probe : [str, list of str]
        The probe label(s), e.g. 'probe01'
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    dataset_types : list of str
        Optional additional spikes/clusters objects to add to the standard default list
    spike_sorter : str
        Name of the spike sorting you want to load (None for default which is pykilosort if it's
        available otherwise the default MATLAB kilosort)
    brain_atlas : iblatlas.atlas.BrainAtlas
        Brain atlas object (default: Allen atlas)
    nested : bool
        If False and only one probe is present, returns the bunches directly instead of
        dictionaries keyed by probe name
    return_collection : bool
        Returns an extra argument with the collection chosen

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    # --- Get spikes and clusters data
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    spikes, clusters, collection = load_spike_sorting(
        eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
    # -- Get brain regions and assign to clusters
    channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
                                      brain_atlas=brain_atlas)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    else:
        return spikes, clusters, channels


def load_ephys_session(eid, one=None):
    """
    From an eid, hits the Alyx database and downloads a standard default set of dataset types
    From a local session Path (pathlib.Path), loads a standard default set of dataset types
    to perform analysis:
        'clusters.channels',
        'clusters.depths',
        'clusters.metrics',
        'spikes.clusters',
        'spikes.times',
        'probes.description'

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        ONE object to use for loading (required)

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    trials : one.alf.io.AlfBunch of numpy.ndarray
        The session trials data
    """
    assert one
    spikes, clusters = load_spike_sorting(eid, one=one)
    trials = one.load_object(eid, 'trials')
    return spikes, clusters, trials
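
# Usage sketch (an illustration; a ONE instance is required by the assert above):
#   spikes, clusters, trials = load_ephys_session(eid, one=ONE())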


def _remove_old_clusters(session_path, probe):
    # gets clusters and spikes from a local session folder
    probe_path = session_path.joinpath('alf', probe)

    # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
    cluster_file = probe_path.joinpath('clusters.metrics.csv')

    if cluster_file.exists():
        os.remove(cluster_file)
        _logger.info('Deleting old clusters.metrics.csv file')


def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
    """
    Takes (default and any extra) values in given keys from channels and assigns them to clusters.
    If channels does not contain any data, the new keys are added to clusters but left empty.

    Parameters
    ----------
    dic_clus : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing cluster information
    channels : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z', 'localCoordinates')
    keys_to_add_extra : list of str
        Any extra keys to load into channels bunches

    Returns
    -------
    dict of one.alf.io.AlfBunch
        clusters (1 bunch per probe) with new keys values.
    """
    probe_labels = list(channels.keys())  # convert dict_keys into list
    keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']

    if keys_to_add_extra is None:
        keys_to_add = keys_to_add_default
    else:
        # append extra optional keys
        keys_to_add = list(set(keys_to_add_extra + keys_to_add_default))

    for label in probe_labels:
        clu_ch = dic_clus[label]['channels']
        for key in keys_to_add:
            try:
                assert key in channels[label].keys()  # check key is in channels
                ch_key = channels[label][key]
                nch_key = len(ch_key) if ch_key is not None else 0
                if max(clu_ch) < nch_key:  # check length as clu_ch will be used as an index
                    dic_clus[label][key] = ch_key[clu_ch]
                else:
                    _logger.warning(
                        f'Probe {label}: key "{key}" has {nch_key} values on channels but cluster channel '
                        f'indices go up to {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.')
                    dic_clus[label][key] = []
            except AssertionError:
                _logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
                continue

    return dic_clus
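
# Usage sketch (an illustration; `dic_clus` and `channels` are dicts keyed by probe
# label, e.g. as returned by load_spike_sorting and load_channel_locations):
#   clusters = merge_clusters_channels(dic_clus, channels)
#   clusters['probe00']['acronym']  # one brain region acronym per cluster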


def load_passive_rfmap(eid, one=None):
    """
    For a given eid load in the passive receptive field mapping protocol data

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        An instance of ONE (may be in 'local' - offline - mode)

    Returns
    -------
    one.alf.io.AlfBunch
        Passive receptive field mapping data
    """
    one = one or ONE()

    # Load in the receptive field mapping data
    rf_map = one.load_object(eid, obj='passiveRFM', collection='alf')
    frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin',
                                          collection='raw_passive_data'), dtype="uint8")
    y_pix, x_pix = 15, 15
    frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0])
    rf_map['frames'] = frames

    return rf_map
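
# Usage sketch (an illustration; assumes `eid` points to a session with passive
# stimulation data):
#   rf_map = load_passive_rfmap(eid, one=ONE())
#   rf_map['frames'].shape  # (n_frames, 15, 15) stimulus frames as uint8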


def load_wheel_reaction_times(eid, one=None):
    """
    Return the calculated reaction times for session. Reaction times are defined as the time
    between the go cue (onset tone) and the onset of the first substantial wheel movement. A
    movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the
    distance to threshold (~0.1 radians).

    Negative times mean the onset of the movement occurred before the go cue. Nans may occur if
    there was no detected movement within the period, or when the goCue_times or feedback_times
    are nan.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        one object to use for loading. Will generate internal one if not used, by default None

    Returns
    -------
    array-like
        reaction times
    """
    if one is None:
        one = ONE()

    trials = one.load_object(eid, 'trials')
    # If already extracted, load and return
    if trials and 'firstMovement_times' in trials:
        return trials['firstMovement_times'] - trials['goCue_times']
    # Otherwise load wheelMoves object and calculate
    moves = one.load_object(eid, 'wheelMoves')
    # Re-extract wheel moves if necessary
    if not moves or 'peakAmplitude' not in moves:
        wheel = one.load_object(eid, 'wheel')
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
    assert trials and moves, 'unable to load trials and wheelMoves data'
    firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials)
    return firstMove_times - trials['goCue_times']
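
# Usage sketch (an illustration; negative values indicate movement onset before
# the go cue):
#   rt = load_wheel_reaction_times(eid, one=ONE())
#   np.nanmedian(rt)  # typical reaction time in seconds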


def load_iti(trials):
    """
    The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey
    screen commencing at stimulus off and lasting until the quiescent period at the start of the
    following trial. Note that the ITI of each trial is the time between that trial and the next,
    therefore the last trial has no defined ITI and the last value is NaN.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'intervals', 'stimOff_times'}.

    Returns
    -------
    np.array
        An array of inter-trial intervals, the last value being NaN.
    """
    if not {'intervals', 'stimOff_times'} <= set(trials.keys()):
        raise ValueError('trials must contain keys {"intervals", "stimOff_times"}')
    return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan]
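
# Usage sketch (an illustration, worked from the formula above: iti[i] is the start
# of trial i + 1 minus the stimulus-off time of trial i):
#   trials = ONE().load_object(eid, 'trials')
#   iti = load_iti(trials)  # array of inter-trial intervals, last value NaN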


def load_channels_from_insertion(ins, depths=None, one=None, ba=None):

    PROV_2_VAL = {
        'Resolved': 90,
        'Ephys aligned histology track': 70,
        'Histology track': 50,
        'Micro-manipulator': 30,
        'Planned': 10}

    one = one or ONE()
    ba = ba or atlas.AllenAtlas()
    traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id'])
    val = [PROV_2_VAL[tr['provenance']] for tr in traj]
    idx = np.argmax(val)
    traj = traj[idx]
    if depths is None:
        # trace_header returns a dict of arrays; 'y' holds the axial coordinates
        depths = trace_header(version=1)['y']
    if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator':
        ins = atlas.Insertion.from_dict(traj)
        # deepest coordinate first
        xyz = np.c_[ins.tip, ins.entry].T
        xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
    else:
        xyz = np.array(ins['json']['xyz_picks']) / 1e6
        if traj['provenance'] == 'Histology track':
            xyz = xyz[np.argsort(xyz[:, 2]), :]
            xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
        else:
            align_key = ins['json']['extended_qc']['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=ba, speedy=True)
            xyz_channels = ephysalign.get_channel_locations(feature, track)
    return xyz_channels
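
# Usage sketch (an illustration; `ins` is an insertion record from Alyx):
#   ins = one.alyx.rest('insertions', 'list', session=eid, name='probe00')[0]
#   xyz_channels = load_channels_from_insertion(ins, one=one)  # (n_channels, 3) in meters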


@dataclass
class SpikeSortingLoader:
    """
    Object that will load spike sorting data for a given probe insertion.
    This class can be instantiated in several manners
    - With Alyx database probe id:
            SpikeSortingLoader(pid=pid, one=one)
    - With Alyx database eid and probe name:
            SpikeSortingLoader(eid=eid, pname='probe00', one=one)
    - From a local session and probe name:
            SpikeSortingLoader(session_path=session_path, pname='probe00')
    NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
    """
    one: One = None
    atlas: None = None
    pid: str = None
    eid: str = ''
    pname: str = ''
    session_path: Path = ''
    # the following properties are the outcome of the post init function
    collections: list = None
    datasets: list = None  # list of all datasets belonging to the session
    # the following properties are the outcome of a reading function
    files: dict = None
    raw_data_files: list = None  # list of raw ap and lf files corresponding to the recording
    collection: str = ''
    histology: str = ''  # 'alf', 'resolved', 'aligned' or 'traced'
    spike_sorter: str = 'pykilosort'
    spike_sorting_path: Path = None
    _sync: dict = None

    def __post_init__(self):
        # pid gets precedence
        if self.pid is not None:
            try:
                self.eid, self.pname = self.one.pid2eid(self.pid)
            except NotImplementedError:
                if self.eid == '' or self.pname == '':
                    raise IOError("Cannot infer session id and probe name from pid. "
                                  "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.")
            self.session_path = self.one.eid2path(self.eid)
        # then eid / pname combination
        elif self.session_path is None or self.session_path == '':
            self.session_path = self.one.eid2path(self.eid)
        # fully local providing a session path
        else:
            if self.one:
                self.eid = self.one.to_eid(self.session_path)
            else:
                self.one = One(cache_dir=self.session_path.parents[2], mode='local')
                df_sessions = cache._make_sessions_df(self.session_path)
                self.one._cache['sessions'] = df_sessions.set_index('id')
                self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
                self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
        # populates default properties
        self.collections = self.one.list_collections(
            self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
        self.datasets = self.one.list_datasets(self.eid)
        if self.atlas is None:
            self.atlas = AllenAtlas()
        self.files = {}
        self.raw_data_files = []

    def _load_object(self, *args, **kwargs):
        """
        This function is a wrapper around alfio.load_object that will remove the UUID in the
        filename if the object is on SDSC.
        """
        remove_uuids = getattr(self.one, 'uuid_filenames', False)
        d = alfio.load_object(*args, **kwargs)
        if remove_uuids:
            # pops the UUID in the key names
            keys = list(d.keys())
            for k in keys:
                d[k[:-37]] = d.pop(k)
        return d

    @staticmethod
    def _get_attributes(dataset_types):
        """Returns attributes to load for spikes and clusters objects."""
        dataset_types = [] if dataset_types is None else dataset_types
        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
        waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl]
        waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
        return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}

    def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
        """
        Filters a list or array of collections to get the relevant spike sorting dataset;
        if there is a pykilosort collection, it takes precedence.
        """
        collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None)
        # otherwise, prefer the shortest collection (alf/probeXX)
        collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
        _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
        return collection

    def load_spike_sorting_object(self, obj, *args, **kwargs):
        """
        Loads an ALF object
        :param obj: object name, one of 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :param missing: 'raise' (default) or 'ignore'
        :return:
        """
        self.download_spike_sorting_object(obj, *args, **kwargs)
        return self._load_object(self.files[obj])

    def get_version(self, spike_sorter='pykilosort'):
        collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy')
        return dset[0]['version'] if len(dset) else 'unknown'

    def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
                                      attribute=None, missing='raise', **kwargs):
        """
        Downloads an ALF object
        :param obj: object name, one of 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :param attribute: list of attributes to load for the object
        :param missing: 'raise' (default) or 'ignore'
        :return:
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        collection = collection or self.collection
        _logger.debug(f"loading spike sorting object {obj} from {collection}")
        attributes = self._get_attributes(dataset_types)
        try:
            self.files[obj] = self.one.load_object(
                self.eid, obj=obj, attribute=attributes.get(obj, None),
                collection=collection, download_only=True, **kwargs)
        except ALFObjectNotFound as e:
            if missing == 'raise':
                raise e

    def download_spike_sorting(self, objects=None, **kwargs):
        """
        Downloads spikes, clusters and channels
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
        :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
        :return:
        """
        objects = ['spikes', 'clusters', 'channels'] if objects is None else objects
        for obj in objects:
            self.download_spike_sorting_object(obj=obj, **kwargs)
        self.spike_sorting_path = self.files['clusters'][0].parent

    def download_raw_electrophysiology(self, band='ap'):
        """
        Downloads raw electrophysiology data files on local disk.
        :param band: "ap" (default) or "lf" for LFP band
        :return: list of raw data files full paths (ch, meta and cbin files)
        """
        raw_data_files = []
        for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']:
            try:
                # FIXME: this will fail if multiple LFP segments are found
                raw_data_files.append(self.one.load_dataset(
                    self.eid,
                    download_only=True,
                    collection=f'raw_ephys_data/{self.pname}',
                    dataset=suffix,
                    check_hash=False,
                ))
            except ALFObjectNotFound:
                _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}")
        self.raw_data_files = list(set(self.raw_data_files + raw_data_files))
        return raw_data_files

    def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
        """
        Returns a reader for the raw electrophysiology data
        By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having
        downloaded the raw data file if necessary
        :param stream: (True) if True, returns a Streamer; if False, downloads the raw data file
         if necessary and returns a spikeglx.Reader
        :param band: "ap" (default) or "lf" for LFP band
        :param kwargs: additional arguments passed to the Streamer constructor
        :return:
        """
        if stream:
            return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs)
        else:
            raw_data_files = self.download_raw_electrophysiology(band=band)
            cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
            if cbin_file is not None:
                return spikeglx.Reader(cbin_file)

    def load_channels(self, **kwargs):
        """
        Loads channels
        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        -   alf: the final version of channel locations, same as resolved with the difference that data is on file
        -   resolved: channel locations alignments have been agreed upon
        -   aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        -   traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
        :return:
        """
        # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
        self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
        if 'brainLocationIds_ccf_2017' not in channels:
            _logger.debug(f"loading channels from alyx for {self.files['channels']}")
            _channels, self.histology = _load_channel_locations_traj(
                self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
            if _channels:
                channels = _channels[self.pname]
        else:
            channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
            self.histology = 'alf'
        return Bunch(channels)

    def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
        """
        Loads spikes, clusters and channels

        There could be several spike sorting collections, by default the loader will get the pykilosort collection

        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        -   alf: the final version of channel locations, same as resolved with the difference that data is on file
        -   resolved: channel locations alignments have been agreed upon
        -   aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        -   traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'pykilosort')
        :param revision: for example "2024-05-06" (defaults to None)
        :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
        :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
        :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :return:
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.files = {}
        self.spike_sorter = spike_sorter
        self.revision = revision
        objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
        self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
        channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
        clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
        if good_units:
            spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
        else:
            spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
        if enforce_version:
            self._assert_version_consistency()
        return spikes, clusters, channels
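
    # Usage sketch for this loader (an illustration; assumes a configured ONE
    # instance and a valid probe insertion id `pid`):
    #   ssl = SpikeSortingLoader(pid=pid, one=one)
    #   spikes, clusters, channels = ssl.load_spike_sorting()
    #   clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)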

    def _assert_version_consistency(self):
        """
        Makes sure the state of the spike sorting object matches the files downloaded
        :return: None
        """
        for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
            for fn in self.files.get(k, []):
                if self.spike_sorter:
                    assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
                        f"You required strict version {self.spike_sorter}, {fn} does not match"
                if self.revision:
                    assert fn.relative_to(self.session_path).parts[3] == f"#{self.revision}#", \
                        f"You required strict revision {self.revision}, {fn} does not match"

    @staticmethod
    def compute_metrics(spikes, clusters=None):
        nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
        metrics = pd.DataFrame(quick_unit_metrics(
            spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
        return metrics

    @staticmethod
    def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False):
        """
        Merge the metrics and the channel information into the clusters dictionary
        :param spikes:
        :param clusters:
        :param channels:
        :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used
         for clusters or analysis applications (defaults to None).
        :param compute_metrics: if True, will explicitly recompute metrics (defaults to False)
        :return: cluster dictionary containing metrics and histology
        """
        if spikes == {}:
            return
        nc = clusters['channels'].size
        # recompute metrics if they are not available
        metrics = None
        if 'metrics' in clusters:
            metrics = clusters.pop('metrics')
            if metrics.shape[0] != nc:
                metrics = None
        if metrics is None or compute_metrics is True:
            _logger.debug("recompute clusters metrics")
            metrics = SpikeSortingLoader.compute_metrics(spikes, clusters)
            if isinstance(cache_dir, Path):
                metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
        for k in metrics.keys():
            clusters[k] = metrics[k].to_numpy()
        for k in channels.keys():
            clusters[k] = channels[k][clusters['channels']]
        if cache_dir is not None:
            _logger.debug(f'caching clusters metrics in {cache_dir}')
            pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
        return clusters

    @property
    def url(self):
        """Gets flatiron URL for the session."""
        webclient = getattr(self.one, '_web_client', None)
        return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None

    def _get_probe_info(self):
        if self._sync is None:
            timestamps = self.one.load_dataset(
                self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
            _ = self.one.load_dataset(  # this is not used here but we want to trigger the download for potential tasks
                self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
            try:
                ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
                    self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
                fs = spikeglx._get_fs_from_meta(ap_meta)
            except ALFObjectNotFound:
                ap_meta = None
                fs = 30_000
            self._sync = {
                'timestamps': timestamps,
                'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'),
                'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
                'ap_meta': ap_meta,
                'fs': fs,
            }

    def timesprobe2times(self, values, direction='forward'):
        self._get_probe_info()
        if direction == 'forward':
            return self._sync['forward'](values * self._sync['fs'])
        elif direction == 'reverse':
            return self._sync['reverse'](values) / self._sync['fs']

    def samples2times(self, values, direction='forward'):
        """
        Converts ephys sample values to session main clock seconds
        :param values: numpy array of times in seconds or samples to resync
        :param direction: 'forward' (samples probe time to seconds main time) or 'reverse'
         (seconds main time to samples probe time)
        :return:
        """
        self._get_probe_info()
        return self._sync[direction](values)
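
    # Usage sketch (an illustration; `spikes` as returned by load_spike_sorting with
    # the optional 'samples' attribute loaded):
    #   t_main = ssl.samples2times(spikes['samples'])  # seconds on the main session clock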

    @property
    def pid2ref(self):
        return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"

    def _default_plot_title(self, spikes):
        title = f"{self.pid2ref}, {self.pid} \n" \
                f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
        return title

    def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
               drift=None, title=None, **kwargs):
        """
        :param spikes: spikes dictionary or Bunch
        :param channels: channels dictionary or Bunch.
        :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png".
        Otherwise, plot.
        :param br: brain regions object (optional)
        :param label: label for saved image (optional, default="raster")
        :param time_series: timeseries dictionary for behavioral event times (optional)
        :param drift: drift dictionary with keys 'times' and 'um' (optional; loaded from disk if available)
        :param title: plot title (optional, defaults to a session/probe reference)
        :param **kwargs: kwargs passed to `driftmap()` (optional)
        :return:
        """
        br = br or BrainRegions()
        time_series = time_series or {}
        fig, axs = plt.subplots(2, 2, gridspec_kw={
            'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col')
        axs[0, 1].set_axis_off()
        # axs[0, 0].set_xticks([])
        if not kwargs:  # **kwargs is never None; an empty dict means no arguments were passed
            # set default raster plot parameters
            kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
        brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs)
        if title is None:
            title = self._default_plot_title(spikes)
        axs[0, 0].title.set_text(title)
        for k, ts in time_series.items():
            vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
        if 'atlas_id' in channels:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=br, display=True, ax=axs[1, 1], title=self.histology)
        axs[1, 0].set_ylim(0, 3800)
        axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1])
        fig.tight_layout()

        if drift is None:
            self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
            if 'drift' in self.files:
                drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
        if isinstance(drift, dict):
            axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
            axs[0, 0].set(ylim=[-15, 15])

        if save_dir is not None:
            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
            fig.savefig(png_file)
            plt.close(fig)
            gc.collect()
        else:
            return fig, axs

    def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
                             channels=None,
                             br: BrainRegions = None,
                             save_dir=None,
                             label='raster',
                             gain=-93,
                             title=None):

        # compute the raw data offset and destripe, we take 400ms around t0
        first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
        raw = sr[first_sample:last_sample, :-sr.nsync].T
        channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
        destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
        # filter out the spikes according to good/bad clusters and to the time slice
        spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
        ss = spikes['samples'][spike_sel]
        sc = clusters['channels'][spikes['clusters'][spike_sel]]
        sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
        if title is None:
            title = self._default_plot_title(spikes)
        # display the raw data snippet with spikes overlaid
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
        Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
        axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
        axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
        axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
        # adds the channel locations if available
        if (channels is not None) and ('atlas_id' in channels):
            br = br or BrainRegions()
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=br, display=True, ax=axs[1], title=self.histology)
            axs[1].get_yaxis().set_visible(False)
        fig.tight_layout()

        if save_dir is not None:
            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
            fig.savefig(png_file)
            plt.close(fig)
            gc.collect()
        else:
            return fig, axs
1251@dataclass
1252class SessionLoader:
1253 """
1254 Object to load session data for a give session in the recommended way.
1256 Parameters
1257 ----------
1258 one: one.api.ONE instance
1259 Can be in remote or local mode (required)
1260 session_path: string or pathlib.Path
1261 The absolute path to the session (one of session_path or eid is required)
1262 eid: string
1263 database UUID of the session (one of session_path or eid is required)
1265 If both are provided, session_path takes precedence over eid.
1267 Examples
1268 --------
1269 1) Load all available session data for one session:
1270 >>> from one.api import ONE
1271 >>> from brainbox.io.one import SessionLoader
1272 >>> one = ONE()
1273 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1274 # Object is instantiated, but no data is loaded, as you can see in the data_info attribute
1275 >>> sess_loader.data_info
1276 name is_loaded
1277 0 trials False
1278 1 wheel False
1279 2 pose False
1280 3 motion_energy False
1281 4 pupil False
1283 # Loading all available session data, the data_info attribute now shows which data has been loaded
1284 >>> sess_loader.load_session_data()
1285 >>> sess_loader.data_info
1286 name is_loaded
1287 0 trials True
1288 1 wheel True
1289 2 pose True
1290 3 motion_energy True
1291 4 pupil False
1293 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
1294 >>> type(sess_loader.trials)
1295 pandas.core.frame.DataFrame
1296 >>> sess_loader.trials.shape
1297 (626, 18)
1298 # Each data comes with its own timestamps in a column called 'times'
1299 >>> sess_loader.wheel['times']
1300 0 0.134286
1301 1 0.135286
1302 2 0.136286
1303 3 0.137286
1304 4 0.138286
1305 ...
1306 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
1307 # The dataframes of all cameras are collected in a dictionary
1308 >>> type(sess_loader.pose)
1309 dict
1310 >>> sess_loader.pose.keys()
1311 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
1312 >>> sess_loader.pose['bodyCamera'].columns
1313 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
1314 # To control the loading of specific data, e.g. by specifying parameters, use the individual
1315 # loading functions:
1316 >>> sess_loader.load_wheel(fs=100)
1317 """
1318 one: One = None
1319 session_path: Path = ''
1320 eid: str = ''
1321 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1322 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1323 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1324 pose: dict = field(default_factory=dict, repr=False)
1325 motion_energy: dict = field(default_factory=dict, repr=False)
1326 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1328 def __post_init__(self):
1329 """
1330 Function that runs automatically after initialisation of the dataclass attributes.
1331 Checks for required inputs, sets session_path and eid, creates data_info table.
1332 """
1333 if self.one is None: 1af
1334 raise ValueError("An input to one is required. If no connection to a database is desired, it can be "
1335 "a fully local instance of One.")
1336 # If session path is given, takes precedence over eid
1337 if self.session_path is not None and self.session_path != '': 1af
1338 self.eid = self.one.to_eid(self.session_path) 1af
1339 self.session_path = Path(self.session_path) 1af
1340 # If no session path is provided, try to infer it from the eid
1341 else:
1342 if self.eid is not None and self.eid != '':
1343 self.session_path = self.one.eid2path(self.eid)
1344 else:
1345 raise ValueError("If no session path is given, eid is required.")
1347 data_names = [ 1af
1348 'trials',
1349 'wheel',
1350 'pose',
1351 'motion_energy',
1352 'pupil'
1353 ]
1354 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1af
1356 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
1357 """
1358 Function to load available session data into the SessionLoader object. Input parameters control which
1359 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
1360 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
1361 in SessionLoader.data_info
1363 Parameters
1364 ----------
1365 trials: boolean
1366 Whether to load all trials data into SessionLoader.trials, default is True
1367 wheel: boolean
1368 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
1369 pose: boolean
1370 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
1371 default is True
1372 motion_energy: boolean
1373 Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
1374 into SessionLoader.motion_energy, default is True
1375 pupil: boolean
1376 Whether to load pupil diameter (raw and smooth) from the left camera into SessionLoader.pupil,
1377 default is True
1378 reload: boolean
1379 Whether to reload data that has already been loaded into this SessionLoader object, default is False
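Examples
--------
An illustrative call: load only trials and wheel data, skipping the camera-derived datasets
>>> sess_loader.load_session_data(pose=False, motion_energy=False, pupil=False)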
1380 """
1381 load_df = self.data_info.copy() 1f
1382 load_df['to_load'] = [ 1f
1383 trials,
1384 wheel,
1385 pose,
1386 motion_energy,
1387 pupil
1388 ]
1389 load_df['load_func'] = [ 1f
1390 self.load_trials,
1391 self.load_wheel,
1392 self.load_pose,
1393 self.load_motion_energy,
1394 self.load_pupil
1395 ]
1397 for idx, row in load_df.iterrows(): 1f
1398 if row['to_load'] is False: 1f
1399 _logger.debug(f"Not loading {row['name']} data, set to False.")
1400 elif row['is_loaded'] is True and reload is False: 1f
1401 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1f
1402 else:
1403 try: 1f
1404 _logger.info(f"Loading {row['name']} data") 1f
1405 row['load_func']() 1f
1406 self.data_info.loc[idx, 'is_loaded'] = True 1f
1407 except BaseException as e:
1408 _logger.warning(f"Could not load {row['name']} data.")
1409 _logger.debug(e)
1411 def _find_behaviour_collection(self, obj):
1412 """
1413 Function to find the trial or wheel collection
1415 Parameters
1416 ----------
1417 obj: str
1418 Alf object to load, either 'trials' or 'wheel'
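Returns
-------
str
The collection containing the dataset, or the default 'alf' if no dataset is found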
1419 """
1420 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj
1421 dsets = self.one.list_datasets(self.eid, dataset) 1fnj
1422 if len(dsets) == 0: 1fnj
1423 return 'alf' 1fn
1424 else:
1425 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj
1426 if len(set(collections)) == 1: 1fj
1427 return collections[0] 1fj
1428 else:
1429 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, '
1430 f'e.g. sl.load_{obj}(collection="{collections[0]}")')
1431 raise ALFMultipleCollectionsFound
1433 def load_trials(self, collection=None):
1434 """
1435 Function to load trials data into SessionLoader.trials
1437 Parameters
1438 ----------
1439 collection: str
1440 Alf collection of trials data
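Examples
--------
An illustrative call, assuming the trials table lives in the default 'alf' collection
>>> sess_loader.load_trials(collection='alf')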
1441 """
1443 if not collection: 1fn
1444 collection = self._find_behaviour_collection('trials') 1fn
1445 # itiDuration frequently has a mismatched dimension, and we don't need it; exclude it using a regex
1446 self.one.wildcards = False 1fn
1447 self.trials = self.one.load_object( 1fn
1448 self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*').to_df()
1449 self.one.wildcards = True 1fn
1450 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1fn
1452 def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
1453 """
1454 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
1455 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1456 a Butterworth low-pass filter is applied.
1458 Parameters
1459 ----------
1460 fs: int, float
1461 Sampling frequency for the wheel position, default is 1000 Hz
1462 corner_frequency: int, float
1463 Corner frequency of Butterworth low-pass filter, default is 20
1464 order: int, float
1465 Order of Butterworth low_pass filter, default is 8
1466 collection: str
1467 Alf collection of wheel data
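Examples
--------
For example, to resample the wheel position at 100 Hz instead of the default 1000 Hz
>>> sess_loader.load_wheel(fs=100)
>>> sess_loader.wheel.columns
Index(['times', 'position', 'velocity', 'acceleration'], dtype='object')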
1468 """
1469 if not collection: 1fj
1470 collection = self._find_behaviour_collection('wheel') 1fj
1471 wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection) 1fj
1472 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1fj
1473 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps'")
1474 # resample the wheel position and compute velocity, acceleration
1475 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1fj
1476 self.wheel['position'], self.wheel['times'] = interpolate_position( 1fj
1477 wheel_raw['timestamps'], wheel_raw['position'], freq=fs)
1478 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1fj
1479 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
1480 self.wheel = self.wheel.apply(np.float32) 1fj
1481 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1fj
1483 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1484 """
1485 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
1486 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
1487 Dataframes with the timestamps and pose data, one row per video frame and one column per tracked feature.
1489 Parameters
1490 ----------
1491 likelihood_thr: float
1492 The position of each tracked body part comes with a likelihood of that estimate for each time point.
1493 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
1494 likelihood_thr=1. Default is 0.9
1495 views: list
1496 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
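Examples
--------
For example, to load only the left camera with a stricter likelihood threshold
>>> sess_loader.load_pose(likelihood_thr=0.95, views=['left'])
>>> sess_loader.pose['leftCamera'] # dataframe with 'times' and one x/y/likelihood triplet per body part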
1497 """
1498 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1499 self.pose = {} 1mhf
1500 for view in views: 1mhf
1501 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times']) 1mhf
1502 # Double-check that the video timestamps are the correct length, or fix them if possible
1503 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1mhf
1504 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1mhf
1505 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1mhf
1506 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1mhf
1508 def load_motion_energy(self, views=['left', 'right', 'body']):
1509 """
1510 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
1511 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
1512 pandas Dataframes with the timestamps and motion energy data.
1513 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
1514 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
1515 body (bodyMotionEnergy).
1517 Parameters
1518 ----------
1519 views: list
1520 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
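Examples
--------
>>> sess_loader.load_motion_energy(views=['left'])
>>> sess_loader.motion_energy['leftCamera'].columns
Index(['times', 'whiskerMotionEnergy'], dtype='object')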
1521 """
1522 names = {'left': 'whiskerMotionEnergy', 1lf
1523 'right': 'whiskerMotionEnergy',
1524 'body': 'bodyMotionEnergy'}
1525 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1526 self.motion_energy = {} 1lf
1527 for view in views: 1lf
1528 me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times']) 1lf
1529 # Double-check that the video timestamps are the correct length, or fix them if possible
1530 times_fixed, motion_energy = self._check_video_timestamps( 1lf
1531 view, me_raw['times'], me_raw['ROIMotionEnergy'])
1532 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1lf
1533 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1lf
1534 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1lf
1536 def load_licks(self):
1537 """
1538 Not yet implemented
1539 """
1540 pass
1542 def load_pupil(self, snr_thresh=5.):
1543 """
1544 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.
1546 Parameters
1547 ----------
1548 snr_thresh: float
1549 An SNR is calculated from the raw and smoothed pupil diameter. If this SNR is below snr_thresh, the data
1550 is considered unusable and is discarded.
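Examples
--------
>>> sess_loader.load_pupil(snr_thresh=5.)
>>> sess_loader.pupil.columns # includes 'pupilDiameter_raw' and 'pupilDiameter_smooth'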
1551 """
1552 # Try to load from features
1553 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features']) 1hf
1554 if 'features' in feat_raw.keys(): 1hf
1555 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
1556 self.pupil = feats.copy()
1557 self.pupil.insert(0, 'times', times_fixed)
1559 # If unavailable, compute on the fly
1560 else:
1561 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1hf
1562 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1hf
1563 and 'leftCamera' in self.pose.keys()):
1564 # If pose data is already loaded, we don't know if it was thresholded at 0.9, so we need a workaround
1565 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1hf
1566 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1hf
1567 dlc_thr = self.pose['leftCamera'].copy() # Save the thresholded pose data in a new variable 1hf
1568 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1hf
1569 else:
1570 self.load_pose(views=['left'], likelihood_thr=0.9)
1571 dlc_thr = self.pose['leftCamera'].copy()
1573 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1hf
1574 try: 1hf
1575 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1hf
1576 except BaseException as e:
1577 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. "
1578 "Saving all NaNs for pupilDiameter_smooth.")
1579 _logger.debug(e)
1580 self.pupil['pupilDiameter_smooth'] = np.nan
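# SNR proxy: variance of the smoothed diameter over the variance of the residual (smooth - raw);
# a low value means the raw trace deviates strongly from the smoothed one, i.e. the data is noisy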
1582 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1hf
1583 good_idxs = np.where( 1hf
1584 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
1585 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1hf
1586 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
1587 if snr < snr_thresh: 1hf
1588 self.pupil = pd.DataFrame() 1h
1589 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h
1591 def _check_video_timestamps(self, view, video_timestamps, video_data):
1592 """
1593 Helper function to compare the length of the video data with the length of the video timestamps, and to
1594 fix the timestamps in case they are longer than the video.
1595 """
1596 # If camera times are shorter than the video data, or empty, there is currently no fix
1597 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf
1598 if video_timestamps.shape[0] == 0:
1599 msg = f'Camera times empty for {view}Camera.'
1600 else:
1601 msg = f'Camera times are shorter than video data for {view}Camera.'
1602 _logger.warning(msg)
1603 raise ValueError(msg)
1604 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
1605 # This is because the first few frames are sometimes not recorded. We can remove the first few
1606 # timestamps in this case
1607 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf
1608 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf
1609 return video_timestamps_fixed, video_data 1lmhf
1610 else:
1611 return video_timestamps, video_data
1614class EphysSessionLoader(SessionLoader):
1615 """
1616 Spike sorting enhanced version of SessionLoader
1617 Loads spike sorting data for all probes in the session into the self.ephys dict.
1618 >>> EphysSessionLoader(eid=eid, one=one)
1619 To restrict loading to a specific probe
1620 >>> EphysSessionLoader(eid=eid, one=one, pid=pid)
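Once loaded, spike sorting data are available per probe in the ephys dict, e.g. for an instance
esl and a hypothetical probe name 'probe00':
>>> esl.ephys['probe00']['spikes']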
1621 """
1622 def __init__(self, *args, pname=None, pid=None, **kwargs):
1623 """
1624 Needs an active connection in order to get the list of insertions in the session
1625 :param args:
1626 :param kwargs:
1627 """
1628 super().__init__(*args, **kwargs)
1629 # if necessary, restrict the query
1630 qargs = {} if pname is None else {'name': pname}
1631 qargs = qargs or ({} if pid is None else {'id': pid})
1632 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs)
1633 self.ephys = {}
1634 for ins in insertions:
1635 self.ephys[ins['name']] = {}
1636 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one)
1638 def load_session_data(self, *args, **kwargs):
1639 super().load_session_data(*args, **kwargs)
1640 self.load_spike_sorting()
1642 def load_spike_sorting(self, pnames=None):
1643 pnames = pnames or list(self.ephys.keys())
1644 for pname in pnames:
1645 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting()
1646 self.ephys[pname]['spikes'] = spikes
1647 self.ephys[pname]['clusters'] = clusters
1648 self.ephys[pname]['channels'] = channels
1650 @property
1651 def probes(self):
1652 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}