Coverage for brainbox/task/trials.py: 51%

160 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1from iblutil.numerical import ismember, bincount2D 

2import numpy as np 

3 

4 

def find_trial_ids(trials, side='all', choice='all', order='trial num', sort='idx',
                   contrast=(1, 0.5, 0.25, 0.125, 0.0625, 0), event=None):
    """
    Find trials that match criteria and return their ids, optionally grouped.

    :param trials: trials object (indexable by key, each value an array with one entry per
        trial). Must contain contrastLeft, contrastRight and feedbackType. response_times and
        goCue_times are additionally required when order='reaction time', and the `event` key
        is required when event is given
    :param side: stimulus side, options are 'all', 'left' or 'right'. The side of a trial is
        the one whose contrast value is finite
    :param choice: trial outcome, options are 'all', 'correct' (feedbackType == 1) or
        'incorrect' (feedbackType == -1)
    :param order: how to order the trials within each group, options are 'trial num' or
        'reaction time' (response_times - goCue_times, ascending)
    :param sort: how to group the trials, options are 'idx' (single group, no dividers),
        'side' (split left right trials), 'choice' (split correct incorrect trials) or
        'choice and side' (left correct, left incorrect, right correct, right incorrect)
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
        considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param event: trial event to align to; trials where this event is NaN are removed
    :return: np.array of trial ids (indices into the full trials table), list of dividers
        giving the cumulative size of each group except the last, i.e. the positions in the
        returned array where a new group starts
    :raises ValueError: if side, choice, order or sort is not one of the supported options
    """
    valid_options = {
        'side': ('all', 'left', 'right'),
        'choice': ('all', 'correct', 'incorrect'),
        'order': ('trial num', 'reaction time'),
        'sort': ('idx', 'side', 'choice', 'choice and side'),
    }
    for name, value in (('side', side), ('choice', choice), ('order', order), ('sort', sort)):
        if value not in valid_options[name]:
            raise ValueError(f"{name} must be one of {valid_options[name]}, got {value!r}")

    if event:
        valid = ~np.isnan(trials[event])
        # Positions of the kept trials in the full table; ids computed on the subset are
        # mapped back through this array before returning
        valid_idx = np.where(valid)[0]
    else:
        valid = np.ones_like(trials['feedbackType'], dtype=bool)

    contrast_left = trials['contrastLeft'][valid]
    contrast_right = trials['contrastRight'][valid]
    feedback = trials['feedbackType'][valid]

    # Trials whose stimulus contrast (on either side) is one of the requested contrasts
    wanted = np.array(contrast)
    in_contrast = np.isin(contrast_left, wanted) | np.isin(contrast_right, wanted)

    is_left = np.isfinite(contrast_left)
    is_right = np.isfinite(contrast_right)
    is_correct = feedback == 1
    is_incorrect = feedback == -1

    # Positions (within the kept subset) of every side/outcome combination
    groups = {
        ('left', 'correct'): np.where(in_contrast & is_correct & is_left)[0],
        ('left', 'incorrect'): np.where(in_contrast & is_incorrect & is_left)[0],
        ('right', 'correct'): np.where(in_contrast & is_correct & is_right)[0],
        ('right', 'incorrect'): np.where(in_contrast & is_incorrect & is_right)[0],
    }

    # Only read reaction-time fields when they are actually used
    reaction_time = None
    if order == 'reaction time':
        reaction_time = trials['response_times'][valid] - trials['goCue_times'][valid]

    def _order_by(ids):
        # Sort by trial number, then optionally re-order by reaction time
        sorted_ids = np.sort(ids)
        if reaction_time is None:
            return sorted_ids
        return sorted_ids[np.argsort(reaction_time[sorted_ids])]

    sides = ('left', 'right') if side == 'all' else (side,)
    choices = ('correct', 'incorrect') if choice == 'all' else (choice,)
    split_side = sort in ('side', 'choice and side')
    split_choice = sort in ('choice', 'choice and side')

    # Each output group is the union of one or more side/outcome combinations. With both
    # splits, choice is nested inside side: left correct, left incorrect, right correct, ...
    if split_side and split_choice:
        subsets = [[(s, c)] for s in sides for c in choices]
    elif split_side:
        subsets = [[(s, c) for c in choices] for s in sides]
    elif split_choice:
        subsets = [[(s, c) for s in sides] for c in choices]
    else:
        subsets = [[(s, c) for s in sides for c in choices]]

    ordered = [_order_by(np.concatenate([groups[key] for key in subset])) for subset in subsets]
    trial_id = np.concatenate(ordered)
    dividers = np.cumsum([g.shape[0] for g in ordered[:-1]], dtype=int).tolist()

    if event:
        trial_id = valid_idx[trial_id]

    return trial_id, dividers

125 

126 

