Coverage for ibllib/io/extractors/training_trials.py: 79%

375 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 14:26 +0100

1import logging 

2import numpy as np 

3from itertools import accumulate 

4from packaging import version 

5from one.alf.io import AlfBunch 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.training_wheel import Wheel 

10 

11 

# Module-level logger used by the extractors in this module for warnings and errors.
_logger = logging.getLogger(__name__)
# Public API of this module: only the aggregate TrainingTrials extractor is exported by `import *`.
__all__ = ['TrainingTrials']

14 

15 

class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to the subject on each trial.
    **Optional:** saves _ibl_trials.feedbackType.npy

    Exactly one of the outcome states 'correct', 'error', 'no_go', 'omit_correct',
    'omit_error' and 'omit_no_go' must have a non-NaN entry time in each raw Bpod trial,
    otherwise an AssertionError is raised. A state absent from a trial counts as not entered.

    feedbackType is +1 for 'correct', -1 for 'error' or 'no_go', and 0 for the 'omit_*' outcomes.
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    def _extract(self):
        outcome_states = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
        # Outcomes not listed here (the 'omit_*' states) keep the default value of 0
        feedback_values = {'correct': 1, 'error': -1, 'no_go': -1}
        feedback = np.zeros(len(self.bpod_trials), np.int64)
        for trial_idx, trial in enumerate(self.bpod_trials):
            timestamps = trial['behavior_data']['States timestamps']
            entered = {
                name: ~np.isnan(timestamps.get(name, [[np.nan]])[0][0])
                for name in outcome_states
            }
            # The outcome states are mutually exclusive
            assert np.sum(list(entered.values())) == 1
            reached_state = next(name for name in entered if entered[name])
            feedback[trial_idx] = feedback_values.get(reached_state, 0)
        return feedback

43 

44 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Get left and right stimulus contrasts from the raw Bpod data.
    **Optional:** saves _ibl_trials.contrastLeft.npy and _ibl_trials.contrastRight.npy

    The sign of each trial's 'position' decides the side: a negative position puts the
    contrast in contrastLeft, a positive one in contrastRight. The other side is NaN.
    The contrast may be stored either as a flat scalar (iblrig v8) or as a dict with a
    'value' key (older versions). The format is detected per trial.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        def trial_contrast(trial):
            # Check for the dict layout rather than for float, so integer contrasts (e.g. 0 or 1)
            # stored flat are handled too.
            contrast = trial['contrast']
            return contrast['value'] if isinstance(contrast, dict) else contrast

        contrasts = [trial_contrast(t) for t in self.bpod_trials]
        sides = [np.sign(t['position']) for t in self.bpod_trials]
        contrastLeft = np.array([c if s < 0 else np.nan for c, s in zip(contrasts, sides)])
        contrastRight = np.array([c if s > 0 else np.nan for c, s in zip(contrasts, sides)])
        return contrastLeft, contrastRight

70 

71 

class ProbabilityLeft(BaseBpodTrialsExtractor):
    """
    Get the per-trial probability of a left stimulus.
    **Optional:** saves _ibl_trials.probabilityLeft.npy

    Read directly from the 'stim_probability_left' field of each raw Bpod trial.
    """
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        probabilities = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(probabilities)

78 

79 

class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Uses the stimulus position and trial_correct:
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial (the 'no_go' state was entered)

    If a trial is correct, the choice is the inverse of the sign of the position;
    otherwise it equals the sign of the position, i.e.
    choice[t] = -sign(position[t]) if trial_correct[t] else sign(position[t]).
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        stim_side = np.array([np.sign(t['position']) for t in self.bpod_trials])
        # Force boolean dtype: these arrays are used as masks, and integer 0/1 flags (or an
        # empty float array) would otherwise be treated as indices.
        trial_correct = np.array([t['trial_correct'] for t in self.bpod_trials], dtype=bool)
        trial_nogo = np.array(
            [~np.isnan(t['behavior_data']['States timestamps']['no_go'][0][0]) for t in self.bpod_trials],
            dtype=bool)
        choice = stim_side.copy()
        choice[trial_correct] = -choice[trial_correct]
        choice[trial_nogo] = 0
        return choice.astype(int)

108 

109 

