Coverage for brainbox/task/closed_loop.py: 84%

179 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

"""
Computes task-related outputs: statistical tests for task-responsive and task-differentiating
units, ROC analyses of spike counts around task events, and generation of pseudo (synthetic)
trial structures and impostor targets.
"""

4 

5import numpy as np 

6from scipy.stats import ranksums, wilcoxon, ttest_ind, ttest_rel 

7from ._statsmodels import multipletests 

8from sklearn.metrics import roc_auc_score 

9import pandas as pd 

10from brainbox.population.decode import get_spike_counts_in_bins 

11 

12 

def responsive_units(spike_times, spike_clusters, event_times, pre_time=(0.5, 0),
                     post_time=(0, 0.5), alpha=0.05, fdr_corr=False, use_fr=False):
    """
    Determine responsive neurons by doing a Wilcoxon Signed-Rank test between a baseline period
    before a certain task event (e.g. stimulus onset) and a period after the task event.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each spike in `spike_times`
    event_times : 1D array
        times (in seconds) of the task events
    pre_time : two-element array
        time (in seconds) preceding the event to get the baseline (e.g. [0.5, 0.2] would be a
        window starting 0.5 seconds before the event and ending at 0.2 seconds before the event)
    post_time : two-element array
        window (in seconds) relative to the event used as the response period (e.g. [0, 0.5]
        is a window from the event to 0.5 seconds after it)
    alpha : float
        alpha to use for statistical significance
    fdr_corr : boolean
        whether to apply a Benjamini-Hochberg FDR correction for multiple testing; when True the
        returned p-values are the corrected ones
    use_fr : bool
        whether to use the firing rate (count divided by window duration) instead of the total
        spike count

    Returns
    -------
    significant_units : ndarray
        cluster ids of the clusters that are significantly modulated
    stats : 1D array
        the statistic of the test that was performed
    p_values : ndarray
        the p-values of all the clusters
    cluster_ids : ndarray
        cluster ids of the p-values
    """
    event_times = np.asarray(event_times)

    # Get spike counts for baseline and event timewindow
    baseline_times = np.column_stack((event_times - pre_time[0], event_times - pre_time[1]))
    baseline_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters,
                                                            baseline_times)
    times = np.column_stack((event_times + post_time[0], event_times + post_time[1]))
    spike_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)

    if use_fr:
        # Normalise by window length so baseline and response windows of different durations
        # are comparable
        baseline_counts = baseline_counts / (pre_time[0] - pre_time[1])
        spike_counts = spike_counts / (post_time[1] - post_time[0])

    # Do statistics (paired: each baseline window is matched with the response to the same event)
    sig_units, stats, p_values = compute_comparison_statistics(
        baseline_counts, spike_counts, test='signrank', alpha=alpha, fdr_corr=fdr_corr)
    # sig_units is a boolean mask over clusters -> select the matching cluster ids
    significant_units = cluster_ids[sig_units]

    return significant_units, stats, p_values, cluster_ids

67 

68 

def differentiate_units(spike_times, spike_clusters, event_times, event_groups,
                        pre_time=0, post_time=0.5, test='ranksums', alpha=0.05, fdr_corr=False):
    """
    Determine units which significantly differentiate between two task events
    (e.g. stimulus left/right) by performing a statistical test between the spike counts
    elicited by the two events. Default is a Wilcoxon Rank Sum test.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each spike in `spike_times`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events as either 0 or 1
    pre_time : float
        time (in seconds) before each event at which the counting window starts
    post_time : float
        time (in seconds) after each event at which the counting window ends
    test : string
        which statistical test to use, options are:
            'ranksums'      Wilcoxon Rank Sums test
            'signrank'      Wilcoxon Signed Rank test (for paired observations)
            'ttest'         independent samples t-test
            'paired_ttest'  paired t-test
    alpha : float
        alpha to use for statistical significance
    fdr_corr : boolean
        whether to apply a Benjamini-Hochberg FDR correction for multiple testing; when True the
        returned p-values are the corrected ones

    Returns
    -------
    significant_units : 1D array
        cluster ids of the clusters that are significantly modulated
    stats : 1D array
        the statistic of the test that was performed
    p_values : 1D array
        the p-values of all the clusters
    cluster_ids : ndarray
        cluster ids of the p-values
    """

    # Check input
    assert test in ['ranksums', 'signrank', 'ttest', 'paired_ttest']
    event_times = np.asarray(event_times)
    event_groups = np.asarray(event_groups)
    in_group_0 = event_groups == 0
    in_group_1 = event_groups == 1
    if test in ('signrank', 'paired_ttest'):
        assert np.sum(in_group_0) == np.sum(in_group_1), \
            'For paired tests the number of events in both groups needs to be the same'

    # Get spike counts for the two events
    times_1 = np.column_stack((event_times[in_group_0] - pre_time,
                               event_times[in_group_0] + post_time))
    counts_1, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times_1)
    times_2 = np.column_stack((event_times[in_group_1] - pre_time,
                               event_times[in_group_1] + post_time))
    counts_2, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times_2)

    # Do statistics
    sig_units, stats, p_values = compute_comparison_statistics(
        counts_1, counts_2, test=test, alpha=alpha, fdr_corr=fdr_corr)
    # sig_units is a boolean mask over clusters -> select the matching cluster ids
    significant_units = cluster_ids[sig_units]

    return significant_units, stats, p_values, cluster_ids

