Coverage for brainbox/task/closed_loop.py: 84%
179 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1'''
2Computes task related output
3'''
5import numpy as np
6from scipy.stats import ranksums, wilcoxon, ttest_ind, ttest_rel
7from ._statsmodels import multipletests
8from sklearn.metrics import roc_auc_score
9import pandas as pd
10from brainbox.population.decode import get_spike_counts_in_bins
def responsive_units(spike_times, spike_clusters, event_times, pre_time=[0.5, 0],
                     post_time=[0, 0.5], alpha=0.05, fdr_corr=False, use_fr=False):
    """
    Determine responsive neurons by doing a Wilcoxon Signed-Rank test between a baseline period
    before a certain task event (e.g. stimulus onset) and a period after the task event.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events
    pre_time : two-element array
        time (in seconds) preceding the event to get the baseline (e.g. [0.5, 0.2] would be a
        window starting 0.5 seconds before the event and ending at 0.2 seconds before the event)
    post_time : two-element array
        time (in seconds) following the event (e.g. [0, 0.5] would be a window starting at the
        event and ending 0.5 seconds after it)
    alpha : float
        alpha to use for statistical significance
    fdr_corr : boolean
        whether to use an FDR correction (Benjamini-Hochberg) to correct for multiple testing
    use_fr : bool
        whether to use the firing rate instead of total spike count

    Returns
    -------
    significant_units : ndarray
        the cluster ids of the clusters that are significantly modulated
    stats : 1D array
        the statistic of the test that was performed
    p_values : ndarray
        the p-values of all the clusters (FDR-corrected when `fdr_corr` is True)
    cluster_ids : ndarray
        cluster ids of the p-values
    """

    # Get spike counts for baseline and event timewindow
    baseline_times = np.column_stack(((event_times - pre_time[0]), (event_times - pre_time[1])))
    baseline_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters,
                                                            baseline_times)
    times = np.column_stack(((event_times + post_time[0]), (event_times + post_time[1])))
    spike_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)

    if use_fr:
        # Normalise by the window durations so both periods are expressed as rates
        baseline_counts = baseline_counts / (pre_time[0] - pre_time[1])
        spike_counts = spike_counts / (post_time[1] - post_time[0])

    # Do statistics; fdr_corr must be forwarded, otherwise the requested multiple-testing
    # correction is silently skipped
    sig_units, stats, p_values = compute_comparison_statistics(
        baseline_counts, spike_counts, test='signrank', alpha=alpha, fdr_corr=fdr_corr)
    significant_units = cluster_ids[sig_units]

    return significant_units, stats, p_values, cluster_ids
def differentiate_units(spike_times, spike_clusters, event_times, event_groups,
                        pre_time=0, post_time=0.5, test='ranksums', alpha=0.05, fdr_corr=False):
    """
    Determine units which significantly differentiate between two task events
    (e.g. stimulus left/right) by performing a statistical test between the spike counts
    elicited by the two events. Default is a Wilcoxon Rank Sum test.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events as either 0 or 1
    pre_time : float
        time (in seconds) before each event at which the counting window starts
    post_time : float
        time (in seconds) after each event at which the counting window ends
    test : string
        which statistical test to use, options are:
            'ranksums'      Wilcoxon Rank Sums test
            'signrank'      Wilcoxon Signed Rank test (for paired observations)
            'ttest'         independent samples t-test
            'paired_ttest'  paired t-test
    alpha : float
        alpha to use for statistical significance
    fdr_corr : boolean
        whether to use an FDR correction (Benjamini-Hochberg) to correct for multiple testing

    Returns
    -------
    significant_units : 1D array
        the cluster ids of the clusters that are significantly modulated
    stats : 1D array
        the statistic of the test that was performed
    p_values : 1D array
        the p-values of all the clusters (FDR-corrected when `fdr_corr` is True)
    cluster_ids : ndarray
        cluster ids of the p-values

    Raises
    ------
    AssertionError
        if `test` is not one of the supported names, or if a paired test is requested while
        the two groups contain a different number of events
    """

    # Check input
    assert test in ['ranksums', 'signrank', 'ttest', 'paired_ttest']
    if (test == 'signrank') or (test == 'paired_ttest'):
        assert np.sum(event_groups == 0) == np.sum(event_groups == 1), \
            'For paired tests the number of events in both groups needs to be the same'

    # Get spike counts for the two events
    times_1 = np.column_stack(((event_times[event_groups == 0] - pre_time),
                               (event_times[event_groups == 0] + post_time)))
    counts_1, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times_1)
    times_2 = np.column_stack(((event_times[event_groups == 1] - pre_time),
                               (event_times[event_groups == 1] + post_time)))
    counts_2, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times_2)

    # Do statistics; fdr_corr must be forwarded, otherwise the requested multiple-testing
    # correction is silently skipped
    sig_units, stats, p_values = compute_comparison_statistics(
        counts_1, counts_2, test=test, alpha=alpha, fdr_corr=fdr_corr)
    significant_units = cluster_ids[sig_units]

    return significant_units, stats, p_values, cluster_ids
