Coverage for brainbox/singlecell.py: 59%

100 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1''' 

2Computes properties of single-cells, e.g. the autocorrelation and firing rate. 

3''' 

4 

5import numpy as np 

6from scipy.signal import convolve 

7from scipy.signal.windows import gaussian 

8from iblutil.util import Bunch 

9from brainbox.population.decode import xcorr 

10 

11 

def acorr(spike_times, bin_size=None, window_size=None):
    """
    Compute the auto-correlogram of a single neuron.

    Every spike is given the same dummy cluster label (0) and the pair is handed to `xcorr`;
    the auto-correlogram is the (0, 0) entry of the resulting correlogram array.

    :param spike_times: spike times in seconds
    :type spike_times: array-like
    :param bin_size: size of the bin, in seconds (forwarded unchanged to `xcorr`)
    :type bin_size: float
    :param window_size: size of the window, in seconds (forwarded unchanged to `xcorr`)
    :type window_size: float
    :return: the auto-correlogram, a `(winsize_samples,)` array
    """
    single_cluster = np.zeros_like(spike_times, dtype=np.int32)
    correlograms = xcorr(
        spike_times,
        single_cluster,
        bin_size=bin_size,
        window_size=window_size,
    )
    auto_correlogram = correlograms[0, 0, :]
    return auto_correlogram

31 

32 

def bin_spikes(times, align_times, pre_time=0.4, post_time=1, bin_size=0.01, weights=None):
    """
    Event aligned raster for a single cluster.

    Spikes are counted (or their weights summed) in bins of width `bin_size`, using
    `ceil(pre_time / bin_size)` bins before and `ceil(post_time / bin_size)` bins after each
    alignment time. A spike at time t is counted for an event when t_start <= t < t_end.

    :param times: spike times in seconds; must be sorted in ascending order because the
        per-event window is located with `np.searchsorted`
    :type times: array-like
    :param align_times: event times (in seconds) to align the raster to
    :type align_times: array-like
    :param pre_time: time (in seconds) to include before each event
    :type pre_time: float
    :param post_time: time (in seconds) to include after each event
    :type post_time: float
    :param bin_size: bin width in seconds
    :type bin_size: float
    :param weights: optional per-spike weights (same length as `times`); when given, each bin
        holds the sum of the weights of its spikes instead of the spike count
    :type weights: array-like or None
    :return: bins, tscale
    :rtype: bins: np.ndarray (n_align_times, n_bins);
        tscale: np.ndarray (n_bins,) bin centres in seconds relative to the event
    """
    # Accept lists as well as arrays (list inputs previously failed on `[:, np.newaxis]`).
    times = np.asarray(times)
    align_times = np.asarray(align_times)
    if weights is not None:
        weights = np.asarray(weights)

    n_bins_pre = int(np.ceil(pre_time / bin_size))
    n_bins_post = int(np.ceil(post_time / bin_size))
    n_bins = n_bins_pre + n_bins_post
    # Bin edges relative to the event: n_bins + 1 values.
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    # Absolute bin edges for every event, shape (n_align_times, n_bins + 1).
    ts = np.repeat(align_times[:, np.newaxis], tscale.size, axis=1) + tscale
    # Index range [first, last) of spikes between the first and last edge of each event.
    epoch_idxs = np.searchsorted(times, np.c_[ts[:, 0], ts[:, -1]])
    bins = np.zeros(shape=(align_times.shape[0], n_bins))

    for i, (ep, t) in enumerate(zip(epoch_idxs, ts)):
        xind = (np.floor((times[ep[0]:ep[1]] - t[0]) / bin_size)).astype(np.int64)
        w = weights[ep[0]:ep[1]] if weights is not None else None
        r = np.bincount(xind, minlength=tscale.shape[0], weights=w)
        # The bincount has one column per edge; keep exactly the n_bins real bins.
        bins[i, :] = r[:n_bins]

    # Convert bin edges to bin centres.
    tscale = (tscale[:-1] + tscale[1:]) / 2

    return bins, tscale

62 

63 

