1"""
2Population functions.
4Code from https://github.com/cortex-lab/phylib/blob/master/phylib/stats/ccg.py by C. Rossant.
5Code for decoding by G. Meijer
6Code from sigtest_pseudosessions and sigtest_linshift by B. Benson
7"""
9import numpy as np
10import scipy as sp
11import scipy.stats
12import types
13from itertools import groupby
14from sklearn.linear_model import LinearRegression, Lasso, Ridge
15from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
16from sklearn.model_selection import KFold, LeaveOneOut, LeaveOneGroupOut
17from sklearn.metrics import accuracy_score


def get_spike_counts_in_bins(spike_times, spike_clusters, intervals):
    """
    Return the number of spikes in a sequence of time intervals, for each neuron.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    intervals : 2D array of shape (n_events, 2)
        the start and end times of the events

    Returns
    -------
    counts : 2D array of shape (n_neurons, n_events)
        the spike counts of all neurons for all events
        value (i, j) is the number of spikes of neuron `neurons[i]` in interval #j
    cluster_ids : 1D array
        list of cluster ids
    """
    # Check input
    assert intervals.ndim == 2
    assert intervals.shape[1] == 2
    assert np.all(np.diff(spike_times) >= 0), "Spike times need to be sorted"

    intervals_idx = np.searchsorted(spike_times, intervals)

    # For each neuron and each interval, the number of spikes in the interval.
    cluster_ids = np.unique(spike_clusters)
    n_neurons = len(cluster_ids)
    n_intervals = intervals.shape[0]
    counts = np.zeros((n_neurons, n_intervals), dtype=np.uint32)
    for j in range(n_intervals):
        i0, i1 = intervals_idx[j, :]
        # Count the number of spikes in the window, for each neuron.
        x = np.bincount(spike_clusters[i0:i1], minlength=cluster_ids.max() + 1)
        counts[:, j] = x[cluster_ids]
    return counts, cluster_ids
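

# Usage sketch (added for illustration, not part of the original module): count
# spikes from two synthetic neurons in two one-second windows. All values below
# are made up; in practice spike_times and spike_clusters come from a recording.
def _example_get_spike_counts_in_bins():
    spike_times = np.array([0.1, 0.2, 0.5, 1.1, 1.4, 2.3])
    spike_clusters = np.array([0, 1, 0, 1, 1, 0])
    intervals = np.array([[0.0, 1.0], [1.0, 2.0]])
    counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, intervals)
    # counts has shape (n_neurons, n_events); here (2, 2)
    return counts, cluster_ids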


def _index_of(arr, lookup):
    """Replace scalars in an array by their indices in a lookup table.

    Implicitly assume that:

    * All elements of arr and lookup are non-negative integers.
    * All elements of arr belong to lookup.

    This is not checked for performance reasons.
    """
    # Equivalent of np.digitize(arr, lookup) - 1, but much faster.
    # TODO: assertions to disable in production for performance reasons.
    # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large
    # values
    lookup = np.asarray(lookup, dtype=np.int32)
    m = (lookup.max() if len(lookup) else 0) + 1
    tmp = np.zeros(m + 1, dtype=int)
    # Ensure that -1 values are kept.
    tmp[-1] = -1
    if len(lookup):
        tmp[lookup] = np.arange(len(lookup))
    return tmp[arr]
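

# Illustrative sketch (not in the original module): _index_of maps scalars to
# their positions in the lookup table, e.g. ids [5, 2, 5] with lookup [2, 5]
# become [1, 0, 1].
def _example_index_of():
    return _index_of(np.array([5, 2, 5]), np.array([2, 5]))  # -> array([1, 0, 1])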


def _increment(arr, indices):
    """Increment some indices in a 1D vector of non-negative integers.
    Repeated indices are taken into account."""
    bbins = np.bincount(indices)
    arr[:len(bbins)] += bbins
    return arr


def _diff_shifted(arr, steps=1):
    """Return the difference between each element and the element `steps` positions before it."""
    return arr[steps:] - arr[:len(arr) - steps]


def _create_correlograms_array(n_clusters, winsize_bins):
    """Allocate an empty (n_clusters, n_clusters, winsize_bins // 2 + 1) correlogram array."""
    return np.zeros((n_clusters, n_clusters, winsize_bins // 2 + 1), dtype=np.int32)


def _symmetrize_correlograms(correlograms):
    """Return the symmetrized version of the CCG arrays."""

    n_clusters, _, n_bins = correlograms.shape
    assert n_clusters == _

    # We symmetrize c[i, j, 0].
    # This is necessary because the algorithm in correlograms()
    # is sensitive to the order of identical spikes.
    correlograms[..., 0] = np.maximum(
        correlograms[..., 0], correlograms[..., 0].T)

    sym = correlograms[..., 1:][..., ::-1]
    sym = np.transpose(sym, (1, 0, 2))

    return np.dstack((sym, correlograms))


def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):
    """Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`.

    Parameters
    ----------
    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.

    Returns
    -------
    correlograms : 3D array
        An `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        cross-correlograms.
    """
    assert np.all(np.diff(spike_times) >= 0), "The spike times must be increasing."
    assert spike_times.ndim == 1
    assert spike_times.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    # Take the cluster order into account.
    clusters = np.unique(spike_clusters)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask specifies which spikes have matching spikes
    # within the correlogram time window.
    mask = np.ones_like(spike_times, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Interval between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_times, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = np.round(spike_diff / bin_size).astype(np.int64)

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins / 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d), correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    return _symmetrize_correlograms(correlograms)
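

# Usage sketch (added for illustration): pairwise cross-correlograms for two
# synthetic clusters with 1 ms bins over a +/- 25 ms window; the data and
# parameter values are arbitrary assumptions.
def _example_xcorr():
    rng = np.random.default_rng(0)
    spike_times = np.sort(rng.uniform(0, 10, size=1000))
    spike_clusters = rng.integers(0, 2, size=1000)
    ccg = xcorr(spike_times, spike_clusters, bin_size=0.001, window_size=0.05)
    # ccg has shape (n_clusters, n_clusters, winsize_bins) = (2, 2, 51)
    return ccg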


def classify(population_activity, trial_labels, classifier, cross_validation=None,
             return_training=False):
    """
    Classify trial identity (e.g. stim left/right) from neural population activity.

    Parameters
    ----------
    population_activity : 2D array (trials x neurons)
        population activity of all neurons in the population for each trial.
    trial_labels : 1D or 2D array
        identities of the trials, can be any number of groups, accepts integers and strings
    classifier : scikit-learn object
        which decoder to use, for example naive Bayes with a multinomial likelihood:
            from sklearn.naive_bayes import MultinomialNB
            classifier = MultinomialNB()
    cross_validation : None or scikit-learn object
        which cross-validation method to use, for example 5-fold:
            from sklearn.model_selection import KFold
            cross_validation = KFold(n_splits=5)
    return_training : bool
        if set to True the classifier will also return the performance on the training set

    Returns
    -------
    accuracy : float
        accuracy of the classifier
    pred : 1D array
        predictions of the classifier
    prob : 1D array
        probability of classification
    training_accuracy : float
        accuracy of the classifier on the training set (only if return_training is True)
    """
    # Check input
    if (cross_validation is None) and (return_training is True):
        raise RuntimeError('cannot return training accuracy without cross-validation')
    if population_activity.shape[0] != trial_labels.shape[0]:
        raise ValueError('trial_labels is not the same length as the first dimension of '
                         'population_activity')

    if cross_validation is None:
        # Fit the model on all the data
        classifier.fit(population_activity, trial_labels)
        pred = classifier.predict(population_activity)
        prob = classifier.predict_proba(population_activity)
        # Probability of the second class (assumes binary classification)
        prob = prob[:, 1]
    else:
        pred = np.empty(trial_labels.shape[0])
        prob = np.empty(trial_labels.shape[0])
        if return_training:
            pred_training = np.empty(trial_labels.shape[0])

        for train_index, test_index in cross_validation.split(population_activity):
            # Fit the model to the training data
            classifier.fit(population_activity[train_index], trial_labels[train_index])

            # Predict the held-out test data
            pred[test_index] = classifier.predict(population_activity[test_index])
            proba = classifier.predict_proba(population_activity[test_index])
            prob[test_index] = proba[:, 1]

            # Predict the training data
            if return_training:
                pred_training[train_index] = classifier.predict(population_activity[train_index])

    # Calculate accuracy
    accuracy = accuracy_score(trial_labels, pred)
    if return_training:
        training_accuracy = accuracy_score(trial_labels, pred_training)
        return accuracy, pred, prob, training_accuracy
    else:
        return accuracy, pred, prob
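

# Usage sketch (added for illustration): decode a binary trial label from
# synthetic activity with Gaussian naive Bayes and 5-fold cross-validation.
# The data, classifier choice, and split count are arbitrary assumptions.
def _example_classify():
    from sklearn.naive_bayes import GaussianNB
    rng = np.random.default_rng(0)
    population_activity = rng.random((100, 20))  # 100 trials x 20 neurons
    trial_labels = rng.integers(0, 2, size=100)
    accuracy, pred, prob = classify(population_activity, trial_labels,
                                    GaussianNB(), cross_validation=KFold(n_splits=5))
    return accuracy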


def regress(population_activity, trial_targets, regularization=None,
            cross_validation=None, return_training=False):
    """
    Perform linear regression to predict a continuous variable from neural data.

    Parameters
    ----------
    population_activity : 2D array (trials x neurons)
        population activity of all neurons in the population for each trial.
    trial_targets : 1D or 2D array
        the decoding target per trial as a continuous variable
    regularization : None or string
        None = no regularization, using ordinary least squares linear regression
        'L1' = L1 regularization using Lasso
        'L2' = L2 regularization using Ridge regression
    cross_validation : None or scikit-learn object
        which cross-validation method to use, for example 5-fold:
            from sklearn.model_selection import KFold
            cross_validation = KFold(n_splits=5)
    return_training : bool
        if set to True the predictions on the training set are also returned

    Returns
    -------
    pred : 1D array
        array with predictions
    pred_training : 1D array
        array with predictions for the training set (only if return_training is True)
    """
    # Check input
    if (cross_validation is None) and (return_training is True):
        raise RuntimeError('cannot return training accuracy without cross-validation')
    if population_activity.shape[0] != trial_targets.shape[0]:
        raise ValueError('trial_targets is not the same length as the first dimension of '
                         'population_activity')

    # Initialize regression
    if regularization is None:
        reg = LinearRegression()
    elif regularization == 'L1':
        reg = Lasso()
    elif regularization == 'L2':
        reg = Ridge()
    else:
        raise ValueError("regularization must be None, 'L1' or 'L2'")

    if cross_validation is None:
        # Fit the model on all the data
        reg.fit(population_activity, trial_targets)
        pred = reg.predict(population_activity)
    else:
        pred = np.empty(trial_targets.shape[0])
        if return_training:
            pred_training = np.empty(trial_targets.shape[0])
        for train_index, test_index in cross_validation.split(population_activity):
            # Fit the model to the training data
            reg.fit(population_activity[train_index], trial_targets[train_index])

            # Predict the held-out test data
            pred[test_index] = reg.predict(population_activity[test_index])

            # Predict the training data
            if return_training:
                pred_training[train_index] = reg.predict(population_activity[train_index])
    if return_training:
        return pred, pred_training
    else:
        return pred
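

# Usage sketch (added for illustration): predict a continuous target with L2
# (Ridge) regularization and 5-fold cross-validation; the synthetic data carry
# a weak linear signal so the predictions are non-trivial. All values are made up.
def _example_regress():
    rng = np.random.default_rng(0)
    population_activity = rng.random((100, 20))  # 100 trials x 20 neurons
    trial_targets = population_activity @ rng.random(20) + rng.normal(0, 0.1, 100)
    pred = regress(population_activity, trial_targets, regularization='L2',
                   cross_validation=KFold(n_splits=5))
    return pred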


def lda_project(spike_times, spike_clusters, event_times, event_groups, pre_time=0, post_time=0.5,
                cross_validation='kfold', num_splits=5, prob_left=None, custom_validation=None):
    """
    Use linear discriminant analysis to project population vectors to the line that best separates
    the two groups. When cross-validation is used, the LDA projection is fitted on the training
    data, after which the test data are projected onto this projection.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events, can be any number of groups, accepts integers and strings
    pre_time : float
        time (in seconds) before each event to include in the population vector
    post_time : float
        time (in seconds) after each event to include in the population vector
    cross_validation : string
        which cross-validation method to use, options are:
            'none'           No cross-validation
            'kfold'          K-fold cross-validation
            'leave-one-out'  Leave out the trial that is being decoded
            'block'          Leave out the block the to-be-decoded trial is in
            'custom'         Any custom cross-validation provided by the user
    num_splits : integer
        ** only for 'kfold' cross-validation **
        Number of splits to use for k-fold cross-validation; a value of 5 means that the decoder
        will be trained on 4/5th of the data and used to predict the remaining 1/5th. This process
        is repeated five times so that all data has been used as both training and test set.
    prob_left : 1D array
        ** only for 'block' cross-validation **
        the probability of the stimulus appearing on the left for each trial in event_times
    custom_validation : generator
        ** only for 'custom' cross-validation **
        a generator object with the splits to be used for cross-validation using this format:
            (
                (split1_train_idxs, split1_test_idxs),
                (split2_train_idxs, split2_test_idxs),
                (split3_train_idxs, split3_test_idxs),
                ...)

    Returns
    -------
    lda_projection : 1D array
        the position along the LDA projection axis for the population vector of each trial
    """
    # Check input
    assert cross_validation in ['none', 'kfold', 'leave-one-out', 'block', 'custom']
    assert event_times.shape[0] == event_groups.shape[0]
    if cross_validation == 'block':
        assert event_times.shape[0] == prob_left.shape[0]
    if cross_validation == 'custom':
        assert isinstance(custom_validation, types.GeneratorType)

    # Get matrix of all neuronal responses
    times = np.column_stack(((event_times - pre_time), (event_times + post_time)))
    pop_vector, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
    pop_vector = pop_vector.T

    # Initialize
    lda = LinearDiscriminantAnalysis()
    lda_projection = np.zeros(event_groups.shape)

    if cross_validation == 'none':
        # Find the best LDA projection on all data and transform those data
        lda_projection = lda.fit_transform(pop_vector, event_groups)

    else:
        # Perform cross-validation
        if cross_validation == 'leave-one-out':
            cv = LeaveOneOut().split(pop_vector)
        elif cross_validation == 'kfold':
            cv = KFold(n_splits=num_splits).split(pop_vector)
        elif cross_validation == 'block':
            block_lengths = [sum(1 for i in g) for k, g in groupby(prob_left)]
            blocks = np.repeat(np.arange(len(block_lengths)), block_lengths)
            cv = LeaveOneGroupOut().split(pop_vector, groups=blocks)
        elif cross_validation == 'custom':
            cv = custom_validation

        # Loop over the splits into train and test
        for train_index, test_index in cv:

            # Find LDA projection on the training data
            lda.fit(pop_vector[train_index], [event_groups[j] for j in train_index])

            # Project the held-out test data to projection
            lda_projection[test_index] = lda.transform(pop_vector[test_index]).T[0]

    return lda_projection
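

# Usage sketch (added for illustration): project synthetic population activity
# onto the LDA axis separating two alternating event groups, using the default
# k-fold cross-validation. All values here are arbitrary assumptions.
def _example_lda_project():
    rng = np.random.default_rng(0)
    spike_times = np.sort(rng.uniform(0, 100, size=5000))
    spike_clusters = rng.integers(0, 10, size=5000)
    event_times = np.linspace(5, 95, 40)
    event_groups = np.tile([0, 1], 20)
    return lda_project(spike_times, spike_clusters, event_times, event_groups)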


def sigtest_pseudosessions(X, y, fStatMeas, genPseudo, npseuds=200):
    """
    Estimate the significance level of any statistical measure following Harris, bioRxiv, 2021
    (https://www.biorxiv.org/content/10.1101/2020.11.29.402719v2).
    fStatMeas computes a scalar statistical measure (e.g. R^2) between the data, X, and the
    decoded variable, y. Pseudosessions are generated npseuds times to create a null
    distribution of statistical measures. The significance level is reported relative to this
    null distribution.

    Parameters
    ----------
    X : 2-d array
        data of size (elements, timetrials)
    y : 1-d array
        predicted variable of size (timetrials)
    fStatMeas : function
        takes arguments (X, y) and returns a statistical measure relating how well X decodes y
    genPseudo : function
        takes no arguments () and returns a pseudosession (same shape as y) drawn from the
        experimentally known null-distribution of y
    npseuds : int
        the number of pseudosessions used to estimate the significance level

    Returns
    -------
    alpha : p-value, e.g. at a significance level of b, if alpha <= b then reject the null
        hypothesis.
    statms_real : the value of the statistical measure evaluated on X and y
    statms_pseuds : array of statistical measures evaluated on pseudosessions
    """
    statms_real = fStatMeas(X, y)
    statms_pseuds = np.zeros(npseuds)
    for i in range(npseuds):
        statms_pseuds[i] = fStatMeas(X, genPseudo())

    alpha = 1 - (0.01 * sp.stats.percentileofscore(statms_pseuds, statms_real, kind='weak'))

    return alpha, statms_real, statms_pseuds
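

# Usage sketch (added for illustration): significance test where the statistical
# measure is a simple correlation and pseudosessions are label permutations.
# Both choices are arbitrary assumptions for the demo; in real use, genPseudo
# should draw from the experimentally known null distribution of y.
def _example_sigtest_pseudosessions():
    rng = np.random.default_rng(0)
    X = rng.random((5, 200))  # 5 elements x 200 timetrials
    y = X[0] + rng.normal(0, 0.5, 200)

    def stat_measure(X, y):
        return np.corrcoef(X[0], y)[0, 1]

    alpha, statms_real, statms_pseuds = sigtest_pseudosessions(
        X, y, stat_measure, genPseudo=lambda: rng.permutation(y), npseuds=100)
    return alpha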


def sigtest_linshift(X, y, fStatMeas, D=300):
    """
    Use a provably conservative linear-shift technique (Harris, Kenneth, arXiv 2021,
    https://arxiv.org/ftp/arxiv/papers/2012/2012.06862.pdf) to estimate the
    significance level of a statistical measure. fStatMeas computes a
    scalar statistical measure (e.g. R^2) from the data matrix, X, and the variable, y.
    A central window of X and y of size D is linearly shifted to generate a null distribution
    of statistical measures. The significance level is reported relative to this null
    distribution.

    Parameters
    ----------
    X : 2-d array
        data of size (elements, timetrials)
    y : 1-d array
        predicted variable of size (timetrials)
    fStatMeas : function
        takes arguments (X, y) and returns a scalar statistical measure of how well X decodes y
    D : int
        the window length along the center of y used to compute the statistical measure;
        must have room to shift both right and left: len(y) >= D + 2

    Returns
    -------
    alpha : conservative p-value, e.g. at a significance level of b, if alpha <= b then reject
        the null hypothesis.
    statms_real : the value of the statistical measure evaluated on X and y
    statms_pseuds : a 1-d array of statistical measures evaluated on shifted versions of y
    """
    assert len(y) >= D + 2

    T = len(y)
    N = int((T - D) / 2)

    shifts = np.arange(-N, N + 1)

    # compute all statms
    statms_real = fStatMeas(X[:, N:T - N], y[N:T - N])
    statms_pseuds = np.zeros(len(shifts))
    for si in range(len(shifts)):
        s = shifts[si]
        statms_pseuds[si] = fStatMeas(np.copy(X[:, N:T - N]), np.copy(y[s + N:s + T - N]))

    M = np.sum(statms_pseuds >= statms_real)
    alpha = M / (N + 1)

    return alpha, statms_real, statms_pseuds
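

# Usage sketch (added for illustration): linear-shift significance test on the
# same kind of synthetic data, using a central window of D = 100 out of 200
# timetrials. The measure and window length are arbitrary assumptions.
def _example_sigtest_linshift():
    rng = np.random.default_rng(0)
    X = rng.random((5, 200))  # 5 elements x 200 timetrials
    y = X[0] + rng.normal(0, 0.5, 200)

    def stat_measure(X, y):
        return np.corrcoef(X[0], y)[0, 1]

    alpha, statms_real, statms_pseuds = sigtest_linshift(X, y, stat_measure, D=100)
    return alpha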