Coverage for brainbox/task/trials.py: 51%
160 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1from iblutil.numerical import ismember, bincount2D
2import numpy as np
def find_trial_ids(trials, side='all', choice='all', order='trial num', sort='idx',
                   contrast=(1, 0.5, 0.25, 0.125, 0.0625, 0), event=None):
    """
    Finds trials that match criterion
    :param trials: trials object indexable by key. Must contain contrastLeft, contrastRight and
    feedbackType; response_times and goCue_times are additionally required when
    order='reaction time', and the key named by ``event`` when event is given
    :param side: stimulus side, options are 'all', 'left' or 'right'
    :param choice: trial choice, options are 'all', 'correct' or 'incorrect'
    :param order: how to order the trials, options are 'trial num' or 'reaction time'
    (reaction time = response_times - goCue_times)
    :param sort: how to sort the trials, options are 'idx' (no split), 'side' (split left right
    trials), 'choice' (split correct incorrect trials), 'choice and side' (split left right and
    correct incorrect). Only used when side or choice is 'all'
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
    considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param event: trial event to align to (in order to remove nan trials for this event)
    :return: np.array of trial ids, list of dividers (cumulative group sizes) to indicate how
    trials are sorted
    :raises ValueError: if side, choice or order is not a recognised option, or if sort is not
    recognised while side or choice is 'all'
    """
    if side not in ('all', 'left', 'right'):
        raise ValueError(f"side must be 'all', 'left' or 'right', got {side!r}")
    if choice not in ('all', 'correct', 'incorrect'):
        raise ValueError(f"choice must be 'all', 'correct' or 'incorrect', got {choice!r}")
    if order not in ('trial num', 'reaction time'):
        raise ValueError(f"order must be 'trial num' or 'reaction time', got {order!r}")
    # sort is ignored when both side and choice are specific, so only validate it when used
    if (side == 'all' or choice == 'all') and \
            sort not in ('idx', 'side', 'choice', 'choice and side'):
        raise ValueError(f"sort must be 'idx', 'side', 'choice' or 'choice and side', "
                         f"got {sort!r}")

    if event:
        idx = ~np.isnan(trials[event])
        # positions of the non-nan trials, used at the end to map back to full-table ids
        valid_ids = np.where(idx)[0]
    else:
        idx = np.ones_like(trials['feedbackType'], dtype=bool)

    contrast_left = trials['contrastLeft'][idx]
    contrast_right = trials['contrastRight'][idx]
    feedback = trials['feedbackType'][idx]

    # Find trials that have specified contrasts (a non-finite contrast never matches)
    wanted = np.array(contrast)
    cont = np.isin(contrast_left, wanted) | np.isin(contrast_right, wanted)
    correct = cont & (feedback == 1)
    incorrect = cont & (feedback == -1)
    # the stimulus side is the one with a finite contrast
    on_right = np.isfinite(contrast_right)
    on_left = np.isfinite(contrast_left)

    # Find different permutations of trials (indices into the filtered arrays)
    cor_r = np.where(correct & on_right)[0]
    cor_l = np.where(correct & on_left)[0]
    incor_r = np.where(incorrect & on_right)[0]
    incor_l = np.where(incorrect & on_left)[0]

    # Only computed (and only required to exist in trials) when ordering by reaction time
    reaction_time = (trials['response_times'][idx] - trials['goCue_times'][idx]
                     if order == 'reaction time' else None)

    def _order_by(_trials):
        # Returns subset of trials either ordered by trial number or by reaction time
        sorted_trials = np.sort(_trials)
        if order == 'reaction time':
            return sorted_trials[np.argsort(reaction_time[sorted_trials])]
        return sorted_trials

    dividers = []
    if side == 'all' and choice == 'all':
        if sort == 'idx':
            trial_id = _order_by(np.r_[cor_r, cor_l, incor_r, incor_l])
        elif sort == 'choice':
            trial_id = np.r_[_order_by(np.r_[cor_l, cor_r]), _order_by(np.r_[incor_l, incor_r])]
            dividers.append(cor_l.size + cor_r.size)
        elif sort == 'side':
            trial_id = np.r_[_order_by(np.r_[cor_l, incor_l]), _order_by(np.r_[cor_r, incor_r])]
            dividers.append(cor_l.size + incor_l.size)
        else:  # 'choice and side'
            trial_id = np.r_[_order_by(cor_l), _order_by(incor_l),
                             _order_by(cor_r), _order_by(incor_r)]
            dividers.extend([cor_l.size,
                             cor_l.size + incor_l.size,
                             cor_l.size + incor_l.size + cor_r.size])
    elif choice == 'all':
        # one side, optionally split into correct then incorrect
        cor, incor = (cor_l, incor_l) if side == 'left' else (cor_r, incor_r)
        if sort in ('idx', 'side'):
            trial_id = _order_by(np.r_[cor, incor])
        else:
            trial_id = np.r_[_order_by(cor), _order_by(incor)]
            dividers.append(cor.size)
    elif side == 'all':
        # one choice, optionally split into left then right
        left, right = (cor_l, cor_r) if choice == 'correct' else (incor_l, incor_r)
        if sort in ('idx', 'choice'):
            trial_id = _order_by(np.r_[left, right])
        else:
            trial_id = np.r_[_order_by(left), _order_by(right)]
            dividers.append(left.size)
    else:
        groups = {('left', 'correct'): cor_l, ('left', 'incorrect'): incor_l,
                  ('right', 'correct'): cor_r, ('right', 'incorrect'): incor_r}
        trial_id = _order_by(groups[(side, choice)])

    if event:
        trial_id = valid_ids[trial_id]

    return trial_id, dividers
