Coverage for ibllib/io/extractors/training_trials.py: 79%

375 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1import logging 

2import numpy as np 

3from itertools import accumulate 

4from packaging import version 

5from one.alf.io import AlfBunch 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.training_wheel import Wheel 

10 

11 

# Module logger and the public API exported by ``from ... import *``.
_logger = logging.getLogger(__name__)
__all__ = ['TrainingTrials']

14 

15 

class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to the subject.
    **Optional:** saves _ibl_trials.feedbackType.npy

    Exactly one of the mutually exclusive outcome states ('correct', 'error', 'no_go',
    'omit_correct', 'omit_error', 'omit_no_go') must have a non-NaN entry timestamp in each
    trial, otherwise an AssertionError is raised. A state missing from a trial's
    'States timestamps' counts as not triggered.

    feedbackType is +1 for 'correct', -1 for 'error' or 'no_go' and 0 for the 'omit_*' states.
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    def _extract(self):
        outcome_states = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')
        # outcomes not listed here (the omit_* states) keep the default value of 0
        feedback_by_outcome = {'correct': 1, 'error': -1, 'no_go': -1}
        feedback = np.zeros(len(self.bpod_trials), np.int64)
        for itrial, trial in enumerate(self.bpod_trials):
            timestamps = trial['behavior_data']['States timestamps']
            triggered = [
                state for state in outcome_states
                if ~np.isnan(timestamps.get(state, [[np.nan]])[0][0])
            ]
            assert len(triggered) == 1
            feedback[itrial] = feedback_by_outcome.get(triggered[0], 0)
        return feedback

43 

44 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Get left and right contrasts from the raw datafile. Optionally saves
    _ibl_trials.contrastLeft.npy and _ibl_trials.contrastRight.npy to the alf folder.

    A trial's contrast goes into contrastLeft when the stimulus position is negative and into
    contrastRight when it is positive; the other side (and both sides when the position is 0)
    is NaN.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        # iblrigv8 stores flat floats in the trial table, older versions store a dict with a
        # 'value' entry; the first trial decides which layout is used for the whole session
        is_flat = isinstance(self.bpod_trials[0]['contrast'], float)

        def contrast_of(trial):
            return trial['contrast'] if is_flat else trial['contrast']['value']

        sides = [np.sign(trial['position']) for trial in self.bpod_trials]
        pairs = list(zip(self.bpod_trials, sides))
        contrastLeft = np.array([contrast_of(trial) if side < 0 else np.nan for trial, side in pairs])
        contrastRight = np.array([contrast_of(trial) if side > 0 else np.nan for trial, side in pairs])
        return contrastLeft, contrastRight

70 

71 

class ProbabilityLeft(BaseBpodTrialsExtractor):
    """
    Get the per-trial probability of the stimulus appearing on the left.
    **Optional:** saves _ibl_trials.probabilityLeft.npy
    """
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        probabilities = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(probabilities)

78 

79 

class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Uses the stimulus position, trial_correct and the 'no_go' state.
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial
    If a trial is correct the choice of the animal was the inverse of the sign
    of the position.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        trials = self.bpod_trials
        # start from the side the stimulus was presented on
        choice = np.array([np.sign(trial['position']) for trial in trials])
        is_correct = np.array([trial['trial_correct'] for trial in trials])
        is_nogo = np.array([
            ~np.isnan(trial['behavior_data']['States timestamps']['no_go'][0][0])
            for trial in trials
        ])
        # on correct trials the choice is opposite to the stimulus side
        choice[is_correct] = -choice[is_correct]
        # no-go trials carry no choice
        choice[is_nogo] = 0
        return choice.astype(int)

108 

109 

