Coverage for brainbox/population/cca.py: 16%

203 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1import numpy as np 

2import matplotlib.pylab as plt 

3from iblutil.numerical import bincount2D 

4 

5 

6def _smooth(data, sd): 

7 from scipy.signal import gaussian 

8 from scipy.signal import convolve 

9 n_bins = data.shape[0] 

10 w = n_bins - 1 if n_bins % 2 == 0 else n_bins 

11 window = gaussian(w, std=sd) 

12 for j in range(data.shape[1]): 

13 data[:, j] = convolve(data[:, j], window, mode='same', method='auto') 

14 return data 

15 

16 

def _pca(data, n_pcs):
    """
    Fit a PCA on `data` and return `data` projected onto the fitted components.

    :param data: array of shape (n_samples, n_features)
    :param n_pcs: number of principal components, passed to sklearn as `n_components`
    :return: output of the fitted PCA's `transform` applied to `data`
    """
    from sklearn.decomposition import PCA

    model = PCA(n_components=n_pcs)
    # fit() returns the estimator itself, so fitting and projecting can be chained
    return model.fit(data).transform(data)

23 

24 

def preprocess(data, smoothing_sd=25, n_pcs=20):
    """
    Optionally smooth, then optionally reduce, neural data ahead of a CCA analysis.

    Each step is skipped when its parameter is not strictly positive; with both
    steps disabled the input object is returned as is.

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param smoothing_sd: gaussian smoothing kernel standard deviation, handed unchanged
        to `_smooth` as `sd` (NOTE(review): previously documented as ms -- confirm the
        unit expected by `_smooth`)
    :type smoothing_sd: float
    :param n_pcs: number of pca dimensions to retain
    :type n_pcs: int
    :return: preprocessed neural data
    :rtype: array-like, shape (n_samples, pca_dims)
    """
    processed = data
    # step 1: temporal smoothing
    if smoothing_sd > 0:
        processed = _smooth(processed, sd=smoothing_sd)
    # step 2: dimensionality reduction
    if n_pcs > 0:
        processed = _pca(processed, n_pcs=n_pcs)
    return processed

43 

44 

def split_trials(trial_ids, n_splits=5, rng_seed=0):
    """
    Build K-fold train/test partitions over the trials.

    The folds come from `sklearn.model_selection.KFold.split`, so the stored values
    are positional indices into `trial_ids`, not the trial ids themselves.

    :param trial_ids: one entry per trial
    :type trial_ids: array-like
    :param n_splits: one split used for testing; remaining splits used for training
    :type n_splits: int
    :param rng_seed: random state for shuffling trials; `None` disables shuffling
    :type rng_seed: int
    :return: list (one item per fold) of dicts of indices with keys `train` and `test`
    """
    from sklearn.model_selection import KFold

    # shuffle only when a seed is supplied
    do_shuffle = rng_seed is not None
    folder = KFold(n_splits=n_splits, random_state=rng_seed, shuffle=do_shuffle)
    folder.get_n_splits(trial_ids)

    folds = [None] * n_splits
    for fold_no, (train_idx, test_idx) in enumerate(folder.split(trial_ids)):
        folds[fold_no] = {'train': train_idx, 'test': test_idx}
    return folds

65 

66 

def split_timepoints(trial_ids, idxs_trial):
    """
    Translate trial-level folds into time-point-level folds.

    For every fold and every key of that fold, the result holds the positions in
    `trial_ids` whose value is listed under that key.

    :param trial_ids: trial id for each timepoint
    :type trial_ids: array-like
    :param idxs_trial: list of dicts that define which trials are in `train` or `test` folds
    :type idxs_trial: list
    :return: list of dicts that define which time points are in `train` and `test` folds
    """
    idxs_time = []
    for fold in idxs_trial:
        per_fold = {}
        for label in fold:
            # boolean mask of the time points belonging to this group of trials
            member = np.isin(trial_ids, fold[label])
            per_fold[label] = np.where(member)[0]
        idxs_time.append(per_fold)
    return idxs_time

82 

83 

def fit_cca(data_0, data_1, n_cca_dims=10):
    """
    Create an sklearn CCA estimator (at most 1000 iterations) and fit it on both datasets.

    :param data_0: shape (n_samples, n_features_0)
    :type data_0: array-like
    :param data_1: shape (n_samples, n_features_1)
    :type data_1: array-like
    :param n_cca_dims: number of CCA dimensions to fit
    :type n_cca_dims: int
    :return: fitted sklearn cca object
    """
    from sklearn.cross_decomposition import CCA

    model = CCA(n_components=n_cca_dims, max_iter=1000)
    # fit() hands back the estimator itself
    return model.fit(data_0, data_1)

100 

101 

