Coverage for ibllib/io/extractors/training_trials.py: 92%

386 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1import logging 

2import numpy as np 

3from itertools import accumulate 

4from packaging import version 

5from one.alf.io import AlfBunch 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.training_wheel import Wheel 

10 

11 

_logger = logging.getLogger(name=__name__)
# Explicit public API of this module for `from ... import *`
__all__ = ['TrainingTrials', 'extract_all']

14 

15 

class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to subject.
    **Optional:** saves _ibl_trials.feedbackType.npy

    For every trial, exactly one of the outcome states ('correct', 'error', 'no_go',
    'omit_correct', 'omit_error', 'omit_no_go') must have been entered; an AssertionError
    is raised otherwise.

    Sets feedbackType to +1 if the 'correct' state was entered,
    to -1 if the 'error' or 'no_go' state was entered,
    and leaves it at 0 for the 'omit_*' states.
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    # Mutually exclusive outcome states; exactly one is expected per trial
    _OUTCOME_STATES = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')

    def _extract(self):
        feedbackType = np.zeros(len(self.bpod_trials), np.int64)
        for i, t in enumerate(self.bpod_trials):
            states = t['behavior_data']['States timestamps']
            # A state that was not entered has a NaN start time (or is absent altogether).
            # np.nan is used rather than the np.NaN alias, which NumPy 2.0 removed.
            entered = [sn for sn in self._OUTCOME_STATES if not np.isnan(states.get(sn, [[np.nan]])[0][0])]
            assert len(entered) == 1, f'trial {i}: expected exactly one outcome state, got {entered}'
            outcome = entered[0]
            if outcome == 'correct':
                feedbackType[i] = 1
            elif outcome in ('error', 'no_go'):
                feedbackType[i] = -1
            # 'omit_*' outcomes keep the default value of 0
        return feedbackType

43 

44 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Get left and right contrasts from raw datafile. Optionally, saves
    _ibl_trials.contrastLeft.npy and _ibl_trials.contrastRight.npy to alf folder.

    The contrast of each trial is assigned to the side given by the sign of its stimulus
    position (negative -> contrastLeft, positive -> contrastRight); the other side is NaN.
    The trial 'contrast' field may either be a scalar (iblrigv8 flat layout) or a dict
    holding the scalar under 'value' (older layout).
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    @staticmethod
    def _contrast_value(trial):
        """Return the scalar contrast of a trial for both the flat and the nested ({'value': ...}) layout."""
        contrast = trial['contrast']
        return contrast['value'] if isinstance(contrast, dict) else contrast

    def _extract(self):
        # The layout is resolved per trial rather than from the first trial's type, so integer
        # contrasts (e.g. 0 or 1) and empty sessions do not break the extraction.
        contrast = np.array([self._contrast_value(t) for t in self.bpod_trials], dtype=float)
        side = np.array([np.sign(t['position']) for t in self.bpod_trials], dtype=float)
        contrastLeft = np.where(side < 0, contrast, np.nan)
        contrastRight = np.where(side > 0, contrast, np.nan)
        return contrastLeft, contrastRight

70 

71 

class ProbabilityLeft(BaseBpodTrialsExtractor):
    """
    Get the 'stim_probability_left' value of every trial
    (presumably the probability of the stimulus being presented on the left).
    **Optional:** saves _ibl_trials.probabilityLeft.npy
    """
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        p_left = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(p_left)

78 

79 

class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Uses the sign of the stimulus position, trial_correct and the 'no_go' state.
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial
    If a trial is correct the choice of the animal was the inverse of the sign
    of the position.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        stim_side = np.array([np.sign(t['position']) for t in self.bpod_trials], dtype=float)
        # Explicit bool dtypes: integer 0/1 flags would otherwise be used as positional indices,
        # and an empty session would give a float array that cannot index at all.
        trial_correct = np.array([t['trial_correct'] for t in self.bpod_trials], dtype=bool)
        trial_nogo = np.array(
            [not np.isnan(t['behavior_data']['States timestamps']['no_go'][0][0]) for t in self.bpod_trials],
            dtype=bool)
        choice = np.where(trial_correct, -stim_side, stim_side)
        choice[trial_nogo] = 0
        return choice.astype(int)

