Coverage for brainbox/population/cca.py: 16%
203 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1import numpy as np
2import matplotlib.pylab as plt
3from iblutil.numerical import bincount2D
6def _smooth(data, sd):
7 from scipy.signal import gaussian
8 from scipy.signal import convolve
9 n_bins = data.shape[0]
10 w = n_bins - 1 if n_bins % 2 == 0 else n_bins
11 window = gaussian(w, std=sd)
12 for j in range(data.shape[1]):
13 data[:, j] = convolve(data[:, j], window, mode='same', method='auto')
14 return data
def _pca(data, n_pcs):
    """
    Fit a PCA model on `data` and return `data` projected onto the fitted components

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param n_pcs: number of principal components to keep
    :type n_pcs: int
    :return: projection of `data` onto the fitted components
    """
    from sklearn.decomposition import PCA
    # fit() returns the estimator itself, so fitting and projecting can be chained
    model = PCA(n_components=n_pcs).fit(data)
    return model.transform(data)
def preprocess(data, smoothing_sd=25, n_pcs=20):
    """
    Preprocess neural data for cca analysis: optional smoothing followed by optional pca

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param smoothing_sd: standard deviation of the gaussian smoothing kernel; smoothing is
        skipped when this is not > 0. NOTE(review): the value is forwarded unchanged to
        `_smooth`, so its unit is presumably bins/samples rather than ms - confirm
    :type smoothing_sd: float
    :param n_pcs: number of pca dimensions to retain; pca is skipped when this is not > 0
    :type n_pcs: int
    :return: preprocessed neural data
    :rtype: array-like, shape (n_samples, pca_dims)
    """
    smoothed = _smooth(data, sd=smoothing_sd) if smoothing_sd > 0 else data
    return _pca(smoothed, n_pcs=n_pcs) if n_pcs > 0 else smoothed
def split_trials(trial_ids, n_splits=5, rng_seed=0):
    """
    Assign each trial to a testing or training set, once per cross-validation fold

    :param trial_ids: the trials to split
    :type trial_ids: array-like
    :param n_splits: number of folds; in each fold one split is used for testing and the
        remaining splits are used for training
    :type n_splits: int
    :param rng_seed: random state for shuffling trials; trials are not shuffled if None
    :type rng_seed: int
    :return: list of `n_splits` dicts with keys `train` and `test`; the values are the index
        arrays produced by `KFold.split`, i.e. positions into `trial_ids`
    """
    from sklearn.model_selection import KFold
    # shuffling is only requested when a seed is supplied
    do_shuffle = rng_seed is not None
    folder = KFold(n_splits=n_splits, random_state=rng_seed, shuffle=do_shuffle)
    return [{'train': train, 'test': test} for train, test in folder.split(trial_ids)]
def split_timepoints(trial_ids, idxs_trial):
    """
    Assign each time point to the testing or training set of every fold

    :param trial_ids: trial id for each timepoint
    :type trial_ids: array-like
    :param idxs_trial: list of dicts (one per fold) whose values list the trials belonging to
        that key, e.g. `train` and `test`
    :type idxs_trial: list
    :return: list of dicts with the same keys as the input dicts; each value holds the indices
        of the time points whose trial id is listed under that key
    """
    idxs_time = []
    for fold in idxs_trial:
        fold_times = {}
        for name, trials_in_fold in fold.items():
            # boolean mask over time points, then the positions where it is set
            mask = np.isin(trial_ids, trials_in_fold)
            fold_times[name] = np.nonzero(mask)[0]
        idxs_time.append(fold_times)
    return idxs_time
def fit_cca(data_0, data_1, n_cca_dims=10):
    """
    Create an sklearn CCA object and fit it to a pair of data sets

    :param data_0: shape (n_samples, n_features_0)
    :type data_0: array-like
    :param data_1: shape (n_samples, n_features_1)
    :type data_1: array-like
    :param n_cca_dims: number of CCA dimensions to fit
    :type n_cca_dims: int
    :return: fitted sklearn cca object
    """
    from sklearn.cross_decomposition import CCA
    model = CCA(n_components=n_cca_dims, max_iter=1000)
    # fit() returns the estimator itself
    return model.fit(data_0, data_1)
def get_cca_projection(cca, data_0, data_1):
    """
    Project a pair of data sets into CCA dimensions

    :param cca: object whose `transform(data_0, data_1)` returns the two projections
    :param data_0: first data set
    :param data_1: second data set
    :return: tuple; (data_0 projection, data_1 projection)
    """
    projections = cca.transform(data_0, data_1)
    proj_0, proj_1 = projections
    return proj_0, proj_1