def bin_spikes2D(spike_times, spike_clusters, cluster_ids, align_times, pre_time=0.4, post_time=1, bin_size=0.01, weights=None):
    """
    Event aligned raster for multiple clusters.

    Spikes are counted (or their weights summed) per cluster in bins of width `bin_size`,
    using `ceil(pre_time / bin_size)` bins before and `ceil(post_time / bin_size)` bins after
    each alignment time. A spike at time t is counted for an event when t_start <= t < t_end.

    :param spike_times: spike times in seconds; must be sorted in ascending order because the
        per-event window is located with `np.searchsorted`
    :type spike_times: array-like
    :param spike_clusters: cluster id of each spike (same length as `spike_times`)
    :type spike_clusters: array-like
    :param cluster_ids: cluster ids to bin; row k of the output corresponds to `cluster_ids[k]`
        (any order). Spikes from clusters not listed are ignored. If an id is repeated, only
        its first occurrence receives counts.
    :type cluster_ids: array-like
    :param align_times: event times (in seconds) to align the raster to
    :type align_times: array-like
    :param pre_time: time (in seconds) to include before each event
    :type pre_time: float
    :param post_time: time (in seconds) to include after each event
    :type post_time: float
    :param bin_size: bin width in seconds
    :type bin_size: float
    :param weights: optional per-spike weights (same length as `spike_times`); when given, each
        bin holds the sum of the weights of its spikes instead of the spike count
    :type weights: array-like or None
    :return: bins, tscale
    :rtype: bins: np.ndarray (n_align_times, n_clusters, n_bins);
        tscale: np.ndarray (n_bins,) bin centres in seconds relative to the event
    """
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    cluster_ids = np.asarray(cluster_ids)
    align_times = np.asarray(align_times)
    if weights is not None:
        weights = np.asarray(weights)

    n_bins_pre = int(np.ceil(pre_time / bin_size))
    n_bins_post = int(np.ceil(post_time / bin_size))
    n_bins = n_bins_pre + n_bins_post
    # Bin edges relative to the event: n_bins + 1 values.
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    # Absolute bin edges for every event, shape (n_align_times, n_bins + 1).
    ts = np.repeat(align_times[:, np.newaxis], tscale.size, axis=1) + tscale
    # Index range [first, last) of spikes between the first and last edge of each event.
    epoch_idxs = np.searchsorted(spike_times, np.c_[ts[:, 0], ts[:, -1]])
    n_clusters = cluster_ids.shape[0]
    bins = np.zeros(shape=(align_times.shape[0], n_clusters, n_bins))

    # Lookup table from cluster id to its row in `cluster_ids`. The previous implementation
    # paired rows of np.unique(window clusters) with np.isin(cluster_ids, ...), which
    # mislabelled rows when cluster_ids was unsorted and failed when the window contained
    # clusters outside cluster_ids.
    order = np.argsort(cluster_ids, kind='stable')
    sorted_ids = cluster_ids[order]

    for i, (ep, t) in enumerate(zip(epoch_idxs, ts)):
        clu = spike_clusters[ep[0]:ep[1]]
        xind = (np.floor((spike_times[ep[0]:ep[1]] - t[0]) / bin_size)).astype(np.int64)
        # Keep only requested clusters and indices that fall inside the n_bins real bins.
        keep = np.isin(clu, cluster_ids) & (xind >= 0) & (xind < n_bins)
        rows = order[np.searchsorted(sorted_ids, clu[keep])]
        w = weights[ep[0]:ep[1]][keep] if weights is not None else None
        ind2d = rows * n_bins + xind[keep]
        r = np.bincount(ind2d, minlength=n_clusters * n_bins, weights=w)
        bins[i] = r.reshape(n_clusters, n_bins)

    # Convert bin edges to bin centres.
    tscale = (tscale[:-1] + tscale[1:]) / 2

    return bins, tscale

100 

101 

