Coverage for brainbox/population/cca.py: 17%

203 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1import numpy as np 

2import matplotlib.pylab as plt 

3from scipy.signal.windows import gaussian 

4from scipy.signal import convolve 

5from sklearn.decomposition import PCA 

6 

7from iblutil.numerical import bincount2D 

8 

9 

10def _smooth(data, sd): 

11 

12 n_bins = data.shape[0] 

13 w = n_bins - 1 if n_bins % 2 == 0 else n_bins 

14 window = gaussian(w, std=sd) 

15 for j in range(data.shape[1]): 

16 data[:, j] = convolve(data[:, j], window, mode='same', method='auto') 

17 return data 

18 

19 

def _pca(data, n_pcs):
    """
    Project `data` onto its leading principal components

    :param data: array of shape (n_samples, n_features)
    :param n_pcs: number of principal components to keep
    :return: array of shape (n_samples, n_pcs)
    """
    # fit() returns the estimator itself, so fitting and projecting can be chained
    return PCA(n_components=n_pcs).fit(data).transform(data)

25 

26 

def preprocess(data, smoothing_sd=25, n_pcs=20):
    """
    Prepare neural data for CCA: optional gaussian smoothing followed by optional PCA

    Each step is skipped when its parameter is not strictly positive.

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param smoothing_sd: standard deviation of the gaussian smoothing kernel; it is forwarded
        unchanged to the smoothing helper (NOTE(review): originally documented as ms - confirm
        the unit actually expected by the kernel)
    :type smoothing_sd: float
    :param n_pcs: number of pca dimensions to retain
    :type n_pcs: int
    :return: preprocessed neural data
    :rtype: array-like, shape (n_samples, pca_dims)
    """
    processed = data
    do_smoothing = smoothing_sd > 0
    do_pca = n_pcs > 0
    if do_smoothing:
        processed = _smooth(processed, sd=smoothing_sd)
    if do_pca:
        processed = _pca(processed, n_pcs=n_pcs)
    return processed

45 

46 

def split_trials(trial_ids, n_splits=5, rng_seed=0):
    """
    Build cross-validation folds over trials

    :param trial_ids: one entry per trial
    :type trial_ids: array-like
    :param n_splits: number of folds; in each fold one split is used for testing and the
        remaining splits for training
    :type n_splits: int
    :param rng_seed: random state used to shuffle trials; pass None to disable shuffling
    :type rng_seed: int
    :return: list (one entry per fold) of dicts with keys `train` and `test`, each holding
        the indices into `trial_ids` yielded by sklearn's KFold
    """
    from sklearn.model_selection import KFold
    # trials are shuffled only when a seed is supplied
    folder = KFold(n_splits=n_splits, random_state=rng_seed, shuffle=rng_seed is not None)
    return [
        {'train': train_idx, 'test': test_idx}
        for train_idx, test_idx in folder.split(trial_ids)
    ]

67 

68 

def split_timepoints(trial_ids, idxs_trial):
    """
    Translate trial-level folds into time-point-level folds

    :param trial_ids: trial id for each timepoint
    :type trial_ids: array-like
    :param idxs_trial: list of dicts (e.g. with keys `train` and `test`) listing the trials
        that belong to each part of a fold
    :type idxs_trial: list
    :return: list of dicts with the same keys as the input dicts; each value holds the
        positions in `trial_ids` whose trial belongs to that part of the fold
    """
    idxs_time = []
    for fold in idxs_trial:
        fold_times = {}
        for part, trials_in_part in fold.items():
            # boolean mask over time points, then positions of the selected ones
            in_part = np.isin(trial_ids, trials_in_part)
            fold_times[part] = np.nonzero(in_part)[0]
        idxs_time.append(fold_times)
    return idxs_time

84 

85 

def fit_cca(data_0, data_1, n_cca_dims=10):
    """
    Create an sklearn CCA model and fit it to two paired datasets

    :param data_0: shape (n_samples, n_features_0)
    :type data_0: array-like
    :param data_1: shape (n_samples, n_features_1)
    :type data_1: array-like
    :param n_cca_dims: number of CCA dimensions to fit
    :type n_cca_dims: int
    :return: fitted sklearn cca object
    """
    from sklearn.cross_decomposition import CCA
    model = CCA(n_components=n_cca_dims, max_iter=1000)
    model.fit(data_0, data_1)
    return model