108 

109 

class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial counts as repeated when its 'debias_trial' flag is set or, for the older layout,
    when trial['contrast']['type'] == 'RepeatContrast'. The count is reset to 0 on every
    trial that is not repeated.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        def _is_repeated(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            contrast = trial['contrast'] if 'contrast' in trial else None
            if isinstance(contrast, dict):
                return contrast['type'] == 'RepeatContrast'
            # For advanced choice world before version 8.19.0 there was no 'debias_trial' field
            # and no debiasing protocol applied, so simply return False
            assert self.settings['PYBPOD_PROTOCOL'].startswith('_iblrig_tasks_advancedChoiceWorld')
            return False

        repeated = np.fromiter((_is_repeated(t) for t in self.bpod_trials), int)
        # Running length of the current streak of repeated trials
        running_counts = accumulate(repeated, lambda run, rep: run + rep if rep else 0)
        return np.fromiter(running_counts, int)

138 

139 

class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load reward volume delivered for each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    Uses the 'reward_amount' field of correct trials; trials that are not correct get a
    volume of 0. The result is a float64 array with one entry per trial.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        volumes = np.asarray(
            [tr['reward_amount'] if tr['trial_correct'] else 0 for tr in self.bpod_trials],
            dtype=np.float64,
        )
        assert len(volumes) == len(self.bpod_trials)
        return volumes

156 

157 

class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    For iblrig < 5.0.0 the feedback time is the onset of whichever of the 'reward', 'error'
    or 'no_go' states was entered (an AssertionError is raised if a trial entered none).
    From iblrig 5.0.0, rewarded trials use the 'reward' state onset and the other trials use
    the last de-duplicated BNC2High event of the trial (NaN if fewer than 2 were recorded).
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for iblrig < 5.0.0.

        :param session_path: session folder, used only when `data` is falsy
        :param task_collection: raw data collection to load from when `data` is falsy
        :param data: list of Bpod trial dicts
        :return: 1-D numpy array with one feedback time per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        states = [tr['behavior_data']['States timestamps'] for tr in data]
        rw_times = [s['reward'][0][0] for s in states]
        err_times = [s['error'][0][0] for s in states]
        nogo_times = [s['no_go'][0][0] for s in states]
        assert sum(np.isnan(rw_times) & np.isnan(err_times) & np.isnan(nogo_times)) == 0
        merge = np.array([np.array(times)[~np.isnan(times)] for times in
                          zip(rw_times, err_times, nogo_times)]).squeeze()
        # squeeze() collapses a single-trial session to a 0-d array; keep the result 1-D
        return np.atleast_1d(merge)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for iblrig >= 5.0.0.

        Error and no-go feedback times are taken from the BNC2High events of the trial: only
        two onsets are expected (go tone and noise), so the 2nd/last one is selected.

        :param session_path: session folder, used only when `data` is falsy
        :param task_collection: raw data collection to load from when `data` is falsy
        :param data: list of Bpod trial dicts
        :return: 1-D numpy array with one feedback time per trial (NaN when unavailable)
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        n_trials = len(data)
        missed_bnc2 = 0
        rw_times = np.full(n_trials, np.nan)
        err_sound_times = np.full(n_trials, np.nan)

        for ind, tr in enumerate(data):
            st = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            st = np.delete(st, np.where(np.diff(st) < 0.020)[0] + 1)
            rw_times[ind] = tr['behavior_data']['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            if st.size >= 2 and np.isnan(rw_times[ind]):
                err_sound_times[ind] = st[-1]

        # an empty session has nothing missing, so only warn when there were trials
        if n_trials > 0 and missed_bnc2 == n_trials:
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        merge = np.full(n_trials, np.nan)
        has_reward = ~np.isnan(rw_times)
        merge[has_reward] = rw_times[has_reward]
        has_err_sound = ~np.isnan(err_sound_times)
        merge[has_err_sound] = err_sound_times[has_err_sound]
        return merge

    def _extract(self):
        # Version check
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            merge = self.get_feedback_times_ge5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        else:
            merge = self.get_feedback_times_lt5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)

222 

223 

class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start to trial end. Trial end includes 1 or 2 seconds after feedback,
    (depending on the feedback) and 0.5 seconds of iti.
    **Optional:** saves _ibl_trials.intervals.npy

    Uses the corrected 'Trial start timestamp' and 'Trial end timestamp' values from PyBpod;
    returns an (n_trials, 2) array of [start, end] rows.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        bounds = [
            (t['behavior_data']['Trial start timestamp'], t['behavior_data']['Trial end timestamp'])
            for t in self.bpod_trials
        ]
        # reshape keeps the (0, 2) shape for an empty session
        return np.array(bounds).reshape(-1, 2)

239 

240 

class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    Uses the end timestamp of the first 'closed_loop' state of each trial.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop_ends = [
            trial['behavior_data']['States timestamps']['closed_loop'][0][1]
            for trial in self.bpod_trials
        ]
        return np.array(closed_loop_ends)

255 

256 

class ItiDuration(BaseBpodTrialsExtractor):
    """
    Calculate duration of iti from state timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    The iti duration is the 'Trial end timestamp' minus the response time extracted
    by ResponseTimes.
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        response_times, _files = ResponseTimes(self.session_path).extract(
            save=False, task_collection=self.task_collection, bpod_trials=self.bpod_trials, settings=self.settings)
        trial_ends = np.array([trial['behavior_data']['Trial end timestamp'] for trial in self.bpod_trials])
        return trial_ends - response_times

273 

274 

class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get trigger times of goCue from state machine.

    Current software solution for triggering sounds uses PyBpod soft codes, so delays can be
    in the order of 10's of ms: this is the time when the command to play the sound was
    executed, not the sound onset. A more accurate time can be obtained from the xonar
    soundcard sync pulse (latencies may vary).

    From iblrig 5.0.0 the onset of the 'play_tone' state is used; before that, the onset of
    the 'closed_loop' state.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        is_v5_or_later = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0')
        state = 'play_tone' if is_v5_or_later else 'closed_loop'
        onsets = [tr['behavior_data']['States timestamps'][state][0][0] for tr in self.bpod_trials]
        return np.array(onsets)

295 

296 

class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial by the outcome state that was entered.
    **Optional:** saves _ibl_trials.type.npy

    1 if the 'reward' state was entered, otherwise -1 if the 'error' state was entered,
    otherwise 0 if the 'no_go' state was entered; NaN (with a warning) when none was.
    """
    save_names = '_ibl_trials.type.npy'
    # `var_names` is the attribute the other extractors define; the previous misspelled
    # `var_name` is kept as an alias for backward compatibility
    var_names = 'trial_type'
    var_name = var_names

    def _extract(self):
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            if not np.isnan(states["reward"][0][0]):
                trial_type.append(1)
            elif not np.isnan(states["error"][0][0]):
                trial_type.append(-1)
            elif not np.isnan(states["no_go"][0][0]):
                trial_type.append(0)
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)

