Coverage for brainbox/singlecell.py: 59%
99 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1'''
2Computes properties of single-cells, e.g. the autocorrelation and firing rate.
3'''
import numpy as np
from scipy.signal import convolve
from scipy.signal.windows import gaussian
from iblutil.util import Bunch
from brainbox.population.decode import xcorr
def acorr(spike_times, bin_size=None, window_size=None):
    """
    Compute the auto-correlogram of a neuron.

    Every spike is handed to `xcorr` as a member of one single cluster (id 0);
    the correlogram of that cluster with itself is the auto-correlogram.

    :param spike_times: spike times (in seconds)
    :type spike_times: array-like
    :param bin_size: width of a correlogram bin (in seconds)
    :type bin_size: float
    :param window_size: width of the correlogram window (in seconds)
    :type window_size: float
    :return: the auto-correlogram, an array of shape `(winsize_samples,)`
    """
    # all spikes share cluster id 0 -> xcorr returns a (1, 1, n) array
    single_cluster = np.zeros_like(spike_times, dtype=np.int32)
    correlograms = xcorr(
        spike_times, single_cluster, bin_size=bin_size, window_size=window_size)
    return correlograms[0, 0, :]
def bin_spikes(times, align_times, pre_time=0.4, post_time=1, bin_size=0.01, weights=None):
    """
    Event aligned raster for single cluster

    For each event time `t` in `align_times`, the spikes falling in
    ``[t - n_bins_pre * bin_size, t + n_bins_post * bin_size)`` are histogrammed into bins of
    width `bin_size`, with ``n_bins_pre = ceil(pre_time / bin_size)`` and
    ``n_bins_post = ceil(post_time / bin_size)`` (so the window can be slightly wider than
    `pre_time` / `post_time` when they are not multiples of `bin_size`).

    :param times: spike times (in seconds), sorted in ascending order (spikes are located
        with `np.searchsorted`)
    :type times: array-like
    :param align_times: event times (in seconds) to align the raster to
    :type align_times: array-like
    :param pre_time: time (in seconds) before each event to include
    :type pre_time: float
    :param post_time: time (in seconds) after each event to include
    :type post_time: float
    :param bin_size: bin width (in seconds)
    :type bin_size: float
    :param weights: optional per-spike weights, same length as `times`; when given, each bin
        holds the sum of the weights of its spikes instead of the spike count
    :type weights: array-like or None
    :return: bins, tscale
    :rtype: bins: np.array (n_align_times, n_bins)
    :rtype: tscale: np.array (n_bins,), bin centres (in seconds) relative to the event
    """
    # Accept any array-like (e.g. lists): the code below relies on ndarray slicing,
    # broadcasting and `[:, np.newaxis]` indexing.
    times = np.asarray(times)
    align_times = np.asarray(align_times)
    if weights is not None:
        weights = np.asarray(weights)

    n_bins_pre = int(np.ceil(pre_time / bin_size))
    n_bins_post = int(np.ceil(post_time / bin_size))
    n_bins = n_bins_pre + n_bins_post
    # bin edges relative to each event (n_bins + 1 values)
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    ts = np.repeat(align_times[:, np.newaxis], tscale.size, axis=1) + tscale
    # index range of the spikes in [first edge, last edge) for each event
    epoch_idxs = np.searchsorted(times, np.c_[ts[:, 0], ts[:, -1]])
    bins = np.zeros(shape=(align_times.shape[0], n_bins))

    for i, (ep, t) in enumerate(zip(epoch_idxs, ts)):
        xind = (np.floor((times[ep[0]:ep[1]] - t[0]) / bin_size)).astype(np.int64)
        w = weights[ep[0]:ep[1]] if weights is not None else None
        r = np.bincount(xind, minlength=tscale.shape[0], weights=w)
        # r has n_bins + 1 slots (one per edge); keep the first n_bins
        bins[i, :] = r[:-1]

    # convert bin edges to bin centres
    tscale = (tscale[:-1] + tscale[1:]) / 2

    return bins, tscale