class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial is a repeat when its 'debias_trial' field is truthy or, if that field is absent,
    when trial['contrast']['type'] == 'RepeatContrast'. Trials with neither field are only
    accepted for advancedChoiceWorld / neuromodulatorChoiceWorld protocols (asserted) and
    are never counted as repeats. The count resets to 0 on every non-repeated trial.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        def is_repeat(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            if 'contrast' in trial and isinstance(trial['contrast'], dict):
                return trial['contrast']['type'] == 'RepeatContrast'
            # For advanced choice world and its subclasses before version 8.19.0 there was no
            # 'debias_trial' field and no debiasing protocol applied, so simply return False
            protocol = self.settings['PYBPOD_PROTOCOL']
            assert protocol.startswith(('_iblrig_tasks_advancedChoiceWorld', 'ccu_neuromodulatorChoiceWorld'))
            return False

        repeated = np.fromiter((is_repeat(trial) for trial in self.bpod_trials), int)
        run_lengths = accumulate(repeated, lambda run, flag: run + flag if flag else 0)
        return np.fromiter(run_lengths, int)

139 

140 

class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load the reward volume delivered on each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    The volume is the trial's 'reward_amount' when 'trial_correct' is truthy and 0 otherwise,
    returned as a float64 array with one entry per trial.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        def delivered(trial):
            if trial['trial_correct']:
                return trial['reward_amount']
            return 0

        volumes = np.array([delivered(trial) for trial in self.bpod_trials]).astype(np.float64)
        assert len(volumes) == len(self.bpod_trials)
        return volumes

157 

158 

class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    For iblrig < 5.0.0, the feedback time of a trial is the entry time of whichever of the
    'reward', 'error' or 'no_go' states was entered. An AssertionError is raised if a trial
    entered none of them.

    For iblrig >= 5.0.0 (or an unknown version), it is the 'reward' state entry time on rewarded
    trials. Otherwise it is the last BNC2High front of the trial, after dropping fronts closer
    than 20 ms to the previous one. Unrewarded trials with fewer than two remaining fronts
    get NaN.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Merge the reward / error / no_go state entry times into one value per trial.

        :param session_path: session folder, used to load the raw data when `data` is falsy
        :param task_collection: collection holding the raw task data
        :param data: raw Bpod trials; loaded from disk when falsy
        :return: numpy array of feedback times, at least 1-D
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        states = [tr['behavior_data']['States timestamps'] for tr in data]
        rw_times = [st['reward'][0][0] for st in states]
        err_times = [st['error'][0][0] for st in states]
        nogo_times = [st['no_go'][0][0] for st in states]
        assert sum(np.isnan(rw_times) & np.isnan(err_times) & np.isnan(nogo_times)) == 0
        merge = np.array([np.array(times)[~np.isnan(times)]
                          for times in zip(rw_times, err_times, nogo_times)]).squeeze()
        # squeeze() turns a single-trial session into a 0-d scalar; keep it a 1-D array
        return np.atleast_1d(merge)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Use the reward state time on rewarded trials and the error-sound TTL front otherwise.

        Error and no-go trials have their sound onset on BNC2High. The last front of the trial is
        taken as the error sound (the first one being the go cue).

        :param session_path: session folder, used to load the raw data when `data` is falsy
        :param task_collection: collection holding the raw task data
        :param data: raw Bpod trials; loaded from disk when falsy
        :return: numpy float array of feedback times, NaN where undetermined
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        n_trials = len(data)
        missed_bnc2 = 0
        rw_times = np.full(n_trials, np.nan)
        err_sound_times = np.full(n_trials, np.nan)

        for ind, tr in enumerate(data):
            st = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            st = np.delete(st, np.where(np.diff(st) < 0.020)[0] + 1)
            rw_times[ind] = tr['behavior_data']['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            if st.size >= 2 and np.isnan(rw_times[ind]):
                err_sound_times[ind] = st[-1]
        # Guard on n_trials so an empty session does not trigger a spurious warning
        if n_trials and missed_bnc2 == n_trials:
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        # err_sound_times is only set where the reward is NaN, so the two sources never overlap
        return np.where(~np.isnan(rw_times), rw_times, err_sound_times)

    def _extract(self):
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        if rig_version >= version.parse('5.0.0'):
            getter = self.get_feedback_times_ge5
        else:
            getter = self.get_feedback_times_lt5
        merge = getter(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)

223 

224 

class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start to trial end. Trial end includes 1 or 2 seconds after feedback,
    (depending on the feedback) and 0.5 seconds of iti.
    **Optional:** saves _ibl_trials.intervals.npy

    Built from the 'Trial start timestamp' and 'Trial end timestamp' values of each PyBpod trial,
    returned as an (n_trials, 2) array of [start, end] rows.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        behaviour = [trial['behavior_data'] for trial in self.bpod_trials]
        trial_starts = [b['Trial start timestamp'] for b in behaviour]
        trial_ends = [b['Trial end timestamp'] for b in behaviour]
        return np.vstack((trial_starts, trial_ends)).T

240 

241 

class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    Taken as the exit time of the 'closed_loop' state, i.e. the second value of its first
    [entry, exit] timestamp pair.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop = (trial['behavior_data']['States timestamps']['closed_loop'] for trial in self.bpod_trials)
        return np.array([state[0][1] for state in closed_loop])

256 

257 

class ItiDuration(BaseBpodTrialsExtractor):
    """
    Calculate the duration of the inter-trial interval from state timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    Computed per trial as 'Trial end timestamp' minus the response time produced by the
    ResponseTimes extractor.
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        response_times, _ = ResponseTimes(self.session_path).extract(
            save=False, task_collection=self.task_collection, bpod_trials=self.bpod_trials, settings=self.settings)
        trial_ends = np.array([trial['behavior_data']['Trial end timestamp'] for trial in self.bpod_trials])
        return trial_ends - response_times

274 

275 

class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the times at which the state machine triggered the go cue.
    **Optional:** saves _ibl_trials.goCueTrigger_times.npy

    For iblrig >= 5.0.0 (or when the version is unknown), this is the entry time of the
    'play_tone' state. For older versions, it is the entry time of 'closed_loop'. These are the
    times the play command was issued. Sounds are triggered through PyBpod soft codes, so the
    actual sound onset can lag by tens of ms.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        trigger_state = 'play_tone' if rig_version >= version.parse('5.0.0') else 'closed_loop'
        return np.array([trial['behavior_data']['States timestamps'][trigger_state][0][0]
                         for trial in self.bpod_trials])

296 

297 

class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial by the outcome state that was entered.
    **Optional:** saves _ibl_trials.type.npy

    Returns 1 if 'reward' was entered, -1 if 'error' was entered and 0 if 'no_go' was entered
    (checked in that order). Trials that entered none of them get NaN and a warning is logged.
    """
    save_names = '_ibl_trials.type.npy'
    # Extractors declare their outputs via `var_names`; this class previously set `var_name`
    var_names = 'trial_type'
    var_name = var_names  # backward-compatible alias of the former misspelled attribute

    def _extract(self):
        outcome_codes = (("reward", 1), ("error", -1), ("no_go", 0))
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            for state, code in outcome_codes:
                if not np.isnan(states[state][0][0]):
                    trial_type.append(code)
                    break
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)

315 

316 

class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the go cue times from the BNC2 TTL fronts recorded by Bpod.
    **Optional:** saves _ibl_trials.goCue_times.npy

    For each trial with BNC2 port events, the go cue time is the first BNC2High front. If there
    is no rising front, the first BNC2Low front minus 0.1 s is used. Trials without usable BNC2
    events get NaN. A warning is logged when some or all trials are missing.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        go_cue_times = np.full(len(self.bpod_trials), np.nan)
        for ind, tr in enumerate(self.bpod_trials):
            if not raw.get_port_events(tr, 'BNC2'):
                continue
            events = tr['behavior_data']['Events timestamps']
            rising = events.get('BNC2High', None)
            if rising:
                go_cue_times[ind] = rising[0]
                continue
            falling = events.get('BNC2Low', None)
            if falling:
                # NOTE(review): the 0.1 s offset presumably is the TTL pulse width — confirm
                go_cue_times[ind] = falling[0] - 0.1

        missing = np.isnan(go_cue_times)
        nmissing = np.sum(missing)
        if np.all(missing):
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        elif np.any(missing):
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')
        return go_cue_times

356 

357 

class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Flag which trials are included.
    **Optional:** saves _ibl_trials.included.npy

    For iblrig < 5.0.0 every trial is included. For newer (or unknown) versions, when the
    SUBJECT_DISENGAGED_TRIGGERED setting is present and not False, the trial numbered
    SUBJECT_DISENGAGED_TRIALNUM (1-based) and all later trials are excluded.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            trials_included = self.get_included_trials_ge5(
                data=self.bpod_trials, settings=self.settings)
        else:
            trials_included = self.get_included_trials_lt5(data=self.bpod_trials)
        return trials_included

    @staticmethod
    def get_included_trials_lt5(data=False):
        """Return a boolean array marking every trial in `data` as included."""
        return np.ones(len(data), dtype=bool)

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """
        Return a boolean inclusion array, excluding trials after the subject disengaged.

        :param data: raw Bpod trials (only the length is used)
        :param settings: task settings mapping
        :return: numpy bool array, one entry per trial
        """
        # np.ones keeps a bool dtype even for an empty session (a list comprehension would give float64)
        trials_included = np.ones(len(data), dtype=bool)
        if ('SUBJECT_DISENGAGED_TRIGGERED' in settings.keys() and settings[
                'SUBJECT_DISENGAGED_TRIGGERED'] is not False):
            idx = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            trials_included[idx:] = False
        return trials_included

383 

384 

class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Get the time each trial entered the inter-trial interval.

    For iblrig >= 5.0.0 (or an unknown version) this is the entry time of 'exit_state';
    older versions return NaN for every trial.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("5.0.0"):
            return np.array([trial["behavior_data"]["States timestamps"]["exit_state"][0][0]
                             for trial in self.bpod_trials])
        return np.full(len(self.bpod_trials), np.nan)

397 

398 

class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time the state machine triggered the error cue on each trial.

    The entry time of 'no_go' is used if that state was entered, otherwise the entry time of
    'error'. Trials that entered neither get NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        trigger_times = np.full(len(self.bpod_trials), np.nan)
        for trial_idx, tr in enumerate(self.bpod_trials):
            states = tr["behavior_data"]["States timestamps"]
            nogo = states["no_go"][0][0]
            error = states["error"][0][0]
            if not np.any(np.isnan(nogo)):
                trigger_times[trial_idx] = nogo
            elif not np.any(np.isnan(error)):
                trigger_times[trial_idx] = error
        return trigger_times

412 

413 

class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time the state machine triggered the stimulus freeze on each trial.

    For iblrig < 6.2.5 every value is NaN. Otherwise it is the entry time of 'freeze_reward' or
    'freeze_error', and NaN on trials that entered 'no_go'. An AssertionError is raised if the
    counts of trials entering 'freeze_reward', 'freeze_error' and 'no_go' do not add up to the
    number of trials.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("6.2.5"):
            return np.full(len(self.bpod_trials), np.nan)

        def entered(trial, state):
            # A state counts as entered when its whole first timestamp pair is non-NaN
            return bool(np.all(~np.isnan(trial["behavior_data"]["States timestamps"][state][0])))

        freeze_reward = np.array([entered(tr, "freeze_reward") for tr in self.bpod_trials])
        freeze_error = np.array([entered(tr, "freeze_error") for tr in self.bpod_trials])
        no_go = np.array([entered(tr, "no_go") for tr in self.bpod_trials])
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        freeze_times = []
        for rewarded, is_nogo, tr in zip(freeze_reward, no_go, self.bpod_trials):
            if is_nogo:
                freeze_times.append(np.nan)
            else:
                state = "freeze_reward" if rewarded else "freeze_error"
                freeze_times.append(tr["behavior_data"]["States timestamps"][state][0][0])
        return np.array(freeze_times, dtype=float)

456 

457 

class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time the state machine triggered the stimulus off on each trial.
    **Optional:** saves _ibl_trials.stimOffTrigger_times.npy

    The trigger is the entry time of 'hide_stim' for iblrig >= 6.2.5 (or an unknown version),
    of 'exit_state' for iblrig >= 5.0.0, and of 'trial_start' before that. For iblrig >= 5.0.0,
    trials that entered 'no_go' use the 'no_go' entry time instead.
    """
    var_names = 'stimOffTrigger_times'
    # Previously pointed at the stimOnTrigger file, so saving would overwrite the stim-on triggers
    save_names = '_ibl_trials.stimOffTrigger_times.npy'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("6.2.5"):
            stim_off_trigger_state = "hide_stim"
        elif rig_version >= version.parse("5.0.0"):
            stim_off_trigger_state = "exit_state"
        else:
            stim_off_trigger_state = "trial_start"

        stimOffTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"][stim_off_trigger_state][0][0]
             for tr in self.bpod_trials]
        )
        # If pre version 5.0.0 no specific nogo Off trigger was given, just return trial_starts
        if stim_off_trigger_state == "trial_start":
            return stimOffTrigger_times

        no_goTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"]["no_go"][0][0] for tr in self.bpod_trials]
        )
        # The stimulus is turned off either in its own state or, when the mouse did not move, in
        # the no_go state. NaNs may remain in the last trial if the session was stopped after the
        # response. Patch with the no_go trigger times where they exist.
        has_no_go = ~np.isnan(no_goTrigger_times)
        stimOffTrigger_times[has_no_go] = no_goTrigger_times[has_no_go]
        return stimOffTrigger_times

493 

494 

class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the time the state machine entered the 'stim_on' state on each trial.
    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # Stack the first timestamp pair of 'stim_on' for every trial and keep its first column
        stim_on_windows = np.array([trial['behavior_data']['States timestamps']['stim_on'][0]
                                    for trial in self.bpod_trials])
        onsets = stim_on_windows[:, 0]
        return onsets.T

504 

505 

class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """
    Deprecated Bpod stimulus-onset extractor; use StimOnOffFreezeTimes instead.
    **Optional:** saves _ibl_trials.stimOn_times.npy
    """
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the first stim_sync (BNC1) pulse inside each trial's stim_on state.

        A pulse qualifies when it falls in (stim_on start, stim_on end]. Only the first qualifying
        pulse is kept, so the output has exactly one entry per trial. Trials with no qualifying
        pulse get NaN, meaning the frame2TTL device failed to detect the stim_sync change.

        :param session_path: session folder, used to load the raw data when `data` is falsy
        :param data: raw Bpod trials; loaded from disk when falsy
        :param task_collection: collection holding the raw task data
        :return: numpy float array of stimOn times, one per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        stimOn_times = np.full(len(data), np.nan)
        for i, tr in enumerate(data):
            sync = np.array(raw.get_port_events(tr, 'BNC1'))
            stim_on_window = tr['behavior_data']['States timestamps']['stim_on'][0]
            on, off = stim_on_window[0], stim_on_window[1]
            pulse = sync[(sync > on) & (sync <= off)]
            if pulse.size > 0:
                # Previously every pulse was appended, misaligning the output with the trials
                stimOn_times[i] = pulse[0]

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')

        # Check if any stim_sync has failed be detected for every trial
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')

        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the first BNC1 front after the state machine command to turn on the stim.

        The command time is the entry time of the 'stim_on' state. Both BNC1High and BNC1Low
        fronts are considered frame changes. Trials without a later front get NaN.

        :param session_path: session folder, used to load the raw data when `data` is falsy
        :param data: raw Bpod trials; loaded from disk when falsy
        :param task_collection: collection holding the raw task data
        :return: numpy float array of stimOn times, one per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        stimOn_times = np.full(len(data), np.nan)
        count_missing = 0
        for i, tr in enumerate(data):
            stim_on = tr['behavior_data']['States timestamps']['stim_on'][0][0]
            events = tr['behavior_data']['Events timestamps']
            # A missing event key means no fronts. This replaces the former np.NINF placeholder,
            # which was removed in NumPy 2.0.
            fronts = np.sort(np.concatenate([np.asarray(events.get('BNC1High', []), dtype=float),
                                             np.asarray(events.get('BNC1Low', []), dtype=float)]))
            after = fronts[fronts > stim_on]
            if after.size == 0:
                count_missing += 1
            else:
                stimOn_times[i] = after[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')

        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')

        return np.array(stimOn_times)

618 

619 

class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    Each stimulus event is the first detected front of the BNC1 signal after the trigger state,
    but before the next trigger state.

    - Stim on: first front between the stim-on trigger and the freeze trigger. The off trigger is
      used instead when there is no freeze (iblrig < 6.2.5, or no-go trials).
    - Stim off: first front after the stim-off trigger.
    - Stim freeze: last front between the freeze and off triggers (iblrig >= 6.2.5), or the last
      front up to the off trigger for older versions. It is NaN on no-go trials (choice == 0).
    """
    save_names = ('_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOff_times.npy', None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        sub_kwargs = dict(bpod_trials=self.bpod_trials, task_collection=self.task_collection,
                          settings=self.settings, save=False)
        choice = Choice(self.session_path).extract(**sub_kwargs)[0]
        stimOnTrigger = StimOnTriggerTimes(self.session_path).extract(**sub_kwargs)[0]
        stimFreezeTrigger = StimFreezeTriggerTimes(self.session_path).extract(**sub_kwargs)[0]
        stimOffTrigger = StimOffTriggerTimes(self.session_path).extract(**sub_kwargs)[0]
        f2TTL = [raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials]
        assert stimOnTrigger.size == stimFreezeTrigger.size == stimOffTrigger.size == choice.size == len(f2TTL)
        assert all(stimOnTrigger < np.nan_to_num(stimFreezeTrigger, nan=np.inf)) and \
            all(np.nan_to_num(stimFreezeTrigger, nan=-np.inf) < stimOffTrigger)

        # Same version fallback as the trigger extractors: a None/empty version means a recent rig.
        # version.parse(None) used to raise here.
        has_freeze = version.parse(self.settings.get('IBLRIG_VERSION') or '100.0.0') >= version.parse('6.2.5')

        def first_or_nan(fronts):
            # Earliest selected front, or NaN when none was detected
            return fronts[0] if fronts.size > 0 else np.nan

        def last_or_nan(fronts):
            # Latest selected front, or NaN when none was detected
            return fronts[-1] if fronts.size > 0 else np.nan

        stimOn_times, stimOff_times, stimFreeze_times = [], [], []
        for fronts, on, freeze, off in zip(f2TTL, stimOnTrigger, stimFreezeTrigger, stimOffTrigger):
            fronts = np.array(fronts)
            # No-go trials have no freeze trigger (NaN). Bounding by NaN would reject every front,
            # so the stim-on window closes at the off trigger instead.
            lim = freeze if has_freeze and not np.isnan(freeze) else off
            stimOn_times.append(first_or_nan(fronts[(on < fronts) & (fronts < lim)]))
            stimOff_times.append(first_or_nan(fronts[off < fronts]))
            # stim freeze - take last event before off trigger
            if has_freeze:
                stimFreeze_times.append(last_or_nan(fronts[(freeze < fronts) & (fronts < off)]))
            else:
                stimFreeze_times.append(last_or_nan(fronts[fronts <= off]))

        stimOn_times = np.array(stimOn_times, dtype=float)
        stimOff_times = np.array(stimOff_times, dtype=float)
        stimFreeze_times = np.array(stimFreeze_times, dtype=float)
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan

        return stimOn_times, stimOff_times, stimFreeze_times

672 

673 

class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """Extract stimulus phase, position and quiescence from Bpod data.

    Values come from the 'stim_phase', 'position' and 'quiescent_period' fields of each trial.
    Only the quiescence period is saved (_ibl_trials.quiescencePeriod.npy).

    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        fields = ('stim_phase', 'position', 'quiescent_period')
        phase, position, quiescence = (np.array([trial[field] for trial in self.bpod_trials]) for field in fields)
        return phase, position, quiescence

687 

688 

class PauseDuration(BaseBpodTrialsExtractor):
    """Extract pause duration from raw trial data.

    Trials without a 'pause_duration' field get 0 for iblrig < 8.9.0 (or an unknown version),
    where pausing did not exist, and NaN for newer versions.
    """
    save_names = None
    var_names = 'pause_duration'

    def _extract(self, **kwargs):
        # pausing logic added in version 8.9.0
        rig_version = version.parse(self.settings.get('IBLRIG_VERSION') or '0')
        fill_value = np.nan if rig_version >= version.parse('8.9.0') else 0.
        durations = (trial.get('pause_duration', fill_value) for trial in self.bpod_trials)
        return np.fromiter(durations, dtype=float)

699 

700 

class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extracts the following into a table from Bpod raw data:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Additionally extracts the following wheel data:
        wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude

    Every extracted variable not listed in `var_names` goes into the table (exactly 12 columns
    are asserted). The variables listed in `var_names` are returned separately, after the table.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [
            Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes,
            FeedbackType, RewardVolume, ProbabilityLeft, Wheel,
        ]
        common = dict(session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
                      save=False, task_collection=self.task_collection)
        out, _ = run_extractor_classes(extractors, **common)
        table = AlfBunch({name: value for name, value in out.items() if name not in self.var_names})
        assert len(table.keys()) == 12

        separate = (out.pop(name) for name in self.var_names if name != 'table')
        return (table.to_df(), *separate)

724 

725 

class TrainingTrials(BaseBpodTrialsExtractor):
    """
    Extract all training-session trial variables from raw Bpod data.

    Runs the repetition, trigger-time, trials-table, phase/position/quiescence and pause
    extractors, then returns their outputs as a dict keyed and ordered by `var_names`.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  '_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt', '_ibl_trials.stimOff_times.npy',
                  None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None,
                  '_ibl_trials.quiescencePeriod.npy', None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
                 'wheelMoves_peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')

    def _extract(self) -> dict:
        extractors = [
            RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
            ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence, PauseDuration,
        ]
        common = dict(session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
                      save=False, task_collection=self.task_collection)
        out, _ = run_extractor_classes(extractors, **common)
        return {name: out[name] for name in self.var_names}