Coverage for brainbox/population/cca.py: 17%
203 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1import numpy as np
2import matplotlib.pylab as plt
3from scipy.signal.windows import gaussian
4from scipy.signal import convolve
5from sklearn.decomposition import PCA
7from iblutil.numerical import bincount2D
10def _smooth(data, sd):
12 n_bins = data.shape[0]
13 w = n_bins - 1 if n_bins % 2 == 0 else n_bins
14 window = gaussian(w, std=sd)
15 for j in range(data.shape[1]):
16 data[:, j] = convolve(data[:, j], window, mode='same', method='auto')
17 return data
def _pca(data, n_pcs):
    """
    Fit a PCA model on `data` and return the same data projected onto its components

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param n_pcs: forwarded to sklearn's `PCA` as `n_components`
    :type n_pcs: int
    :return: output of `PCA.transform` on `data`
    :rtype: array-like
    """
    model = PCA(n_components=n_pcs)
    # fit and project the very same samples
    return model.fit(data).transform(data)
def preprocess(data, smoothing_sd=25, n_pcs=20):
    """
    Preprocess neural data for cca analysis with smoothing and pca

    Smoothing is applied first, then the (smoothed) data is reduced with pca.

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param smoothing_sd: gaussian smoothing kernel standard deviation, expressed in bins
        (samples along the first axis of `data`), so it equals milliseconds only for data
        binned at 1 ms; smoothing is skipped when this is None or not positive
    :type smoothing_sd: float
    :param n_pcs: number of pca dimensions to retain; pca is skipped when this is None or not
        positive
    :type n_pcs: int
    :return: preprocessed neural data
    :rtype: array-like, shape (n_samples, pca_dims)
    """
    # `None` is accepted as an explicit "skip this step" in addition to non-positive values
    if smoothing_sd is not None and smoothing_sd > 0:
        data = _smooth(data, sd=smoothing_sd)
    if n_pcs is not None and n_pcs > 0:
        data = _pca(data, n_pcs=n_pcs)
    return data
def split_trials(trial_ids, n_splits=5, rng_seed=0):
    """
    Partition trials into cross-validation folds with sklearn's `KFold`

    :param trial_ids: the trials to split; handed to `KFold.split`
    :type trial_ids: array-like
    :param n_splits: number of folds; in each fold one split is used for testing and the
        remaining splits for training
    :type n_splits: int
    :param rng_seed: random state used to shuffle the trials; trials are shuffled unless this
        is None
    :type rng_seed: int
    :return: list with one dict per fold, with keys `train` and `test`; the values are the
        index arrays yielded by `KFold.split` (positions into `trial_ids`)
    """
    from sklearn.model_selection import KFold
    folds = KFold(n_splits=n_splits, random_state=rng_seed, shuffle=rng_seed is not None)
    return [{'train': train, 'test': test} for train, test in folds.split(trial_ids)]
def split_timepoints(trial_ids, idxs_trial):
    """
    Translate a per-trial train/test assignment into a per-timepoint one

    :param trial_ids: trial id for each timepoint
    :type trial_ids: array-like
    :param idxs_trial: list of dicts (one per fold) listing the trials of each subset,
        e.g. under the keys `train` and `test`
    :type idxs_trial: list
    :return: list of dicts with the same keys as the input dicts; each value holds the
        positions of the timepoints whose trial id is listed for that subset
    """
    return [
        {
            subset: np.nonzero(np.isin(trial_ids, members))[0]
            for subset, members in fold.items()
        }
        for fold in idxs_trial
    ]
def fit_cca(data_0, data_1, n_cca_dims=10):
    """
    Build an sklearn `CCA` model and fit it on a pair of datasets

    :param data_0: shape (n_samples, n_features_0)
    :type data_0: array-like
    :param data_1: shape (n_samples, n_features_1)
    :type data_1: array-like
    :param n_cca_dims: number of CCA dimensions to fit
    :type n_cca_dims: int
    :return: the fitted sklearn cca object
    """
    from sklearn.cross_decomposition import CCA
    model = CCA(max_iter=1000, n_components=n_cca_dims)
    model.fit(data_0, data_1)
    return model
def get_cca_projection(cca, data_0, data_1):
    """
    Project a pair of datasets into CCA dimensions

    :param cca: fitted cca object; its `transform(data_0, data_1)` method is called
    :param data_0: first dataset
    :param data_1: second dataset
    :return: tuple; (data_0 projection, data_1 projection)
    """
    proj_0, proj_1 = cca.transform(data_0, data_1)
    return (proj_0, proj_1)