def bin_spikes2D(spike_times, spike_clusters, cluster_ids, align_times, pre_time=0.4, post_time=1, bin_size=0.01, weights=None):
    """
    Event aligned raster for multiple clusters

    For each event time `t` in `align_times`, the spikes falling in
    ``[t - n_bins_pre * bin_size, t + n_bins_post * bin_size)`` are histogrammed per cluster
    into bins of width `bin_size`, with ``n_bins_pre = ceil(pre_time / bin_size)`` and
    ``n_bins_post = ceil(post_time / bin_size)``.

    :param spike_times: spike times (in seconds), sorted in ascending order (spikes are located
        with `np.searchsorted`)
    :type spike_times: array-like
    :param spike_clusters: cluster id of each spike in `spike_times`
    :type spike_clusters: array-like
    :param cluster_ids: cluster ids to bin; row `k` of the output corresponds to
        `cluster_ids[k]` (any order is allowed). Spikes from clusters not listed are ignored.
        If an id is repeated, its spikes are counted in the row of its first occurrence only.
    :type cluster_ids: array-like
    :param align_times: event times (in seconds) to align the raster to
    :type align_times: array-like
    :param pre_time: time (in seconds) before each event to include
    :type pre_time: float
    :param post_time: time (in seconds) after each event to include
    :type post_time: float
    :param bin_size: bin width (in seconds)
    :type bin_size: float
    :param weights: optional per-spike weights, same length as `spike_times`; when given, each
        bin holds the sum of the weights of its spikes instead of the spike count
    :type weights: array-like or None
    :return: bins, tscale
    :rtype: bins: np.array (n_align_times, n_clusters, n_bins)
    :rtype: tscale: np.array (n_bins,), bin centres (in seconds) relative to the event
    """
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    cluster_ids = np.asarray(cluster_ids)
    align_times = np.asarray(align_times)
    if weights is not None:
        weights = np.asarray(weights)

    n_bins_pre = int(np.ceil(pre_time / bin_size))
    n_bins_post = int(np.ceil(post_time / bin_size))
    n_bins = n_bins_pre + n_bins_post
    n_clusters = cluster_ids.shape[0]
    # bin edges relative to each event (n_bins + 1 values)
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    ts = np.repeat(align_times[:, np.newaxis], tscale.size, axis=1) + tscale
    # index range of the spikes in [first edge, last edge) for each event
    epoch_idxs = np.searchsorted(spike_times, np.c_[ts[:, 0], ts[:, -1]])
    bins = np.zeros(shape=(align_times.shape[0], n_clusters, n_bins))
    tscale_centres = (tscale[:-1] + tscale[1:]) / 2

    if n_clusters == 0:
        return bins, tscale_centres

    # Lookup from cluster id to output row that does not assume `cluster_ids` is sorted:
    # search in the sorted ids, then map back to the caller's order via `order`.
    order = np.argsort(cluster_ids, kind='stable')
    sorted_ids = cluster_ids[order]
    nx = tscale.size

    for i, (ep, t) in enumerate(zip(epoch_idxs, ts)):
        epoch = slice(ep[0], ep[1])
        xind = (np.floor((spike_times[epoch] - t[0]) / bin_size)).astype(np.int64)
        clusters = spike_clusters[epoch]
        pos = np.minimum(np.searchsorted(sorted_ids, clusters), n_clusters - 1)
        # drop spikes whose cluster is not among `cluster_ids`
        keep = sorted_ids[pos] == clusters
        rows = order[pos[keep]]
        w = weights[epoch][keep] if weights is not None else None
        ind2d = np.ravel_multi_index((rows, xind[keep]), dims=(n_clusters, nx))
        r = np.bincount(ind2d, minlength=n_clusters * nx, weights=w).reshape(n_clusters, nx)
        # r has n_bins + 1 slots per cluster (one per edge); keep the first n_bins
        bins[i] = r[:, :-1]

    return bins, tscale_centres