132 

133 

def compute_comparison_statistics(value1, value2, test='ranksums', alpha=0.05, fdr_corr=False):
    """
    Compute a statistical test between two sets of values, one test per row.

    Parameters
    ----------
    value1 : 2D array
        first set of values, one row per unit (e.g. cluster) and one column per observation
    value2 : 2D array
        second set of values with the same number of rows as `value1`; for the paired tests
        ('signrank', 'paired_ttest') its columns are paired with those of `value1`
    test : string
        which statistical test to use, options are:
            'ranksums'      Wilcoxon Rank Sums test
            'signrank'      Wilcoxon Signed Rank test (for paired observations)
            'ttest'         independent samples t-test
            'paired_ttest'  paired t-test
    alpha : float
        alpha to use for statistical significance
    fdr_corr : boolean
        whether to apply a Benjamini-Hochberg FDR correction for multiple testing; when True the
        returned p-values are the corrected ones

    Returns
    -------
    significant_units : 1D boolean array
        True for rows whose p-value is below `alpha` (or that survive the FDR correction)
    stats : 1D array
        the statistic of the test that was performed
    p_values : 1D array
        the p-values of all the rows

    Raises
    ------
    ValueError
        if `test` is not one of the supported options
    """
    test_functions = {
        'ranksums': ranksums,
        'signrank': wilcoxon,
        'ttest': ttest_ind,
        'paired_ttest': ttest_rel,
    }
    if test not in test_functions:
        raise ValueError(f"Unknown test {test!r}; options are {sorted(test_functions)}")
    test_function = test_functions[test]

    value1 = np.asarray(value1)
    value2 = np.asarray(value2)
    n_rows = len(value1)
    p_values = np.empty(n_rows)
    stats = np.empty(n_rows)
    for i in range(n_rows):
        row1, row2 = value1[i, :], value2[i, :]
        if test == 'signrank':
            # The signed-rank test is undefined when every paired difference is zero. Differences
            # that merely cancel out (sum to zero) are still a valid input and must be tested.
            degenerate = np.all(row1 == row2)
        else:
            # Both rows entirely zero (e.g. a unit that never fired): nothing to compare
            degenerate = not np.any(row1) and not np.any(row2)
        if degenerate:
            stats[i], p_values[i] = 0, 1
        else:
            stats[i], p_values[i] = test_function(row1, row2)

    # Perform Benjamini-Hochberg FDR correction for multiple testing
    if fdr_corr:
        sig_units, p_values, _, _ = multipletests(p_values, alpha, method='fdr_bh')
    else:
        sig_units = p_values < alpha

    return sig_units, stats, p_values

193 

194 