def calculate_peths(
        spike_times, spike_clusters, cluster_ids, align_times, pre_time=0.2,
        post_time=0.5, bin_size=0.025, smoothing=0.025, return_fr=True):
    """
    Calculate peri-event time histograms; return means and standard deviations
    for each time point across specified clusters.

    Spikes are binned in a window around each align time. When `smoothing > 0` the window is
    padded by `n_offset = 5 * ceil(smoothing / bin_size)` bins on each side. Each trial's
    counts are convolved with a normalised Gaussian kernel, and the padding is trimmed off the
    outputs afterwards to get rid of boundary effects of the convolution.

    :param spike_times: spike times (in seconds)
    :type spike_times: array-like
    :param spike_clusters: cluster ids corresponding to each spike in `spike_times`
    :type spike_clusters: array-like
    :param cluster_ids: subset of cluster ids for calculating peths; the cluster axis of the
        outputs follows the sorted unique ids (returned as `peths.cscale`), so duplicates
        are ignored
    :type cluster_ids: array-like
    :param align_times: times (in seconds) to align peths to
    :type align_times: array-like
    :param pre_time: time (in seconds) to precede align times in peth
    :type pre_time: float
    :param post_time: time (in seconds) to follow align times in peth
    :type post_time: float
    :param bin_size: width of time windows (in seconds) to bin spikes
    :type bin_size: float
    :param smoothing: standard deviation (in seconds) of Gaussian kernel for
        smoothing peths; use `smoothing=0` to skip smoothing
    :type smoothing: float
    :param return_fr: `True` to return (estimated) firing rate, `False` to return spike counts
        in `peths.means` / `peths.stds`
    :type return_fr: bool
    :return: peths, binned_spikes
    :rtype: peths: Bunch({'means': peth_means, 'stds': peth_stds, 'tscale': ts, 'cscale': ids})
        where `means` and `stds` are (n_clusters, n_bins) statistics across align times (from
        the smoothed data when `smoothing > 0`, divided by `bin_size` when `return_fr`),
        `tscale` holds bin centres relative to the align time and `cscale` the sorted unique
        cluster ids
    :rtype: binned_spikes: np.array (n_align_times, n_clusters, n_bins) of raw spike counts
        (never smoothed, never converted to rates)
    """
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    align_times = np.asarray(align_times)
    # The cluster axis of every output is indexed by the sorted unique ids.
    ids = np.unique(cluster_ids)

    # initialize containers
    n_offset = 5 * int(np.ceil(smoothing / bin_size))  # get rid of boundary effects for smoothing
    n_bins_pre = int(np.ceil(pre_time / bin_size)) + n_offset
    n_bins_post = int(np.ceil(post_time / bin_size)) + n_offset
    n_bins = n_bins_pre + n_bins_post
    # Sized by the unique ids: sizing by len(cluster_ids) broke the boolean row indexing below
    # whenever cluster_ids contained duplicates.
    binned_spikes = np.zeros(shape=(len(align_times), len(ids), n_bins))

    # build gaussian kernel if requested
    if smoothing > 0:
        # kernel length is forced to be odd
        w = n_bins - 1 if n_bins % 2 == 0 else n_bins
        window = gaussian(w, std=smoothing / bin_size)
        window /= np.sum(window)
        binned_spikes_conv = np.copy(binned_spikes)

    # Filter spikes outside of the loop: drop spikes that cannot fall in any window and spikes
    # of clusters that were not requested.
    idxs = ((spike_times >= np.min(align_times) - (n_bins_pre + 1) * bin_size)
            & (spike_times <= np.max(align_times) + (n_bins_post + 1) * bin_size))
    idxs &= np.isin(spike_clusters, cluster_ids)
    spike_times = spike_times[idxs]
    spike_clusters = spike_clusters[idxs]

    # Bin edges relative to the align time: n_bins + 1 values.
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    for i, t_0 in enumerate(align_times):
        # define bin edges
        ts = tscale + t_0
        # filter spikes (both edges inclusive)
        idxs = (spike_times >= ts[0]) & (spike_times <= ts[-1])
        i_spikes = spike_times[idxs]
        i_clusters = spike_clusters[idxs]

        # bin spikes similar to bincount2D: x = spike times, y = spike clusters
        xind = (np.floor((i_spikes - ts[0]) / bin_size)).astype(np.int64)
        yscale, yind = np.unique(i_clusters, return_inverse=True)
        nx, ny = ts.size, yscale.size
        ind2d = np.ravel_multi_index(np.c_[yind, xind].transpose(), dims=(ny, nx))
        r = np.bincount(ind2d, minlength=nx * ny, weights=None).reshape(ny, nx)

        # store (ts represent bin edges, so there are one fewer bins); both yscale and ids are
        # sorted, so the rows of r line up with the True entries of bs_idxs
        bs_idxs = np.isin(ids, yscale)
        binned_spikes[i, bs_idxs, :] = r[:, :-1]

        # smooth each cluster's counts for this align time
        if smoothing > 0:
            rows = np.where(bs_idxs)[0]
            for j in range(r.shape[0]):
                binned_spikes_conv[i, rows[j], :] = convolve(
                    r[j, :], window, mode='same', method='auto')[:-1]

    # average across align times
    binned_spikes_ = np.copy(binned_spikes_conv if smoothing > 0 else binned_spikes)
    if return_fr:
        binned_spikes_ /= bin_size

    peth_means = np.mean(binned_spikes_, axis=0)
    peth_stds = np.std(binned_spikes_, axis=0)

    if smoothing > 0:
        # trim the padding that was only added for the convolution
        peth_means = peth_means[:, n_offset:-n_offset]
        peth_stds = peth_stds[:, n_offset:-n_offset]
        binned_spikes = binned_spikes[:, :, n_offset:-n_offset]
        tscale = tscale[n_offset:-n_offset]

    # package output (bin edges -> bin centres)
    tscale = (tscale[:-1] + tscale[1:]) / 2
    peths = Bunch({'means': peth_means, 'stds': peth_stds, 'tscale': tscale, 'cscale': ids})
    return peths, binned_spikes