def get_correlations(cca, data_0, data_1):
    """
    Correlation between the two projected datasets along each CCA dimension

    :param cca: fitted cca object; its `transform(data_0, data_1)` method is called
    :param data_0: shape (n_samples, n_features_0)
    :param data_1: shape (n_samples, n_features_1)
    :return: correlation coefficient between the data_0 and data_1 projections, one value per
        CCA dimension
    :rtype: numpy.ndarray, shape (n_cca_dims,)
    """
    # same projection as `get_cca_projection`
    x_scores, y_scores = cca.transform(data_0, data_1)
    # corrcoef stacks the rows of both inputs, giving a (2 * n_dims, 2 * n_dims) matrix; the
    # correlation of x dimension i with y dimension i sits on the diagonal shifted by n_dims.
    # The shift is the number of projected dimensions, not the number of input features.
    n_dims = x_scores.shape[1]
    corrs_tmp = np.corrcoef(x_scores.T, y_scores.T)
    corrs = np.diagonal(corrs_tmp, offset=n_dims)
    return corrs
def shuffle_analysis(data_0, data_1, n_shuffles=100, **cca_kwargs):
    """
    Perform CCA on shuffled data (placeholder: not implemented yet)

    :param data_0: currently unused
    :param data_1: currently unused
    :param n_shuffles: currently unused
    :return: None
    """
    # TODO
    return None
def plot_correlations(corrs, errors=None, ax=None, **plot_kwargs):
    """
    Correlation vs CCA dimension

    :param corrs: correlation values for the CCA dimensions
    :type corrs: 1-D numpy array
    :param errors: error values, drawn as a shaded band of `corrs` +/- `errors`
    :type errors: 1-D numpy array of size len(corrs)
    :param ax: axis to plot on; the current matplotlib axis is used when None (default)
    :type ax: matplotlib axis object
    :param plot_kwargs: keyword arguments forwarded to `ax.plot` and, except for `alpha`, to
        `ax.fill_between`
    :return: the axis that was plotted on
    :raises AssertionError: if `corrs` or `errors` is not a numpy array
    """
    # Validate explicitly rather than with `assert` statements (which are stripped under
    # `python -O`); AssertionError is kept so existing callers see the same exception.
    if not isinstance(corrs, np.ndarray):
        raise AssertionError("'corrs' is not a numpy array.")
    if errors is not None and not isinstance(errors, np.ndarray):
        raise AssertionError("'errors' is not a numpy array.")
    # create axis if no axis is passed
    if ax is None:
        ax = plt.gca()
    # CCA dimensions are numbered from 1 on the x axis
    x_data = range(1, len(corrs) + 1)
    ax.plot(x_data, corrs, **plot_kwargs)
    if errors is not None:
        # The band is always drawn at alpha 0.2; overriding the key here (instead of passing
        # `alpha` next to **plot_kwargs) avoids a duplicate-keyword TypeError when the caller
        # sets `alpha` for the line.
        band_kwargs = {**plot_kwargs, 'alpha': 0.2}
        ax.fill_between(x_data, corrs - errors, corrs + errors, **band_kwargs)
    # change y and x labels and ticks
    ax.set_xticks(x_data)
    ax.set_ylabel("Correlation")
    ax.set_xlabel("CCA dimension")
    return ax