102 

103 

def get_cca_projection(cca, data_0, data_1):
    """
    Project two datasets into CCA space

    :param cca: fitted object exposing `transform(data_0, data_1)` that returns two score arrays
    :param data_0: first dataset, forwarded as-is to `cca.transform`
    :param data_1: second dataset, forwarded as-is to `cca.transform`
    :return: tuple; (data_0 projection, data_1 projection)
    """
    proj_0, proj_1 = cca.transform(data_0, data_1)
    return proj_0, proj_1

115 

116 

def get_correlations(cca, data_0, data_1):
    """
    Correlation between the paired CCA projections of two datasets

    Both datasets are projected with `cca.transform`, then the Pearson correlation between
    the k-th projection of `data_0` and the k-th projection of `data_1` is computed for every
    CCA dimension k.

    :param cca: fitted object exposing `transform(data_0, data_1)` that returns two score
        arrays of shape (n_samples, n_cca_dims)
    :param data_0: shape (n_samples, n_features_0)
    :param data_1: shape (n_samples, n_features_1)
    :return: 1-D array of length n_cca_dims with one correlation per CCA dimension
    """
    x_scores, y_scores = cca.transform(data_0, data_1)
    # np.corrcoef stacks both sets of scores into one (2 * n_dims, 2 * n_dims) matrix; the
    # x_k / y_k correlations sit on the diagonal shifted by the number of CCA dimensions.
    # (The offset must not be the number of input features, which only coincides with the
    # number of CCA dimensions when as many dimensions as features are fitted.)
    n_dims = x_scores.shape[1]
    corrs_tmp = np.corrcoef(x_scores.T, y_scores.T)
    corrs = np.diagonal(corrs_tmp, offset=n_dims)
    return corrs

129 

130 

def shuffle_analysis(data_0, data_1, n_shuffles=100, **cca_kwargs):
    """
    Perform CCA on shuffled data

    Not implemented yet: the arguments are accepted but ignored and None is returned.

    :param data_0: unused for now
    :param data_1: unused for now
    :param n_shuffles: unused for now
    :return: None
    """
    # TODO
    return None

142 

143 

def plot_correlations(corrs, errors=None, ax=None, **plot_kwargs):
    """
    Correlation vs CCA dimension

    :param corrs: correlation values for the CCA dimensions
    :type corrs: 1-D numpy array
    :param errors: error values, drawn as a shaded band of corrs +/- errors
    :type errors: 1-D numpy array of size len(corrs)
    :param ax: axis to plot on; the current matplotlib axis is used when None (default)
    :type ax: matplotlib axis object
    :param plot_kwargs: keyword arguments forwarded to `ax.plot` and, with `alpha` forced to
        0.2, to `ax.fill_between`
    :return: the axis that was drawn on
    """
    # evaluate if np.arrays are passed
    assert isinstance(corrs, np.ndarray), "'corrs' is not a numpy array."
    if errors is not None:
        assert isinstance(errors, np.ndarray), "'errors' is not a numpy array."
    # create axis if no axis is passed
    if ax is None:
        ax = plt.gca()
    # get the data for the x and y axis; CCA dimensions are numbered from 1
    y_data = corrs
    x_data = range(1, (len(corrs) + 1))
    # create the plot object
    ax.plot(x_data, y_data, **plot_kwargs)
    if errors is not None:
        # the band is always drawn at alpha=0.2; merging into one dict (instead of passing
        # alpha next to **plot_kwargs) avoids a duplicate-keyword TypeError when the caller
        # also supplies `alpha` for the line
        band_kwargs = {**plot_kwargs, 'alpha': 0.2}
        ax.fill_between(x_data, y_data - errors, y_data + errors, **band_kwargs)
    # change y and x labels and ticks
    ax.set_xticks(x_data)
    ax.set_ylabel("Correlation")
    ax.set_xlabel("CCA dimension")
    return ax