314 

315 

class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the go cue times from the BNC2 TTL events recorded by Bpod
    (presumably the soundcard sync pulse — TODO confirm wiring).
    **Optional:** saves _ibl_trials.goCue_times.npy

    For each trial with BNC2 events the first BNC2High time is used; if only BNC2Low events
    exist, 0.1 s is subtracted from the first BNC2Low time (presumably the tone duration —
    TODO confirm). Trials without BNC2 events are NaN and a warning is logged.
    Unlike GoCueTriggerTimes this is not the time the state machine issued the command.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        go_cue_times = np.full(len(self.bpod_trials), np.nan)
        for ind, tr in enumerate(self.bpod_trials):
            if not raw.get_port_events(tr, 'BNC2'):
                continue
            events = tr['behavior_data']['Events timestamps']
            bnchigh = events.get('BNC2High', None)
            if bnchigh:
                go_cue_times[ind] = bnchigh[0]
                continue
            bnclow = events.get('BNC2Low', None)
            if bnclow:
                go_cue_times[ind] = bnclow[0] - 0.1

        nmissing = int(np.sum(np.isnan(go_cue_times)))
        # nothing missing (including an empty session): no warning
        if nmissing == 0:
            return go_cue_times
        if nmissing == go_cue_times.size:
            # every BNC2 TTL failed to be detected
            _logger.warning(f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        else:
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')
        return go_cue_times

355 

356 

class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Flag the trials to include.
    **Optional:** saves _ibl_trials.included.npy

    For iblrig < 5.0.0 all trials are included. From 5.0.0, when the settings report
    SUBJECT_DISENGAGED_TRIGGERED (any value other than False), the trial numbered
    SUBJECT_DISENGAGED_TRIALNUM (treated as 1-based) and all later trials are excluded.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            trials_included = self.get_included_trials_ge5(
                data=self.bpod_trials, settings=self.settings)
        else:
            trials_included = self.get_included_trials_lt5(data=self.bpod_trials)
        return trials_included

    @staticmethod
    def get_included_trials_lt5(data=False):
        """Return a boolean array flagging every trial in `data` as included."""
        trials_included = np.ones(len(data), dtype=bool)
        return trials_included

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """
        Return a boolean array of included trials, excluding those after the subject disengaged.

        :param data: list of Bpod trial dicts
        :param settings: task settings dict
        """
        # explicit bool dtype: `np.array([True for t in data])` is float64 for an empty session
        trials_included = np.ones(len(data), dtype=bool)
        if settings.get('SUBJECT_DISENGAGED_TRIGGERED', False) is not False:
            idx = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            trials_included[idx:] = False
        return trials_included

382 

383 

class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the 'exit_state' state of each trial.

    Sessions recorded with iblrig < 5.0.0 yield NaN for every trial.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') >= version.parse("5.0.0"):
            return np.array(
                [trial["behavior_data"]["States timestamps"]["exit_state"][0][0] for trial in self.bpod_trials]
            )
        return np.full(len(self.bpod_trials), np.nan)

396 

397 

class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the 'no_go' state, or else of the 'error' state, for each trial.

    Trials that entered neither state are NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        trigger_times = np.full(len(self.bpod_trials), np.nan)
        for i, trial in enumerate(self.bpod_trials):
            states = trial["behavior_data"]["States timestamps"]
            nogo_onset, error_onset = (states[name][0][0] for name in ("no_go", "error"))
            # no_go takes precedence over error
            for onset in (nogo_onset, error_onset):
                if np.all(~np.isnan(onset)):
                    trigger_times[i] = onset
                    break
        return trigger_times

411 

412 

class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time at which the state machine triggered the stimulus freeze.

    For iblrig >= 6.2.5 this is the onset of the 'freeze_reward' or 'freeze_error' state,
    whichever was entered; trials that entered the 'no_go' state are NaN. Earlier versions
    have no freeze states and all trials are NaN.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("6.2.5"):
            return np.full(len(self.bpod_trials), np.nan)
        states = [tr["behavior_data"]["States timestamps"] for tr in self.bpod_trials]

        def entered(state_name):
            """Boolean array: True for trials whose first `state_name` timestamps are all non-NaN."""
            return np.array([bool(np.all(~np.isnan(s[state_name][0]))) for s in states], dtype=bool)

        freeze_reward = entered("freeze_reward")
        freeze_error = entered("freeze_error")
        no_go = entered("no_go")
        # the three outcome counts must add up to the number of trials
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        # Built as a list and converted once (np.append in a loop is quadratic)
        stimFreezeTrigger = [
            np.nan if n else s["freeze_reward" if r else "freeze_error"][0][0]
            for r, n, s in zip(freeze_reward, no_go, states)
        ]
        return np.array(stimFreezeTrigger, dtype=float)

455 

456 

class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time at which the state machine triggered the stimulus offset.

    The trigger state depends on the iblrig version: 'hide_stim' (>= 6.2.5), 'exit_state'
    (>= 5.0.0) or 'trial_start' (< 5.0.0). From 5.0.0, trials that entered the 'no_go' state
    use the no_go onset instead, as the stimulus is then turned off from within no_go.
    """
    var_names = 'stimOffTrigger_times'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("6.2.5"):
            trigger_state = "hide_stim"
        elif rig_version >= version.parse("5.0.0"):
            trigger_state = "exit_state"
        else:
            trigger_state = "trial_start"

        states = [tr["behavior_data"]["States timestamps"] for tr in self.bpod_trials]
        stimOffTrigger_times = np.array([s[trigger_state][0][0] for s in states])
        if trigger_state == "trial_start":
            # before 5.0.0 there is no specific no-go stim off trigger: return trial starts as they are
            return stimOffTrigger_times

        # Patch trials whose no_go state was entered with the no_go onset; the others keep the
        # trigger state onset (which may be NaN on a last trial stopped after the response)
        no_go_onsets = np.array([s["no_go"][0][0] for s in states])
        use_no_go = ~np.isnan(no_go_onsets)
        stimOffTrigger_times[use_no_go] = no_go_onsets[use_no_go]
        return stimOffTrigger_times

491 

492 

class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time at which the state machine triggered the stimulus onset.
    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy

    This is the onset of the first 'stim_on' state of each trial.
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # Index each trial's [start, end] pair directly instead of slicing a 2-D array, which
        # failed with an IndexError for an empty session
        return np.array([tr['behavior_data']['States timestamps']['stim_on'][0][0]
                         for tr in self.bpod_trials])

502 

503 

class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """
    Deprecated stimulus onset extraction from Bpod BNC1 events.

    From iblrig 5.0.0, StimOnOffFreezeTimes should be used instead.
    """
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        # Version check
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Stimulus onset times for iblrig >= 5.0.0.

        The stimulus onset is the first BNC1 event (stim updates are in BNC1High and BNC1Low -
        frame2TTL device) within (start, end] of the first 'stim_on' state of the trial.
        If no such event exists the HW device failed to detect the stim_sync square change
        and the trial's value is NaN.

        :param session_path: session folder, used only when `data` is falsy
        :param data: list of Bpod trial dicts
        :param task_collection: raw data collection to load from when `data` is falsy
        :return: 1-D numpy array with one stimOn time per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        stimOn_times = np.full(len(data), np.nan)
        for i, tr in enumerate(data):
            sync = np.array(raw.get_port_events(tr, 'BNC1'))
            on, off = tr['behavior_data']['States timestamps']['stim_on'][0][:2]
            pulse = sync[(sync > on) & (sync <= off)]
            if pulse.size > 0:
                # only the first pulse: appending all of them would misalign the output with trials
                stimOn_times[i] = pulse[0]

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected (an empty session has nothing missing)
        if stimOn_times.size > 0 and np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')
        # Check if any stim_sync has failed to be detected
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')
        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Stimulus onset times for iblrig < 5.0.0.

        Find the time of the statemachine command to turn on the stim
        (state stim_on start or rotary_encoder_event2), then the next frame change from the
        photodiode after that timestamp (frame changes are in BNC1High and BNC1Low).
        Screen is not displaying anything until then. Trials without such a change are NaN.

        :param session_path: session folder, used only when `data` is falsy
        :param data: list of Bpod trial dicts
        :param task_collection: raw data collection to load from when `data` is falsy
        :return: 1-D numpy array with one stimOn time per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        count_missing = 0
        stimOn_times = np.zeros(len(data))
        for i, tr in enumerate(data):
            stim_on = tr['behavior_data']['States timestamps']['stim_on'][0][0]
            events = tr['behavior_data']['Events timestamps']
            # -inf placeholders for absent events never pass the `> stim_on` test
            # (-np.inf rather than np.NINF, which NumPy 2.0 removed)
            fronts = np.sort(np.concatenate([events.get('BNC1High', [-np.inf]),
                                             events.get('BNC1Low', [-np.inf])]))
            after = fronts[fronts > stim_on]
            if after.size == 0:
                stimOn_times[i] = np.nan
                count_missing += 1
            else:
                stimOn_times[i] = after[0]

        if stimOn_times.size > 0 and np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')
        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')
        return stimOn_times

616 

617 

class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    - stimOn: first BNC1 front after the stim on trigger and before the next trigger (the
      freeze trigger when there is one, otherwise the stim off trigger).
    - stimOff: first BNC1 front after the stim off trigger.
    - stimFreeze: last BNC1 front between the freeze trigger and the stim off trigger
      (iblrig >= 6.2.5), or last front at or before the stim off trigger (earlier versions);
      NaN for no-go trials.
    """
    save_names = ('_ibl_trials.stimOn_times.npy', None, None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        kwargs = dict(bpod_trials=self.bpod_trials, task_collection=self.task_collection,
                      settings=self.settings, save=False)
        choice = Choice(self.session_path).extract(**kwargs)[0]
        stimOnTrigger = StimOnTriggerTimes(self.session_path).extract(**kwargs)[0]
        stimFreezeTrigger = StimFreezeTriggerTimes(self.session_path).extract(**kwargs)[0]
        stimOffTrigger = StimOffTriggerTimes(self.session_path).extract(**kwargs)[0]
        f2TTL = [raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials]
        assert stimOnTrigger.size == stimFreezeTrigger.size == stimOffTrigger.size == choice.size == len(f2TTL)
        assert all(stimOnTrigger < np.nan_to_num(stimFreezeTrigger, nan=np.inf)) and \
            all(np.nan_to_num(stimFreezeTrigger, nan=-np.inf) < stimOffTrigger)

        # A None/empty version falls back to '100.0.0' as in the trigger extractors above, so that
        # has_freeze agrees with StimFreezeTriggerTimes (version.parse(None) would raise);
        # a missing key keeps the historical '0' default.
        rig_version = self.settings.get('IBLRIG_VERSION', '0') or '100.0.0'
        has_freeze = version.parse(rig_version) >= version.parse('6.2.5')

        def first(x):
            return x[0] if x.size > 0 else np.nan

        def last(x):
            return x[-1] if x.size > 0 else np.nan

        stimOn_times, stimOff_times, stimFreeze_times = [], [], []
        for tr, on, freeze, off in zip(f2TTL, stimOnTrigger, stimFreezeTrigger, stimOffTrigger):
            tr = np.array(tr)
            # stim on: no-go trials have no freeze trigger (NaN), in which case the off trigger
            # bounds the search; a NaN limit would make every comparison False
            lim = freeze if has_freeze and not np.isnan(freeze) else off
            stimOn_times.append(first(tr[(on < tr) & (tr < lim)]))
            # stim off
            stimOff_times.append(first(tr[off < tr]))
            # stim freeze - take last event before off trigger
            if has_freeze:
                stimFreeze_times.append(last(tr[(freeze < tr) & (tr < off)]))
            else:
                stimFreeze_times.append(last(tr[tr <= off]))

        stimOn_times = np.array(stimOn_times, dtype=float)
        stimOff_times = np.array(stimOff_times, dtype=float)
        stimFreeze_times = np.array(stimFreeze_times, dtype=float)
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan

        return stimOn_times, stimOff_times, stimFreeze_times

670 

671 

class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """Extract stimulus phase, position and quiescence from Bpod data.

    Reads the 'stim_phase', 'position' and 'quiescent_period' fields of every trial.
    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        fields = ('stim_phase', 'position', 'quiescent_period')
        phase, position, quiescence = (np.array([trial[f] for trial in self.bpod_trials]) for f in fields)
        return phase, position, quiescence

685 

686 

class PauseDuration(BaseBpodTrialsExtractor):
    """Extract pause duration from raw trial data.

    Trials without a 'pause_duration' field get 0 for iblrig < 8.9.0 (before pausing existed)
    and NaN otherwise.
    """
    save_names = None
    var_names = 'pause_duration'

    def _extract(self, **kwargs):
        # pausing logic added in version 8.9.0
        rig_version = version.parse(self.settings.get('IBLRIG_VERSION') or '0')
        fallback = np.nan if rig_version >= version.parse('8.9.0') else 0.
        durations = [trial.get('pause_duration', fallback) for trial in self.bpod_trials]
        return np.fromiter(durations, dtype=float)

697 

698 

class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extracts the following into a table from Bpod raw data:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Additionally extracts the following wheel data:
        wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude

    Every output of the underlying extractors whose name is NOT in `var_names` becomes a table
    column. The remaining outputs named in `var_names` are returned separately, after the table.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the trial/wheel extractors and split their outputs into a table plus extras.

        Returns
        -------
        tuple
            The trials table as a DataFrame (via AlfBunch.to_df), followed by the outputs for
            each remaining name in `var_names`, in order.
        """
        # NOTE(review): `extractor_classes` is accepted but unused here.
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes,
                      FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        extracted, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)

        columns = {}
        for name, values in extracted.items():
            if name not in self.var_names:
                columns[name] = values
        table = AlfBunch(columns)
        # Sanity check: exactly the 12 columns listed in the class docstring.
        assert len(table.keys()) == 12

        table_df = table.to_df()
        extras = tuple(extracted.pop(name) for name in self.var_names if name != 'table')
        return (table_df, *extras)

722 

723 

class TrainingTrials(BaseBpodTrialsExtractor):
    """Aggregate extractor for training-task trials and wheel data.

    Runs the trigger-time extractors together with TrialsTable, PhasePosQuiescence and
    PauseDuration, then returns their outputs keyed by the names in `var_names`.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  None, None, None, '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None, None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
                 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')

    def _extract(self) -> dict:
        """Run all sub-extractors (without saving) and return {var_name: data} in `var_names` order."""
        extractors = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence, PauseDuration]
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        collected = {}
        for name in self.var_names:
            collected[name] = results[name]
        return collected

740 

741 

def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None):
    """Extract trials and wheel data.

    For task versions >= 5.0.0, outputs wheel data and trials.table dataset (+ some extra datasets).
    Older versions are extracted with the individual (legacy) per-variable extractors.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    save : bool
        If true save the data files to ALF
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset. If not provided (or
        empty), they are loaded from `task_collection`.
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset. If not provided (or
        empty), they are loaded from `task_collection`. If no rig version can be determined,
        'IBLRIG_VERSION' is assumed to be '100.0.0'. The other settings are kept, and the
        caller's dict is not modified.
    task_collection : str
        The session sub-folder containing the raw task data (default 'raw_behavior_data').
    save_path : str, pathlib.Path, optional
        Passed to the extractors as `path_out`; presumably the output folder used when `save`
        is True (TODO confirm the default used when None).

    Returns
    -------
    out : dict-like
        The extracted data, keyed by variable name.
    files : list or None
        A list of file paths if save is True (otherwise None).
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    # A missing settings file, a missing key, or an empty/None version all mean the rig
    # version is unknown: assume a modern rig. Build a new dict so the remaining settings
    # are preserved for downstream extractors and the caller's dict is not mutated.
    if not (settings or {}).get('IBLRIG_VERSION'):
        settings = {**(settings or {}), 'IBLRIG_VERSION': '100.0.0'}

    # Version check
    if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'):
        # We now extract a single trials table
        base = [TrainingTrials]
    else:
        base = [
            RepNum, GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice, IncludedTrials,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil