Coverage for brainbox/io/one.py: 57%
772 statements
1"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2from dataclasses import dataclass, field
3import gc
4import logging
5import re
6import os
7from pathlib import Path
8from collections import defaultdict
10import numpy as np
11import pandas as pd
12from scipy.interpolate import interp1d
13import matplotlib.pyplot as plt
15from one.api import ONE, One
16from one.alf.path import get_alf_path, full_path_parts, filename_parts
17from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
18from one.alf import cache
19import one.alf.io as alfio
20from neuropixel import TIP_SIZE_UM, trace_header
21import spikeglx
23import ibldsp.voltage
24from ibldsp.waveform_extraction import WaveformsLoader
25from iblutil.util import Bunch
26from iblatlas.atlas import AllenAtlas, BrainRegions
27from iblatlas import atlas
28from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
29from ibllib.pipes import histology
30from ibllib.pipes.ephys_alignment import EphysAlignment
31from ibllib.plots import vertical_lines, Density
33import brainbox.plot
34from brainbox.io.spikeglx import Streamer
35from brainbox.ephys_plots import plot_brain_regions
36from brainbox.metrics.single_units import quick_unit_metrics
37from brainbox.behavior.wheel import interpolate_position, velocity_filtered
38from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
40_logger = logging.getLogger('ibllib')
43SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
44CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
45WAVEFORMS_ATTRIBUTES = ['templates']


def load_lfp(eid, one=None, dataset_types=None, **kwargs):
    """
    TODO Verify works
    From an eid, hits the Alyx database and downloads the standard set of datasets
    needed for LFP
    :param eid: experiment session identifier (UUID, URL or Path)
    :param one: an instance of ONE used to query and download the datasets
    :param dataset_types: additional dataset types to add to the list
    :param kwargs: additional keyword arguments passed to spikeglx.Reader (e.g. open=True to open the readers)
    :return: list of spikeglx.Reader, one per LFP file found
    """
    if dataset_types is None:
        dataset_types = []
    dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*']
    [one.load_dataset(eid, dset, download_only=True) for dset in dtypes]
    session_path = one.eid2path(eid)

    efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False)
              if ef.get('lf', None)]
    return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles]
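

# Illustrative usage sketch for `load_lfp` (not part of the original module); assumes a
# configured ONE instance and an `eid` with raw LFP data registered:
# >>> from one.api import ONE
# >>> one = ONE()
# >>> readers = load_lfp(eid, one=one)  # one spikeglx.Reader per probe
# >>> first_reader = readers[0]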


def _collection_filter_from_args(probe, spike_sorter=None):
    collection = f'alf/{probe}/{spike_sorter}'
    collection = collection.replace('None', '*')
    collection = collection.replace('/*', '*')
    collection = collection[:-1] if collection.endswith('/') else collection
    return collection


def _get_spike_sorting_collection(collections, pname):
    """
    Filters a list or array of collections to get the relevant spike sorting dataset
    if there is a pykilosort, load it
    """
    # first look for a pykilosort collection
    collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None)
    # otherwise, prefer the shortest
    collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None)
    _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}")
    return collection


def _channels_alyx2bunch(chans):
    channels = Bunch({
        'atlas_id': np.array([ch['brain_region'] for ch in chans]),
        'x': np.array([ch['x'] for ch in chans]) / 1e6,
        'y': np.array([ch['y'] for ch in chans]) / 1e6,
        'z': np.array([ch['z'] for ch in chans]) / 1e6,
        'axial_um': np.array([ch['axial'] for ch in chans]),
        'lateral_um': np.array([ch['lateral'] for ch in chans])
    })
    return channels


def _channels_traj2bunch(xyz_chans, brain_atlas):
    brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans))
    channels = {
        'x': xyz_chans[:, 0],
        'y': xyz_chans[:, 1],
        'z': xyz_chans[:, 2],
        'acronym': brain_regions['acronym'],
        'atlas_id': brain_regions['id']
    }

    return channels


def _channels_bunch2alf(channels):
    channels_ = {
        'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6,
        'brainLocationIds_ccf_2017': channels['atlas_id'],
        'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]}
    return channels_


def _channels_alf2bunch(channels, brain_regions=None):
    # reformat the dictionary according to the standard that comes out of Alyx
    channels_ = {
        'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6,
        'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6,
        'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6,
        'acronym': None,
        'atlas_id': channels['brainLocationIds_ccf_2017'],
        'axial_um': channels['localCoordinates'][:, 1],
        'lateral_um': channels['localCoordinates'][:, 0],
    }
    # any extra keys in the input carry over to the output dictionary
    for k in channels:
        if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']:
            channels_[k] = channels[k]
    if brain_regions:
        channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym']
    return channels_


def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
                        brain_regions=None):
    """
    Generic function to load spike sorting data using ONE.

    Will try to load one spike sorting for any probe present for the eid matching the collection.
    For each probe it will load a spike sorting:
        - if there is one version: loads this one
        - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX)

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (may be in 'local' mode)
    collection : str
        collection filter word - accepts wildcards - can be a combination of spike sorter and
        probe. See `ALF documentation`_ for details.
    revision : str
        A particular revision to return (defaults to latest revision). See `ALF documentation`_
        for details.
    return_channels : bool
        If True (default), channel locations are loaded from disk and returned as well.

    .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs
        non-lateralized.
    """
    one = one or ONE()
    # enumerate probes and load according to the name
    collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    pnames = list(set(c.split('/')[1] for c in collections))
    spikes, clusters, channels = ({} for _ in range(3))

    spike_attributes, cluster_attributes = _get_attributes(dataset_types)

    for pname in pnames:
        probe_collection = _get_spike_sorting_collection(collections, pname)
        spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes',
                                        attribute=spike_attributes, namespace='')
        clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters',
                                          attribute=cluster_attributes, namespace='')
    if return_channels:
        channels = _load_channels_locations_from_disk(
            eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
        return spikes, clusters, channels
    else:
        return spikes, clusters


def _get_attributes(dataset_types):
    if dataset_types is None:
        return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
    else:
        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
        return spike_attributes, cluster_attributes


def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None):
    _logger.debug('loading spike sorting from disk')
    channels = Bunch({})
    collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision)
    if len(collections) == 0:
        _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
    probes = list(set([c.split('/')[1] for c in collections]))
    for probe in probes:
        probe_collection = _get_spike_sorting_collection(collections, probe)
        channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels')
        # if the spike sorter has not aligned data, try and get the alignment available
        if 'brainLocationIds_ccf_2017' not in channels[probe].keys():
            aligned_channel_collections = one.list_collections(
                eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
            if len(aligned_channel_collections) == 0:
                _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
                continue
            _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
            ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
            channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
            channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
        # only have to reformat channels if we were able to load coordinates from disk
        channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
    return channels


def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
    """
    Interpolate channel locations from an aligned channel map onto another channel map, as the
    channel maps of different spike sorters may differ.
    If there is no spike sorting in the base folder, the alignment does not have the
    localCoordinates field, so it is reconstructed from the Neuropixel map. This only happens
    for early pykilosort sorts.
    :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
     'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
     OR
     'x', 'y', 'z', 'acronym', 'axial_um'
     those are the guide for the interpolation
    :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates'
    :param brain_regions: None (default) or iblatlas.regions.BrainRegions object
     if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017'
     if a brain region object is provided, outputs a dict with keys
     'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
    :return: Bunch or dictionary of channels with brain coordinates keys
    """
    NEUROPIXEL_VERSION = 1
    h = trace_header(version=NEUROPIXEL_VERSION)
    if channels is None:
        channels = {'localCoordinates': np.c_[h['x'], h['y']]}
    nch = channels['localCoordinates'].shape[0]
    if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())):
        channels_aligned = _channels_bunch2alf(channels_aligned)
    if 'localCoordinates' in channels_aligned.keys():
        aligned_depths = channels_aligned['localCoordinates'][:, 1]
    else:  # this is an edge case for a few spike sorting sessions
        assert channels_aligned['mlapdv'].shape[0] == 384
        aligned_depths = h['y']
    depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
    depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True)
    channels['mlapdv'] = np.zeros((nch, 3))
    for i in np.arange(3):
        channels['mlapdv'][:, i] = np.interp(
            depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
    # the brain locations have to be interpolated by nearest neighbour
    fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest')
    channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
    if brain_regions is not None:
        return _channels_alf2bunch(channels, brain_regions=brain_regions)
    else:
        return channels


def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
                                 brain_atlas=None, return_source=False):
    if not hasattr(one, 'alyx'):
        return {}, None
    _logger.debug(f"trying to load from traj {probe}")
    channels = Bunch()
    brain_atlas = brain_atlas or AllenAtlas()
    # find the collection corresponding to the insertion
    insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0]
    collection = _collection_filter_from_args(probe=probe)
    collections = one.list_collections(eid, filename='channels*', collection=collection,
                                       revision=revision)
    probe_collection = _get_spike_sorting_collection(collections, probe)
    chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
    depths = chn_coords[:, 1]

    tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
        get('tracing_exists', False)
    resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
        get('alignment_resolved', False)
    counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
        get('alignment_count', 0)

    if tracing:
        xyz = np.array(insertion['json']['xyz_picks']) / 1e6
        if resolved:
            _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
                          f'Channel and cluster locations obtained from ephys aligned histology '
                          f'track.')
            traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
                                 provenance='Ephys aligned histology track')[0]
            align_key = insertion['json']['extended_qc']['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=brain_atlas, speedy=True)
            chans = ephysalign.get_channel_locations(feature, track)
            channels[probe] = _channels_traj2bunch(chans, brain_atlas)
            source = 'resolved'
        elif counts > 0 and aligned:
            _logger.debug(f'Channel locations for {eid}/{probe} have not been '
                          f'resolved. However, alignment flag set to True so channel and cluster'
                          f' locations will be obtained from latest available ephys aligned '
                          f'histology track.')
            # get the latest user aligned channels
            traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
                                 provenance='Ephys aligned histology track')[0]
            align_key = insertion['json']['extended_qc']['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=brain_atlas, speedy=True)
            chans = ephysalign.get_channel_locations(feature, track)

            channels[probe] = _channels_traj2bunch(chans, brain_atlas)
            source = 'aligned'
        else:
            _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
                          f'Channel and cluster locations obtained from histology track.')
            # get the channels from histology tracing
            xyz = xyz[np.argsort(xyz[:, 2]), :]
            chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
            channels[probe] = _channels_traj2bunch(chans, brain_atlas)
            source = 'traced'
        channels[probe]['axial_um'] = chn_coords[:, 1]
        channels[probe]['lateral_um'] = chn_coords[:, 0]

    else:
        _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
        source = ''
        channels = None

    if return_source:
        return channels, source
    else:
        return channels


def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
    """
    Load the brain locations of each channel for a given session/probe

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    probe : [str, list of str]
        The probe label(s), e.g. 'probe01'
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    brain_atlas : iblatlas.BrainAtlas
        Brain atlas object (default: Allen atlas)

    Returns
    -------
    dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    optional: string 'resolved', 'aligned', 'traced' or ''
    """
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    if isinstance(eid, dict):
        ses = eid
        eid = ses['url'][-36:]
    else:
        eid = one.to_eid(eid)
    collection = _collection_filter_from_args(probe=probe)
    channels = _load_channels_locations_from_disk(eid, one=one, collection=collection,
                                                  brain_regions=brain_atlas.regions)
    incomplete_probes = [k for k in channels if 'x' not in channels[k]]
    for iprobe in incomplete_probes:
        channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
                                                         brain_atlas=brain_atlas, return_source=True)
        if channels_ is not None:
            channels[iprobe] = channels_[iprobe]
    return channels
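

# Illustrative usage sketch for `load_channel_locations` (not part of the original module);
# assumes a configured ONE instance in online mode and a session with histology available:
# >>> channels = load_channel_locations(eid, probe='probe00', one=one)
# >>> channels['probe00']['acronym']  # brain region acronym per channel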


def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                            brain_regions=None, nested=True, collection=None, return_collection=False):
    """
    From an eid, loads spikes and clusters for all probes
    The following set of dataset types are loaded:
        'clusters.channels',
        'clusters.depths',
        'clusters.metrics',
        'spikes.clusters',
        'spikes.times',
        'probes.description'
    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param nested: if a single probe is required, do not output a dictionary with the probe name as key
    :param return_collection: (False) if True, will return the collection used to load
    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
    """
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    if collection is None:
        collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
                  brain_regions=brain_regions)
    spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    else:
        return spikes, clusters, channels


def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
                       brain_regions=None, return_collection=False):
    """
    From an eid, loads spikes and clusters for all probes
    The following set of dataset types are loaded:
        'clusters.channels',
        'clusters.depths',
        'clusters.metrics',
        'spikes.clusters',
        'spikes.times',
        'probes.description'
    :param eid: experiment UUID or pathlib.Path of the local session
    :param one: an instance of OneAlyx
    :param probe: name of probe to load in, if not given all probes for session will be loaded
    :param dataset_types: additional spikes/clusters objects to add to the standard default list
    :param spike_sorter: name of the spike sorting you want to load (None for default)
    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
    :param return_collection: (bool - False) if True, returns the collection used for loading the data
    :return: spikes, clusters (dict of bunch, 1 bunch per probe)
    """
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    collection = _collection_filter_from_args(probe, spike_sorter)
    _logger.debug(f"load spike sorting with collection filter {collection}")
    spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
                                           return_channels=False, dataset_types=dataset_types,
                                           brain_regions=brain_regions)
    if return_collection:
        return spikes, clusters, collection
    else:
        return spikes, clusters
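

# Illustrative usage sketch for `load_spike_sorting` (not part of the original module); assumes a
# configured ONE instance and an ephys session `eid`. Note this function is deprecated in favour
# of SpikeSortingLoader:
# >>> spikes, clusters = load_spike_sorting(eid, one=one, probe='probe00')
# >>> spikes['probe00']['times']  # spike times for the selected probe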


def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
                                    spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
    """
    For a given eid, get spikes, clusters and channels information, and merges clusters
    and channels information before returning all three variables.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        An instance of ONE (shouldn't be in 'local' mode)
    probe : [str, list of str]
        The probe label(s), e.g. 'probe01'
    aligned : bool
        Whether to get the latest user aligned channel when not resolved or use histology track
    dataset_types : list of str
        Optional additional spikes/clusters objects to add to the standard default list
    spike_sorter : str
        Name of the spike sorting you want to load (None for default which is pykilosort if it's
        available otherwise the default MATLAB kilosort)
    brain_atlas : iblatlas.atlas.BrainAtlas
        Brain atlas object (default: Allen atlas)
    nested : bool
        If False and a single probe is loaded, the bunches are returned directly instead of
        dictionaries keyed by probe name
    return_collection : bool
        Returns an extra argument with the collection chosen

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    channels : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains channel locations with keys ('acronym',
        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
    """
    # --- Get spikes and clusters data
    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions. '
                    'Use brainbox.io.one.SpikeSortingLoader instead')
    one = one or ONE()
    brain_atlas = brain_atlas or AllenAtlas()
    spikes, clusters, collection = load_spike_sorting(
        eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
    # -- Get brain regions and assign to clusters
    channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
                                      brain_atlas=brain_atlas)
    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
    if nested is False and len(spikes.keys()) == 1:
        k = list(spikes.keys())[0]
        channels = channels[k]
        clusters = clusters[k]
        spikes = spikes[k]
    if return_collection:
        return spikes, clusters, channels, collection
    else:
        return spikes, clusters, channels


def load_ephys_session(eid, one=None):
    """
    From an eid, hits the Alyx database and downloads a standard default set of dataset types
    From a local session Path (pathlib.Path), loads a standard default set of dataset types
    to perform analysis:
        'clusters.channels',
        'clusters.depths',
        'clusters.metrics',
        'spikes.clusters',
        'spikes.times',
        'probes.description'

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx
        ONE object to use for loading (required)

    Returns
    -------
    spikes : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
        session and spike sorter, with keys ('clusters', 'times')
    clusters : dict of one.alf.io.AlfBunch
        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
        ('channels', 'depths', 'metrics')
    trials : one.alf.io.AlfBunch of numpy.ndarray
        The session trials data
    """
    assert one
    spikes, clusters = load_spike_sorting(eid, one=one)
    trials = one.load_object(eid, 'trials')
    return spikes, clusters, trials
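

# Illustrative usage sketch for `load_ephys_session` (not part of the original module); assumes a
# configured ONE instance (the function asserts that one is passed):
# >>> spikes, clusters, trials = load_ephys_session(eid, one=one)
# >>> trials['goCue_times']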


def _remove_old_clusters(session_path, probe):
    # gets clusters and spikes from a local session folder
    probe_path = session_path.joinpath('alf', probe)

    # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
    cluster_file = probe_path.joinpath('clusters.metrics.csv')

    if cluster_file.exists():
        os.remove(cluster_file)
        _logger.info('Deleting old clusters.metrics.csv file')


def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
    """
    Takes (default and any extra) values in given keys from channels and assigns them to clusters.
    If channels does not contain any data, the new keys are added to clusters but left empty.

    Parameters
    ----------
    dic_clus : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing cluster information
    channels : dict of one.alf.io.AlfBunch
        1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', 'z', 'localCoordinates')
    keys_to_add_extra : list of str
        Any extra keys to load into channels bunches

    Returns
    -------
    dict of one.alf.io.AlfBunch
        clusters (1 bunch per probe) with new keys values.
    """
    probe_labels = list(channels.keys())  # Convert dict_keys into list
    keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']

    if keys_to_add_extra is None:
        keys_to_add = keys_to_add_default
    else:
        # Append extra optional keys
        keys_to_add = list(set(keys_to_add_extra + keys_to_add_default))

    for label in probe_labels:
        clu_ch = dic_clus[label]['channels']
        for key in keys_to_add:
            try:
                assert key in channels[label].keys()  # Check key is in channels
                ch_key = channels[label][key]
                nch_key = len(ch_key) if ch_key is not None else 0
                if max(clu_ch) < nch_key:  # Check length as will use clu_ch as index
                    dic_clus[label][key] = ch_key[clu_ch]
                else:
                    _logger.warning(
                        f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels'
                        f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.')
                    dic_clus[label][key] = []
            except AssertionError:
                _logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
                continue

    return dic_clus
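

# Illustrative usage sketch for `merge_clusters_channels` (not part of the original module);
# assumes `clusters` and `channels` dicts keyed by probe label, as returned by
# load_spike_sorting / load_channel_locations:
# >>> clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
# >>> clusters['probe00']['acronym']  # brain region acronym per cluster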


def load_passive_rfmap(eid, one=None):
    """
    For a given eid load in the passive receptive field mapping protocol data

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        An instance of ONE (may be in 'local' - offline - mode)

    Returns
    -------
    one.alf.io.AlfBunch
        Passive receptive field mapping data
    """
    one = one or ONE()

    # Load in the receptive field mapping data
    rf_map = one.load_object(eid, obj='passiveRFM', collection='alf')
    frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin',
                                          collection='raw_passive_data'), dtype="uint8")
    y_pix, x_pix = 15, 15
    frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0])
    rf_map['frames'] = frames

    return rf_map
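

# Illustrative usage sketch for `load_passive_rfmap` (not part of the original module); assumes a
# session that ran the passive protocol:
# >>> rf_map = load_passive_rfmap(eid, one=one)
# >>> rf_map['frames'].shape  # (n_frames, 15, 15) uint8 stimulus frames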


def load_wheel_reaction_times(eid, one=None):
    """
    Return the calculated reaction times for session. Reaction times are defined as the time
    between the go cue (onset tone) and the onset of the first substantial wheel movement. A
    movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the
    distance to threshold (~0.1 radians).

    Negative times mean the onset of the movement occurred before the go cue. NaNs may occur if
    there was no detected movement within the period, or when the goCue_times or feedback_times
    are nan.

    Parameters
    ----------
    eid : [str, UUID, Path, dict]
        Experiment session identifier; may be a UUID, URL, experiment reference string
        details dict or Path
    one : one.api.OneAlyx, optional
        one object to use for loading. Will generate internal one if not used, by default None

    Returns
    -------
    array-like
        reaction times
    """
    if one is None:
        one = ONE()

    trials = one.load_object(eid, 'trials')
    # If already extracted, load and return
    if trials and 'firstMovement_times' in trials:
        return trials['firstMovement_times'] - trials['goCue_times']
    # Otherwise load wheelMoves object and calculate
    moves = one.load_object(eid, 'wheelMoves')
    # Re-extract wheel moves if necessary
    if not moves or 'peakAmplitude' not in moves:
        wheel = one.load_object(eid, 'wheel')
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
    assert trials and moves, 'unable to load trials and wheelMoves data'
    firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials)
    return firstMove_times - trials['goCue_times']
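

# Illustrative usage sketch for `load_wheel_reaction_times` (not part of the original module);
# assumes a behaviour session with wheel data:
# >>> rt = load_wheel_reaction_times(eid, one=one)  # seconds from go cue to first wheel movement
# >>> np.nanmedian(rt)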


def load_iti(trials):
    """
    The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey
    screen commencing at stimulus off and lasting until the quiescent period at the start of the
    following trial. Note that the ITI of a given trial is the time between that trial and the
    next, therefore the value for the last trial is NaN.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'intervals', 'stimOff_times'}.

    Returns
    -------
    np.array
        An array of inter-trial intervals, the last value being NaN.
    """
    if not {'intervals', 'stimOff_times'} <= set(trials.keys()):
        raise ValueError('trials must contain keys {"intervals", "stimOff_times"}')
    return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan]
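

# Illustrative usage sketch for `load_iti` (not part of the original module):
# >>> trials = one.load_object(eid, 'trials')
# >>> iti = load_iti(trials)  # same length as trials, last value is NaN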


def load_channels_from_insertion(ins, depths=None, one=None, ba=None):

    PROV_2_VAL = {
        'Resolved': 90,
        'Ephys aligned histology track': 70,
        'Histology track': 50,
        'Micro-manipulator': 30,
        'Planned': 10}

    one = one or ONE()
    ba = ba or atlas.AllenAtlas()
    traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id'])
    val = [PROV_2_VAL[tr['provenance']] for tr in traj]
    idx = np.argmax(val)
    traj = traj[idx]
    if depths is None:
        depths = trace_header(version=1)['y']
    if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator':
        ins = atlas.Insertion.from_dict(traj)
        # Deepest coordinate first
        xyz = np.c_[ins.tip, ins.entry].T
        xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
    else:
        xyz = np.array(ins['json']['xyz_picks']) / 1e6
        if traj['provenance'] == 'Histology track':
            xyz = xyz[np.argsort(xyz[:, 2]), :]
            xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
        else:
            align_key = ins['json']['extended_qc']['alignment_stored']
            feature = traj['json'][align_key][0]
            track = traj['json'][align_key][1]
            ephysalign = EphysAlignment(xyz, depths, track_prev=track,
                                        feature_prev=feature,
                                        brain_atlas=ba, speedy=True)
            xyz_channels = ephysalign.get_channel_locations(feature, track)
    return xyz_channels
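

# Illustrative usage sketch for `load_channels_from_insertion` (not part of the original module);
# assumes an insertion record fetched from Alyx:
# >>> ins = one.alyx.rest('insertions', 'list', session=eid, name='probe00')[0]
# >>> xyz_channels = load_channels_from_insertion(ins, one=one)  # (n_channels, 3) xyz in meters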


@dataclass
class SpikeSortingLoader:
    """
    Object that will load spike sorting data for a given probe insertion.
    This class can be instantiated in several manners
    - With Alyx database probe id:
            SpikeSortingLoader(pid=pid, one=one)
    - With Alyx database eid and probe name:
            SpikeSortingLoader(eid=eid, pname='probe00', one=one)
    - From a local session and probe name:
            SpikeSortingLoader(session_path=session_path, pname='probe00')
    NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
    """
    one: One = None
    atlas: None = None
    pid: str = None
    eid: str = ''
    pname: str = ''
    session_path: Path = ''
    # the following properties are the outcome of the post init function
    collections: list = None
    datasets: list = None  # list of all datasets belonging to the session
    # the following properties are the outcome of a reading function
    files: dict = None
    raw_data_files: list = None  # list of raw ap and lf files corresponding to the recording
    collection: str = ''
    histology: str = ''  # 'alf', 'resolved', 'aligned' or 'traced'
    spike_sorter: str = 'pykilosort'
    spike_sorting_path: Path = None
    _sync: dict = None
    revision: str = None

    def __post_init__(self):
        # pid gets precedence
        if self.pid is not None:
            try:
                self.eid, self.pname = self.one.pid2eid(self.pid)
            except NotImplementedError:
                if self.eid == '' or self.pname == '':
                    raise IOError("Cannot infer session id and probe name from pid. "
                                  "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.")
            self.session_path = self.one.eid2path(self.eid)
        # then eid / pname combination
        elif self.session_path is None or self.session_path == '':
            self.session_path = self.one.eid2path(self.eid)
        # fully local providing a session path
        else:
            if self.one:
                self.eid = self.one.to_eid(self.session_path)
            else:
                self.one = One(cache_dir=self.session_path.parents[2], mode='local')
                df_sessions = cache._make_sessions_df(self.session_path)
                self.one._cache['sessions'] = df_sessions.set_index('id')
                self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
                self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
        # populates default properties
        self.collections = self.one.list_collections(
            self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
        self.datasets = self.one.list_datasets(self.eid)
        if self.atlas is None:
            self.atlas = AllenAtlas()
        self.files = {}
        self.raw_data_files = []

    def _load_object(self, *args, **kwargs):
        """
        This function is a wrapper around alfio.load_object that will remove the UUID in the
        filename if the object is on SDSC.
        """
        remove_uuids = getattr(self.one, 'uuid_filenames', False)
        d = alfio.load_object(*args, **kwargs)
        if remove_uuids:
            # pops the UUID in the key names
            keys = list(d.keys())
            for k in keys:
                d[k[:-37]] = d.pop(k)
        return d

    @staticmethod
    def _get_attributes(dataset_types):
        """Returns attributes to load for spikes and clusters objects."""
        dataset_types = [] if dataset_types is None else dataset_types
        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
        waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl]
        waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
        return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}

    def _get_spike_sorting_collection(self, spike_sorter=None):
        """
        Filters a list or array of collections to get the relevant spike sorting dataset
        if there is a pykilosort, load it
        """
        for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']):
            if sorter is None:
                continue
            if sorter == "":
                collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None)
            else:
                collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None)
            if collection is not None:
                return collection
        # if none is found amongst the defaults, prefers the shortest
        collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
        _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
        return collection

    def load_spike_sorting_object(self, obj, *args, revision=None, **kwargs):
        """
        Loads an ALF object
        :param obj: object name, str between 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :param missing: 'raise' (default) or 'ignore'
        :param revision: the dataset revision to load
        :return:
        """
        revision = revision if revision is not None else self.revision
        self.download_spike_sorting_object(obj, *args, revision=revision, **kwargs)
        return self._load_object(self.files[obj])

    def get_version(self, spike_sorter=None):
        spike_sorter = (spike_sorter or self.spike_sorter) or 'iblsorter'
        collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy')
        return dset[0]['version'] if len(dset) else 'unknown'

    def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None,
                                      attribute=None, missing='raise', revision=None, **kwargs):
        """
        Downloads an ALF object
        :param obj: object name, str between 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :param attribute: list of attributes to load for the object
        :param missing: 'raise' (default) or 'ignore'
        :param revision: the dataset revision to load
        :return:
        """
        revision = revision if revision is not None else self.revision
        if spike_sorter is None:
            spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter'
        if len(self.collections) == 0:
            return {}, {}, {}
        self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        collection = collection or self.collection
        _logger.debug(f"loading spike sorting object {obj} from {collection}")
        attributes = self._get_attributes(dataset_types)
        try:
            self.files[obj] = self.one.load_object(
                self.eid, obj=obj, attribute=attributes.get(obj, None),
                collection=collection, download_only=True, revision=revision, **kwargs)
        except ALFObjectNotFound as e:
            if missing == 'raise':
                raise e

    def download_spike_sorting(self, objects=None, **kwargs):
        """
        Downloads spikes, clusters and channels
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
        :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
        :return:
        """
        objects = ['spikes', 'clusters', 'channels'] if objects is None else objects
        for obj in objects:
            self.download_spike_sorting_object(obj=obj, **kwargs)
        self.spike_sorting_path = self.files['clusters'][0].parent

    def download_raw_electrophysiology(self, band='ap'):
        """
        Downloads raw electrophysiology data files on local disk.
        :param band: "ap" (default) or "lf" for LFP band
        :return: list of raw data files full paths (ch, meta and cbin files)
        """
        raw_data_files = []
        for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']:
            try:
                # FIXME: this will fail if multiple LFP segments are found
                raw_data_files.append(self.one.load_dataset(
                    self.eid,
                    download_only=True,
                    collection=f'raw_ephys_data/{self.pname}',
                    dataset=suffix,
                    check_hash=False,
                ))
            except ALFObjectNotFound:
                _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}")
        self.raw_data_files = list(set(self.raw_data_files + raw_data_files))
        return raw_data_files

    def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
        """
        Returns a reader for the raw electrophysiology data
        By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having
        downloaded the raw data file if necessary
        :param stream: bool (True), if False the raw data files are downloaded and read locally
        :param band: "ap" (default) or "lf" for LFP band
        :param kwargs: additional arguments passed to the Streamer
        :return: Streamer or spikeglx.Reader
        """
        if stream:
            return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs)
        else:
            raw_data_files = self.download_raw_electrophysiology(band=band)
            cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
            if cbin_file is not None:
                return spikeglx.Reader(cbin_file)
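
    # Illustrative usage sketch for `raw_electrophysiology` (not part of the original module);
    # assumes an online ONE instance and a SpikeSortingLoader instantiated with a pid:
    # >>> ssl = SpikeSortingLoader(pid=pid, one=one)
    # >>> sr = ssl.raw_electrophysiology(band='ap', stream=True)  # streaming reader for the AP band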

    def download_raw_waveforms(self, **kwargs):
        """
        Downloads raw waveforms extracted from sorting to local disk.
        """
        _logger.debug(f"loading waveforms from {self.collection}")
        return self.one.load_object(
            id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"],
            collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
        )

    def raw_waveforms(self, **kwargs):
        wf_paths = self.download_raw_waveforms(**kwargs)
        return WaveformsLoader(wf_paths[0].parent)

    def load_channels(self, **kwargs):
        """
        Loads channels
        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        -   alf: the final version of channel locations, same as resolved with the difference that data is on file
        -   resolved: channel locations alignments have been agreed upon
        -   aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        -   traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
        :return:
        """
        # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
        self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
            if alfio.check_dimensions(channels) != 0:
                channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
                channels['rawInd'] = np.arange(channels[list(channels.keys())[0]].shape[0])
        if 'brainLocationIds_ccf_2017' not in channels:
            _logger.debug(f"loading channels from alyx for {self.files['channels']}")
            _channels, self.histology = _load_channel_locations_traj(
                self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True)
            if _channels:
                channels = _channels[self.pname]
        else:
            channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
            self.histology = 'alf'
        return Bunch(channels)

    @staticmethod
    def filter_files_by_namespace(all_files, namespace):

        # Create dict for each file with available namespaces, files with no namespace are stored under the key None
        namespace_files = defaultdict(dict)
        available_namespaces = []
        for file in all_files:
            fparts = filename_parts(file.name, as_dict=True)
            fname = f"{fparts['object']}.{fparts['attribute']}"
            nspace = fparts['namespace']
            available_namespaces.append(nspace)
            namespace_files[fname][nspace] = file

        if namespace not in set(available_namespaces):
            _logger.info(f'Could not find manual curation results for {namespace}, returning default'
                         f' non manually curated spikesorting data')

        # Return the files with the chosen namespace.
        files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()]
        # remove any None files
        files = [f for f in files if f]
        return files

    def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False,
                           namespace=None, **kwargs):
        """
        Loads spikes, clusters and channels

        There could be several spike sorting collections, by default the loader will get the pykilosort collection

        The channel locations can come from several sources, it will load the most advanced version of the histology available,
        regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging):
        -   alf: the final version of channel locations, same as resolved with the difference that data is on file
        -   resolved: channel locations alignments have been agreed upon
        -   aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
        -   traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'iblsorter')
        :param revision: for example "2024-05-06" (defaults to None)
        :param enforce_version: if True, will raise an error if the spike sorting version and revision are not the expected ones
        :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
        :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
        :param namespace: None, if given will load the manually curated spike sorting with the given namespace,
         e.g. to load '_av_.clusters.depths' use namespace='av'
        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :return:
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.files = {}
        self.spike_sorter = spike_sorter
        self.revision = revision

        if good_units and namespace is not None:
            _logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with '
                         'good_units=False and filter the spikes post hoc by the good clusters.')
            return [None] * 3
        objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
        self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
        channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
        self.files['clusters'] = self.filter_files_by_namespace(self.files['clusters'], namespace)
        clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)

        if good_units:
            spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
        else:
            self.files['spikes'] = self.filter_files_by_namespace(self.files['spikes'], namespace)
            spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
        if enforce_version:
            self._assert_version_consistency()
        return spikes, clusters, channels
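
    # Illustrative usage sketch for SpikeSortingLoader.load_spike_sorting (not part of the
    # original module); assumes an online ONE instance and a probe insertion id `pid`:
    # >>> ssl = SpikeSortingLoader(pid=pid, one=one)
    # >>> spikes, clusters, channels = ssl.load_spike_sorting()
    # >>> clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)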

    def _assert_version_consistency(self):
        """
        Makes sure the state of the spike sorting object matches the files downloaded
        :return: None
        """
        for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
            for fn in self.files.get(k, []):
                if self.spike_sorter:
                    assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
                        f"You required strict version {self.spike_sorter}, {fn} does not match"
                if self.revision:
                    assert full_path_parts(fn)[5] == self.revision, \
                        f"You required strict revision {self.revision}, {fn} does not match"

    @staticmethod
    def compute_metrics(spikes, clusters=None):
        nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
        metrics = pd.DataFrame(quick_unit_metrics(
            spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
        return metrics

    @staticmethod
    def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False):
        """
        Merge the metrics and the channel information into the clusters dictionary
        :param spikes:
        :param clusters:
        :param channels:
        :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used
         for clusters or analysis applications (defaults to None).
        :param compute_metrics: if True, will explicitly recompute metrics (defaults to False)
        :return: cluster dictionary containing metrics and histology
        """
        if spikes == {}:
            return
        nc = clusters['channels'].size
        # recompute metrics if they are not available
        metrics = None
        if 'metrics' in clusters:
            metrics = clusters.pop('metrics')
            if metrics.shape[0] != nc:
                metrics = None
        if metrics is None or compute_metrics is True:
            _logger.debug("recompute clusters metrics")
            metrics = SpikeSortingLoader.compute_metrics(spikes, clusters)
            if isinstance(cache_dir, Path):
                metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
        for k in metrics.keys():
            clusters[k] = metrics[k].to_numpy()
        for k in channels.keys():
            clusters[k] = channels[k][clusters['channels']]
        if cache_dir is not None:
            _logger.debug(f'caching clusters metrics in {cache_dir}')
            pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
        return clusters

    @property
    def url(self):
        """Gets flatiron URL for the session."""
        webclient = getattr(self.one, '_web_client', None)
        return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None

    def _get_probe_info(self, revision=None):
        revision = revision if revision is not None else self.revision
        if self._sync is None:
            timestamps = self.one.load_dataset(
                self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision)
            _ = self.one.load_dataset(  # this is not used here but we want to trigger the download for potential tasks
                self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision)
            try:
                ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
                    self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
                fs = spikeglx._get_fs_from_meta(ap_meta)
            except ALFObjectNotFound:
                ap_meta = None
                fs = 30_000
            self._sync = {
                'timestamps': timestamps,
                'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'),
                'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
                'ap_meta': ap_meta,
                'fs': fs,
            }

    def timesprobe2times(self, values, direction='forward'):
        self._get_probe_info()
        if direction == 'forward':
            return self._sync['forward'](values * self._sync['fs'])
        elif direction == 'reverse':
            return self._sync['reverse'](values) / self._sync['fs']

    def samples2times(self, values, direction='forward'):
        """
        Converts ephys sample values to session main clock seconds
        :param values: numpy array of times in seconds or samples to resync
        :param direction: 'forward' (samples probe time to seconds main time) or 'reverse'
         (seconds main time to samples probe time)
        :return:
        """
        self._get_probe_info()
        return self._sync[direction](values)
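
    # Illustrative usage sketch for `samples2times` (not part of the original module); converts
    # spike samples on the probe clock into seconds on the session main clock:
    # >>> spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
    # >>> t_main = ssl.samples2times(spikes['samples'], direction='forward')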

    @property
    def pid2ref(self):
        return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"

    def _default_plot_title(self, spikes):
        title = f"{self.pid2ref}, {self.pid} \n" \
                f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
        return title
1227 def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
1228 drift=None, title=None, **kwargs):
1229 """
1230 :param spikes: spikes dictionary or Bunch
1231 :param channels: channels dictionary or Bunch.
1232 :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png".
1233 Otherwise, plot.
1234 :param br: brain regions object (optional)
1235 :param label: label for saved image (optional, default="raster")
1236 :param time_series: timeseries dictionary for behavioral event times (optional)
1237 :param **kwargs: kwargs passed to `driftmap()` (optional)
1238 :return:
1239 """
1240 br = br or BrainRegions() 1c
1241 time_series = time_series or {} 1c
1242 fig, axs = plt.subplots(2, 2, gridspec_kw={ 1c
1243 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col')
1244 axs[0, 1].set_axis_off() 1c
1245 # axs[0, 0].set_xticks([])
1246 if kwargs is None: 1c
1247 # set default raster plot parameters
1248 kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
1249 brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) 1c
1250 if title is None: 1c
1251 title = self._default_plot_title(spikes) 1c
1252 axs[0, 0].title.set_text(title) 1c
1253 for k, ts in time_series.items(): 1c
1254 vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
1255 if 'atlas_id' in channels: 1c
1256 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], 1c
1257 brain_regions=br, display=True, ax=axs[1, 1], title=self.histology)
1258 axs[1, 0].set_ylim(0, 3800) 1c
1259 axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) 1c
1260 fig.tight_layout() 1c
1262 if drift is None: 1c
1263 self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') 1c
1264 if 'drift' in self.files: 1c
1265 drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
1266 if isinstance(drift, dict): 1c
1267 axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
1268 axs[0, 0].set(ylim=[-15, 15])
1270 if save_dir is not None: 1c
1271 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
1272 fig.savefig(png_file)
1273 plt.close(fig)
1274 gc.collect()
1275 else:
1276 return fig, axs 1c
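A short usage sketch for the raster plot above, assuming a SpikeSortingLoader `ssl` whose spike sorting has already been loaded into `spikes` and `channels` (names and the output directory are placeholders). The extra keyword arguments are forwarded to `driftmap()`.

    from pathlib import Path
    # save a driftmap raster with brain regions alongside
    ssl.raster(spikes, channels, save_dir=Path('/tmp/qc_plots'), label='raster', t_bin=0.01, d_bin=20)
    # or keep the figure open for further customisation
    fig, axs = ssl.raster(spikes, channels)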
1278 def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
1279 channels=None,
1280 br: BrainRegions = None,
1281 save_dir=None,
1282 label='raster',
1283 gain=-93,
1284 title=None):
1286 # compute the raw data offset and destripe, we take 400ms around t0
1287 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
1288 raw = sr[first_sample:last_sample, :-sr.nsync].T
1289 channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
1290 destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
1291 # filter out the spikes according to good/bad clusters and to the time slice
1292 spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
1293 ss = spikes['samples'][spike_sel]
1294 sc = clusters['channels'][spikes['clusters'][spike_sel]]
1295 sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
1296 if title is None:
1297 title = self._default_plot_title(spikes)
1298 # display the raw data snippet with spikes overlaid
1299 fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
1300 Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
1301 axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
1302 axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
1303 axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
1304 # adds the channel locations if available
1305 if (channels is not None) and ('atlas_id' in channels):
1306 br = br or BrainRegions()
1307 plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
1308 brain_regions=br, display=True, ax=axs[1], title=self.histology)
1309 axs[1].get_yaxis().set_visible(False)
1310 fig.tight_layout()
1312 if save_dir is not None:
1313 png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
1314 fig.savefig(png_file)
1315 plt.close(fig)
1316 gc.collect()
1317 else:
1318 return fig, axs
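A hedged sketch of calling the snippet plot above. It assumes `ssl` has spike sorting loaded with 'samples' available in `spikes` and 'label' in `clusters`, and that a raw AP-band reader is obtained through the streaming interface; the Streamer call and the t0 value are illustrative assumptions, not part of this module.

    from brainbox.io.spikeglx import Streamer
    t0 = 600.0  # example time (s) around which 400 ms of raw data is displayed
    sr = Streamer(pid=ssl.pid, one=ssl.one, remove_cached=False, typ='ap')
    fig, axs = ssl.plot_rawdata_snippet(sr, spikes, clusters, t0, channels=channels)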
1321@dataclass
1322class SessionLoader:
1323 """
1324 Object to load session data for a given session in the recommended way.
1326 Parameters
1327 ----------
1328 one: one.api.ONE instance
1329 Can be in remote or local mode (required)
1330 session_path: string or pathlib.Path
1331 The absolute path to the session (one of session_path or eid is required)
1332 eid: string
1333 database UUID of the session (one of session_path or eid is required)
1335 If both are provided, session_path takes precedence over eid.
1337 Examples
1338 --------
1339 1) Load all available session data for one session:
1340 >>> from one.api import ONE
1341 >>> from brainbox.io.one import SessionLoader
1342 >>> one = ONE()
1343 >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1344 # Object is initialised, but no data is loaded, as you can see in the data_info attribute
1345 >>> sess_loader.data_info
1346 name is_loaded
1347 0 trials False
1348 1 wheel False
1349 2 pose False
1350 3 motion_energy False
1351 4 pupil False
1353 # Loading all available session data, the data_info attribute now shows which data has been loaded
1354 >>> sess_loader.load_session_data()
1355 >>> sess_loader.data_info
1356 name is_loaded
1357 0 trials True
1358 1 wheel True
1359 2 pose True
1360 3 motion_energy True
1361 4 pupil False
1363 # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
1364 >>> type(sess_loader.trials)
1365 pandas.core.frame.DataFrame
1366 >>> sess_loader.trials.shape
1367 (626, 18)
1368 # Each data comes with its own timestamps in a column called 'times'
1369 >>> sess_loader.wheel['times']
1370 0 0.134286
1371 1 0.135286
1372 2 0.136286
1373 3 0.137286
1374 4 0.138286
1375 ...
1376 # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
1377 # The dataframes of all cameras are collected in a dictionary
1378 >>> type(sess_loader.pose)
1379 dict
1380 >>> sess_loader.pose.keys()
1381 dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
1382 >>> sess_loader.pose['bodyCamera'].columns
1383 Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
1384 # In order to control the loading of specific data, e.g. by specifying parameters, use the individual loading
1385 # functions:
1386 >>> sess_loader.load_wheel(fs=100)
1387 """
1388 one: One = None
1389 session_path: Path = ''
1390 eid: str = ''
1391 revision: str = ''
1392 data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1393 trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1394 wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1395 pose: dict = field(default_factory=dict, repr=False)
1396 motion_energy: dict = field(default_factory=dict, repr=False)
1397 pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1399 def __post_init__(self):
1400 """
1401 Function that runs automatically after initialisation of the dataclass attributes.
1402 Checks for required inputs, sets session_path and eid, creates data_info table.
1403 """
1404 if self.one is None: 1bf
1405 raise ValueError("A ONE instance is required. If no connection to a database is desired, it can be "
1406 "a fully local instance of One.")
1407 # If session path is given, takes precedence over eid
1408 if self.session_path is not None and self.session_path != '': 1bf
1409 self.eid = self.one.to_eid(self.session_path) 1bf
1410 self.session_path = Path(self.session_path) 1bf
1411 # If no session path is provided, try to infer it from the eid
1412 else:
1413 if self.eid is not None and self.eid != '':
1414 self.session_path = self.one.eid2path(self.eid)
1415 else:
1416 raise ValueError("If no session path is given, eid is required.")
1418 data_names = [ 1bf
1419 'trials',
1420 'wheel',
1421 'pose',
1422 'motion_energy',
1423 'pupil'
1424 ]
1425 self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) 1bf
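For illustration, the loader can also be constructed from an eid alone; a sketch, assuming `one` is connected and `eid` is a valid session UUID string:

    sess_loader = SessionLoader(one=one, eid=eid)
    print(sess_loader.session_path)  # resolved through one.eid2path(eid)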
1427 def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
1428 """
1429 Function to load available session data into the SessionLoader object. Input parameters control which
1430 data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
1431 parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
1432 in SessionLoader.data_info.
1434 Parameters
1435 ----------
1436 trials: boolean
1437 Whether to load all trials data into SessionLoader.trials, default is True
1438 wheel: boolean
1439 Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
1440 pose: boolean
1441 Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
1442 default is True
1443 motion_energy: boolean
1444 Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
1445 into SessionLoader.motion_energy, default is True
1446 pupil: boolean
1447 Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
1448 default is True
1449 reload: boolean
1450 Whether to reload data that has already been loaded into this SessionLoader object, default is False
1451 """
1452 load_df = self.data_info.copy() 1f
1453 load_df['to_load'] = [ 1f
1454 trials,
1455 wheel,
1456 pose,
1457 motion_energy,
1458 pupil
1459 ]
1460 load_df['load_func'] = [ 1f
1461 self.load_trials,
1462 self.load_wheel,
1463 self.load_pose,
1464 self.load_motion_energy,
1465 self.load_pupil
1466 ]
1468 for idx, row in load_df.iterrows(): 1f
1469 if row['to_load'] is False: 1f
1470 _logger.debug(f"Not loading {row['name']} data, set to False.")
1471 elif row['is_loaded'] is True and reload is False: 1f
1472 _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") 1f
1473 else:
1474 try: 1f
1475 _logger.info(f"Loading {row['name']} data") 1f
1476 row['load_func']() 1f
1477 self.data_info.loc[idx, 'is_loaded'] = True 1f
1478 except BaseException as e:
1479 _logger.warning(f"Could not load {row['name']} data.")
1480 _logger.debug(e)
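A small usage sketch of the selective loading described above, assuming an initialised SessionLoader `sess_loader`:

    # load only trials and wheel, skipping the camera-derived data
    sess_loader.load_session_data(pose=False, motion_energy=False, pupil=False)
    # force a fresh reload of everything, including data already marked as loaded
    sess_loader.load_session_data(reload=True)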
1482 def _find_behaviour_collection(self, obj):
1483 """
1484 Function to find the trial or wheel collection
1486 Parameters
1487 ----------
1488 obj: str
1489 Alf object to load, either 'trials' or 'wheel'
1490 """
1491 dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' 1fnj
1492 dsets = self.one.list_datasets(self.eid, dataset) 1fnj
1493 if len(dsets) == 0: 1fnj
1494 return 'alf' 1fn
1495 else:
1496 collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] 1fj
1497 if len(set(collections)) == 1: 1fj
1498 return collections[0] 1fj
1499 else:
1500 _logger.error(f'Multiple collections found {collections}. Specify collection when loading, '
1501 f'e.g. sl.load_{obj}(collection="{collections[0]}")')
1502 raise ALFMultipleCollectionsFound
1504 def load_trials(self, collection=None):
1505 """
1506 Function to load trials data into SessionLoader.trials
1508 Parameters
1509 ----------
1510 collection: str
1511 Alf collection of trials data
1512 """
1514 if not collection: 1fn
1515 collection = self._find_behaviour_collection('trials') 1fn
1516 # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
1517 self.one.wildcards = False 1fn
1518 self.trials = self.one.load_object( 1fn
1519 self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df()
1520 self.one.wildcards = True 1fn
1521 self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True 1fn
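A brief usage sketch; the explicit collection name is only an example, by default the collection is inferred automatically:

    sess_loader.load_trials()                           # infer the collection containing the trials table
    sess_loader.load_trials(collection='alf/task_00')   # or load from an explicit collection
    print(sess_loader.trials.columns)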
1523 def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
1524 """
1525 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
1526 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1527 a Butterworth low-pass filter is applied.
1529 Parameters
1530 ----------
1531 fs: int, float
1532 Sampling frequency for the wheel position, default is 1000 Hz
1533 corner_frequency: int, float
1534 Corner frequency of Butterworth low-pass filter, default is 20
1535 order: int, float
1536 Order of Butterworth low_pass filter, default is 8
1537 collection: str
1538 Alf collection of wheel data
1539 """
1540 if not collection: 1fj
1541 collection = self._find_behaviour_collection('wheel') 1fj
1542 wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None) 1fj
1543 if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: 1fj
1544 raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps'")
1545 # resample the wheel position and compute velocity, acceleration
1546 self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) 1fj
1547 self.wheel['position'], self.wheel['times'] = interpolate_position( 1fj
1548 wheel_raw['timestamps'], wheel_raw['position'], freq=fs)
1549 self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( 1fj
1550 self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order)
1551 self.wheel = self.wheel.apply(np.float32) 1fj
1552 self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True 1fj
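For example, to resample the wheel at a lower rate with a gentler low-pass filter (values are illustrative only):

    sess_loader.load_wheel(fs=250, corner_frequency=10, order=8)
    print(sess_loader.wheel[['times', 'position', 'velocity']].head())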
1554 def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1555 """
1556 Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
1557 dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
1558 Dataframes with the timestamps and pose data, one column per tracked body part coordinate and one row per video frame.
1560 Parameters
1561 ----------
1562 likelihood_thr: float
1563 The position of each tracked body part comes with a likelihood of that estimate for each time point.
1564 Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
1565 likelihood_thr=1. Default is 0.9
1566 views: list
1567 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
1568 """
1569 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1570 self.pose = {} 1mhf
1571 for view in views: 1mhf
1572 pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None) 1mhf
1573 # Double check if video timestamps are correct length or can be fixed
1574 times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) 1mhf
1575 self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) 1mhf
1576 self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) 1mhf
1577 self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True 1mhf
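For instance, to load only the left camera with a stricter likelihood threshold (a sketch, assuming `sess_loader` is initialised):

    sess_loader.load_pose(likelihood_thr=0.95, views=['left'])
    left_dlc = sess_loader.pose['leftCamera']
    print(left_dlc.columns)  # 'times' plus x/y/likelihood columns for each tracked body part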
1579 def load_motion_energy(self, views=['left', 'right', 'body']):
1580 """
1581 Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
1582 dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
1583 pandas Dataframes with the timestamps and motion energy data.
1584 The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
1585 (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
1586 body (bodyMotionEnergy).
1588 Parameters
1589 ----------
1590 views: list
1591 List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
1592 """
1593 names = {'left': 'whiskerMotionEnergy', 1lf
1594 'right': 'whiskerMotionEnergy',
1595 'body': 'bodyMotionEnergy'}
1596 # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1597 self.motion_energy = {} 1lf
1598 for view in views: 1lf
1599 me_raw = self.one.load_object( 1lf
1600 self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None)
1601 # Double check if video timestamps are correct length or can be fixed
1602 times_fixed, motion_energy = self._check_video_timestamps( 1lf
1603 view, me_raw['times'], me_raw['ROIMotionEnergy'])
1604 self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) 1lf
1605 self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) 1lf
1606 self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True 1lf
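A brief usage sketch, restricting the loading to the body camera:

    sess_loader.load_motion_energy(views=['body'])
    me_body = sess_loader.motion_energy['bodyCamera']
    print(me_body.columns)  # ['times', 'bodyMotionEnergy']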
1608 def load_licks(self):
1609 """
1610 Not yet implemented
1611 """
1612 pass
1614 def load_pupil(self, snr_thresh=5.):
1615 """
1616 Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.
1618 Parameters
1619 ----------
1620 snr_thresh: float
1621 An SNR is calculated from the raw and smoothed pupil diameter. If this SNR is below snr_thresh the data
1622 is considered unusable and is discarded.
1623 """
1624 # Try to load from features
1625 feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None) 1hf
1626 if 'features' in feat_raw.keys(): 1hf
1627 times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
1628 self.pupil = feats.copy()
1629 self.pupil.insert(0, 'times', times_fixed)
1631 # If unavailable compute on the fly
1632 else:
1633 _logger.info('Pupil diameter not available, trying to compute on the fly.') 1hf
1634 if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] 1hf
1635 and 'leftCamera' in self.pose.keys()):
1636 # If pose data is already loaded, we don't know if it was thresholded at 0.9, so we need a small workaround 1hf
1637 copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data 1hf
1638 self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 1hf
1639 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable 1hf
1640 self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place 1hf
1641 else:
1642 self.load_pose(views=['left'], likelihood_thr=0.9)
1643 dlc_thr = self.pose['leftCamera'].copy()
1645 self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) 1hf
1646 try: 1hf
1647 self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') 1hf
1648 except BaseException as e:
1649 _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. "
1650 "Saving all NaNs for pupilDiameter_smooth.")
1651 _logger.debug(e)
1652 self.pupil['pupilDiameter_smooth'] = np.nan
1654 if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): 1hf
1655 good_idxs = np.where( 1hf
1656 ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
1657 snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / 1hf
1658 (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
1659 if snr < snr_thresh: 1hf
1660 self.pupil = pd.DataFrame() 1h
1661 raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') 1h
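A small usage sketch; the threshold value is illustrative. When the SNR check fails, the data is removed and a ValueError is raised:

    try:
        sess_loader.load_pupil(snr_thresh=10.)
        print(sess_loader.pupil[['pupilDiameter_raw', 'pupilDiameter_smooth']].head())
    except ValueError as e:
        print(e)  # pupil data discarded because the smoothed/raw SNR was below threshold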
1663 def _check_video_timestamps(self, view, video_timestamps, video_data):
1664 """
1665 Helper function to check for the length of the video frames vs video timestamps and fix in case
1666 timestamps are longer than video frames.
1667 """
1668 # If camera times are shorter than video data, or empty, no current fix
1669 if video_timestamps.shape[0] < video_data.shape[0]: 1lmhf
1670 if video_timestamps.shape[0] == 0:
1671 msg = f'Camera times empty for {view}Camera.'
1672 else:
1673 msg = f'Camera times are shorter than video data for {view}Camera.'
1674 _logger.warning(msg)
1675 raise ValueError(msg)
1676 # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
1677 # This is because the first few frames are sometimes not recorded. We can remove the first few
1678 # timestamps in this case
1679 elif video_timestamps.shape[0] > video_data.shape[0]: 1lmhf
1680 video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] 1lmhf
1681 return video_timestamps_fixed, video_data 1lmhf
1682 else:
1683 return video_timestamps, video_data
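To illustrate the trimming rule above with plain numpy (a toy sketch, independent of the loader):

    import numpy as np
    # 5 timestamps but only 3 frames of data -> keep the last 3 timestamps
    video_timestamps = np.array([0.0, 0.033, 0.066, 0.100, 0.133])
    video_data = np.zeros((3, 2))
    fixed = video_timestamps[-video_data.shape[0]:]
    assert fixed.shape[0] == video_data.shape[0]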
1686class EphysSessionLoader(SessionLoader):
1687 """
1688 Spike sorting enhanced version of SessionLoader
1689 Loads spike sorting data for all probes in the session, in the self.ephys dict
1690 >>> EphysSessionLoader(eid=eid, one=one)
1691 To select for a specific probe
1692 >>> EphysSessionLoader(eid=eid, one=one, pid=pid)
1693 """
1694 def __init__(self, *args, pname=None, pid=None, **kwargs):
1695 """
1696 Needs an active Alyx connection in order to get the list of insertions for the session; other arguments are passed to SessionLoader
1697 :param pname: (optional) restrict loading to insertions with this probe name
1698 :param pid: (optional) restrict loading to the insertion with this UUID; ignored if pname is given
1699 """
1700 super().__init__(*args, **kwargs)
1701 # if necessary, restrict the query
1702 qargs = {} if pname is None else {'name': pname}
1703 qargs = qargs or ({} if pid is None else {'id': pid})
1704 insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs)
1705 self.ephys = {}
1706 for ins in insertions:
1707 self.ephys[ins['name']] = {}
1708 self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one)
1710 def load_session_data(self, *args, **kwargs):
1711 super().load_session_data(*args, **kwargs)
1712 self.load_spike_sorting()
1714 def load_spike_sorting(self, pnames=None):
1715 pnames = pnames or list(self.ephys.keys())
1716 for pname in pnames:
1717 spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting()
1718 self.ephys[pname]['spikes'] = spikes
1719 self.ephys[pname]['clusters'] = clusters
1720 self.ephys[pname]['channels'] = channels
1722 @property
1723 def probes(self):
1724 return {k: self.ephys[k]['ssl'].pid for k in self.ephys}
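A short usage sketch putting the ephys loader together, assuming `one` is connected and `eid` refers to a session with probe insertions registered on Alyx:

    esl = EphysSessionLoader(one=one, eid=eid)
    esl.load_session_data()  # behaviour plus spike sorting for every insertion
    for pname, probe_data in esl.ephys.items():
        print(pname, probe_data['spikes']['times'].size, 'spikes')
    print(esl.probes)  # mapping from probe name to probe insertion UUID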