175 

176 

def plot_pairwise_correlations(means, stderrs=None, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions

    Panels are laid out as a lower triangle: the pair (regions r, c) with c < r is drawn in
    row r - 1, column c of a (n_regions - 1) x (n_regions - 1) grid. Pairs whose entry in
    `means` is None are left blank. The figure is shown with `plt.show()` before returning.

    :param means: list of lists; means[i][j] contains the mean corrs between regions i, j
    :param stderrs: list of lists; stderrs[i][j] contains std errors of corrs between regions
        i, j. May be None (no error bands), and individual entries may be None
    :param n_dims: number of CCA dimensions to plot (None plots all of them)
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    :raises ValueError: if fewer than two regions are given
    """
    n_regions = len(means)
    if n_regions < 2:
        raise ValueError('at least two regions are required to plot pairwise correlations')
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` a 2-D array even when there is a single pair of regions
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for r in range(1, n_regions):
        for c in range(r):
            tmp = means[r][c]
            if tmp is not None:
                max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            pair_means = means[r][c]
            if pair_means is None:
                # nothing was computed for this pair; keep its panel switched off
                continue
            pair_errs = None if stderrs is None else stderrs[r][c]
            if pair_errs is not None:
                pair_errs = pair_errs[:n_dims]
            ax = axes[r - 1, c]
            ax.axis('on')
            ax = plot_correlations(pair_means[:n_dims], pair_errs, ax=ax, **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # The grid position is known from (r, c), so the first column / last row are
            # tested directly instead of via Axes.is_first_col / is_last_row, which recent
            # matplotlib releases no longer provide.
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_regions - 1:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig

226 

227 

def plot_pairwise_correlations_mult(
        means, stderrs, colvec, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions, for multiple behavioural events

    Panels are laid out as a lower triangle: the pair (regions r, c) with c < r is drawn in
    row r - 1, column c of a (n_regions - 1) x (n_regions - 1) grid, with one curve per
    behavioural event. Events whose entry in `means` is None for a pair are skipped, and a
    pair without any data is left blank. The figure is shown with `plt.show()` before
    returning.

    :param means: list of lists; means[k][i][j] contains the mean corrs between regions i, j for
        behavioral event k
    :param stderrs: list of lists; stderrs[k][i][j] contains std errors of corrs between
        regions i, j for behavioral event k (individual entries may be None)
    :param colvec: color vector; colvec[k] is the color used for behavioral event k
    :param n_dims: number of CCA dimensions to plot (None plots all of them)
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    :raises ValueError: if fewer than two regions are given
    """
    n_events = len(means)
    n_regions = len(means[0])
    if n_regions < 2:
        raise ValueError('at least two regions are required to plot pairwise correlations')
    n_panels = n_regions - 1

    # squeeze=False keeps `axes` a 2-D array even when there is a single pair of regions
    fig, axes = plt.subplots(n_panels, n_panels, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # get max correlation to standardize y axes
    max_val = 0
    for b in range(n_events):
        for r in range(1, n_regions):
            for c in range(r):
                tmp = means[b][r][c]
                if tmp is not None:
                    max_val = np.max([max_val, np.max(tmp)])

    for r in range(1, n_regions):
        for c in range(r):
            # events that actually have data for this pair of regions
            events = [b for b in range(n_events) if means[b][r][c] is not None]
            if not events:
                # nothing was computed for this pair; keep its panel switched off
                continue
            ax = axes[r - 1, c]
            ax.axis('on')
            for b in events:
                pair_errs = stderrs[b][r][c]
                if pair_errs is not None:
                    pair_errs = pair_errs[:n_dims]
                plot_correlations(means[b][r][c][:n_dims], pair_errs,
                                  ax=ax, color=colvec[b], **kwargs)
            ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
            if region_strs is not None:
                ax.text(
                    x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])),
                    horizontalalignment='right',
                    verticalalignment='top',
                    transform=ax.transAxes)
            ax.set_ylim([-0.05, max_val + 0.05])
            # The grid position is known from (r, c), so the first column / last row are
            # tested directly instead of via Axes.is_first_col / is_last_row, which recent
            # matplotlib releases no longer provide.
            if c != 0:
                ax.set_ylabel('')
                ax.set_yticks([])
            if r != n_regions - 1:
                ax.set_xlabel('')
                ax.set_xticks([])
    plt.tight_layout()
    plt.show()

    return fig