class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial is a repeat when its 'debias_trial' flag is set or, for older data, when
    trial['contrast']['type'] == 'RepeatContrast'. The count resets to 0 on every
    non-repeated trial.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        # protocols that may legitimately lack any repeat information
        no_debias_protocols = ('_iblrig_tasks_advancedChoiceWorld', 'ccu_neuromodulatorChoiceWorld')

        def is_repeat(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            if 'contrast' in trial and isinstance(trial['contrast'], dict):
                return trial['contrast']['type'] == 'RepeatContrast'
            # For advanced choice world and its subclasses before version 8.19.0 there was no
            # 'debias_trial' field and no debiasing protocol applied, so the trial is not a repeat
            assert self.settings['PYBPOD_PROTOCOL'].startswith(no_debias_protocols)
            return False

        repeated = np.fromiter((is_repeat(trial) for trial in self.bpod_trials), int)
        repNum = np.zeros(repeated.size, dtype=int)
        streak = 0
        for itrial, flag in enumerate(repeated):
            streak = streak + flag if flag else 0
            repNum[itrial] = streak
        return repNum

139 

140 

class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load the reward volume delivered on each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    Correct trials get the trial's 'reward_amount'; all other trials get 0.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        volumes = []
        for trial in self.bpod_trials:
            volumes.append(trial['reward_amount'] if trial['trial_correct'] else 0)
        reward_volume = np.array(volumes).astype(np.float64)
        assert len(reward_volume) == len(self.bpod_trials)
        return reward_volume

157 

158 

class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    For iblrig < 5.0.0 the feedback time is the start of whichever of the 'reward', 'error'
    or 'no_go' states was entered. From iblrig 5.0.0 onwards, rewarded trials use the
    'reward' state start and unrewarded trials use the last BNC2High front of the trial
    (the error sound), or NaN when that front is missing.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for iblrig < 5.0.0 taken from the reward / error / no_go state starts.

        :param session_path: session folder, used only when `data` is not provided
        :param task_collection: raw data collection, used only when `data` is not provided
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :return: numpy array with one feedback time per trial
        :raises AssertionError: if a trial entered none of the three states
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        rw_times = [tr['behavior_data']['States timestamps']['reward'][0][0] for tr in data]
        err_times = [tr['behavior_data']['States timestamps']['error'][0][0] for tr in data]
        nogo_times = [tr['behavior_data']['States timestamps']['no_go'][0][0] for tr in data]
        # every trial must have entered at least one of the feedback states
        assert sum(np.isnan(rw_times) & np.isnan(err_times) & np.isnan(nogo_times)) == 0
        merge = np.array([np.array(times)[~np.isnan(times)] for times in
                          zip(rw_times, err_times, nogo_times)])
        # One value per trial gives an (n_trials, 1) array: drop only that trailing axis.
        # A plain squeeze() would turn a single-trial session into a 0-d array.
        if merge.ndim == 2 and merge.shape[1] == 1:
            merge = merge[:, 0]
        return np.array(merge)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Feedback times for iblrig >= 5.0.0.

        Rewarded trials use the 'reward' state start. Unrewarded trials use the last BNC2High
        front of the trial, provided at least two fronts (go tone and error sound) remain after
        de-duplication; otherwise NaN.

        :param session_path: session folder, used only when `data` is not provided
        :param task_collection: raw data collection, used only when `data` is not provided
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :return: numpy array with one feedback time per trial (NaN where unavailable)
        """
        # get err and no go trig times -- look for BNC2High of trial -- verify
        # only 2 onset times go tone and noise, select 2nd/-1 OR select the one
        # that is greater than the nogo or err trial onset time
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        missed_bnc2 = 0
        rw_times, err_sound_times, merge = [np.zeros([len(data), ]) for _ in range(3)]

        for ind, tr in enumerate(data):
            st = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            st = np.delete(st, np.where(np.diff(st) < 0.020)[0] + 1)
            rw_times[ind] = tr['behavior_data']['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            err_sound_times[ind] = st[-1] if st.size >= 2 and np.isnan(rw_times[ind]) else np.nan
        if missed_bnc2 == len(data):
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        merge *= np.nan
        merge[~np.isnan(rw_times)] = rw_times[~np.isnan(rw_times)]
        merge[~np.isnan(err_sound_times)] = err_sound_times[~np.isnan(err_sound_times)]

        return merge

    def _extract(self):
        # a missing IBLRIG_VERSION is treated as a recent rig version
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            merge = self.get_feedback_times_ge5(
                self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        else:
            merge = self.get_feedback_times_lt5(
                self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)

223 

224 

class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start to trial end. Trial end includes 1 or 2 seconds after feedback
    (depending on the feedback) and 0.5 seconds of ITI.
    **Optional:** saves _ibl_trials.intervals.npy

    Uses the corrected 'Trial start timestamp' and 'Trial end timestamp' values from PyBpod.
    Returns an (n_trials, 2) array of [start, end] rows.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        behaviour = [trial['behavior_data'] for trial in self.bpod_trials]
        starts = [bd['Trial start timestamp'] for bd in behaviour]
        ends = [bd['Trial end timestamp'] for bd in behaviour]
        return np.column_stack((starts, ends))

240 

241 

class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    The response time is the end timestamp of the 'closed_loop' state.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop_ends = []
        for trial in self.bpod_trials:
            closed_loop = trial['behavior_data']['States timestamps']['closed_loop']
            closed_loop_ends.append(closed_loop[0][1])
        return np.array(closed_loop_ends)

256 

257 

class ItiDuration(BaseBpodTrialsExtractor):
    """
    Calculate the duration of the ITI from state timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    The ITI duration is the 'Trial end timestamp' minus the response time, i.e. the end of
    the 'closed_loop' state as returned by ResponseTimes.
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        rt, _ = ResponseTimes(self.session_path).extract(
            save=False, task_collection=self.task_collection, bpod_trials=self.bpod_trials, settings=self.settings)
        ends = np.array([t['behavior_data']['Trial end timestamp'] for t in self.bpod_trials])
        iti_dur = ends - rt
        return iti_dur

274 

275 

class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the times at which the state machine triggered the go cue.
    **Optional:** saves _ibl_trials.goCueTrigger_times.npy

    From iblrig 5.0.0 this is the start of the 'play_tone' state; before that it is the
    start of the 'closed_loop' state. A missing IBLRIG_VERSION is treated as a recent
    version. Sounds are triggered through PyBpod soft codes, so the actual sound onset can
    lag this command by tens of ms; the measured onset is extracted from the BNC2 TTL by
    GoCueTimes.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        trigger_state = 'play_tone' if rig_version >= version.parse('5.0.0') else 'closed_loop'
        return np.array([
            trial['behavior_data']['States timestamps'][trigger_state][0][0]
            for trial in self.bpod_trials
        ])

296 

297 

class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial by the feedback state it entered.
    **Optional:** saves _ibl_trials.type.npy

    1 when the 'reward' state was entered, -1 for 'error', 0 for 'no_go' (checked in that
    order); NaN, with a warning, when none of them was entered.
    """
    save_names = '_ibl_trials.type.npy'
    var_names = 'trial_type'
    # Previously misspelled attribute name, kept as an alias for backward compatibility.
    var_name = var_names

    def _extract(self):
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            if not np.isnan(states["reward"][0][0]):
                trial_type.append(1)
            elif not np.isnan(states["error"][0][0]):
                trial_type.append(-1)
            elif not np.isnan(states["no_go"][0][0]):
                trial_type.append(0)
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)

315 

316 

class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the measured go cue times from the BNC2 TTL fronts recorded by Bpod.
    **Optional:** saves _ibl_trials.goCue_times.npy

    Unlike GoCueTriggerTimes (the time the state machine issued the command), this reads the
    sound TTL on BNC2. For each trial with BNC2 events it takes the first BNC2High front;
    failing that, the first BNC2Low front minus 0.1 s (presumably the go cue tone duration --
    TODO confirm). Trials with no usable BNC2 event get NaN, and a warning is logged when some
    or all trials are missing.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        go_cue_times = np.zeros([len(self.bpod_trials), ])
        for ind, tr in enumerate(self.bpod_trials):
            if raw.get_port_events(tr, 'BNC2'):
                bnchigh = tr['behavior_data']['Events timestamps'].get('BNC2High', None)
                if bnchigh:
                    go_cue_times[ind] = bnchigh[0]
                    continue
                bnclow = tr['behavior_data']['Events timestamps'].get('BNC2Low', None)
                if bnclow:
                    go_cue_times[ind] = bnclow[0] - 0.1
                    continue
                go_cue_times[ind] = np.nan
            else:
                go_cue_times[ind] = np.nan

        nmissing = np.sum(np.isnan(go_cue_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(go_cue_times)):
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        # Check if any stim_sync has failed be detected for every trial
        elif np.any(np.isnan(go_cue_times)):
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')

        return go_cue_times

356 

357 

class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Boolean mask of the trials to include in analyses.
    **Optional:** saves _ibl_trials.included.npy

    Before iblrig 5.0.0 all trials are included. From 5.0.0, if the subject-disengaged
    criterion was triggered, the trial at SUBJECT_DISENGAGED_TRIALNUM (1-based) and every
    trial after it are excluded.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            trials_included = self.get_included_trials_ge5(
                data=self.bpod_trials, settings=self.settings)
        else:
            trials_included = self.get_included_trials_lt5(data=self.bpod_trials)
        return trials_included

    @staticmethod
    def get_included_trials_lt5(data=False):
        """Return an all-True boolean mask with one entry per trial in `data`."""
        trials_included = np.ones(len(data), dtype=bool)
        return trials_included

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """
        Return a boolean mask with one entry per trial in `data`.

        When settings['SUBJECT_DISENGAGED_TRIGGERED'] is present and not False, the trials
        from settings['SUBJECT_DISENGAGED_TRIALNUM'] (1-based) onwards are set to False.
        """
        # explicit bool dtype so that an empty session still yields a boolean mask
        trials_included = np.ones(len(data), dtype=bool)
        if ('SUBJECT_DISENGAGED_TRIGGERED' in settings.keys() and settings[
                'SUBJECT_DISENGAGED_TRIGGERED'] is not False):
            idx = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            trials_included[idx:] = False
        return trials_included

383 

384 

class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Start of the 'exit_state' state for each trial (iblrig >= 5.0.0).

    Earlier rig versions have no such state, so every trial gets NaN. A missing
    IBLRIG_VERSION is treated as a recent version.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("5.0.0"):
            return np.array([
                trial["behavior_data"]["States timestamps"]["exit_state"][0][0]
                for trial in self.bpod_trials
            ])
        return np.full(len(self.bpod_trials), np.nan)

397 

398 

class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Time the error cue was triggered: the start of the 'no_go' state if it was entered,
    otherwise the start of the 'error' state, otherwise NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        trigger_times = np.full(len(self.bpod_trials), np.nan)
        for itrial, trial in enumerate(self.bpod_trials):
            states = trial["behavior_data"]["States timestamps"]
            nogo_start, error_start = (states[name][0][0] for name in ("no_go", "error"))
            if np.all(~np.isnan(nogo_start)):
                trigger_times[itrial] = nogo_start
            elif np.all(~np.isnan(error_start)):
                trigger_times[itrial] = error_start
        return trigger_times

412 

413 

class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Time the stimulus freeze was triggered: the start of the 'freeze_reward' or
    'freeze_error' state. No-go trials get NaN.

    Rig versions before 6.2.5 have no freeze states, so every trial gets NaN. A missing
    IBLRIG_VERSION is treated as a recent version. Raises AssertionError when the number of
    freeze_reward + freeze_error + no_go trials does not equal the number of trials.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("6.2.5"):
            return np.full(len(self.bpod_trials), np.nan)

        def entered(trial, state):
            return bool(np.all(~np.isnan(trial["behavior_data"]["States timestamps"][state][0])))

        freeze_reward = np.array([entered(tr, "freeze_reward") for tr in self.bpod_trials])
        freeze_error = np.array([entered(tr, "freeze_error") for tr in self.bpod_trials])
        no_go = np.array([entered(tr, "no_go") for tr in self.bpod_trials])
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        trigger_times = []
        for trial, rewarded, nogo in zip(self.bpod_trials, freeze_reward, no_go):
            if nogo:
                trigger_times.append(np.nan)
            else:
                state = "freeze_reward" if rewarded else "freeze_error"
                trigger_times.append(trial["behavior_data"]["States timestamps"][state][0][0])
        return np.array(trigger_times, dtype=float)

456 

457 

class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Time the stimulus-off command was triggered.
    **Optional:** saves _ibl_trials.stimOffTrigger_times.npy

    The trigger state depends on the rig version (a missing IBLRIG_VERSION counts as recent):
    'hide_stim' from 6.2.5, 'exit_state' from 5.0.0 and 'trial_start' before that. From 5.0.0,
    trials that entered the 'no_go' state use the 'no_go' start instead.
    """
    var_names = 'stimOffTrigger_times'
    # was '_ibl_trials.stimOnTrigger_times.npy', which clashed with StimOnTriggerTimes' dataset
    save_names = '_ibl_trials.stimOffTrigger_times.npy'

    def _extract(self):
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("6.2.5"):
            stim_off_trigger_state = "hide_stim"
        elif rig_version >= version.parse("5.0.0"):
            stim_off_trigger_state = "exit_state"
        else:
            stim_off_trigger_state = "trial_start"

        stimOffTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"][stim_off_trigger_state][0][0]
             for tr in self.bpod_trials]
        )
        # If pre version 5.0.0 no specific nogo Off trigger was given, just return trial_starts
        if stim_off_trigger_state == "trial_start":
            return stimOffTrigger_times

        no_goTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"]["no_go"][0][0] for tr in self.bpod_trials]
        )
        # The stim off trigger is either in its own state or in the no_go state when the mouse
        # did not move. NaNs can remain, e.g. in the last trial if the session was stopped after
        # the response. Patch with the no_go state trigger times.
        has_nogo = ~np.isnan(no_goTrigger_times)
        stimOffTrigger_times[has_nogo] = no_goTrigger_times[has_nogo]
        return stimOffTrigger_times