def roc_single_event(spike_times, spike_clusters, event_times,
                     pre_time=[0.5, 0], post_time=[0, 0.5]):
    """
    Quantify each neuron's response to a task event as the area under the ROC curve that
    separates spike counts in a baseline window before the event (negative class) from spike
    counts in a window after the event (positive class). AUC above 0.5 means counts tend to be
    higher after the event; below 0.5 means they tend to be lower.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each spike in `spike_times`
    event_times : 1D array
        times (in seconds) of the task events
    pre_time : two-element array
        baseline window relative to the event, given as seconds before it (e.g. [0.5, 0.2] is a
        window from 0.5 s to 0.2 s before the event)
    post_time : two-element array
        response window relative to the event, given as seconds after it

    Returns
    -------
    auc_roc : 1D array
        area under the ROC curve, one value per cluster
    cluster_ids : 1D array
        cluster ids matching `auc_roc`
    """
    # Baseline windows end before each event; response windows start at/after it
    pre_windows = np.column_stack(((event_times - pre_time[0]), (event_times - pre_time[1])))
    pre_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters,
                                                       pre_windows)
    post_windows = np.column_stack(((event_times + post_time[0]), (event_times + post_time[1])))
    post_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters,
                                                        post_windows)

    # Class labels are the same for every unit: 0 for baseline windows, 1 for response windows
    window_labels = np.concatenate((np.zeros(pre_counts.shape[1]),
                                    np.ones(post_counts.shape[1])))
    n_units = post_counts.shape[0]
    auc_roc = np.array(
        [roc_auc_score(window_labels,
                       np.concatenate((pre_counts[unit, :], post_counts[unit, :])))
         for unit in range(n_units)],
        dtype=float)

    return auc_roc, cluster_ids

240 

241 

def roc_between_two_events(spike_times, spike_clusters, event_times, event_groups,
                           pre_time=0, post_time=0.25):
    """
    Calculate, per neuron, the area under the ROC curve that indicates how well its spike count
    around an event distinguishes between two groups of events (e.g. movement to the right vs
    left). A value of 0.5 indicates the neuron cannot distinguish between the two events; values
    towards 0 or 1 indicate stronger distinction. Values above 0.5 mean counts tend to be higher
    for events in group 1.

    Note: only the AUC is computed. No significance test (e.g. bootstrapping) is performed here.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each spike in `spike_times`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events as either 0 or 1; both groups must be present
    pre_time : float
        time (in seconds) before each event at which the counting window starts
    post_time : float
        time (in seconds) after each event at which the counting window ends

    Returns
    -------
    auc_roc : 1D array
        an array of the area under the ROC curve for every neuron
    cluster_ids : 1D array
        cluster ids of the AUC values
    """

    # Get spike counts in one window per event
    times = np.column_stack(((event_times - pre_time), (event_times + post_time)))
    spike_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)

    # Calculate area under the ROC curve per neuron, using the event groups as class labels
    auc_roc = np.empty(spike_counts.shape[0])
    for i in range(spike_counts.shape[0]):
        auc_roc[i] = roc_auc_score(event_groups, spike_counts[i, :])

    return auc_roc, cluster_ids

285 

286 

287def _get_biased_probs(n: int, idx: int = -1, prob: float = 0.5) -> list: 

288 n_1 = n - 1 1a

289 z = n_1 + prob 1a

290 p = [1 / z] * (n_1 + 1) 1a

291 p[idx] *= prob 1a

292 return p 1a

293 

294 

295def _draw_contrast( 

296 contrast_set: list, prob_type: str = "biased", idx: int = -1, idx_prob: float = 0.5 

297) -> float: 

298 if prob_type in ["non-uniform", "biased"]: 1ab

299 p = _get_biased_probs(len(contrast_set), idx=idx, prob=idx_prob) 1a

300 return np.random.choice(contrast_set, p=p) 1a

301 elif prob_type == "uniform": 1ab

302 return np.random.choice(contrast_set) 1ab

303 

304 

305def _draw_position(position_set, stim_probability_left): 

306 return int( 1ab

307 np.random.choice( 

308 position_set, p=[stim_probability_left, 1 - stim_probability_left] 

309 ) 

310 ) 

311 

312 