def calculate_peths(
        spike_times, spike_clusters, cluster_ids, align_times, pre_time=0.2,
        post_time=0.5, bin_size=0.025, smoothing=0.025, return_fr=True):
    """
    Calculate peri-event time histograms; return means and standard deviations
    for each time point across specified clusters

    :param spike_times: spike times (in seconds)
    :type spike_times: array-like
    :param spike_clusters: cluster ids corresponding to each spike in `spike_times`
    :type spike_clusters: array-like
    :param cluster_ids: subset of cluster ids for calculating peths; duplicates are collapsed and
        the output rows follow the sorted unique ids (returned as `cscale`)
    :type cluster_ids: array-like
    :param align_times: times (in seconds) to align peths to
    :type align_times: array-like
    :param pre_time: time (in seconds) to precede align times in peth
    :type pre_time: float
    :param post_time: time (in seconds) to follow align times in peth
    :type post_time: float
    :param bin_size: width of time windows (in seconds) to bin spikes
    :type bin_size: float
    :param smoothing: standard deviation (in seconds) of Gaussian kernel for
        smoothing peths; use `smoothing=0` to skip smoothing
    :type smoothing: float
    :param return_fr: `True` to return (estimated) firing rate, `False` to return spike counts
    :type return_fr: bool
    :return: peths, binned_spikes
    :rtype: peths: Bunch({'means': peth_means, 'stds': peth_stds, 'tscale': ts, 'cscale': ids})
        where `tscale` holds bin centres relative to the align times and `cscale` the cluster
        id of each row
    :rtype: binned_spikes: np.array (n_align_times, n_clusters, n_bins), raw (unsmoothed)
        spike counts with clusters ordered as `cscale`
    """
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    align_times = np.asarray(align_times)
    # Rows follow the sorted unique ids. Sizing the container from `ids` (not
    # `len(cluster_ids)`) keeps it consistent with the `bs_idxs` mask below when
    # `cluster_ids` contains duplicates.
    ids = np.unique(cluster_ids)

    # initialize containers
    n_offset = 5 * int(np.ceil(smoothing / bin_size))  # get rid of boundary effects for smoothing
    n_bins_pre = int(np.ceil(pre_time / bin_size)) + n_offset
    n_bins_post = int(np.ceil(post_time / bin_size)) + n_offset
    n_bins = n_bins_pre + n_bins_post
    binned_spikes = np.zeros(shape=(len(align_times), len(ids), n_bins))

    # build gaussian kernel if requested
    if smoothing > 0:
        w = n_bins - 1 if n_bins % 2 == 0 else n_bins  # odd kernel length
        window = gaussian(w, std=smoothing / bin_size)
        window /= np.sum(window)
        binned_spikes_conv = np.copy(binned_spikes)

    # filter spikes outside of the loop
    idxs = (spike_times >= np.min(align_times) - (n_bins_pre + 1) * bin_size) & \
        (spike_times <= np.max(align_times) + (n_bins_post + 1) * bin_size)
    idxs &= np.isin(spike_clusters, ids)
    spike_times = spike_times[idxs]
    spike_clusters = spike_clusters[idxs]

    # compute floating tscale (bin edges relative to each align time)
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    # bin spikes
    for i, t_0 in enumerate(align_times):
        # define bin edges
        ts = tscale + t_0
        # filter spikes
        idxs = (spike_times >= ts[0]) & (spike_times <= ts[-1])
        i_spikes = spike_times[idxs]
        i_clusters = spike_clusters[idxs]

        # bin spikes similar to bincount2D: x = spike times, y = spike clusters
        xscale = ts
        xind = (np.floor((i_spikes - np.min(ts)) / bin_size)).astype(np.int64)
        yscale, yind = np.unique(i_clusters, return_inverse=True)
        nx, ny = [xscale.size, yscale.size]
        ind2d = np.ravel_multi_index(np.c_[yind, xind].transpose(), dims=(ny, nx))
        r = np.bincount(ind2d, minlength=nx * ny, weights=None).reshape(ny, nx)

        # store (ts represent bin edges, so there are one fewer bins); yscale is a sorted
        # subset of the sorted `ids`, so the mask selects exactly its rows, in order
        bs_idxs = np.isin(ids, yscale)
        binned_spikes[i, bs_idxs, :] = r[:, :-1]

        # smooth
        if smoothing > 0:
            idxs = np.where(bs_idxs)[0]
            for j in range(r.shape[0]):
                binned_spikes_conv[i, idxs[j], :] = convolve(
                    r[j, :], window, mode='same', method='auto')[:-1]

    # average
    if smoothing > 0:
        binned_spikes_ = np.copy(binned_spikes_conv)
    else:
        binned_spikes_ = np.copy(binned_spikes)
    if return_fr:
        binned_spikes_ /= bin_size

    peth_means = np.mean(binned_spikes_, axis=0)
    peth_stds = np.std(binned_spikes_, axis=0)

    # drop the padding bins that were only added to absorb smoothing edge effects
    if smoothing > 0:
        peth_means = peth_means[:, n_offset:-n_offset]
        peth_stds = peth_stds[:, n_offset:-n_offset]
        binned_spikes = binned_spikes[:, :, n_offset:-n_offset]
        tscale = tscale[n_offset:-n_offset]

    # package output
    tscale = (tscale[:-1] + tscale[1:]) / 2
    peths = Bunch({'means': peth_means, 'stds': peth_stds, 'tscale': tscale, 'cscale': ids})
    return peths, binned_spikes