def compute_comparison_statistics(value1, value2, test='ranksums', alpha=0.05, fdr_corr=False):
    """
    Compute a statistical test between matching rows of two arrays

    Parameters
    ----------
    value1 : 2D array
        first array of values to compare, one row per unit and one column per observation
    value2 : 2D array
        second array of values to compare, with the same number of rows as `value1`
    test : string
        which statistical test to use, options are:
            'ranksums'      Wilcoxon Rank Sums test
            'signrank'      Wilcoxon Signed Rank test (for paired observations)
            'ttest'         independent samples t-test
            'paired_ttest'  paired t-test
    alpha : float
        alpha to use for statistical significance
    fdr_corr : boolean
        whether to use an FDR correction (Benjamini-Hochberg) to correct for multiple testing

    Returns
    -------
    significant_units : 1D boolean array
        True for every row that is significantly modulated
    stats : 1D array
        the statistic of the test that was performed
    p_values : 1D array
        the p-values of all the rows (FDR-corrected when `fdr_corr` is True)

    Raises
    ------
    ValueError
        if `test` is not one of the supported test names
    """
    test_functions = {
        'ranksums': ranksums,
        'signrank': wilcoxon,
        'ttest': ttest_ind,
        'paired_ttest': ttest_rel,
    }
    if test not in test_functions:
        # Without this check the output arrays would be returned uninitialised
        raise ValueError("test must be one of %s, got %r" % (sorted(test_functions), test))
    run_test = test_functions[test]

    n_rows = len(value1)
    p_values = np.empty(n_rows)
    stats = np.empty(n_rows)
    for i in range(n_rows):
        row1, row2 = value1[i, :], value2[i, :]
        if test == 'signrank':
            # The signed-rank test is undefined only when every paired difference is zero.
            # Comparing element-wise (rather than summing the differences) keeps rows whose
            # positive and negative differences happen to cancel out.
            no_information = bool(np.all(row1 == row2))
        else:
            # Nothing to compare when both rows contain only zeros
            no_information = not np.any(row1) and not np.any(row2)

        if no_information:
            stats[i] = 0
            p_values[i] = 1
        else:
            stats[i], p_values[i] = run_test(row1, row2)

    # Perform Benjamini-Hochberg FDR correction for multiple testing
    if fdr_corr:
        sig_units, p_values, _, _ = multipletests(p_values, alpha, method='fdr_bh')
    else:
        sig_units = p_values < alpha

    return sig_units, stats, p_values
def roc_single_event(spike_times, spike_clusters, event_times,
                     pre_time=[0.5, 0], post_time=[0, 0.5]):
    """
    Quantify how strongly each neuron responds to a task event as the area under the ROC curve
    separating spike counts in a baseline window before the event from spike counts in a window
    after the event. Baseline observations are labelled 0 and post-event observations 1, so
    values above 0.5 mean more spikes after the event and values below 0.5 mean fewer.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events
    pre_time : two-element array
        baseline window relative to the event, given as seconds before the event
        (e.g. [0.5, 0.2] starts 0.5 seconds before the event and ends 0.2 seconds before it)
    post_time : two-element array
        response window relative to the event, given as seconds after the event

    Returns
    -------
    auc_roc : 1D array
        the area under the ROC curve for every cluster
    cluster_ids : 1D array
        cluster ids of the AUC values
    """

    # Spike counts in the baseline window
    baseline_start = event_times - pre_time[0]
    baseline_end = event_times - pre_time[1]
    baseline_counts, cluster_ids = get_spike_counts_in_bins(
        spike_times, spike_clusters, np.column_stack((baseline_start, baseline_end)))

    # Spike counts in the post-event window
    response_start = event_times + post_time[0]
    response_end = event_times + post_time[1]
    spike_counts, cluster_ids = get_spike_counts_in_bins(
        spike_times, spike_clusters, np.column_stack((response_start, response_end)))

    # The class labels are the same for every neuron: 0 = baseline, 1 = post-event
    labels = np.concatenate((np.zeros(baseline_counts.shape[1]),
                             np.ones(spike_counts.shape[1])))

    # Area under the ROC curve per neuron
    auc_roc = np.empty(spike_counts.shape[0])
    for unit, response in enumerate(spike_counts):
        scores = np.concatenate((baseline_counts[unit, :], response))
        auc_roc[unit] = roc_auc_score(labels, scores)

    return auc_roc, cluster_ids