def generate_pseudo_blocks(n_trials, factor=60, min_=20, max_=100, first5050=90):
    """
    Generate a pseudo block structure.

    The first `first5050` trials have probability left 0.5. They are followed by blocks that
    alternate between 0.2 and 0.8; the first block's value is chosen at random. Each block
    length is int(x), where x is drawn from an exponential distribution with scale `factor`.
    The draw is repeated until x lies strictly between `min_` and `max_`.

    Parameters
    ----------
    n_trials : int
        how many trials to generate
    factor : int
        scale of the exponential distribution of block lengths
    min_ : int
        minimum number of trials per block (exclusive bound on the exponential draw)
    max_ : int
        maximum number of trials per block (exclusive bound on the exponential draw)
    first5050 : int
        amount of trials with 50/50 left right probability at the beginning

    Returns
    -------
    probabilityLeft : 1D array
        array with probability left per trial; always of length `n_trials`

    Raises
    ------
    ValueError
        if no block length can ever be accepted (max_ <= max(min_, 0)), which previously looped
        forever
    """
    if max_ <= max(min_, 0):
        raise ValueError('max_ must be larger than both min_ and 0, got min_=%s, max_=%s'
                         % (min_, max_))

    block_ids = []
    while len(block_ids) < n_trials:
        x = np.random.exponential(factor)
        while (x <= min_) or (x >= max_):
            x = np.random.exponential(factor)
        # A coin is flipped for every block, not only the first one. This keeps the random
        # stream identical to the historical implementation, whose non-short-circuit `&`
        # evaluated the draw on every iteration. Seeded callers therefore get the same blocks.
        coin = np.random.randint(2)
        if not block_ids:
            block_value = 0.2 if coin == 0 else 0.8
        else:
            # Blocks alternate between 0.2 and 0.8
            block_value = 0.8 if block_ids[-1] == 0.2 else 0.2
        block_ids += [block_value] * int(x)

    # Clamp both parts so the output has exactly n_trials entries, even when
    # n_trials < first5050. A negative slice end previously produced a longer array.
    n_5050 = min(first5050, n_trials)
    n_block_trials = max(n_trials - first5050, 0)
    return np.array([0.5] * n_5050 + block_ids[:n_block_trials])

350 

351 

def generate_pseudo_stimuli(n_trials, contrast_set=(0, 0.06, 0.12, 0.25, 1), first5050=90):
    """
    Generate a block structure with stimuli.

    For each trial a side is drawn from that trial's probability left, and a contrast is drawn
    uniformly from `contrast_set`. The contrast is stored on the drawn side, and the other side
    is NaN.

    Parameters
    ----------
    n_trials : int
        number of trials to generate
    contrast_set : 1D array
        the contrasts that are presented. The default is (0, 0.06, 0.12, 0.25, 1).
    first5050 : int
        Number of 50/50 trials at the beginning of the session. The default is 90.

    Returns
    -------
    p_left : 1D array
        probability of left stimulus
    contrast_left : 1D array
        contrast on the left (NaN on trials where the stimulus was on the right)
    contrast_right : 1D array
        contrast on the right (NaN on trials where the stimulus was on the left)
    """

    # Initialize vectors; the side without a stimulus stays NaN
    contrast_left = np.full(n_trials, np.nan)
    contrast_right = np.full(n_trials, np.nan)

    # Generate block structure; trimmed so all returned arrays share the length n_trials
    p_left = generate_pseudo_blocks(n_trials, first5050=first5050)[:n_trials]

    for i in range(n_trials):

        # Draw position and contrast for this trial
        position = _draw_position([-1, 1], p_left[i])
        contrast = _draw_contrast(contrast_set, 'uniform')

        # Add to trials
        if position == -1:
            contrast_left[i] = contrast
        elif position == 1:
            contrast_right[i] = contrast

    return p_left, contrast_left, contrast_right

398 

399 