def get_event_aligned_raster(times, events, tbin=0.02, values=None, epoch=(-0.4, 1), bin=True):
    """
    Get event aligned raster.

    :param times: array of times e.g spike times or dlc points
    :param events: array of events to epoch around; NaN events give all-NaN rows
    :param tbin: bin size over which to count events (when bin is False it is replaced by the
        mean spacing of times)
    :param values: when bin is True, weights passed to bincount2D to scale counts by; when bin
        is False, the values sampled at each entry of times
    :param epoch: (start, end) window around each event, in the same units as times
    :param bin: whether to bin times in tbin windows or not
    :return: (raster, t) where raster has one row per event and one column per entry of t,
        and t holds the window times relative to the event. Rows are NaN when the event is
        NaN, when the window starts before the first sample, or when the window runs past
        the end of the data
    """
    events = np.asarray(events, dtype=float)

    if bin:
        vals, bin_times, _ = bincount2D(times, np.ones_like(times), xbin=tbin, weights=values)
        vals = vals[0]
        t = np.arange(epoch[0], epoch[1] + tbin, tbin)
    else:
        vals = np.asarray(values)
        bin_times = np.asarray(times)
        tbin = np.mean(np.diff(bin_times))
        t = np.arange(epoch[0], epoch[1], tbin)
    nbin = t.shape[0]

    # Only non-NaN events get a window; the others stay NaN in the final output
    valid = ~np.isnan(events)
    starts = events[valid] + epoch[0]
    ends = events[valid] + epoch[1]

    start_idx = np.searchsorted(bin_times, starts)
    # Windows that cannot be filled: ending after the last sample, or (because of bin
    # rounding) needing more samples than remain after the start index
    out = (ends > bin_times[-1]) | (start_idx + nbin > len(vals))

    raster = np.full((starts.shape[0], nbin), np.nan)
    keep = ~out
    # Gather nbin consecutive samples for each fillable window in one vectorized index
    raster[keep] = vals[start_idx[keep][:, None] + np.arange(nbin)]

    # Windows that start before the first value (e.g. a cluster that only starts spiking after
    # the first trial because the brain was still settling) are filled with NaN
    raster[starts < bin_times[0]] = np.nan

    # Place rows back at the position of their event, by mask, so ordering never depends on
    # events being sorted
    all_event_raster = np.full((events.shape[0], nbin), np.nan)
    all_event_raster[valid] = raster

    return all_event_raster, t

182 

183 

def get_psth(raster, trial_ids=None):
    """
    Average a raster over a selection of trials.

    :param raster: array with trials along the first axis (e.g. output of
        get_event_aligned_raster); NaN entries are ignored by the mean and std
    :param trial_ids: row indices to average over; every row is used when None
    :return: (mean, err) per column, where err is the NaN-ignoring standard deviation divided
        by the square root of the number of selected rows
    """
    # NOTE(review): the denominator counts all selected rows, including rows that are entirely
    # NaN and therefore ignored by nanstd; confirm this is the intended standard error
    selected = raster if trial_ids is None else raster[trial_ids, :]
    n_rows = selected.shape[0]
    return np.nanmean(selected, axis=0), np.nanstd(selected, axis=0) / np.sqrt(n_rows)

200 

201 

def filter_by_trial(raster, trial_id):
    """
    Select the rows of a raster belonging to the trials of interest.

    :param raster: array with trials along the first axis and at least two dimensions
    :param trial_id: row indices (or boolean mask) of the trials to keep
    :return: the selected rows of raster, keeping all columns
    """
    every_column = slice(None)
    return raster[trial_id, every_column]

210 

211 

def filter_correct_incorrect_left_right(trials, event_raster, event, contrast, order='trial num'):
    """
    Return psth for left correct, left incorrect, right correct and right incorrect trials,
    and the raster sorted by these trial groups.

    :param trials: trials object, passed to find_trial_ids
    :param event_raster: output from get_event_aligned_raster (one row per trial)
    :param event: event the raster is aligned to e.g 'goCue_times', 'stimOn_times'; trials
        where it is NaN are excluded
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
        considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by, either 'trial num' or 'reaction time'
    :return: raster dict with 'vals' (rows of event_raster in grouped order) and 'dividers'
        (group boundaries from find_trial_ids), and psth dict keyed by condition, each entry
        holding 'vals' (mean), 'err' and 'linestyle' (a dict of plotting style options)
    """
    selection = {'event': event, 'order': order, 'contrast': contrast}
    trials_sorted, div = find_trial_ids(trials, sort='choice and side', **selection)

    # psth key -> (side, choice, line style used for plotting)
    conditions = {
        'left_correct': ('left', 'correct', {'color': 'r'}),
        'left_incorrect': ('left', 'incorrect', {'color': 'r', 'linestyle': 'dashed'}),
        'right_correct': ('right', 'correct', {'color': 'b'}),
        'right_incorrect': ('right', 'incorrect', {'color': 'b', 'linestyle': 'dashed'}),
    }
    ids_per_condition = {
        key: find_trial_ids(trials, side=side, choice=choice, **selection)[0]
        for key, (side, choice, _) in conditions.items()
    }

    psth = {}
    for key, (_, _, style) in conditions.items():
        mean, err = get_psth(event_raster, ids_per_condition[key])
        psth[key] = {'vals': mean, 'err': err, 'linestyle': style}

    raster = {'vals': filter_by_trial(event_raster, trials_sorted), 'dividers': div}
    return raster, psth