def roc_between_two_events(spike_times, spike_clusters, event_times, event_groups,
                           pre_time=0, post_time=0.25):
    """
    Calculate the area under the ROC curve that indicates how well the activity of each neuron
    distinguishes between two events (e.g. movement to the right vs left). A value of 0.5
    indicates the neuron cannot distinguish between the two events. A value of 0 or 1 indicates
    maximum distinction.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events as either 0 or 1
    pre_time : float
        time (in seconds) before each event at which the counting window starts
    post_time : float
        time (in seconds) after each event at which the counting window ends

    Returns
    -------
    auc_roc : 1D array
        an array of the area under the ROC curve for every neuron
    cluster_ids : 1D array
        cluster ids of the AUC values
    """

    # One counting window per event
    window_start = event_times - pre_time
    window_end = event_times + post_time
    spike_counts, cluster_ids = get_spike_counts_in_bins(
        spike_times, spike_clusters, np.column_stack((window_start, window_end)))

    # Area under the ROC curve per neuron, using the event group as the class label
    n_units = spike_counts.shape[0]
    auc_roc = np.empty(n_units)
    for unit, unit_counts in enumerate(spike_counts):
        auc_roc[unit] = roc_auc_score(event_groups, unit_counts)

    return auc_roc, cluster_ids
287def _get_biased_probs(n: int, idx: int = -1, prob: float = 0.5) -> list:
288 n_1 = n - 1 1a
289 z = n_1 + prob 1a
290 p = [1 / z] * (n_1 + 1) 1a
291 p[idx] *= prob 1a
292 return p 1a
295def _draw_contrast(
296 contrast_set: list, prob_type: str = "biased", idx: int = -1, idx_prob: float = 0.5
297) -> float:
298 if prob_type in ["non-uniform", "biased"]: 1ab
299 p = _get_biased_probs(len(contrast_set), idx=idx, prob=idx_prob) 1a
300 return np.random.choice(contrast_set, p=p) 1a
301 elif prob_type == "uniform": 1ab
302 return np.random.choice(contrast_set) 1ab
305def _draw_position(position_set, stim_probability_left):
306 return int( 1ab
307 np.random.choice(
308 position_set, p=[stim_probability_left, 1 - stim_probability_left]
309 )
310 )
def generate_pseudo_blocks(n_trials, factor=60, min_=20, max_=100, first5050=90):
    """
    Generate a pseudo block structure

    The session starts with `first5050` trials at probability 0.5, followed by blocks that
    alternate between 0.2 and 0.8. Block lengths are drawn from an exponential distribution
    and redrawn until they fall strictly between `min_` and `max_`. Which of 0.2 / 0.8 comes
    first is decided by a coin flip. Draws use numpy's global random state.

    Parameters
    ----------
    n_trials : int
        how many trials to generate
    factor : int
        factor (scale) of the exponential
    min_ : int
        minimum number of trials per block
    max_ : int
        maximum number of trials per block
    first5050 : int
        amount of trials with 50/50 left right probability at the beginning

    Returns
    ---------
    probabilityLeft : 1D array
        array with probability left per trial; always has exactly `n_trials` entries, also when
        `n_trials` is smaller than `first5050`
    """

    block_ids = []
    while len(block_ids) < n_trials:
        # Redraw until the block length lies strictly between min_ and max_
        x = np.random.exponential(factor)
        while (x <= min_) | (x >= max_):
            x = np.random.exponential(factor)

        # The coin is drawn on every pass, even though only the first block uses it, so that
        # the sequence of random numbers consumed stays the same for seeded runs
        coin = np.random.randint(2)
        if len(block_ids) == 0:
            p_left = 0.2 if coin == 0 else 0.8
        elif block_ids[-1] == 0.2:
            p_left = 0.8
        else:
            p_left = 0.2
        block_ids += [p_left] * int(x)

    # Truncate after prepending the 50/50 trials: slicing block_ids with
    # `n_trials - first5050` would go negative for short sessions and return the wrong length
    return np.array(([0.5] * first5050 + block_ids)[:n_trials])
def generate_pseudo_stimuli(n_trials, contrast_set=[0, 0.06, 0.12, 0.25, 1], first5050=90):
    """
    Generate a block structure with stimuli

    Every trial gets a stimulus side drawn according to the block probability and a contrast
    drawn uniformly from `contrast_set`. The contrast is stored in the vector of the drawn
    side; the other side stays NaN for that trial.

    Parameters
    ----------
    n_trials : int
        number of trials to generate
    contrast_set : 1D array
        the contrasts that are presented. The default is [0, 0.06, 0.12, 0.25, 1].
    first5050 : int
        Number of 50/50 trials at the beginning of the session. The default is 90.

    Returns
    -------
    p_left : 1D array
        probability of left stimulus
    contrast_left : 1D array
        contrast on the left
    contrast_right : 1D array
        contrast on the right

    """

    # Start with all-NaN contrast vectors
    contrast_left = np.full(n_trials, np.nan)
    contrast_right = np.full(n_trials, np.nan)

    # Generate block structure
    p_left = generate_pseudo_blocks(n_trials, first5050=first5050)

    for trial in range(n_trials):

        # Side first, then contrast, so the random draws happen in a fixed order
        side = _draw_position([-1, 1], p_left[trial])
        drawn_contrast = _draw_contrast(contrast_set, 'uniform')

        # Store the contrast on the drawn side
        if side == -1:
            contrast_left[trial] = drawn_contrast
        elif side == 1:
            contrast_right[trial] = drawn_contrast

    return p_left, contrast_left, contrast_right
def generate_pseudo_session(trials, generate_choices=True, contrast_distribution='non-uniform'):
    """
    Generate a complete pseudo session with biased blocks, stimulus contrasts and, optionally,
    choices and feedback. Blocks come from `generate_pseudo_blocks`; contrasts are drawn from
    the set of left contrasts found in `trials`. For each synthetic trial the choice is drawn
    from a Bernoulli distribution whose probability of choice 1 is the fraction of real trials
    with choice 1 among those sharing the same signed contrast and block probability.
    No-go trials (choice 0) are ignored in the generating of the synthetic choices.

    Parameters
    ----------
    trials : DataFrame
        Pandas dataframe with columns as trial vectors loaded using ONE; the columns
        'contrastLeft', 'contrastRight', 'choice' and 'probabilityLeft' are read
    generate_choices : bool
        whether to generate the choices (runs faster without)
    contrast_distribution: str ['uniform', 'non-uniform']
        the absolute contrast distribution.
        If uniform, the zero contrast is as likely as other contrasts: BiasedChoiceWorld task
        If 'non-uniform', the zero contrast is half as likely to occur: EphysChoiceWorld task
        ('biased' is kept for compatibility, but is deprecated as it is confusing)

    Returns
    -------
    pseudo_trials : DataFrame
        a trials dataframe with synthetically generated trials
    """
    # Contrast set presented to the animal, and the signed contrast of every real trial
    # (right contrasts positive, left contrasts negative)
    left_contrast = trials['contrastLeft']
    contrast_set = np.unique(left_contrast[~np.isnan(left_contrast)])
    signed_contrast = trials['contrastRight'].copy()
    signed_contrast[np.isnan(signed_contrast)] = -left_contrast[~np.isnan(left_contrast)]

    # Generate synthetic session
    pseudo_trials = pd.DataFrame()
    pseudo_trials['probabilityLeft'] = generate_pseudo_blocks(trials.shape[0])

    # For each trial draw stimulus side and contrast, and optionally a synthetic choice
    for i in range(pseudo_trials.shape[0]):
        block_prob = pseudo_trials['probabilityLeft'][i]

        # Side first, then contrast, so the random draws happen in a fixed order
        position = _draw_position([-1, 1], block_prob)
        zero_contrast_idx = np.where(contrast_set == 0)[0][0]
        contrast = _draw_contrast(contrast_set, prob_type=contrast_distribution,
                                  idx=zero_contrast_idx)
        signed_stim = contrast * np.sign(position)

        this_choice = None
        if generate_choices:
            # Real trials with the same stimulus and block, excluding no-go trials
            # NOTE(review): if no real trial matches, the ratio below is 0/0 - confirm callers
            # always provide sessions that cover every stimulus/block combination
            matching = ((signed_contrast == signed_stim) & (trials['choice'] != 0)
                        & (trials['probabilityLeft'] == block_prob))
            matching_choices = trials['choice'][matching]
            p_right = np.sum(matching_choices == 1) / matching_choices.shape[0]
            this_choice = [-1, 1][np.random.binomial(1, p_right)]

        # Store the contrast on the drawn side
        if position == -1:
            pseudo_trials.loc[i, 'contrastLeft'] = contrast
        elif position == 1:
            pseudo_trials.loc[i, 'contrastRight'] = contrast

        if this_choice is not None:
            # Feedback is +1 for (left stimulus, choice 1) and (right stimulus, choice -1),
            # and -1 for the two other combinations
            rewarded = (position == -1) == (this_choice == 1)
            pseudo_trials.loc[i, 'feedbackType'] = 1 if rewarded else -1
            pseudo_trials.loc[i, 'choice'] = this_choice

        pseudo_trials.loc[i, 'stim_side'] = position

    # Signed contrast of the synthetic trials: right as is, left negated
    pseudo_trials['signed_contrast'] = pseudo_trials['contrastRight']
    pseudo_trials.loc[pseudo_trials['signed_contrast'].isnull(),
                      'signed_contrast'] = -pseudo_trials['contrastLeft']
    return pseudo_trials