def get_event_aligned_raster(times, events, tbin=0.02, values=None, epoch=(-0.4, 1), bin=True):
    """
    Get event aligned raster
    :param times: array of times e.g spike times or dlc points
    :param events: array of events to epoch around; nan events give an all-nan row
    :param tbin: bin size over which to count events (when bin=False it is replaced by the mean
    spacing of ``times``)
    :param values: values to scale counts by (bin=True), or the samples taken at ``times``
    (bin=False)
    :param epoch: (start, end) window around each event, relative to the event time
    :param bin: whether to bin times in tbin windows or not
    :return: (raster, t) -- raster is a float array with one row per event and one column per
    entry of t; t holds the time of each column relative to the event. A row is all nan when
    the event is nan, when its window ends after the last time, starts before the first time,
    or would run past the end of the data
    """
    if bin:
        vals, bin_times, _ = bincount2D(times, np.ones_like(times), xbin=tbin, weights=values)
        vals = vals[0]
        t = np.arange(epoch[0], epoch[1] + tbin, tbin)
    else:
        vals = values
        bin_times = times
        tbin = np.mean(np.diff(bin_times))
        t = np.arange(epoch[0], epoch[1], tbin)
    nbin = t.shape[0]

    events = np.asarray(events)
    valid = ~np.isnan(events)
    starts = events[valid] + epoch[0]
    ends = events[valid] + epoch[1]
    start_idx = np.searchsorted(bin_times, starts)

    # Windows ending after the last value in bin_times cannot be filled. The second condition
    # guards against float rounding in np.arange making the window longer than the remaining data
    keep = ~((ends > bin_times[-1]) | (start_idx + nbin > len(vals)))

    # Gather nbin consecutive values starting at each kept window's first index
    windows = vals[start_idx[keep][:, np.newaxis] + np.arange(nbin)].astype(float)

    # Windows starting before the first value time are filled with nans (case for example
    # where spiking of cluster doesn't start till after start of first trial due to settling of
    # brain)
    windows[starts[keep] < bin_times[0]] = np.nan

    # Place each kept window in its own event's row; every other row (nan events, windows out of
    # range) stays nan, regardless of the order of the events
    all_event_raster = np.full((events.shape[0], nbin), np.nan)
    all_event_raster[np.flatnonzero(valid)[keep]] = windows

    return all_event_raster, t
def get_psth(raster, trial_ids=None):
    """
    Compute psth averaged over chosen trials
    :param raster: output from event aligned raster, window of activity around event, one row
    per trial
    :param trial_ids: rows of the raster to average over; every row is used when None
    :return: (mean, err) -- nan-aware mean across the selected rows and its standard error,
    computed as nanstd / sqrt(number of selected rows). Note the row count includes rows that
    contain nans
    """
    selected = raster if trial_ids is None else filter_by_trial(raster, trial_ids)
    n_rows = selected.shape[0]
    return np.nanmean(selected, axis=0), np.nanstd(selected, axis=0) / np.sqrt(n_rows)
def filter_by_trial(raster, trial_id):
    """
    Select trials of interest for raster
    :param raster: array with at least two dimensions, one row per trial (e.g. the output of
    get_event_aligned_raster)
    :param trial_id: row selector, e.g. integer trial ids or a boolean mask over rows
    :return: the selected rows with all columns kept
    """
    row_selection = (trial_id, slice(None))
    return raster[row_selection]