def generate_pseudo_session(trials, generate_choices=True, contrast_distribution='non-uniform'):
    """
    Generate a complete pseudo session with biased blocks, all stimulus contrasts, choices and
    rewards and omissions.

    Biased blocks and stimulus contrasts are generated using the same statistics as used in the
    actual task. The choices of the animal are generated using the actual psychometrics of the
    animal in the session. For each synthetic trial the choice is determined by drawing from a
    Bernoulli distribution. Its probability is the proportion of real trials with the same
    signed contrast and block probability in which the animal chose 1. No-go trials
    (choice == 0) are ignored when generating the synthetic choices.

    Parameters
    ----------
    trials : DataFrame
        Pandas dataframe with columns as trial vectors loaded using ONE. The columns
        'contrastLeft' and 'contrastRight' are always read; 'choice' and 'probabilityLeft' are
        read when `generate_choices` is True.
    generate_choices : bool
        whether to generate the choices (runs faster without)
    contrast_distribution: str ['uniform', 'non-uniform']
        the absolute contrast distribution.
        If uniform, the zero contrast is as likely as other contrasts: BiasedChoiceWorld task
        If 'non-uniform', the zero contrast is half as likely to occur: EphysChoiceWorld task
        ('biased' is kept for compatibility, but is deprecated as it is confusing)

    Returns
    -------
    pseudo_trials : DataFrame
        a trials dataframe with synthetically generated trials. It has the columns
        'probabilityLeft', 'contrastLeft', 'contrastRight', 'stim_side' and 'signed_contrast',
        plus 'feedbackType' and 'choice' when `generate_choices` is True.

    Raises
    ------
    ValueError
        if a non-uniform contrast distribution is requested but the session has no zero contrast
    """
    # Get contrast set presented to the animal
    contrast_set = np.unique(trials['contrastLeft'][~np.isnan(trials['contrastLeft'])])
    signed_contrast = trials['contrastRight'].copy()
    signed_contrast[np.isnan(signed_contrast)] = -trials['contrastLeft'][
        ~np.isnan(trials['contrastLeft'])]

    # Index of the zero contrast, i.e. the entry down-weighted by a non-uniform draw. It is
    # loop-invariant, so it is computed once. It is only needed for non-uniform draws; a uniform
    # draw ignores it, so sessions without a zero contrast no longer crash in that mode.
    zero_positions = np.flatnonzero(contrast_set == 0)
    if zero_positions.size:
        zero_idx = int(zero_positions[0])
    elif contrast_distribution == 'uniform':
        zero_idx = -1  # unused by a uniform draw
    else:
        raise ValueError('contrast_distribution=%r requires a zero contrast in the session, '
                         'found contrasts %s' % (contrast_distribution, contrast_set))

    # Generate synthetic session (one pseudo trial per real trial)
    n_trials = trials.shape[0]
    pseudo_trials = pd.DataFrame()
    pseudo_trials['probabilityLeft'] = generate_pseudo_blocks(n_trials)[:n_trials]

    # For each trial draw stimulus contrast and side and generate a synthetic choice
    for i in range(pseudo_trials.shape[0]):

        # Draw position (-1: left, 1: right) and contrast for this trial
        position = _draw_position([-1, 1], pseudo_trials['probabilityLeft'][i])
        contrast = _draw_contrast(contrast_set, prob_type=contrast_distribution, idx=zero_idx)
        signed_stim = contrast * np.sign(position)

        contrast_column = 'contrastLeft' if position == -1 else 'contrastRight'
        pseudo_trials.loc[i, contrast_column] = contrast

        if generate_choices:
            # Generate synthetic choice by drawing from Bernoulli distribution
            trial_select = ((signed_contrast == signed_stim) & (trials['choice'] != 0)
                            & (trials['probabilityLeft'] == pseudo_trials['probabilityLeft'][i]))
            # NOTE(review): if no real trial matches, this is 0/0 = NaN and np.random.binomial
            # raises; no fallback probability is defined here.
            p_right = (np.sum(trials['choice'][trial_select] == 1)
                       / trials['choice'][trial_select].shape[0])
            this_choice = [-1, 1][np.random.binomial(1, p_right)]

            # Correct (1) when choice and stimulus side have opposite signs, error (-1) otherwise
            pseudo_trials.loc[i, 'feedbackType'] = -this_choice * position
            pseudo_trials.loc[i, 'choice'] = this_choice

        pseudo_trials.loc[i, 'stim_side'] = position

    # Both contrast columns must exist even if every stimulus happened to be on one side
    for column in ('contrastLeft', 'contrastRight'):
        if column not in pseudo_trials:
            pseudo_trials[column] = np.nan

    pseudo_trials['signed_contrast'] = pseudo_trials['contrastRight']
    pseudo_trials.loc[pseudo_trials['signed_contrast'].isnull(),
                      'signed_contrast'] = -pseudo_trials['contrastLeft']
    return pseudo_trials

477 

478 