def get_impostor_target(targets, labels, current_label=None,
                        seed_idx=None, verbose=False):
    """
    Build an impostor ('Frankenstein') target by stitching together randomly chosen targets
    that belong to other labels and cutting out a random window with the length of the current
    label's target. Often used for evaluating a null distribution while decoding.

    Note that numpy's global random state is seeded with `seed_idx` on entry and re-seeded
    with None before returning.

    Parameters
    ----------
    targets : list of all targets
        targets may be arrays of any dimension (a,b,...,z)
        but must have the same shape except for the last dimension, z. All targets must
        have z > 0.
    labels : numpy array of strings
        labels corresponding to each target e.g. session eid.
        only the first target of every unique label is used to create the impostor target.
        Typically, use eid as the label because each eid has a unique target.
    current_label : string
        targets with the current label are not used to create impostor
        target. Size of corresponding target is used to determine size of impostor
        target. If None, a random selection from the set of unique labels is used.
    seed_idx : int or None
        seed passed to `np.random.seed` before any random draw
    verbose : bool
        whether to print progress information

    Returns
    --------
    impostor_final : numpy array, same shape as all targets except last dimension

    """

    np.random.seed(seed_idx)

    # Keep a single target (the first occurrence) per unique label
    unique_labels, first_occurrence = np.unique(labels, return_index=True)
    unique_targets = [targets[idx] for idx in first_occurrence]
    if current_label is None:
        current_label = np.random.choice(unique_labels)

    # Split the unique labels into the current one and the candidates for stitching
    is_candidate = ~(unique_labels == current_label)
    current_positions = np.nonzero(~is_candidate)[0]
    # current label must correspond to exactly one unique label
    assert len(current_positions) == 1
    current_position = current_positions[0]
    candidate_positions = np.nonzero(is_candidate)[0]
    all_impostor_targets = [unique_targets[pos] for pos in candidate_positions]
    all_impostor_sizes = np.array([cand.shape[-1] for cand in all_impostor_targets])
    current_target_size = unique_targets[current_position].shape[-1]
    if verbose:
        print('impostor target has length %s' % (current_target_size))

    assert np.min(all_impostor_sizes) > 0  # all targets must be nonzero in size

    # Draw, without replacement, the pool of candidates that may be used as tiles
    max_needed_to_tile = int(np.max(all_impostor_sizes) / np.min(all_impostor_sizes)) + 1
    candidate_pool = np.arange(len(all_impostor_targets), dtype=int)
    tile_indices = np.random.choice(candidate_pool, size=max_needed_to_tile, replace=False)
    impostor_tiles = [all_impostor_targets[idx] for idx in tile_indices]
    impostor_tile_sizes = all_impostor_sizes[tile_indices]
    if verbose:
        print('Randomly chose %s targets to tile the impostor target' % (max_needed_to_tile))
        print('with the following sizes:', impostor_tile_sizes)

    # Keep only the leading tiles: those whose running length is still below the required
    # length, plus one more
    running_length = np.cumsum(impostor_tile_sizes)
    number_of_tiles_needed = np.sum(running_length < current_target_size) + 1
    impostor_tiles = impostor_tiles[:number_of_tiles_needed]
    if verbose:
        print('%s of %s needed to tile the entire impostor target' % (number_of_tiles_needed,
                                                                      max_needed_to_tile))

    # Stitch along the last axis and cut a randomly shifted window of the required length
    impostor_stitch = np.concatenate(impostor_tiles, axis=-1)
    slack = impostor_stitch.shape[-1] - current_target_size
    start_ind = np.random.randint(slack + 1)
    stop_ind = start_ind + current_target_size
    impostor_final = impostor_stitch[..., start_ind:stop_ind]
    if verbose:
        print('%s targets stitched together with shift of %s\n' % (number_of_tiles_needed,
                                                                   start_ind))

    np.random.seed(None)  # reset numpy seed to None

    return impostor_final