Coverage for brainbox/population/decode.py: 79% (168 statements)


"""
Population functions.

Code from https://github.com/cortex-lab/phylib/blob/master/phylib/stats/ccg.py by C. Rossant.
Code for decoding by G. Meijer.
Code for sigtest_pseudosessions and sigtest_linshift by B. Benson.
"""

import numpy as np
import scipy as sp
import scipy.stats
import types
from itertools import groupby
from sklearn.linear_model import LinearRegression, Lasso, Ridge
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import KFold, LeaveOneOut, LeaveOneGroupOut
from sklearn.metrics import accuracy_score


def get_spike_counts_in_bins(spike_times, spike_clusters, intervals):
    """
    Return the number of spikes in a sequence of time intervals, for each neuron.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    intervals : 2D array of shape (n_events, 2)
        the start and end times of the events

    Returns
    -------
    counts : 2D array of shape (n_neurons, n_events)
        the spike counts of all neurons for all events;
        value (i, j) is the number of spikes of neuron `cluster_ids[i]` in interval #j
    cluster_ids : 1D array
        list of cluster ids
    """
    # Check input
    assert intervals.ndim == 2
    assert intervals.shape[1] == 2
    assert np.all(np.diff(spike_times) >= 0), "Spike times need to be sorted"

    # For each interval, find the index range of the spikes falling inside it (binary search)
    intervals_idx = np.searchsorted(spike_times, intervals)

    # For each neuron and each interval, the number of spikes in the interval.
    cluster_ids = np.unique(spike_clusters)
    n_neurons = len(cluster_ids)
    n_intervals = intervals.shape[0]
    counts = np.zeros((n_neurons, n_intervals), dtype=np.uint32)
    for j in range(n_intervals):
        i0, i1 = intervals_idx[j, :]
        # Count the number of spikes in the window, for each neuron.
        x = np.bincount(spike_clusters[i0:i1], minlength=cluster_ids.max() + 1)
        counts[:, j] = x[cluster_ids]
    return counts, cluster_ids
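# Example usage for get_spike_counts_in_bins (a commented sketch with synthetic
# data; variable names and values are illustrative only, not part of this module):
#     rng = np.random.default_rng(0)
#     spike_times = np.sort(rng.uniform(0, 100, size=5000))
#     spike_clusters = rng.integers(0, 10, size=5000)
#     intervals = np.array([[10.0, 10.5], [20.0, 20.5], [30.0, 30.5]])
#     counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, intervals)
#     # counts has shape (len(cluster_ids), 3); rows follow the sorted order of cluster_ids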

def _index_of(arr, lookup):
    """Replace scalars in an array by their indices in a lookup table.

    Implicitly assume that:

    * All elements of arr and lookup are non-negative integers.
    * All elements of arr belong to lookup.

    This is not checked for performance reasons.
    """
    # Equivalent of np.digitize(arr, lookup) - 1, but much faster.
    # TODO: assertions to disable in production for performance reasons.
    # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large
    # values
    lookup = np.asarray(lookup, dtype=np.int32)
    m = (lookup.max() if len(lookup) else 0) + 1
    tmp = np.zeros(m + 1, dtype=int)
    # Ensure that -1 values are kept.
    tmp[-1] = -1
    if len(lookup):
        tmp[lookup] = np.arange(len(lookup))
    return tmp[arr]
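# For instance (a commented sketch; values chosen arbitrarily):
#     _index_of(np.array([5, 7, 5]), [5, 7])  # -> array([0, 1, 0])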

def _increment(arr, indices):
    """Increment some indices in a 1D vector of non-negative integers.
    Repeated indices are taken into account."""
    bbins = np.bincount(indices)
    arr[:len(bbins)] += bbins
    return arr


def _diff_shifted(arr, steps=1):
    return arr[steps:] - arr[:len(arr) - steps]


def _create_correlograms_array(n_clusters, winsize_bins):
    return np.zeros((n_clusters, n_clusters, winsize_bins // 2 + 1), dtype=np.int32)
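# For instance (commented sketches; values chosen arbitrarily):
#     _increment(np.zeros(4, dtype=int), [1, 1, 3])    # -> array([0, 2, 0, 1])
#     _diff_shifted(np.array([1, 3, 6, 10]), steps=2)  # -> array([5, 7])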

def _symmetrize_correlograms(correlograms):
    """Return the symmetrized version of the CCG arrays."""

    n_clusters, _, n_bins = correlograms.shape
    assert n_clusters == _

    # We symmetrize c[i, j, 0].
    # This is necessary because the algorithm in correlograms()
    # is sensitive to the order of identical spikes.
    correlograms[..., 0] = np.maximum(
        correlograms[..., 0], correlograms[..., 0].T)

    sym = correlograms[..., 1:][..., ::-1]
    sym = np.transpose(sym, (1, 0, 2))

    return np.dstack((sym, correlograms))

def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):
    """Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`.

    Parameters
    ----------
    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.

    Returns
    -------
    correlograms : 3D array
        An `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        cross-correlograms.
    """
    assert np.all(np.diff(spike_times) >= 0), "The spike times must be increasing."
    assert spike_times.ndim == 1
    assert spike_times.shape == spike_clusters.shape

    # Clamp `bin_size` to a sensible range.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    # Take the cluster order into account.
    clusters = np.unique(spike_clusters)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask indicates which spikes have matching spikes
    # within the correlogram time window.
    mask = np.ones_like(spike_times, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Interval between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_times, shift)

        # Bin the delays between spike i and spike i+shift.
        spike_diff_b = np.round(spike_diff / bin_size).astype(np.int64)

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins / 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d), correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    return _symmetrize_correlograms(correlograms)
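# Example usage for xcorr (a commented sketch; synthetic spike train and
# illustrative parameter values):
#     rng = np.random.default_rng(0)
#     spike_times = np.sort(rng.uniform(0, 10, size=2000))
#     spike_clusters = rng.integers(0, 3, size=2000)
#     ccg = xcorr(spike_times, spike_clusters, bin_size=0.001, window_size=0.05)
#     # ccg has shape (3, 3, 51): 3 clusters, 2 * 25 + 1 lag bins of 1 ms each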

def classify(population_activity, trial_labels, classifier, cross_validation=None,
             return_training=False):
    """
    Classify trial identity (e.g. stim left/right) from neural population activity.

    Parameters
    ----------
    population_activity : 2D array (trials x neurons)
        population activity of all neurons in the population for each trial
    trial_labels : 1D or 2D array
        identities of the trials, can be any number of groups, accepts integers and strings
    classifier : scikit-learn object
        which decoder to use, for example a naive Bayes classifier with multinomial likelihood:
            from sklearn.naive_bayes import MultinomialNB
            classifier = MultinomialNB()
    cross_validation : None or scikit-learn object
        which cross-validation method to use, for example 5-fold:
            from sklearn.model_selection import KFold
            cross_validation = KFold(n_splits=5)
    return_training : bool
        if set to True the classifier will also return the performance on the training set

    Returns
    -------
    accuracy : float
        accuracy of the classifier
    pred : 1D array
        predictions of the classifier
    prob : 1D array
        classification probability of the second class (column 1 of `predict_proba`),
        which assumes a binary classification problem
    training_accuracy : float
        accuracy of the classifier on the training set (only if return_training is True)
    """
    # Check input
    if (cross_validation is None) and (return_training is True):
        raise RuntimeError('cannot return training accuracy without cross-validation')
    if population_activity.shape[0] != trial_labels.shape[0]:
        raise ValueError('trial_labels is not the same length as the first dimension of '
                         'population_activity')

    if cross_validation is None:
        # Fit the model on all the data
        classifier.fit(population_activity, trial_labels)
        pred = classifier.predict(population_activity)
        prob = classifier.predict_proba(population_activity)
        prob = prob[:, 1]
    else:
        pred = np.empty(trial_labels.shape[0])
        prob = np.empty(trial_labels.shape[0])
        if return_training:
            pred_training = np.empty(trial_labels.shape[0])

        for train_index, test_index in cross_validation.split(population_activity):
            # Fit the model to the training data
            classifier.fit(population_activity[train_index], trial_labels[train_index])

            # Predict the held-out test data
            pred[test_index] = classifier.predict(population_activity[test_index])
            proba = classifier.predict_proba(population_activity[test_index])
            prob[test_index] = proba[:, 1]

            # Predict the training data
            if return_training:
                pred_training[train_index] = classifier.predict(population_activity[train_index])

    # Calculate accuracy
    accuracy = accuracy_score(trial_labels, pred)
    if return_training:
        training_accuracy = accuracy_score(trial_labels, pred_training)
        return accuracy, pred, prob, training_accuracy
    else:
        return accuracy, pred, prob
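# Example usage for classify (a commented sketch with synthetic data;
# MultinomialNB expects non-negative features such as spike counts):
#     from sklearn.naive_bayes import MultinomialNB
#     from sklearn.model_selection import KFold
#     rng = np.random.default_rng(0)
#     activity = rng.poisson(2, size=(100, 20))   # 100 trials x 20 neurons
#     labels = rng.integers(0, 2, size=100)       # binary trial labels
#     acc, pred, prob = classify(activity, labels, MultinomialNB(),
#                                cross_validation=KFold(n_splits=5))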

def regress(population_activity, trial_targets, regularization=None,
            cross_validation=None, return_training=False):
    """
    Perform linear regression to predict a continuous variable from neural data.

    Parameters
    ----------
    population_activity : 2D array (trials x neurons)
        population activity of all neurons in the population for each trial
    trial_targets : 1D or 2D array
        the decoding target per trial as a continuous variable
    regularization : None or string
        None = no regularization, using ordinary least squares linear regression
        'L1' = L1 regularization using Lasso
        'L2' = L2 regularization using Ridge regression
    cross_validation : None or scikit-learn object
        which cross-validation method to use, for example 5-fold:
            from sklearn.model_selection import KFold
            cross_validation = KFold(n_splits=5)
    return_training : bool
        if set to True the regression will also return the predictions on the training set

    Returns
    -------
    pred : 1D array
        array with predictions
    pred_training : 1D array
        array with predictions for the training set (only if return_training is True)
    """
    # Check input
    if (cross_validation is None) and (return_training is True):
        raise RuntimeError('cannot return training accuracy without cross-validation')
    if population_activity.shape[0] != trial_targets.shape[0]:
        raise ValueError('trial_targets is not the same length as the first dimension of '
                         'population_activity')

    # Initialize regression
    if regularization is None:
        reg = LinearRegression()
    elif regularization == 'L1':
        reg = Lasso()
    elif regularization == 'L2':
        reg = Ridge()
    else:
        raise ValueError("regularization must be None, 'L1' or 'L2'")

    if cross_validation is None:
        # Fit the model on all the data
        reg.fit(population_activity, trial_targets)
        pred = reg.predict(population_activity)
    else:
        pred = np.empty(trial_targets.shape[0])
        if return_training:
            pred_training = np.empty(trial_targets.shape[0])
        for train_index, test_index in cross_validation.split(population_activity):
            # Fit the model to the training data
            reg.fit(population_activity[train_index], trial_targets[train_index])

            # Predict the held-out test data
            pred[test_index] = reg.predict(population_activity[test_index])

            # Predict the training data
            if return_training:
                pred_training[train_index] = reg.predict(population_activity[train_index])
    if return_training:
        return pred, pred_training
    else:
        return pred
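# Example usage for regress (a commented sketch with synthetic data; the target
# is built as a linear readout of the activity plus noise):
#     from sklearn.model_selection import KFold
#     rng = np.random.default_rng(0)
#     activity = rng.poisson(2, size=(100, 20)).astype(float)
#     targets = activity @ rng.normal(size=20) + rng.normal(size=100)
#     pred = regress(activity, targets, regularization='L2',
#                    cross_validation=KFold(n_splits=5))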

def lda_project(spike_times, spike_clusters, event_times, event_groups, pre_time=0, post_time=0.5,
                cross_validation='kfold', num_splits=5, prob_left=None, custom_validation=None):
    """
    Use linear discriminant analysis to project population vectors to the line that best separates
    the two groups. When cross-validation is used, the LDA projection is fitted on the training
    data, after which the test data are projected onto this projection.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events, can be any number of groups, accepts integers and strings
    pre_time : float
        time (in seconds) before each event time to include in the population vectors
    post_time : float
        time (in seconds) after each event time to include in the population vectors
    cross_validation : string
        which cross-validation method to use, options are:
            'none'           No cross-validation
            'kfold'          K-fold cross-validation
            'leave-one-out'  Leave out the trial that is being decoded
            'block'          Leave out the block the to-be-decoded trial is in
            'custom'         Any custom cross-validation provided by the user
    num_splits : integer
        ** only for 'kfold' cross-validation **
        Number of splits to use for k-fold cross-validation; a value of 5 means that the decoder
        will be trained on 4/5th of the data and used to predict the remaining 1/5th. This process
        is repeated five times so that all data have been used as both training and test set.
    prob_left : 1D array
        ** only for 'block' cross-validation **
        the probability of the stimulus appearing on the left for each trial in event_times
    custom_validation : generator
        ** only for 'custom' cross-validation **
        a generator object with the splits to be used for cross-validation, in this format:
            (
                (split1_train_idxs, split1_test_idxs),
                (split2_train_idxs, split2_test_idxs),
                (split3_train_idxs, split3_test_idxs),
                ...)

    Returns
    -------
    lda_projection : 1D array
        the position along the LDA projection axis for the population vector of each trial
    """
    # Check input
    assert cross_validation in ['none', 'kfold', 'leave-one-out', 'block', 'custom']
    assert event_times.shape[0] == event_groups.shape[0]
    if cross_validation == 'block':
        assert event_times.shape[0] == prob_left.shape[0]
    if cross_validation == 'custom':
        assert isinstance(custom_validation, types.GeneratorType)

    # Get matrix of all neuronal responses
    times = np.column_stack(((event_times - pre_time), (event_times + post_time)))
    pop_vector, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
    pop_vector = pop_vector.T

    # Initialize
    lda = LinearDiscriminantAnalysis()
    lda_projection = np.zeros(event_groups.shape)

    if cross_validation == 'none':
        # Find the best LDA projection on all data and transform those data
        lda_projection = lda.fit_transform(pop_vector, event_groups)

    else:
        # Perform cross-validation
        if cross_validation == 'leave-one-out':
            cv = LeaveOneOut().split(pop_vector)
        elif cross_validation == 'kfold':
            cv = KFold(n_splits=num_splits).split(pop_vector)
        elif cross_validation == 'block':
            block_lengths = [sum(1 for i in g) for k, g in groupby(prob_left)]
            blocks = np.repeat(np.arange(len(block_lengths)), block_lengths)
            cv = LeaveOneGroupOut().split(pop_vector, groups=blocks)
        elif cross_validation == 'custom':
            cv = custom_validation

        # Loop over the splits into train and test
        for train_index, test_index in cv:

            # Find LDA projection on the training data
            lda.fit(pop_vector[train_index], [event_groups[j] for j in train_index])

            # Project the held-out test data to projection
            lda_projection[test_index] = lda.transform(pop_vector[test_index]).T[0]

    return lda_projection
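# Example usage for lda_project (a commented sketch with synthetic data;
# variable names and sizes are illustrative only):
#     rng = np.random.default_rng(0)
#     spike_times = np.sort(rng.uniform(0, 100, size=10000))
#     spike_clusters = rng.integers(0, 20, size=10000)
#     event_times = np.arange(1, 91, dtype=float)      # 90 events
#     event_groups = rng.integers(0, 2, size=90)       # two groups
#     projection = lda_project(spike_times, spike_clusters, event_times, event_groups,
#                              cross_validation='kfold', num_splits=5)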

def sigtest_pseudosessions(X, y, fStatMeas, genPseudo, npseuds=200):
    """
    Estimate the significance level of any statistical measure following Harris, 2021
    (bioRxiv, https://www.biorxiv.org/content/10.1101/2020.11.29.402719v2).
    fStatMeas computes a scalar statistical measure (e.g. R^2) between the data, X, and the
    decoded variable, y. Pseudosessions are generated npseuds times to create a null
    distribution of statistical measures. The significance level is reported relative to this
    null distribution.

    Parameters
    ----------
    X : 2-d array
        data of size (elements, timetrials)
    y : 1-d array
        predicted variable of size (timetrials)
    fStatMeas : function
        takes arguments (X, y) and returns a statistical measure relating how well X decodes y
    genPseudo : function
        takes no arguments () and returns a pseudosession (same shape as y) drawn from the
        experimentally known null distribution of y
    npseuds : int
        the number of pseudosessions used to estimate the significance level

    Returns
    -------
    alpha : p-value, e.g. at a significance level of b, if alpha <= b then reject the null
        hypothesis
    statms_real : the value of the statistical measure evaluated on X and y
    statms_pseuds : array of statistical measures evaluated on pseudosessions
    """
    statms_real = fStatMeas(X, y)
    statms_pseuds = np.zeros(npseuds)
    for i in range(npseuds):
        statms_pseuds[i] = fStatMeas(X, genPseudo())

    alpha = 1 - (0.01 * sp.stats.percentileofscore(statms_pseuds, statms_real, kind='weak'))

    return alpha, statms_real, statms_pseuds
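# Example usage for sigtest_pseudosessions (a commented sketch; the statistic
# and pseudosession generator below are illustrative stand-ins, not part of
# this module):
#     rng = np.random.default_rng(0)
#     X = rng.normal(size=(10, 200))                    # 10 elements x 200 trials
#     y = X[0] + rng.normal(size=200)                   # variable decodable from X
#     fstat = lambda X, y: np.corrcoef(X[0], y)[0, 1]   # scalar statistical measure
#     gen_pseudo = lambda: rng.permutation(y)           # draw from an assumed null
#     alpha, stat, null = sigtest_pseudosessions(X, y, fstat, gen_pseudo, npseuds=200)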

def sigtest_linshift(X, y, fStatMeas, D=300):
    """
    Use the provably conservative Linear Shift technique (Harris, 2021, arXiv,
    https://arxiv.org/ftp/arxiv/papers/2012/2012.06862.pdf) to estimate the
    significance level of a statistical measure. fStatMeas computes a
    scalar statistical measure (e.g. R^2) from the data matrix, X, and the variable, y.
    A central window of X and y of size D is linearly shifted to generate a null distribution
    of statistical measures. The significance level is reported relative to this null
    distribution.

    Parameters
    ----------
    X : 2-d array
        data of size (elements, timetrials)
    y : 1-d array
        predicted variable of size (timetrials)
    fStatMeas : function
        takes arguments (X, y) and returns a scalar statistical measure of how well X decodes y
    D : int
        the window length along the center of y used to compute the statistical measure;
        must have room to shift both right and left: len(y) >= D + 2

    Returns
    -------
    alpha : conservative p-value, e.g. at a significance level of b, if alpha <= b then reject
        the null hypothesis
    statms_real : the value of the statistical measure evaluated on X and y
    statms_pseuds : a 1-d array of statistical measures evaluated on shifted versions of y
    """
    assert len(y) >= D + 2

    T = len(y)
    N = int((T - D) / 2)

    shifts = np.arange(-N, N + 1)

    # compute all statms
    statms_real = fStatMeas(X[:, N:T - N], y[N:T - N])
    statms_pseuds = np.zeros(len(shifts))
    for si in range(len(shifts)):
        s = shifts[si]
        statms_pseuds[si] = fStatMeas(np.copy(X[:, N:T - N]), np.copy(y[s + N:s + T - N]))

    M = np.sum(statms_pseuds >= statms_real)
    alpha = M / (N + 1)

    return alpha, statms_real, statms_pseuds
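# Example usage for sigtest_linshift (a commented sketch; the statistic below
# is an illustrative stand-in, not part of this module):
#     rng = np.random.default_rng(0)
#     X = rng.normal(size=(10, 400))                    # 10 elements x 400 trials
#     y = X[0] + rng.normal(size=400)
#     fstat = lambda X, y: np.corrcoef(X[0], y)[0, 1]
#     alpha, stat, null = sigtest_linshift(X, y, fstat, D=300)
#     # len(y) = 400 >= D + 2, so the central window can shift both ways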