def get_cca_projection(cca, data_0, data_1):
    """
    Project both datasets with `cca.transform`.

    :param cca: object exposing `transform(data_0, data_1)` that returns two items
    :param data_0: first dataset, passed to `transform` as first argument
    :param data_1: second dataset, passed to `transform` as second argument
    :return: tuple; (data_0 projection, data_1 projection)
    """
    proj_0, proj_1 = cca.transform(data_0, data_1)
    return (proj_0, proj_1)

113 

114 

def get_correlations(cca, data_0, data_1):
    """
    Correlation between the two projections along each CCA dimension

    :param cca: fitted cca object, as accepted by `get_cca_projection`
    :param data_0: first dataset, shape (n_samples, n_features_0)
    :param data_1: second dataset, shape (n_samples, n_features_1)
    :return: 1-D array with one correlation value per CCA dimension
    """
    x_scores, y_scores = get_cca_projection(cca, data_0, data_1)
    n_dims = x_scores.shape[1]
    # corrcoef stacks the rows of both inputs, giving a (2 * n_dims, 2 * n_dims) matrix;
    # the correlation of x-dimension k with y-dimension k sits on the diagonal shifted
    # by the number of CCA dimensions (not by the number of input features).
    corrs_tmp = np.corrcoef(x_scores.T, y_scores.T)
    corrs = np.diagonal(corrs_tmp, offset=n_dims)
    return corrs

127 

128 

def shuffle_analysis(data_0, data_1, n_shuffles=100, **cca_kwargs):
    """
    Perform CCA on shuffled data (placeholder: not implemented, always returns None)

    :param data_0: currently unused
    :param data_1: currently unused
    :param n_shuffles: currently unused
    :return: None
    """
    # TODO: implement the shuffle analysis
    return None

140 

141 

def plot_correlations(corrs, errors=None, ax=None, **plot_kwargs):
    """
    Correlation vs CCA dimension

    :param corrs: correlation values for the CCA dimensions
    :type corrs: 1-D numpy array
    :param errors: error values, drawn as a shaded band of +/- `errors` around `corrs`
    :type errors: 1-D numpy array of size len(corrs)
    :param ax: axis to plot on (default None: the current matplotlib axis is used)
    :type ax: matplotlib axis object
    :param plot_kwargs: keyword arguments forwarded to `ax.plot`, and to
        `ax.fill_between` with `alpha` forced to 0.2
    :return: the axis that was drawn on
    """
    # evaluate if np.arrays are passed
    assert type(corrs) is np.ndarray, "'corrs' is not a numpy array."
    if errors is not None:
        assert type(errors) is np.ndarray, "'errors' is not a numpy array."
    # create axis if no axis is passed
    if ax is None:
        ax = plt.gca()
    # get the data for the x and y axis; dimensions are numbered from 1
    y_data = corrs
    x_data = np.arange(1, len(corrs) + 1)
    # create the plot object
    ax.plot(x_data, y_data, **plot_kwargs)
    if errors is not None:
        # Build the band's kwargs separately: passing `alpha=0.2` next to
        # `**plot_kwargs` raised a TypeError whenever the caller supplied `alpha`.
        band_kwargs = {**plot_kwargs, 'alpha': 0.2}
        ax.fill_between(x_data, y_data - errors, y_data + errors, **band_kwargs)
    # change y and x labels and ticks
    ax.set_xticks(x_data)
    ax.set_ylabel("Correlation")
    ax.set_xlabel("CCA dimension")
    return ax

173 

174 