284 

285 

def bin_spikes_trials(spikes, trials, bin_size=0.01):
    """
    Bin the spike times into a raster of counts and assign a trial number to each bin

    :param spikes: spikes object; its 'times' and 'clusters' entries are used
    :type spikes: Bunch
    :param trials: trials object; the first column of its 'intervals' entry is used as the
        trial start times
    :type trials: Bunch
    :param bin_size: size, in s, of the bins
    :type bin_size: float
    :return: the transposed output of bincount2D (documented as a (bins, SpikeCounts)
        matrix), a vector with one trial ID per bin (-1 for bins that start before the first
        trial start), and a vector with the time of each bin as returned by bincount2D
    """
    counts, bin_times, _ = bincount2D(spikes['times'], spikes['clusters'], bin_size)
    trial_starts = trials['intervals'][:, 0]
    # np.digitize gives 0 to everything before the first trial start, so shifting by one
    # makes the first trial 0 and marks pre-session bins with -1
    trial_of_bin = np.digitize(bin_times, trial_starts) - 1
    return counts.T, trial_of_bin, bin_times

306 

307 

def split_by_area(binned_spikes, cl_brainAcronyms, active_clusters, brain_areas):
    """
    Split a matrix of binned spikes into a list of matrices, one per brain area, each holding
    only the columns (clusters) that belong to that area

    :param binned_spikes: binned spike data of shape (n_bins, n_clusters); columns are expected
        to follow the order of `active_clusters`
    :type binned_spikes: numpy.ndarray
    :param cl_brainAcronyms: brain region for each cluster, in a column named 'brainAcronyms'
        and indexed by cluster ID
    :type cl_brainAcronyms: pandas.core.frame.DataFrame
    :param active_clusters: list of clusterIDs, one per column of `binned_spikes`
    :type active_clusters: numpy.ndarray
    :param brain_areas: list of brain areas to select
    :type brain_areas: numpy.ndarray
    :return: list of numpy.ndarrays, one per entry of `brain_areas`
    """
    # TODO: check that this is doing what it is supposed to!!!
    # TODO: check that input is as expected

    def _columns_of(area):
        # cluster IDs labelled with this area, then a boolean mask over the active clusters
        ids_in_area = cl_brainAcronyms.loc[cl_brainAcronyms['brainAcronyms'] == area].index
        return np.isin(active_clusters, ids_in_area)

    return [binned_spikes[:, _columns_of(area)] for area in brain_areas]

337 

338 

def get_event_bin_indexes(event_times, bin_times, window):
    """
    Get the indexes of the bins corresponding to a specific behavioral event within a window

    For every event the window starts at `event_time + window[0]` (window[0] is negative for
    a window that starts before the event) and covers ceil((window[1] - window[0]) / bin_size)
    consecutive bins, beginning with the bin whose time is closest to the window start.
    Indexes that would fall past the last bin are dropped.

    :param event_times: time series of an event
    :type event_times: numpy.array
    :param bin_times: time series of starting point of bins; at least two regularly spaced
        values are needed to infer the bin size
    :type bin_times: numpy.array
    :param window: list of size 2 specifying the window in seconds [-time before, time after]
    :type window: numpy.array
    :return: 1-D integer array of bin indexes, events concatenated in order
    """
    bin_times = np.asarray(bin_times)
    event_times = np.asarray(event_times, dtype=float)
    # find bin size
    bin_size = bin_times[1] - bin_times[0]
    # find window size in bin units
    bin_window = int(np.ceil((window[1] - window[0]) / bin_size))
    # move each event time to the start of its window; window[0] is already signed
    # ("-time before"), so it has to be added, not subtracted
    window_starts = event_times + window[0]

    # index of the bin closest to each window start
    start_idxs = np.array(
        [np.abs(bin_times - start).argmin() for start in window_starts], dtype=int)
    # expand every start index into `bin_window` consecutive indexes in one broadcast
    idx_array = (start_idxs[:, None] + np.arange(bin_window)[None, :]).ravel()

    # remove the non-existing bins if any (windows running past the end of the recording)
    return idx_array[idx_array < len(bin_times)]