def get_impostor_target(targets, labels, current_label=None,
                        seed_idx=None, verbose=False):
    """
    Generate impostor targets by selecting from a list of current targets of variable length.
    Targets are selected and stitched together to the length of the current labeled target,
    aka 'Frankenstein' targets, often used for evaluating a null distribution while decoding.

    Parameters
    ----------
    targets : list of all targets
        targets may be arrays of any dimension (a,b,...,z)
        but must have the same shape except for the last dimension, z. All targets must
        have z > 0.
    labels : numpy array of strings
        labels corresponding to each target e.g. session eid.
        only targets with unique labels are used to create impostor target. Typically,
        use eid as the label because each eid has a unique target.
    current_label : string
        targets with the current label are not used to create impostor
        target. Size of corresponding target is used to determine size of impostor
        target. If None, a random selection from the set of unique labels is used.
    seed_idx : int or None
        seed of the random generator used for all draws. For a given seed, the result matches
        what np.random.seed(seed_idx) would give. numpy's global random state is not modified.
    verbose : bool
        whether to print details of the tiling

    Returns
    -------
    impostor_final : numpy array, same shape as all targets except last dimension, whose
        length equals that of the current label's target

    Raises
    ------
    ValueError
        if there is no target with another label, or if all other targets together are shorter
        than the current target
    """
    # Private legacy generator: the same Mersenne Twister stream that np.random.seed(seed_idx)
    # would give, without seeding (and afterwards re-seeding) numpy's global random state.
    rng = np.random.RandomState(seed_idx)

    unique_labels, unique_label_idxs = np.unique(labels, return_index=True)
    unique_targets = [targets[idx] for idx in unique_label_idxs]
    if current_label is None:
        current_label = rng.choice(unique_labels)
    avoid_same_label = ~(unique_labels == current_label)
    # current label must correspond to exactly one unique label
    assert len(np.nonzero(~avoid_same_label)[0]) == 1
    avoided_index = np.nonzero(~avoid_same_label)[0][0]
    nonavoided_indices = np.nonzero(avoid_same_label)[0]
    ntargets = len(nonavoided_indices)
    if ntargets == 0:
        raise ValueError('At least one target with a label other than %r is needed'
                         % (current_label,))
    all_impostor_targets = [unique_targets[idx] for idx in nonavoided_indices]
    all_impostor_sizes = np.array([target.shape[-1] for target in all_impostor_targets])
    current_target_size = unique_targets[avoided_index].shape[-1]
    if verbose:
        print('impostor target has length %s' % (current_target_size))
    assert np.min(all_impostor_sizes) > 0  # all targets must be nonzero in size

    # Random order of all candidate tiles. For RandomState, choice(n, size=k, replace=False) is
    # permutation(n)[:k], so for any inputs the old code handled, this gives identical tiles.
    # It no longer fails when the old fixed tile budget int(max/min) + 1 exceeds the number of
    # available targets (e.g. with only two labels).
    tile_order = rng.permutation(ntargets)
    tile_sizes = all_impostor_sizes[tile_order]
    cumulative_sizes = np.cumsum(tile_sizes)
    if cumulative_sizes[-1] < current_target_size:
        raise ValueError('Impostor targets are too short: %s samples available, %s needed'
                         % (cumulative_sizes[-1], current_target_size))
    if verbose:
        print('Randomly ordered %s targets to tile the impostor target' % (ntargets))
        print('with the following sizes:', tile_sizes)

    # Smallest prefix of the random order whose total length covers the current target
    number_of_tiles_needed = int(np.sum(cumulative_sizes < current_target_size)) + 1
    impostor_tiles = [all_impostor_targets[idx] for idx in tile_order[:number_of_tiles_needed]]
    if verbose:
        print('%s of %s needed to tile the entire impostor target' % (number_of_tiles_needed,
                                                                      ntargets))

    impostor_stitch = np.concatenate(impostor_tiles, axis=-1)
    # Random offset so the impostor does not always start at the beginning of the first tile
    start_ind = rng.randint((impostor_stitch.shape[-1] - current_target_size) + 1)
    impostor_final = impostor_stitch[..., start_ind:start_ind + current_target_size]
    if verbose:
        print('%s targets stitched together with shift of %s\n' % (number_of_tiles_needed,
                                                                   start_ind))

    return impostor_final