Coverage for brainbox/spike_features.py: 0%
25 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1'''
2Functions that compute spike features from spike waveforms.
3'''
5import numpy as np
6from brainbox.io.spikeglx import extract_waveforms
def depth(ephys_file, spks_b, clstrs_b, chnls_b, tmplts_b, unit, n_ch=12, n_ch_probe=385, sr=30000,
          dtype='int16', car=False):
    '''
    Estimates the depth of every spike of a unit from waveform-template dot products.

    Takes `n_ch` contiguous channels around the unit's channel of max amplitude, extracts all of
    the unit's spike waveforms on those channels from the binary data file, and for each spike
    computes, per channel, the dot product (over time) of the spike waveform with the unit
    template. The spike depth is the center-of-mass of the channel depths, weighted by these dot
    products after they are shifted to be non-negative.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    spks_b : bunch
        A spikes bunch; its 'clusters' (cluster ID per spike) and 'times' fields are used.
    clstrs_b : bunch
        A clusters bunch; its 'channels' field (channel of max amplitude, indexed by unit) is
        used.
    chnls_b : bunch
        A channels bunch; column 1 of its 'localCoordinates' field is used as channel depth.
    tmplts_b : bunch
        A templates bunch; its 'waveforms' field, indexed as [unit, sample, channel], is used.
    unit : int
        The unit for which to return the spike depths.
    n_ch : int (optional)
        The number of contiguous channels, always including the channel of max amplitude, used
        to compute the depths. The window is centered on that channel where possible and is
        shifted inward at the ends of the probe. It is capped at the number of channels
        available in the data file, the templates and the channel coordinates.
    n_ch_probe : int (optional)
        The number of channels of the recording (as stored in `ephys_file`).
    sr : int (optional)
        The sampling rate (hz) that the ephys data was acquired at.
    dtype : str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car : bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    d : ndarray
        The estimated spike depths (float) for all spikes in `unit`, in the units of
        `chnls_b['localCoordinates']`. A spike whose dot-product weights are all zero has no
        center-of-mass and gets NaN. The array is empty if the unit has no spikes.

    See Also
    --------
    brainbox.io.spikeglx.extract_waveforms

    Examples
    --------
    1) Get the spike depths for unit 1.
    >>> import numpy as np
    >>> import brainbox as bb
    >>> import alf.io as aio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Get the necessary alf objects from an alf directory.
    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
    >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
    >>> chnls_b = aio.load_object(path_to_alf_out, 'channels')
    >>> tmplts_b = aio.load_object(path_to_alf_out, 'templates')
    # Compute spike depths.
    >>> unit1_depths = bb.spike_features.depth(path_to_ephys_file, spks_b, clstrs_b, chnls_b,
                                               tmplts_b, unit=1)
    '''
    # Get the unit's spike timestamps.
    ts = spks_b['times'][spks_b['clusters'] == unit]
    if len(ts) == 0:
        # No spikes: nothing to extract, and no depths to compute.
        return np.zeros(0)

    tmplt_wfs = tmplts_b['waveforms']
    ch_coords = chnls_b['localCoordinates']
    # The channel window indexes the raw data file, the templates and the channel coordinates,
    # so keep it within the bounds of all three.
    n_avail = min(n_ch_probe, tmplt_wfs.shape[2], len(ch_coords))
    ch = _channel_window(clstrs_b['channels'][unit], n_ch, n_avail)

    # Unit template restricted to `ch`, shape (n_samples, len(ch)).
    unit_tmplt = tmplt_wfs[unit][:, ch]
    wf_t = tmplt_wfs.shape[1] / (sr / 1000)  # duration (ms) of each waveform
    wf = extract_waveforms(ephys_file=ephys_file, ts=ts, ch=ch, t=wf_t, sr=sr,
                           n_ch_probe=n_ch_probe, dtype=dtype, car=car)

    ch_depths = ch_coords[ch, 1]  # shape (len(ch),)
    return _com_depths(wf, unit_tmplt, ch_depths)


def _channel_window(max_ch, n_ch, n_avail):
    '''
    Returns `n_ch` contiguous channel indices around `max_ch`, clamped to [0, n_avail).

    The window starts `n_ch // 2` channels below `max_ch`. When that would run off either end of
    the available channels, it is shifted inward so it still has `n_ch` channels. `n_ch` is
    capped at `n_avail`. For 1 <= n_ch and 0 <= max_ch < n_avail the window contains `max_ch`.
    '''
    n_ch = min(int(n_ch), int(n_avail))
    start = int(max_ch) - n_ch // 2
    start = max(0, min(start, int(n_avail) - n_ch))
    return np.arange(start, start + n_ch)


def _com_depths(wf, unit_tmplt, ch_depths):
    '''
    Computes the center-of-mass depth of each spike from per-channel waveform-template dot products.

    Parameters
    ----------
    wf : ndarray
        Spike waveforms laid out as (n_spikes, n_samples, n_ch).
    unit_tmplt : ndarray
        Unit template, (n_samples, n_ch).
    ch_depths : ndarray
        Depth of each channel, (n_ch,).

    Returns
    -------
    ndarray
        (n_spikes,) float depths. NaN where all of a spike's weights are zero.
    '''
    # Per-spike, per-channel dot product over time. einsum avoids materialising wf * template.
    # Casting the (small) template to float64 keeps the products in floating point.
    dots = np.einsum('snc,nc->sc', wf, np.asarray(unit_tmplt, dtype=float))
    # Shift each spike's dot products by the magnitude of their minimum so the weights are
    # non-negative. (Any rescaling of the weights cancels in the center-of-mass.)
    # NOTE(review): for a positive minimum this adds it rather than shifting it to zero.
    # Kept to match previous results; confirm whether `dots - min` was intended.
    weights = dots + np.abs(dots.min(axis=1, keepdims=True))
    total = weights.sum(axis=1)
    d = np.full(len(weights), np.nan)
    valid = total > 0  # all-zero weights have no center-of-mass
    d[valid] = (weights[valid] @ ch_depths) / total[valid]
    return d