370 

371 

if __name__ == '__main__':
    # Example script: download one session, bin and preprocess the spikes of a few brain
    # areas, then run cross-validated CCA on every pair of areas and plot the correlations.

    from pathlib import Path
    from oneibl.one import ONE
    import alf.io as ioalf

    # NOTE(review): BIN_SIZE is not used below - the binning call passes 0.01 explicitly;
    # confirm which value is intended
    BIN_SIZE = 0.025  # seconds
    # NOTE(review): this value is handed over as the smoothing kernel sd as-is; confirm the
    # unit the smoothing helper expects
    SMOOTH_SIZE = 0.025  # seconds; standard deviation of gaussian kernel
    PCA_DIMS = 20
    CCA_DIMS = PCA_DIMS
    N_SPLITS = 5
    RNG_SEED = 0

    # get the data from flatiron
    subject = 'KS005'
    date = '2019-08-30'
    number = 1

    one = ONE()
    eid = one.search(subject=subject, date=date, number=number)
    D = one.load(eid[0], download_only=True)
    session_path = Path(D.local_path[0]).parent

    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    # channels = ioalf.load_object(session_path, 'channels')
    trials = ioalf.load_object(session_path, 'trials')

    # bin spikes and get trial IDs associated with them
    binned_spikes, binned_trialIDs, _ = bin_spikes_trials(spikes, trials, bin_size=0.01)

    # define areas
    brain_areas = np.unique(clusters.brainAcronyms)
    brain_areas = brain_areas[1:4]  # [take subset for testing]

    # split data by brain area
    # (bin_spikes_trials does not return info for inactive clusters)
    active_clusters = np.unique(spikes['clusters'])
    split_binned_spikes = split_by_area(
        binned_spikes, clusters.brainAcronyms, active_clusters, brain_areas)

    # preprocess data
    split_binned_spikes = [
        preprocess(pop, n_pcs=PCA_DIMS, smoothing_sd=SMOOTH_SIZE)
        for pop in split_binned_spikes]

    # split trials
    idxs_trial = split_trials(np.unique(binned_trialIDs), n_splits=N_SPLITS, rng_seed=RNG_SEED)
    # get train/test indices into spike arrays
    idxs_time = split_timepoints(binned_trialIDs, idxs_trial)

    # Create empty "matrices" to store cca objects and correlation statistics
    n_regions = len(brain_areas)

    def _empty_grid():
        return [[None for _ in range(n_regions)] for _ in range(n_regions)]

    cca_mat = _empty_grid()
    means_list = _empty_grid()
    serrs_list = _empty_grid()

    # For each pair of populations (lower triangle only: j < i):
    for i in range(n_regions):
        pop_0 = split_binned_spikes[i]
        for j in range(i):
            # print progress
            print('Fitting CCA on regions {} / {}'.format(i, j))
            pop_1 = split_binned_spikes[j]
            ccas = [None for _ in range(N_SPLITS)]
            corrs = [None for _ in range(N_SPLITS)]
            # for each xv fold: fit on the training time points, evaluate on the test ones
            for k, fold in enumerate(idxs_time):
                train, test = fold['train'], fold['test']
                ccas[k] = fit_cca(pop_0[train, :], pop_1[train, :], n_cca_dims=CCA_DIMS)
                corrs[k] = get_correlations(ccas[k], pop_0[test, :], pop_1[test, :])
            # only the model of the last fold is kept for this pair
            cca_mat[i][j] = ccas[k]
            vals = np.stack(corrs, axis=1)
            means_list[i][j] = np.mean(vals, axis=1)
            serrs_list[i][j] = np.std(vals, axis=1) / np.sqrt(N_SPLITS)

    # plot matrix of correlations
    fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)

449 fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)