def plot_pairwise_correlations(means, stderrs=None, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions

    One panel is drawn per pair of regions, laid out as the lower triangle of an
    (n_regions - 1) x (n_regions - 1) grid; only entries [r][c] with c < r are read. Pairs whose
    entry in `means` is None are left blank. The figure is shown with `plt.show()`.

    :param means: list of lists; means[i][j] contains the mean corrs between regions i, j
    :param stderrs: list of lists; stderrs[i][j] contains std errors of corrs between regions i, j;
        no error band is drawn when this (or an individual entry) is None
    :param n_dims: number of CCA dimensions to plot
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_regions = len(means)
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` two-dimensional even when there is a single pair of regions
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # region pairs (lower triangle) that have correlations to show
    pairs = [(r, c) for r in range(1, n_regions) for c in range(r) if means[r][c] is not None]

    # get max correlation to standardize y axes
    max_val = 0
    for r, c in pairs:
        max_val = np.max([max_val, np.max(means[r][c])])

    for r, c in pairs:
        ax = axes[r - 1, c]
        ax.axis('on')
        errors = None if stderrs is None else stderrs[r][c]
        if errors is not None:
            errors = errors[:n_dims]
        ax = plot_correlations(means[r][c][:n_dims], errors, ax=ax, **kwargs)
        ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
        if region_strs is not None:
            ax.text(
                x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes)
        ax.set_ylim([-0.05, max_val + 0.05])
        # Grid position is known from (r, c), so the `Axes.is_first_col` / `is_last_row`
        # helpers (removed from recent Matplotlib releases) are not needed.
        if c != 0:
            ax.set_ylabel('')
            ax.set_yticks([])
        if r != n_panels:
            ax.set_xlabel('')
            ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig
def plot_pairwise_correlations_mult(
        means, stderrs, colvec, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions, for multiple behavioural events

    One panel is drawn per pair of regions, laid out as the lower triangle of an
    (n_regions - 1) x (n_regions - 1) grid; only entries [k][r][c] with c < r are read. Events
    whose entry in `means` is None for a pair are skipped, and a pair without any data is left
    blank. The figure is shown with `plt.show()`.

    :param means: list of lists; means[k][i][j] contains the mean corrs between regions i, j for
        behavioral event k
    :param stderrs: list of lists; stderrs[k][i][j] contains std errors of corrs between
        regions i, j for behavioral event k; no error band is drawn for entries that are None
    :param colvec: color vector; colvec[k] is the color used for behavioral event k
    :param n_dims: number of CCA dimensions to plot
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_events = len(means)
    n_regions = len(means[0])
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` two-dimensional even when there is a single pair of regions
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for b in range(n_events):
        for r in range(1, n_regions):
            for c in range(r):
                tmp = means[b][r][c]
                if tmp is not None:
                    max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            # events that have correlations for this pair of regions
            events = [b for b in range(n_events) if means[b][r][c] is not None]
            if not events:
                continue
            ax = axes[r - 1, c]
            ax.axis('on')
            for b in events:
                errors = stderrs[b][r][c]
                if errors is not None:
                    errors = errors[:n_dims]
                plot_correlations(means[b][r][c][:n_dims], errors,
                                  ax=ax, color=colvec[b], **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # Grid position is known from (r, c), so the `Axes.is_first_col` / `is_last_row`
            # helpers (removed from recent Matplotlib releases) are not needed.
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_panels:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig
def bin_spikes_trials(spikes, trials, bin_size=0.01):
    """
    Bin the spike times into a raster and label every bin with a trial number

    :param spikes: spikes object; its `times` and `clusters` entries are binned
    :type spikes: Bunch
    :param trials: trials object; the first column of its `intervals` entry is used as the
        trial start times
    :type trials: Bunch
    :param bin_size: size, in s, of the bins
    :type bin_size: float
    :return: the transposed output of `bincount2D` (documented by the original author as a
        (bins, SpikeCounts) matrix), a vector with the trial ID of each bin (-1 for bins
        before the first trial start), and a vector with the time at which each bin starts
    """
    counts, bin_times, _ = bincount2D(spikes['times'], spikes['clusters'], bin_size)
    trial_starts = trials['intervals'][:, 0]
    # np.digitize gives 0 to whatever happens before the first trial, so shift by one to get
    # zero-based trial numbers
    trial_of_bin = np.digitize(bin_times, trial_starts) - 1

    return counts.T, trial_of_bin, bin_times
def split_by_area(binned_spikes, cl_brainAcronyms, active_clusters, brain_areas):
    """
    Split a matrix of binned spikes into a list of matrices, one per brain area, each holding
    the columns (clusters) that belong to that area

    :param binned_spikes: binned spike data of shape (n_bins, n_clusters)
    :type binned_spikes: numpy.ndarray
    :param cl_brainAcronyms: brain region for each cluster, in a column named `brainAcronyms`;
        the index is matched against `active_clusters`
    :type cl_brainAcronyms: pandas.core.frame.DataFrame
    :param active_clusters: list of clusterIDs, one per column of `binned_spikes`
    :type active_clusters: numpy.ndarray
    :param brain_areas: list of brain areas to select
    :type brain_areas: numpy.ndarray
    :return: list of numpy.ndarrays, one per entry of `brain_areas` and in the same order
    """
    # TODO: check that this is doing what it is supposed to!!!
    # TODO: check that input is as expected

    def _columns_of(area):
        # ids of the clusters located in the area
        ids_in_area = cl_brainAcronyms.loc[cl_brainAcronyms['brainAcronyms'] == area].index
        # keep the columns whose cluster id is one of those
        return binned_spikes[:, np.isin(active_clusters, ids_in_area)]

    return [_columns_of(area) for area in brain_areas]
def get_event_bin_indexes(event_times, bin_times, window):
    """
    Get the indexes of the bins corresponding to a specific behavioral event within a window

    For every event the window starts at `event + window[0]`; the first bin taken is the one
    whose start time is closest to that instant, followed by as many consecutive bins as are
    needed to span `window[1] - window[0]` seconds. Indexes past the last bin are dropped. The
    bin size is taken from the first two entries of `bin_times`.

    :param event_times: time series of an event
    :type event_times: numpy.array
    :param bin_times: time series of starting point of bins
    :type bin_times: numpy.array
    :param window: list of size 2 specifying the window in seconds [-time before, time after]
    :type window: numpy.array
    :return: array of indexes, concatenated over events in the order of `event_times`
    :rtype: numpy.ndarray of int
    """
    bin_times = np.asarray(bin_times)
    event_times = np.atleast_1d(np.asarray(event_times))
    # find bin size
    bin_size = bin_times[1] - bin_times[0]
    # find window size in bin units
    bin_window = int(np.ceil((window[1] - window[0]) / bin_size))
    # window[0] is minus the time before the event, so it is added to reach the window start
    window_starts = event_times + window[0]

    # index of the bin that starts closest to each window start
    start_idxs = np.array(
        [np.abs(bin_times - start).argmin() for start in window_starts], dtype=int)
    # add the window: one row of consecutive bin indexes per event, flattened in event order
    idx_array = (start_idxs[:, None] + np.arange(bin_window)[None, :]).ravel()

    # remove the non-existing bins if any
    return idx_array[idx_array < bin_times.size]
if __name__ == '__main__':
    # Example script: fit cross-validated CCA between pairs of brain regions of one session
    # and plot the correlations obtained on held-out trials.

    from pathlib import Path
    from oneibl.one import ONE
    import alf.io as ioalf

    BIN_SIZE = 0.025  # seconds
    SMOOTH_SIZE = 0.025  # seconds; standard deviation of gaussian kernel
    PCA_DIMS = 20
    CCA_DIMS = PCA_DIMS
    N_SPLITS = 5
    RNG_SEED = 0

    # get the data from flatiron
    subject = 'KS005'
    date = '2019-08-30'
    number = 1

    one = ONE()
    eid = one.search(subject=subject, date=date, number=number)
    D = one.load(eid[0], download_only=True)
    session_path = Path(D.local_path[0]).parent

    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    # channels = ioalf.load_object(session_path, 'channels')
    trials = ioalf.load_object(session_path, 'trials')

    # bin spikes and get trial IDs associated with them; the bin size comes from the constant
    # declared above so that the smoothing width below can be converted consistently
    binned_spikes, binned_trialIDs, _ = bin_spikes_trials(spikes, trials, bin_size=BIN_SIZE)

    # define areas
    brain_areas = np.unique(clusters.brainAcronyms)
    brain_areas = brain_areas[1:4]  # [take subset for testing]

    # split data by brain area
    # (bin_spikes_trials does not return info for inactive clusters)
    active_clusters = np.unique(spikes['clusters'])
    split_binned_spikes = split_by_area(
        binned_spikes, clusters.brainAcronyms, active_clusters, brain_areas)

    # preprocess data; the smoothing kernel width is handed to scipy's gaussian window, which
    # expects it in bins, so convert it from seconds first
    smoothing_sd_bins = SMOOTH_SIZE / BIN_SIZE
    for i, pop in enumerate(split_binned_spikes):
        split_binned_spikes[i] = preprocess(
            pop, n_pcs=PCA_DIMS, smoothing_sd=smoothing_sd_bins)

    # split trials
    # NOTE(review): split_trials returns KFold positions into the unique trial ids, which
    # split_timepoints then compares against trial id values; the two only coincide when the
    # unique ids are 0..n-1 (bins before the first trial carry id -1) - confirm this is intended
    idxs_trial = split_trials(np.unique(binned_trialIDs), n_splits=N_SPLITS, rng_seed=RNG_SEED)
    # get train/test indices into spike arrays
    idxs_time = split_timepoints(binned_trialIDs, idxs_trial)

    # Create empty "matrix" to store cca objects
    n_regions = len(brain_areas)
    cca_mat = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    means_list = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    serrs_list = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    # For each pair of populations (lower triangle only):
    for i in range(n_regions):
        pop_0 = split_binned_spikes[i]
        for j in range(i):
            # print progress
            print('Fitting CCA on regions {} / {}'.format(i, j))
            pop_1 = split_binned_spikes[j]
            ccas = [None for _ in range(N_SPLITS)]
            corrs = [None for _ in range(N_SPLITS)]
            # for each xv fold: fit on the training bins, evaluate on the held-out bins
            for k, idxs in enumerate(idxs_time):
                ccas[k] = fit_cca(
                    pop_0[idxs['train'], :], pop_1[idxs['train'], :], n_cca_dims=CCA_DIMS)
                corrs[k] = get_correlations(
                    ccas[k], pop_0[idxs['test'], :], pop_1[idxs['test'], :])
            # only the model of the last fold is kept for each pair
            cca_mat[i][j] = ccas[-1]
            # mean and standard error across folds, per CCA dimension
            vals = np.stack(corrs, axis=1)
            means_list[i][j] = np.mean(vals, axis=1)
            serrs_list[i][j] = np.std(vals, axis=1) / np.sqrt(N_SPLITS)

    # plot matrix of correlations
    fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)