251 

252 

def filter_correct_incorrect(trials, event_raster, event, contrast, order='trial num'):
    """
    Return psth for correct and incorrect trials and the raster sorted by correct/incorrect.

    :param trials: trials object, passed to find_trial_ids
    :param event_raster: output from get_event_aligned_raster (one row per trial)
    :param event: event the raster is aligned to e.g 'goCue_times', 'stimOn_times'; trials
        where it is NaN are excluded
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
        considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by, either 'trial num' or 'reaction time'
    :return: raster dict with 'vals' and 'dividers', and psth dict with 'correct' and
        'incorrect' entries, each holding 'vals' (mean), 'err' and 'linestyle'
    """
    selection = dict(event=event, order=order, contrast=contrast)
    trials_sorted, div = find_trial_ids(trials, sort='choice', **selection)
    correct_ids, _ = find_trial_ids(trials, side='all', choice='correct', **selection)
    incorrect_ids, _ = find_trial_ids(trials, side='all', choice='incorrect', **selection)

    psth = {}
    for key, ids, colour in (('correct', correct_ids, 'r'), ('incorrect', incorrect_ids, 'b')):
        mean, err = get_psth(event_raster, ids)
        psth[key] = {'vals': mean, 'err': err, 'linestyle': {'color': colour}}

    raster = {'vals': filter_by_trial(event_raster, trials_sorted), 'dividers': div}
    return raster, psth

279 

280 

def filter_left_right(trials, event_raster, event, contrast, order='trial num'):
    """
    Return psth for left and right trials and the raster sorted by left/right.

    :param trials: trials object, passed to find_trial_ids
    :param event_raster: output from get_event_aligned_raster (one row per trial)
    :param event: event the raster is aligned to e.g 'goCue_times', 'stimOn_times'; trials
        where it is NaN are excluded
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
        considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by, either 'trial num' or 'reaction time'
    :return: raster dict with 'vals' and 'dividers', and psth dict with 'left' and 'right'
        entries, each holding 'vals' (mean), 'err' and 'linestyle'
    """
    selection = dict(event=event, order=order, contrast=contrast)
    trials_sorted, div = find_trial_ids(trials, sort='side', **selection)
    ids_by_side = {
        stim_side: find_trial_ids(trials, side=stim_side, choice='all', **selection)[0]
        for stim_side in ('left', 'right')
    }

    colours = {'left': 'r', 'right': 'b'}
    psth = {}
    for stim_side, ids in ids_by_side.items():
        mean, err = get_psth(event_raster, ids)
        psth[stim_side] = {'vals': mean, 'err': err, 'linestyle': {'color': colours[stim_side]}}

    raster = {'vals': filter_by_trial(event_raster, trials_sorted), 'dividers': div}
    return raster, psth

307 

308 

def filter_trials(trials, event_raster, event, contrast=(1, 0.5, 0.25, 0.125, 0.0625, 0),
                  order='trial num', sort='choice'):
    """
    Wrapper to get out psth and raster for a chosen trial split.

    :param trials: trials object
    :param event_raster: output from get_event_aligned_raster (one row per trial)
    :param event: event the raster is aligned to e.g 'goCue_times', 'stimOn_times'
    :param contrast: contrast of stimulus, pass in list/tuple of all contrasts that want to be
        considered e.g [1, 0.5] would only look for trials with 100 % and 50 % contrast
    :param order: order to sort trials by, either 'trial num' or 'reaction time'
    :param sort: how to divide trials, options are 'choice' (e.g correct vs incorrect), 'side'
        (e.g left vs right) and 'choice and side' (e.g correct vs incorrect and left vs right)
    :return: (raster, psth) as returned by the matching filter_* function
    :raises ValueError: if sort is not one of the supported options (previously this surfaced
        as an UnboundLocalError)
    """
    if sort == 'choice':
        return filter_correct_incorrect(trials, event_raster, event, contrast, order)
    if sort == 'side':
        return filter_left_right(trials, event_raster, event, contrast, order)
    if sort == 'choice and side':
        return filter_correct_incorrect_left_right(trials, event_raster, event, contrast, order)
    raise ValueError(f"sort must be 'choice', 'side' or 'choice and side', got {sort!r}")