def filter_correct_incorrect_left_right(trials, event_raster, event, contrast, order='trial num'):
    """
    Return psth for left correct, left incorrect, right correct, right incorrect and raster
    sorted by these trials
    :param trials: trials object
    :param event_raster: output from get_event_aligned_raster
    :param event: event to align to e.g 'goCue_times', 'stimOn_times'
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
    considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by either 'trial num' or 'reaction time'
    :return: (raster, psth) -- raster is a dict with 'vals' (event_raster rows sorted by group)
    and 'dividers'; psth maps 'left_correct', 'left_incorrect', 'right_correct' and
    'right_incorrect' to dicts with 'vals', 'err' and 'linestyle'
    """
    trials_sorted, div = find_trial_ids(trials, sort='choice and side', event=event,
                                        order=order, contrast=contrast)
    # The 'choice and side' result is [left correct, left incorrect, right correct,
    # right incorrect], each group ordered on its own, with div marking the boundaries, so
    # splitting at div gives the per-group ids without repeating the search
    trials_lc, trials_li, trials_rc, trials_ri = np.split(trials_sorted, div)

    groups = [
        ('left_correct', trials_lc, {'color': 'r'}),
        ('left_incorrect', trials_li, {'color': 'r', 'linestyle': 'dashed'}),
        ('right_correct', trials_rc, {'color': 'b'}),
        ('right_incorrect', trials_ri, {'color': 'b', 'linestyle': 'dashed'}),
    ]
    psth = dict()
    for name, ids, style in groups:
        mean, err = get_psth(event_raster, ids)
        psth[name] = {'vals': mean, 'err': err, 'linestyle': style}

    raster = {}
    raster['vals'] = filter_by_trial(event_raster, trials_sorted)
    raster['dividers'] = div

    return raster, psth
def filter_correct_incorrect(trials, event_raster, event, contrast, order='trial num'):
    """
    Return psth for correct and incorrect trials and raster sorted by correct incorrect
    :param trials: trials object
    :param event_raster: output from get_event_aligned_raster
    :param event: event to align to e.g 'goCue_times', 'stimOn_times'
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
    considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by either 'trial num' or 'reaction time'
    :return: (raster, psth) -- raster is a dict with 'vals' (event_raster rows, correct first)
    and 'dividers'; psth maps 'correct' and 'incorrect' to dicts with 'vals', 'err' and
    'linestyle'
    """
    trials_sorted, div = find_trial_ids(trials, sort='choice', event=event, order=order,
                                        contrast=contrast)
    # The 'choice' result is [correct, incorrect], each group ordered on its own, with div
    # marking the boundary, so splitting gives the per-choice ids without repeating the search
    trials_c, trials_i = np.split(trials_sorted, div)

    psth = dict()
    mean, err = get_psth(event_raster, trials_c)
    psth['correct'] = {'vals': mean, 'err': err, 'linestyle': {'color': 'r'}}
    mean, err = get_psth(event_raster, trials_i)
    psth['incorrect'] = {'vals': mean, 'err': err, 'linestyle': {'color': 'b'}}

    raster = {}
    raster['vals'] = filter_by_trial(event_raster, trials_sorted)
    raster['dividers'] = div

    return raster, psth
def filter_left_right(trials, event_raster, event, contrast, order='trial num'):
    """
    Return psth for left and right trials and raster sorted by left right
    :param trials: trials object
    :param event_raster: output from get_event_aligned_raster
    :param event: event to align to e.g 'goCue_times', 'stimOn_times'
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
    considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by either 'trial num' or 'reaction time'
    :return: (raster, psth) -- raster is a dict with 'vals' (event_raster rows, left first) and
    'dividers'; psth maps 'left' and 'right' to dicts with 'vals', 'err' and 'linestyle'
    """
    trials_sorted, div = find_trial_ids(trials, sort='side', event=event, order=order,
                                        contrast=contrast)
    # The 'side' result is [left, right], each group ordered on its own, with div marking the
    # boundary, so splitting gives the per-side ids without repeating the search
    trials_l, trials_r = np.split(trials_sorted, div)

    psth = dict()
    mean, err = get_psth(event_raster, trials_l)
    psth['left'] = {'vals': mean, 'err': err, 'linestyle': {'color': 'r'}}
    mean, err = get_psth(event_raster, trials_r)
    psth['right'] = {'vals': mean, 'err': err, 'linestyle': {'color': 'b'}}

    raster = {}
    raster['vals'] = filter_by_trial(event_raster, trials_sorted)
    raster['dividers'] = div

    return raster, psth
def filter_trials(trials, event_raster, event, contrast=(1, 0.5, 0.25, 0.125, 0.0625, 0),
                  order='trial num', sort='choice'):
    """
    Wrapper to get out psth and raster for trial choice
    :param trials: trials object
    :param event_raster: output from get_event_aligned_raster
    :param event: event to align to e.g 'goCue_times', 'stimOn_times'
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
    considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by either 'trial num' or 'reaction time'
    :param sort: how to divide trials options are 'choice' (e.g correct vs incorrect), 'side'
    (e.g left vs right) and 'choice and side' (e.g correct vs incorrect and left vs right)
    :return: (raster, psth) as returned by the matching filter_* function
    :raises ValueError: if sort is not one of the options above
    """
    if sort == 'choice':
        return filter_correct_incorrect(trials, event_raster, event, contrast, order)
    if sort == 'side':
        return filter_left_right(trials, event_raster, event, contrast, order)
    if sort == 'choice and side':
        return filter_correct_incorrect_left_right(trials, event_raster, event, contrast, order)
    raise ValueError(f"sort must be 'choice', 'side' or 'choice and side', got {sort!r}")