def firing_rate(ts, hist_win=0.01, fr_win=0.5):
    '''
    Computes the instantaneous firing rate of a unit over time by computing a histogram of spike
    counts over a specified window of time, and summing this histogram over a sliding window of
    specified time over a specified period of total time.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate. Assumed sorted in ascending
        order: the total duration is taken as ``ts[-1] - ts[0]``.
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.

    Returns
    -------
    fr : ndarray
        The instantaneous firing rate over time (in Hz): the spike count in each position of the
        sliding window that lies entirely inside the recording, divided by `fr_win`.

    Raises
    ------
    ValueError
        If `ts` is empty, if ``ts[-1] - ts[0]`` is shorter than `hist_win` or `fr_win`, or if
        `fr_win` spans less than one `hist_win` bin.

    See Also
    --------
    metrics.firing_rate_cv
    metrics.firing_rate_fano_factor
    plot.firing_rate

    Examples
    --------
    1) Compute the firing rate for unit 1 from the time of its first to last spike.
    >>> import brainbox as bb
    >>> import alf.io as aio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Load a spikes bunch and get the timestamps for unit 1, and calculate the instantaneous
    # firing rate.
    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
    >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
    >>> ts = spks_b['times'][unit_idxs]
    >>> fr = bb.singlecell.firing_rate(ts)
    '''
    ts = np.asarray(ts)
    if ts.size == 0:
        raise ValueError('`ts` must contain at least one spike time.')

    # Compute histogram of spike counts.
    t_tot = ts[-1] - ts[0]
    n_bins_hist = int(t_tot / hist_win)
    n_bins_fr = int(t_tot / fr_win)
    if n_bins_hist < 1 or n_bins_fr < 1:
        raise ValueError('The spike train must span at least one `hist_win` and one `fr_win`.')
    counts = np.histogram(ts, n_bins_hist)[0]
    # Number of histogram bins covered by one sliding window.
    step_sz = int(len(counts) / n_bins_fr)
    if step_sz < 1:
        raise ValueError('`fr_win` must span at least one `hist_win` bin.')
    # Moving sum over every complete window ('valid' keeps all len(counts) - step_sz + 1 of
    # them, including the last one), converted to a rate in Hz.
    fr = np.convolve(counts, np.ones(step_sz), mode='valid') / fr_win
    return fr