def plot_pairwise_correlations(means, stderrs=None, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions

    Panels are laid out as a lower triangle: the pair (r, c) with c < r is drawn at
    grid position (r - 1, c). Pairs whose entry in `means` is None are left blank.
    At least two regions are required.

    :param means: list of lists; means[i][j] contains the mean corrs between regions i, j
    :param stderrs: list of lists; stderrs[i][j] contains std errors of corrs between
        regions i, j (default None: no error bands)
    :param n_dims: number of CCA dimensions to plot (default None: all)
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_regions = len(means)
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` two-dimensional even when there is a single panel
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for r in range(1, n_regions):
        for c in range(r):
            tmp = means[r][c]
            if tmp is not None:
                max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            pair_means = means[r][c]
            if pair_means is None:
                # nothing computed for this pair: keep the panel switched off
                continue
            pair_errs = None if stderrs is None else stderrs[r][c]
            if pair_errs is not None:
                pair_errs = pair_errs[:n_dims]
            ax = axes[r - 1, c]
            ax.axis('on')
            ax = plot_correlations(pair_means[:n_dims], pair_errs, ax=ax, **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # Grid position is known from (r, c); this replaces Axes.is_first_col() /
            # Axes.is_last_row(), which are not available in recent Matplotlib.
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_regions - 1:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig

224 

225 

def plot_pairwise_correlations_mult(
        means, stderrs, colvec, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions, for multiple behavioural events

    Panels are laid out as a lower triangle: the pair (r, c) with c < r is drawn at
    grid position (r - 1, c), with one line per behavioral event. Events whose entry
    in `means` is None are skipped; a pair with no event to draw is left blank.
    At least two regions are required.

    :param means: list of lists; means[k][i][j] contains the mean corrs between regions i, j for
        behavioral event k
    :param stderrs: list of lists; stderrs[k][i][j] contains std errors of corrs between
        regions i, j for behavioral event k
    :param colvec: color vector; colvec[k] is the color used for behavioral event k
    :param n_dims: number of CCA dimensions to plot (default None: all)
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_events = len(means)
    n_regions = len(means[0])
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` two-dimensional even when there is a single panel
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for b in range(n_events):
        for r in range(1, n_regions):
            for c in range(r):
                tmp = means[b][r][c]
                if tmp is not None:
                    max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            events = [b for b in range(n_events) if means[b][r][c] is not None]
            if not events:
                # nothing computed for this pair: keep the panel switched off
                continue
            ax = axes[r - 1, c]
            ax.axis('on')
            for b in events:
                pair_errs = stderrs[b][r][c]
                if pair_errs is not None:
                    pair_errs = pair_errs[:n_dims]
                plot_correlations(means[b][r][c][:n_dims], pair_errs,
                                  ax=ax, color=colvec[b], **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # Grid position is known from (r, c); this replaces Axes.is_first_col() /
            # Axes.is_last_row(), which are not available in recent Matplotlib.
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_regions - 1:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig

282 

283 

def bin_spikes_trials(spikes, trials, bin_size=0.01):
    """
    Bin the spike times into a raster and assign a trial number to each bin

    Trial numbers come from `np.digitize` against the trial start times
    (first column of `trials['intervals']`), shifted by one so that bins located
    before the first trial start get -1.

    :param spikes: spikes object; keys `times` and `clusters` are read
    :type spikes: Bunch
    :param trials: trials object; key `intervals` is read
    :type trials: Bunch
    :param bin_size: size, in s, of the bins
    :type bin_size: float
    :return: the transposed count matrix from `bincount2D` (presumably
        (n_bins, n_clusters) -- confirm against `bincount2D`), a vector with the
        trial ID of each bin, and the vector of bin times returned by `bincount2D`
    """
    counts, bin_times, _ = bincount2D(spikes['times'], spikes['clusters'], bin_size)
    trial_starts = trials['intervals'][:, 0]
    # digitize yields 0 for whatever happens before the first trial; subtracting 1
    # makes the trial numbering start at 0 and marks those early bins with -1
    trial_of_bin = np.digitize(bin_times, trial_starts) - 1

    return counts.T, trial_of_bin, bin_times

304 

305 

def split_by_area(binned_spikes, cl_brainAcronyms, active_clusters, brain_areas):
    """
    Split a matrix of binned spikes into one matrix per brain area, keeping for each
    area only the columns whose cluster belongs to it

    :param binned_spikes: binned spike data of shape (n_bins, n_clusters)
    :type binned_spikes: numpy.ndarray
    :param cl_brainAcronyms: brain region for each cluster, in column `brainAcronyms`;
        the frame's index is matched against `active_clusters`
    :type cl_brainAcronyms: pandas.core.frame.DataFrame
    :param active_clusters: list of clusterIDs, one per column of `binned_spikes`
    :type active_clusters: numpy.ndarray
    :param brain_areas: list of brain areas to select
    :type brain_areas: numpy.ndarray
    :return: list of numpy.ndarrays, one per entry of `brain_areas`
    """
    # TODO: check that this is doing what it is supposed to, and validate the inputs

    def _columns_for(area):
        # ids (frame index) of the clusters labelled with this area ...
        in_area = cl_brainAcronyms['brainAcronyms'] == area
        cluster_ids = cl_brainAcronyms.loc[in_area].index
        # ... turned into a boolean mask over the columns of `binned_spikes`
        return np.isin(active_clusters, cluster_ids)

    return [binned_spikes[:, _columns_for(area)] for area in brain_areas]

335 

336 

def get_event_bin_indexes(event_times, bin_times, window):
    """
    Get the indexes of the bins corresponding to a specific behavioral event within a window

    For each event the window starts at the bin closest to `event + window[0]` and
    spans `ceil((window[1] - window[0]) / bin_size)` consecutive bins. Indexes that
    fall past the last bin are dropped.

    :param event_times: time series of an event
    :type event_times: numpy.array
    :param bin_times: time series of starting point of bins (at least two, evenly spaced;
        the bin size is taken from the first two entries)
    :type bin_times: numpy.array
    :param window: list of size 2 specifying the window in seconds [-time before, time after]
    :type window: numpy.array
    :return: array of integer indexes, events concatenated in input order
    """
    bin_times = np.asarray(bin_times)
    # find bin size
    bin_size = bin_times[1] - bin_times[0]
    # find window size in bin units
    bin_window = int(np.ceil((window[1] - window[0]) / bin_size))
    # Shift the event times to the start of the window. window[0] is the (negative)
    # offset before the event, so it is added: subtracting it moved the window
    # to after the event.
    window_starts = np.asarray(event_times, dtype=float).ravel() + window[0]

    # closest bin to each window start (first one on ties)
    start_idx = np.asarray(
        [np.abs(bin_times - start).argmin() for start in window_starts], dtype=int)
    # expand every start into `bin_window` consecutive indexes, in one allocation
    idx_array = (start_idx[:, np.newaxis] + np.arange(bin_window)[np.newaxis, :]).ravel()

    # remove the non-existing bins if any
    return idx_array[idx_array < bin_times.size]

368 

369 

if __name__ == '__main__':
    # Example script: load one session, run cross-validated CCA between every pair of
    # a few brain areas and plot the correlations per CCA dimension.

    from pathlib import Path
    from oneibl.one import ONE
    import alf.io as ioalf

    BIN_SIZE = 0.025  # seconds
    SMOOTH_SIZE = 0.025  # seconds; standard deviation of gaussian kernel
    PCA_DIMS = 20
    CCA_DIMS = PCA_DIMS
    N_SPLITS = 5
    RNG_SEED = 0

    # get the data from flatiron
    subject = 'KS005'
    date = '2019-08-30'
    number = 1

    one = ONE()
    eid = one.search(subject=subject, date=date, number=number)
    D = one.load(eid[0], download_only=True)
    session_path = Path(D.local_path[0]).parent

    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    # channels = ioalf.load_object(session_path, 'channels')
    trials = ioalf.load_object(session_path, 'trials')

    # bin spikes and get trial IDs associated with them
    # (use the BIN_SIZE constant instead of a separate hard-coded bin size)
    binned_spikes, binned_trialIDs, _ = bin_spikes_trials(spikes, trials, bin_size=BIN_SIZE)

    # define areas
    brain_areas = np.unique(clusters.brainAcronyms)
    brain_areas = brain_areas[1:4]  # [take subset for testing]

    # split data by brain area
    # (bin_spikes_trials does not return info for inactive clusters)
    active_clusters = np.unique(spikes['clusters'])
    split_binned_spikes = split_by_area(
        binned_spikes, clusters.brainAcronyms, active_clusters, brain_areas)

    # preprocess data
    # The gaussian window's standard deviation is expressed in samples (bins), so the
    # kernel width given in seconds has to be converted before it is passed on.
    smoothing_sd_bins = SMOOTH_SIZE / BIN_SIZE
    for i, pop in enumerate(split_binned_spikes):
        split_binned_spikes[i] = preprocess(
            pop, n_pcs=PCA_DIMS, smoothing_sd=smoothing_sd_bins)

    # split trials
    idxs_trial = split_trials(np.unique(binned_trialIDs), n_splits=N_SPLITS, rng_seed=RNG_SEED)
    # get train/test indices into spike arrays
    idxs_time = split_timepoints(binned_trialIDs, idxs_trial)

    # Create empty "matrix" to store cca objects
    n_regions = len(brain_areas)
    cca_mat = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    means_list = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    serrs_list = [[None for _ in range(n_regions)] for _ in range(n_regions)]
    # For each pair of populations (lower triangle only: j < i):
    for i in range(n_regions):
        pop_0 = split_binned_spikes[i]
        for j in range(i):
            # print progress
            print('Fitting CCA on regions {} / {}'.format(i, j))
            pop_1 = split_binned_spikes[j]
            ccas = [None for _ in range(N_SPLITS)]
            corrs = [None for _ in range(N_SPLITS)]
            # for each xv fold: fit on the training bins, evaluate on the held-out bins
            for k, idxs in enumerate(idxs_time):
                ccas[k] = fit_cca(
                    pop_0[idxs['train'], :], pop_1[idxs['train'], :], n_cca_dims=CCA_DIMS)
                corrs[k] = get_correlations(
                    ccas[k], pop_0[idxs['test'], :], pop_1[idxs['test'], :])
            # only the model of the last fold is kept
            cca_mat[i][j] = ccas[-1]
            vals = np.stack(corrs, axis=1)
            means_list[i][j] = np.mean(vals, axis=1)
            serrs_list[i][j] = np.std(vals, axis=1) / np.sqrt(N_SPLITS)

    # plot matrix of correlations
    fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)