def get_correlations(cca, data_0, data_1):
    """
    Correlation between the two projections along each CCA dimension

    :param cca: fitted cca object; `cca.transform(data_0, data_1)` must return the two
        projections, each of shape (n_samples, n_cca_dims)
    :param data_0: shape (n_samples, n_features_0)
    :param data_1: shape (n_samples, n_features_1)
    :return: 1-D array of length n_cca_dims; entry k is the correlation between dimension k
        of the data_0 projection and dimension k of the data_1 projection
    """
    x_scores, y_scores = cca.transform(data_0, data_1)
    # corrcoef stacks the variables of both inputs, giving a (2 * n_dims, 2 * n_dims) matrix;
    # the x_k / y_k correlations sit on the diagonal shifted by the number of CCA dimensions
    # (not by the number of input features, which differs whenever n_features_0 != n_dims)
    n_dims = x_scores.shape[1]
    corrs_full = np.corrcoef(x_scores.T, y_scores.T)
    corrs = np.diagonal(corrs_full, offset=n_dims)
    return corrs
def shuffle_analysis(data_0, data_1, n_shuffles=100, **cca_kwargs):
    """
    Perform CCA on shuffled data (placeholder: not implemented, all arguments are ignored)

    :param data_0:
    :param data_1:
    :param n_shuffles:
    :param cca_kwargs:
    :return: None
    """
    # TODO
    return None
def plot_correlations(corrs, errors=None, ax=None, **plot_kwargs):
    """
    Correlation vs CCA dimension

    :param corrs: correlation values for the CCA dimensions
    :type corrs: 1-D numpy.ndarray
    :param errors: error values, drawn as a shaded band of corrs +/- errors (default None)
    :type errors: 1-D numpy.ndarray of size len(corrs)
    :param ax: axis to plot on; the current axis is used if None (default None)
    :type ax: matplotlib axis object
    :param plot_kwargs: keyword arguments forwarded to `ax.plot` and `ax.fill_between`
    :return: the axis that was drawn on
    :raises AssertionError: if `corrs` or `errors` is not a numpy array
    """
    # explicit raises instead of `assert` statements so the checks survive `python -O`;
    # AssertionError is kept so callers see the same exception type as before
    if not isinstance(corrs, np.ndarray):
        raise AssertionError("'corrs' is not a numpy array.")
    if errors is not None and not isinstance(errors, np.ndarray):
        raise AssertionError("'errors' is not a numpy array.")
    # create axis if no axis is passed
    if ax is None:
        ax = plt.gca()
    # CCA dimensions are numbered from 1 on the x axis
    x_data = range(1, len(corrs) + 1)
    ax.plot(x_data, corrs, **plot_kwargs)
    if errors is not None:
        # the error band is always drawn with alpha=0.2; merging it into the kwargs avoids a
        # "multiple values for keyword argument" TypeError when the caller also passes alpha
        band_kwargs = {**plot_kwargs, 'alpha': 0.2}
        ax.fill_between(x_data, corrs - errors, corrs + errors, **band_kwargs)
    # change y and x labels and ticks
    ax.set_xticks(x_data)
    ax.set_ylabel("Correlation")
    ax.set_xlabel("CCA dimension")
    return ax