209 

210 

def firing_rate(ts, hist_win=0.01, fr_win=0.5):
    '''
    Computes the instantaneous firing rate of a unit over time by computing a histogram of spike
    counts over a specified window of time, and summing this histogram over a sliding window of
    specified time over a specified period of total time.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate, in ascending order (the
        analysed span is `ts[-1] - ts[0]`).
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.

    Returns
    -------
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    Raises
    ------
    ValueError
        If the spike train spans less than `hist_win` or less than `fr_win`, or if `fr_win` is
        shorter than `hist_win`, so that no histogram bin or sliding window fits.

    See Also
    --------
    metrics.firing_rate_cv
    metrics.firing_rate_fano_factor
    plot.firing_rate

    Examples
    --------
    1) Compute the firing rate for unit 1 from the time of its first to last spike.
    >>> import brainbox as bb
    >>> import alf.io as aio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Load a spikes bunch and get the timestamps for unit 1, and calculate the instantaneous
    # firing rate.
    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
    >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
    >>> ts = spks_b['times'][unit_idxs]
    >>> fr = bb.singlecell.firing_rate(ts)
    '''
    ts = np.asarray(ts)

    # Compute histogram of spike counts.
    t_tot = ts[-1] - ts[0]
    n_bins_hist = int(t_tot / hist_win)
    n_bins_fr = int(t_tot / fr_win)
    # Previously these cases crashed with ZeroDivisionError / an opaque numpy error.
    if n_bins_hist < 1 or n_bins_fr < 1:
        raise ValueError(
            f"spike train spans {t_tot} s, which is shorter than hist_win={hist_win} s "
            f"or fr_win={fr_win} s")
    counts = np.histogram(ts, n_bins_hist)[0]
    # Number of histogram bins summed by the sliding window.
    step_sz = int(len(counts) / n_bins_fr)
    if step_sz < 1:
        raise ValueError(f"fr_win={fr_win} s must not be shorter than hist_win={hist_win} s")
    # Compute moving sum of spike counts divided by the window length to get a rate in Hz.
    fr = np.convolve(counts, np.ones(step_sz)) / fr_win
    # NOTE(review): fully overlapping samples of the full convolution sit at indices
    # step_sz - 1 .. len(counts) - 1; this slice also drops the last of them. Kept unchanged to
    # preserve the output length callers may rely on.
    fr = fr[step_sz - 1:- step_sz]
    return fr