493 

494 

class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Time the stimulus-on command was triggered, i.e. the start of the 'stim_on' state.
    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # one [start, end] row per trial; keep the start column
        stim_on_windows = np.array([
            trial['behavior_data']['States timestamps']['stim_on'][0]
            for trial in self.bpod_trials
        ])
        starts = stim_on_windows[:, 0]
        return starts.T

504 

505 

class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """
    Legacy stimOn_times extractor based on Bpod BNC1 (frame2TTL) fronts.
    **Optional:** saves _ibl_trials.stimOn_times.npy

    Kept for old sessions; from iblrig 5.0.0 StimOnOffFreezeTimes should be used instead.
    """
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2), then the next frame change from the
        photodiode after that timestamp. Frame changes are in BNC1 High and BNC1 Low.
        """
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        For each trial, return the first BNC1 front strictly after the start of the 'stim_on'
        state and at or before its end, or NaN when there is none.

        Logs an error when every trial is missing and a warning when any trial is missing.

        :param session_path: session folder, used for logging and when `data` is not provided
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :param task_collection: raw data collection, used only when `data` is not provided
        :return: numpy array with one value per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # All stim_sync (BNC1) fronts detected in each trial
        stim_sync_all = [np.array(raw.get_port_events(tr, 'BNC1')) for tr in data]
        # [start, end] of the stim_on state that triggers the onset of the stim
        stim_on_state = np.array([tr['behavior_data']['States timestamps']
                                  ['stim_on'][0] for tr in data])

        stimOn_times = np.full(len(stim_sync_all), np.nan)
        for i, (sync, on, off) in enumerate(zip(stim_sync_all, stim_on_state[:, 0], stim_on_state[:, 1])):
            pulse = sync[(sync > on) & (sync <= off)]
            # keep only the first front so the output stays aligned with the trials
            if pulse.size > 0:
                stimOn_times[i] = pulse[0]

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')

        # Check if any stim_sync has failed be detected for every trial
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')

        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        For each trial, return the first BNC1 front (High or Low) strictly after the start of
        the 'stim_on' state, or NaN when there is none.

        Logs an error when every trial is missing and a warning when any trial is missing.

        :param session_path: session folder, used for logging and when `data` is not provided
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :param task_collection: raw data collection, used only when `data` is not provided
        :return: numpy array with one value per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        stim_on = []
        bnc_h = []
        bnc_l = []
        for tr in data:
            events = tr['behavior_data']['Events timestamps']
            stim_on.append(tr['behavior_data']['States timestamps']['stim_on'][0][0])
            # -inf placeholders are never after stim_on, so trials without fronts become NaN
            # (np.NINF was removed in NumPy 2.0)
            bnc_h.append(np.array(events['BNC1High']) if 'BNC1High' in events else np.array([-np.inf]))
            bnc_l.append(np.array(events['BNC1Low']) if 'BNC1Low' in events else np.array([-np.inf]))

        stim_on = np.array(stim_on)
        count_missing = 0
        stimOn_times = np.full(len(stim_on), np.nan)
        for i, (on, high, low) in enumerate(zip(stim_on, bnc_h, bnc_l)):
            fronts = np.sort(np.concatenate([high, low]))
            after_on = fronts[fronts > on]
            if after_on.size == 0:
                count_missing += 1
            else:
                stimOn_times[i] = after_on[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')

        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')

        return np.array(stimOn_times)

618 

619 

class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    - stimOn: first BNC1 front strictly between the stim-on trigger and the freeze trigger
      (or the stim-off trigger for rigs without freeze states, < 6.2.5).
    - stimOff: first BNC1 front strictly after the stim-off trigger.
    - stimFreeze: last BNC1 front strictly between the freeze and stim-off triggers (for
      older rigs, the last front at or before the stim-off trigger); NaN on no-go trials.
    Trials without a matching front get NaN.
    """
    save_names = ('_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOff_times.npy', None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        common = dict(bpod_trials=self.bpod_trials, task_collection=self.task_collection,
                      settings=self.settings, save=False)
        choice = Choice(self.session_path).extract(**common)[0]
        stimOnTrigger = StimOnTriggerTimes(self.session_path).extract(**common)[0]
        stimFreezeTrigger = StimFreezeTriggerTimes(self.session_path).extract(**common)[0]
        stimOffTrigger = StimOffTriggerTimes(self.session_path).extract(**common)[0]
        f2TTL = [raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials]
        assert stimOnTrigger.size == stimFreezeTrigger.size == stimOffTrigger.size == choice.size == len(f2TTL)
        assert all(stimOnTrigger < np.nan_to_num(stimFreezeTrigger, nan=np.inf)) and \
            all(np.nan_to_num(stimFreezeTrigger, nan=-np.inf) < stimOffTrigger)

        n_trials = len(f2TTL)
        stimOn_times = np.full(n_trials, np.nan)
        stimOff_times = np.full(n_trials, np.nan)
        stimFreeze_times = np.full(n_trials, np.nan)
        # Same rule as StimFreezeTriggerTimes (a missing/None version counts as recent), so that
        # the freeze triggers and has_freeze agree; previously None crashed version.parse
        has_freeze = version.parse(self.settings.get('IBLRIG_VERSION') or '100.0.0') >= version.parse('6.2.5')
        for i, (tr, on, freeze, off) in enumerate(zip(f2TTL, stimOnTrigger, stimFreezeTrigger, stimOffTrigger)):
            tr = np.array(tr)
            # stim on
            lim = freeze if has_freeze else off
            idx, = np.where(np.logical_and(on < tr, tr < lim))
            if idx.size > 0:
                stimOn_times[i] = tr[idx[0]]
            # stim off
            idx, = np.where(off < tr)
            if idx.size > 0:
                stimOff_times[i] = tr[idx[0]]
            # stim freeze - take last event before off trigger
            if has_freeze:
                idx, = np.where(np.logical_and(freeze < tr, tr < off))
            else:
                idx, = np.where(tr <= off)
            if idx.size > 0:
                stimFreeze_times[i] = tr[idx[-1]]
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan

        return stimOn_times, stimOff_times, stimFreeze_times

672 

673 

class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """
    Extract stimulus phase, position and quiescence period from Bpod data.
    **Optional:** saves _ibl_trials.quiescencePeriod.npy

    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        trial_keys = ('stim_phase', 'position', 'quiescent_period')
        phase, position, quiescence = (
            np.array([trial[key] for trial in self.bpod_trials]) for key in trial_keys
        )
        return phase, position, quiescence

687 

688 

class PauseDuration(BaseBpodTrialsExtractor):
    """
    Extract the pause duration of each trial from raw trial data.

    Trials without a 'pause_duration' field get 0 for rig versions before 8.9.0 (pausing
    logic was added in 8.9.0) and NaN otherwise. A missing IBLRIG_VERSION counts as '0'.
    """
    save_names = None
    var_names = 'pause_duration'

    def _extract(self, **kwargs):
        rig_version = version.parse(self.settings.get('IBLRIG_VERSION') or '0')
        if rig_version < version.parse('8.9.0'):
            fallback = 0.
        else:
            fallback = np.nan
        durations = (trial.get('pause_duration', fallback) for trial in self.bpod_trials)
        return np.fromiter(durations, dtype=float)

699 

700 

class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extracts the following into a table from Bpod raw data:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Additionally returns, outside the table:
        stimOff_times, stimFreeze_times, wheel_timestamps, wheel_position, wheelMoves_intervals,
        wheelMoves_peakAmplitude, wheelMoves_peakVelocity_times, is_final_movement
    Raises AssertionError if the table does not have exactly 12 columns.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes,
                      FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        # everything not returned separately becomes a column of the trials table
        columns = {name: values for name, values in out.items() if name not in self.var_names}
        table = AlfBunch(columns)
        assert len(table.keys()) == 12
        extras = [out.pop(name) for name in self.var_names if name != 'table']
        return (table.to_df(), *extras)

724 

725 

class TrainingTrials(BaseBpodTrialsExtractor):
    """
    Run the training-task trial extractors on Bpod data and collect their outputs.

    Returns a dict keyed by ``var_names``; 'table' is the trials table produced by TrialsTable,
    the other entries come from the individual extractors listed in ``_extract``.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  '_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt', '_ibl_trials.stimOff_times.npy',
                  None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None,
                  '_ibl_trials.quiescencePeriod.npy', None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
                 'wheelMoves_peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')

    def _extract(self) -> dict:
        extractors = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence, PauseDuration]
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        return {name: results[name] for name in self.var_names}