def plot_pairwise_correlations(means, stderrs=None, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions

    One panel is drawn per region pair (r, c) with c < r, laid out as the lower triangle of an
    (n_regions - 1) x (n_regions - 1) grid; at least two regions are needed to form a pair.
    Pairs whose `means[r][c]` is None are left blank.

    :param means: list of lists; means[i][j] contains the mean corrs between regions i, j
    :param stderrs: list of lists; stderrs[i][j] contains std errors of corrs between regions
        i, j; no error band is drawn if None (default None)
    :param n_dims: number of CCA dimensions to plot; all dimensions if None
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_regions = len(means)
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` 2-D even when there is a single panel (two regions)
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for r in range(1, n_regions):
        for c in range(r):
            tmp = means[r][c]
            if tmp is not None:
                max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            if means[r][c] is None:
                # nothing to draw for this pair; leave the panel switched off
                continue
            ax = axes[r - 1, c]
            ax.axis('on')
            if stderrs is None or stderrs[r][c] is None:
                errors = None
            else:
                errors = stderrs[r][c][:n_dims]
            ax = plot_correlations(means[r][c][:n_dims], errors, ax=ax, **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # grid position is derived from the loop indices: `Axes.is_first_col()` and
            # `Axes.is_last_row()` are not available in recent Matplotlib releases
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_regions - 1:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig
def plot_pairwise_correlations_mult(
        means, stderrs, colvec, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions, for multiple behavioural events

    One panel is drawn per region pair (r, c) with c < r, laid out as the lower triangle of an
    (n_regions - 1) x (n_regions - 1) grid; every behavioural event is overlaid in its own
    color. Events whose `means[k][r][c]` is None are skipped, and a pair with no data for any
    event is left blank.

    :param means: list of lists; means[k][i][j] contains the mean corrs between regions i, j for
        behavioral event k
    :param stderrs: list of lists; stderrs[k][i][j] contains std errors of corrs between
        regions i, j for behavioral event k
    :param colvec: color vector; colvec[k] is the color used for behavioral event k
    :param n_dims: number of CCA dimensions to plot; all dimensions if None
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_events = len(means)
    n_regions = len(means[0])
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` 2-D even when there is a single panel (two regions)
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for b in range(n_events):
        for r in range(1, n_regions):
            for c in range(r):
                tmp = means[b][r][c]
                if tmp is not None:
                    max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            events_with_data = [b for b in range(n_events) if means[b][r][c] is not None]
            if not events_with_data:
                # nothing to draw for this pair; leave the panel switched off
                continue
            ax = axes[r - 1, c]
            ax.axis('on')
            for b in events_with_data:
                plot_correlations(means[b][r][c][:n_dims], stderrs[b][r][c][:n_dims],
                                  ax=ax, color=colvec[b], **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # grid position is derived from the loop indices: `Axes.is_first_col()` and
            # `Axes.is_last_row()` are not available in recent Matplotlib releases
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_regions - 1:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig
def bin_spikes_trials(spikes, trials, bin_size=0.01):
    """
    Bin the spike times into a raster of counts and assign a trial number to each bin

    :param spikes: spikes object; its `times` and `clusters` entries are used
    :type spikes: Bunch
    :param trials: trials object; the first column of its `intervals` entry is used as the
        trial start times
    :type trials: Bunch
    :param bin_size: size, in s, of the bins
    :type bin_size: float
    :return: the transposed `bincount2D` count matrix, a vector with the trial ID of each bin
        (-1 for bins before the first trial start), and the vector of bin times returned by
        `bincount2D`
    """
    counts, bin_times, _ = bincount2D(spikes['times'], spikes['clusters'], bin_size)
    trial_starts = trials['intervals'][:, 0]
    # np.digitize gives 0 for bins that precede the first trial start, so shifting by one
    # makes trial IDs zero-based and marks pre-trial bins with -1
    trial_of_bin = np.digitize(bin_times, trial_starts) - 1

    return counts.T, trial_of_bin, bin_times
def split_by_area(binned_spikes, cl_brainAcronyms, active_clusters, brain_areas):
    """
    Convert a matrix of binned spikes into a list of matrices, one per brain area, each holding
    the columns of the clusters that belong to that area

    :param binned_spikes: binned spike data of shape (n_bins, n_clusters)
    :type binned_spikes: numpy.ndarray
    :param cl_brainAcronyms: brain region for each cluster, in a `brainAcronyms` column and
        indexed by cluster ID
    :type cl_brainAcronyms: pandas.core.frame.DataFrame
    :param active_clusters: cluster ID of each column of `binned_spikes`
    :type active_clusters: numpy.ndarray
    :param brain_areas: list of brain areas to select
    :type brain_areas: numpy.ndarray
    :return: list of numpy.ndarrays, one per entry of `brain_areas` and in the same order
    """
    # TODO: check that this is doing what it is supposed to!!!
    # TODO: check that input is as expected

    def _columns_for(area):
        # IDs of the clusters located in `area` ...
        in_area = cl_brainAcronyms['brainAcronyms'] == area
        cluster_ids = cl_brainAcronyms.loc[in_area].index
        # ... turned into a boolean mask over the columns of `binned_spikes`
        return np.isin(active_clusters, cluster_ids)

    return [binned_spikes[:, _columns_for(area)] for area in brain_areas]
def get_event_bin_indexes(event_times, bin_times, window):
    """
    Get the indexes of the bins corresponding to a specific behavioral event within a window

    For every event the window starts at the bin whose time is closest to
    `event + window[0]` and covers ceil((window[1] - window[0]) / bin_size) consecutive bins.
    Indexes that fall beyond the last bin are dropped.

    :param event_times: time series of an event
    :type event_times: numpy.array
    :param bin_times: time series of starting point of bins; needs at least two entries, the
        bin size is taken from the first two
    :type bin_times: numpy.array
    :param window: list of size 2 specifying the window in seconds [-time before, time after]
    :type window: numpy.array
    :return: array of (integer) indexes, concatenated over events in the order of `event_times`
    """
    bin_times = np.asarray(bin_times)
    event_times = np.atleast_1d(event_times)
    # find bin size
    bin_size = bin_times[1] - bin_times[0]
    # find window size in bin units; rounding first stops floating point noise in the ratio
    # (e.g. 4.000000000000001) from adding a spurious extra bin
    bin_window = int(np.ceil(np.round((window[1] - window[0]) / bin_size, 9)))
    # shift event times to the start of the window; window[0] is the (negative) offset before
    # the event, so it has to be added
    window_starts = event_times + window[0]

    # collect the bins of every event window and concatenate once at the end
    offsets = np.arange(bin_window)
    chunks = []
    for start_time in window_starts:
        start_idx = np.abs(bin_times - start_time).argmin()
        chunks.append(start_idx + offsets)
    if not chunks:
        return np.empty(shape=0).astype(int)
    idx_array = np.concatenate(chunks, axis=None).astype(int)

    # remove the non-existing bins if any
    return idx_array[idx_array < len(bin_times)]
if __name__ == '__main__':
    # Example analysis: cross-validated CCA between every pair of brain regions of one session

    from pathlib import Path
    # NOTE(review): `oneibl` and `alf` are presumably legacy IBL packages - confirm that they
    # are still installable in the target environment
    from oneibl.one import ONE
    import alf.io as ioalf

    BIN_SIZE = 0.025  # seconds
    SMOOTH_SIZE = 0.025  # seconds; standard deviation of gaussian kernel
    PCA_DIMS = 20
    CCA_DIMS = PCA_DIMS
    N_SPLITS = 5
    RNG_SEED = 0

    # get the data from flatiron
    subject = 'KS005'
    date = '2019-08-30'
    number = 1

    one = ONE()
    eid = one.search(subject=subject, date=date, number=number)
    D = one.load(eid[0], download_only=True)
    session_path = Path(D.local_path[0]).parent

    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    # channels = ioalf.load_object(session_path, 'channels')
    trials = ioalf.load_object(session_path, 'trials')

    # bin spikes and get trial IDs associated with them; the bin size comes from the
    # BIN_SIZE constant (it was previously declared but a literal 0.01 was used here)
    binned_spikes, binned_trialIDs, _ = bin_spikes_trials(spikes, trials, bin_size=BIN_SIZE)

    # define areas
    brain_areas = np.unique(clusters.brainAcronyms)
    brain_areas = brain_areas[1:4]  # [take subset for testing]

    # split data by brain area
    # (bin_spikes_trials does not return info for inactive clusters)
    active_clusters = np.unique(spikes['clusters'])
    split_binned_spikes = split_by_area(
        binned_spikes, clusters.brainAcronyms, active_clusters, brain_areas)

    # preprocess data; the smoothing kernel width is applied in units of bins, so the value
    # in seconds is converted first (passing 0.025 directly gives a kernel far narrower than
    # one bin, i.e. effectively no smoothing)
    smoothing_sd_bins = SMOOTH_SIZE / BIN_SIZE
    split_binned_spikes = [
        preprocess(pop, n_pcs=PCA_DIMS, smoothing_sd=smoothing_sd_bins)
        for pop in split_binned_spikes]

    # split trials
    idxs_trial = split_trials(np.unique(binned_trialIDs), n_splits=N_SPLITS, rng_seed=RNG_SEED)
    # get train/test indices into spike arrays
    idxs_time = split_timepoints(binned_trialIDs, idxs_trial)

    # Create empty "matrix" to store cca objects
    n_regions = len(brain_areas)
    cca_mat = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    means_list = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    serrs_list = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    # For each pair of populations (lower triangle only, j < i):
    for i in range(n_regions):
        pop_0 = split_binned_spikes[i]
        for j in range(i):
            # print progress
            print('Fitting CCA on regions {} / {}'.format(i, j))
            pop_1 = split_binned_spikes[j]
            ccas = [None for _ in range(N_SPLITS)]
            corrs = [None for _ in range(N_SPLITS)]
            # for each xv fold: fit on the training time points, evaluate on the test ones
            for k, idxs in enumerate(idxs_time):
                ccas[k] = fit_cca(
                    pop_0[idxs['train'], :], pop_1[idxs['train'], :], n_cca_dims=CCA_DIMS)
                corrs[k] = get_correlations(
                    ccas[k], pop_0[idxs['test'], :], pop_1[idxs['test'], :])
            # only the model of the last fold is kept for each pair
            cca_mat[i][j] = ccas[-1]
            # mean and standard error across folds, per CCA dimension
            vals = np.stack(corrs, axis=1)
            means_list[i][j] = np.mean(vals, axis=1)
            serrs_list[i][j] = np.std(vals, axis=1) / np.sqrt(N_SPLITS)

    # plot matrix of correlations
    fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)