Coverage for ibllib/io/extractors/training_trials.py: 80%

375 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-02 18:55 +0100

1import logging 

2import numpy as np 

3from itertools import accumulate 

4from packaging import version 

5from one.alf.io import AlfBunch 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.training_wheel import Wheel 

10 

11 

12_logger = logging.getLogger(__name__) 

13__all__ = ['TrainingTrials'] 

14 

15 

class FeedbackType(BaseBpodTrialsExtractor):
    """
    Feedback delivered to the subject on each trial.

    **Optional:** saves _ibl_trials.feedbackType.npy

    For every trial the outcome states (correct, error, no_go and their
    ``omit_*`` counterparts) are inspected; a state counts as entered when its
    first timestamp is not NaN.  Exactly one of them must have been entered,
    otherwise an AssertionError is raised.

    feedbackType is +1 when the ``correct`` state was entered, -1 when the
    ``error`` or ``no_go`` state was entered, and stays 0 for the ``omit_*``
    states.
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    def _extract(self):
        outcome_states = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')
        # outcomes that are not listed here leave the initial value of 0 untouched
        feedback_value = {'correct': 1, 'error': -1, 'no_go': -1}
        feedback = np.zeros(len(self.bpod_trials), np.int64)
        for n, trial in enumerate(self.bpod_trials):
            timestamps = trial['behavior_data']['States timestamps']
            # a state missing from the trial is treated like a state that was never entered
            entered = [name for name in outcome_states
                       if not np.isnan(timestamps.get(name, [[np.nan]])[0][0])]
            assert len(entered) == 1
            feedback[n] = feedback_value.get(entered[0], 0)
        return feedback

43 

44 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Left and right stimulus contrasts of every trial.

    Optionally saves _ibl_trials.contrastLeft.npy and
    _ibl_trials.contrastRight.npy to the alf folder.

    The sign of the trial ``position`` decides the side: a negative position
    puts the contrast in the left vector, a positive one in the right vector;
    the other side receives NaN.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        # iblrigv8 stores the contrast as a flat float in the trial table, older
        # versions nest it in a dict under the 'value' key; the first trial
        # decides which layout is read for the whole session
        is_flat = isinstance(self.bpod_trials[0]['contrast'], float)

        def contrast_for_side(left):
            values = []
            for trial in self.bpod_trials:
                side = np.sign(trial['position'])
                on_this_side = side < 0 if left else side > 0
                if not on_this_side:
                    values.append(np.nan)
                elif is_flat:
                    values.append(trial['contrast'])
                else:
                    values.append(trial['contrast']['value'])
            return np.array(values)

        contrastLeft = contrast_for_side(left=True)
        contrastRight = contrast_for_side(left=False)
        return contrastLeft, contrastRight

70 

71 

class ProbabilityLeft(BaseBpodTrialsExtractor):
    """Collect the ``stim_probability_left`` field of every Bpod trial into an array."""
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        probabilities = []
        for trial in self.bpod_trials:
            probabilities.append(trial['stim_probability_left'])
        return np.array(probabilities)

78 

79 

class Choice(BaseBpodTrialsExtractor):
    """
    The subject's choice on every trial.

    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    Derived from the sign of the stimulus position and the trial_correct flag:
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial

    On a correct trial the choice is the inverse of the sign of the position,
    on an incorrect trial it is the sign of the position itself.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        stim_side = np.array([np.sign(trial['position']) for trial in self.bpod_trials])
        is_correct = np.array([trial['trial_correct'] for trial in self.bpod_trials])
        # a no_go trial is one whose no_go state has a non-NaN first timestamp
        is_nogo = np.array([
            ~np.isnan(trial['behavior_data']['States timestamps']['no_go'][0][0])
            for trial in self.bpod_trials
        ])
        response = stim_side.copy()
        flipped = np.negative(response[is_correct])
        response[is_correct] = flipped
        # no_go overrides whatever the correct flag produced
        response[is_nogo] = 0
        return response.astype(int)

108 

109 

class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.

    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial is flagged as repeated from its 'debias_trial' field when present,
    otherwise from trial['contrast']['type'] == 'RepeatContrast'.  The counter
    grows along each run of repeated trials and resets to 0 on a non-repeated one.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        def is_repeat(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            if 'contrast' in trial and isinstance(trial['contrast'], dict):
                return trial['contrast']['type'] == 'RepeatContrast'
            # For advanced choice world and its subclasses before version 8.19.0 there was no
            # 'debias_trial' field and no debiasing protocol applied, so simply return False
            protocol = self.settings['PYBPOD_PROTOCOL']
            assert (protocol.startswith('_iblrig_tasks_advancedChoiceWorld') or
                    protocol.startswith('ccu_neuromodulatorChoiceWorld'))
            return False

        repeated = np.fromiter((is_repeat(trial) for trial in self.bpod_trials), int)
        rep_num = np.zeros(repeated.size, dtype=int)
        streak = 0
        for n, flag in enumerate(repeated):
            # extend the current run on a repeated trial, start over otherwise
            streak = streak + flag if flag else 0
            rep_num[n] = streak
        return rep_num

139 

140 

class RewardVolume(BaseBpodTrialsExtractor):
    """
    Reward volume delivered on each trial.

    **Optional:** saves _ibl_trials.rewardVolume.npy

    The trial's ``reward_amount`` is used when ``trial_correct`` is truthy,
    otherwise the volume is 0.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        delivered = []
        for trial in self.bpod_trials:
            if trial['trial_correct']:
                delivered.append(trial['reward_amount'])
            else:
                delivered.append(0)
        reward_volume = np.array(delivered).astype(np.float64)
        assert len(reward_volume) == len(self.bpod_trials)
        return reward_volume

157 

158 

class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Times at which the water or the error tone was delivered to the animal.

    **Optional:** saves _ibl_trials.feedback_times.npy

    For rig versions below 5.0.0 the onsets of the reward, error and no_go
    states are merged.  From 5.0.0 on, the reward state onset is used on
    rewarded trials and the last BNC2High event of the trial otherwise.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """Merge the reward, error and no_go state onsets of each trial into one vector."""
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)

        def state_onsets(state):
            return [trial['behavior_data']['States timestamps'][state][0][0] for trial in data]

        reward = state_onsets('reward')
        error = state_onsets('error')
        nogo = state_onsets('no_go')
        # every trial must have entered at least one of the three states
        without_outcome = np.isnan(reward) & np.isnan(error) & np.isnan(nogo)
        assert sum(without_outcome) == 0
        per_trial = [np.array(onsets)[~np.isnan(onsets)] for onsets in zip(reward, error, nogo)]
        merged = np.array(per_trial).squeeze()

        return np.array(merged)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """Combine reward state onsets with the last BNC2High event of unrewarded trials."""
        # get err and no go trig times -- look for BNC2High of trial -- verify
        # only 2 onset times go tone and noise, select 2nd/-1 OR select the one
        # that is greater than the nogo or err trial onset time
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        n_trials = len(data)
        reward_times = np.zeros([n_trials, ])
        error_sound_times = np.zeros([n_trials, ])
        trials_without_bnc2 = 0

        for k, trial in enumerate(data):
            behavior = trial['behavior_data']
            sound_ttl = behavior['Events timestamps'].get('BNC2High', None)
            if not sound_ttl:
                sound_ttl = np.array([np.nan, np.nan])
                trials_without_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            duplicates = np.where(np.diff(sound_ttl) < 0.020)[0] + 1
            sound_ttl = np.delete(sound_ttl, duplicates)
            reward_times[k] = behavior['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            if sound_ttl.size >= 2 and np.isnan(reward_times[k]):
                error_sound_times[k] = sound_ttl[-1]
            else:
                error_sound_times[k] = np.nan
        if trials_without_bnc2 == n_trials:
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        merged = np.full(n_trials, np.nan)
        rewarded = ~np.isnan(reward_times)
        merged[rewarded] = reward_times[rewarded]
        # the error sound is written last, so it wins wherever it is defined
        with_error_sound = ~np.isnan(error_sound_times)
        merged[with_error_sound] = error_sound_times[with_error_sound]

        return merged

    def _extract(self):
        # Version check: an empty version is treated as a recent rig
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        if rig_version >= version.parse('5.0.0'):
            getter = self.get_feedback_times_ge5
        else:
            getter = self.get_feedback_times_lt5
        merged = getter(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merged)

223 

224 

class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start to trial end. Trial end includes 1 or 2 seconds after feedback,
    (depending on the feedback) and 0.5 seconds of iti.

    **Optional:** saves _ibl_trials.intervals.npy

    Reads the 'Trial start timestamp' and 'Trial end timestamp' values of each
    trial and returns them as an (n trials, 2) array.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        trial_starts = []
        for trial in self.bpod_trials:
            trial_starts.append(trial['behavior_data']['Trial start timestamp'])
        trial_ends = []
        for trial in self.bpod_trials:
            trial_ends.append(trial['behavior_data']['Trial end timestamp'])
        stacked = np.array([trial_starts, trial_ends])
        # one row per trial: [start, end]
        return stacked.T

240 

241 

class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.

    **Optional:** saves _ibl_trials.response_times.npy

    Taken as the end timestamp of the first closed_loop state of each trial.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        closed_loop_ends = []
        for trial in self.bpod_trials:
            closed_loop = trial['behavior_data']['States timestamps']['closed_loop']
            # index 1 of the first [start, end] pair is the state exit time
            closed_loop_ends.append(closed_loop[0][1])
        return np.array(closed_loop_ends)

256 

257 

class ItiDuration(BaseBpodTrialsExtractor):
    """
    Duration of the inter-trial interval, computed from state timestamps.

    **Optional:** saves _ibl_trials.itiDuration.npy

    The duration is the trial end timestamp minus the response time obtained
    from the ResponseTimes extractor.
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        response_extractor = ResponseTimes(self.session_path)
        response_times, _ = response_extractor.extract(
            save=False,
            task_collection=self.task_collection,
            bpod_trials=self.bpod_trials,
            settings=self.settings,
        )
        trial_ends = np.array([trial['behavior_data']['Trial end timestamp'] for trial in self.bpod_trials])
        return trial_ends - response_times

274 

275 

class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Trigger times of the go cue, read from the state machine.

    Current software solution for triggering sounds uses PyBpod soft codes.
    Delays can be in the order of 10's of ms. This is the time when the command
    to play the sound was executed, not the measured sound onset.

    From rig version 5.0.0 the onset of the play_tone state is used, earlier
    versions use the onset of the closed_loop state.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        # an empty version is treated as a recent rig
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        if rig_version >= version.parse('5.0.0'):
            trigger_state = 'play_tone'
        else:
            trigger_state = 'closed_loop'
        onsets = []
        for trial in self.bpod_trials:
            onsets.append(trial['behavior_data']['States timestamps'][trigger_state][0][0])
        return np.array(onsets)

296 

297 

class TrialType(BaseBpodTrialsExtractor):
    """
    Classify each trial from the outcome state that was entered.

    **Optional:** saves _ibl_trials.type.npy

    The states are checked in order and the first one whose onset timestamp is
    not NaN decides the value: reward -> 1, error -> -1, no_go -> 0.  When none
    of them was entered a warning is logged and NaN is stored for that trial.
    """
    save_names = '_ibl_trials.type.npy'
    # Every other extractor in this module declares `var_names`; the attribute
    # used to be spelled `var_name` here, so the variable name was never declared.
    var_names = 'trial_type'
    # legacy spelling kept as an alias so existing references still resolve
    var_name = var_names

    def _extract(self):
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            if not np.isnan(states["reward"][0][0]):
                trial_type.append(1)
            elif not np.isnan(states["error"][0][0]):
                trial_type.append(-1)
            elif not np.isnan(states["no_go"][0][0]):
                trial_type.append(0)
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)

315 

316 

class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Go cue times read from the BNC2 events recorded by Bpod.

    For each trial the first BNC2High event is used.  When a trial has BNC2
    events but no BNC2High, the first BNC2Low event minus 0.1 is used instead.
    Trials without any usable BNC2 event get NaN, and a warning reports how
    many trials are affected.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        # NaN is the value for every trial that no branch below fills in
        go_cue_times = np.full(len(self.bpod_trials), np.nan)
        for k, trial in enumerate(self.bpod_trials):
            if not raw.get_port_events(trial, 'BNC2'):
                continue
            events = trial['behavior_data']['Events timestamps']
            rising = events.get('BNC2High', None)
            if rising:
                go_cue_times[k] = rising[0]
                continue
            falling = events.get('BNC2Low', None)
            if falling:
                go_cue_times[k] = falling[0] - 0.1

        missing = np.isnan(go_cue_times)
        nmissing = np.sum(missing)
        if np.all(missing):
            # Not a single trial has a detected BNC2 TTL
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        elif np.any(missing):
            # Only some trials lack a detected BNC2 TTL
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')

        return go_cue_times

356 

357 

class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Boolean flag per trial telling whether the trial is included.

    Before rig version 5.0.0 every trial is included.  From 5.0.0 on, trials
    from SUBJECT_DISENGAGED_TRIALNUM - 1 onwards are excluded when the settings
    carry a SUBJECT_DISENGAGED_TRIGGERED value other than False.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        # an empty version is treated as a recent rig
        rig_version = version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0')
        if rig_version < version.parse('5.0.0'):
            return self.get_included_trials_lt5(data=self.bpod_trials)
        return self.get_included_trials_ge5(data=self.bpod_trials, settings=self.settings)

    @staticmethod
    def get_included_trials_lt5(data=False):
        """Return an all-True boolean vector with one entry per trial."""
        return np.ones(len(data), dtype=bool)

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """Return the inclusion vector, cleared from the disengagement trial onwards."""
        included = np.array([True for _ in data])
        disengaged = 'SUBJECT_DISENGAGED_TRIGGERED' in settings.keys() and \
            settings['SUBJECT_DISENGAGED_TRIGGERED'] is not False
        if disengaged:
            first_excluded = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            included[first_excluded:] = False
        return included

383 

384 

class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Onset of the exit_state state of each trial.

    For rig versions below 5.0.0 a vector of NaNs is returned instead.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        # an empty version is treated as a recent rig
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version < version.parse("5.0.0"):
            return np.full(len(self.bpod_trials), np.nan)
        exit_onsets = []
        for trial in self.bpod_trials:
            exit_onsets.append(trial["behavior_data"]["States timestamps"]["exit_state"][0][0])
        return np.array(exit_onsets)

397 

398 

class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Onset of the no_go or error state of each trial.

    The no_go onset takes precedence when it is defined; otherwise the error
    onset is used.  Trials where both are NaN keep NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        trigger_times = np.full(len(self.bpod_trials), np.nan)
        for k, trial in enumerate(self.bpod_trials):
            states = trial["behavior_data"]["States timestamps"]
            nogo_onset = states["no_go"][0][0]
            error_onset = states["error"][0][0]
            if not np.any(np.isnan(nogo_onset)):
                trigger_times[k] = nogo_onset
            elif not np.any(np.isnan(error_onset)):
                trigger_times[k] = error_onset
        return trigger_times

412 

413 

class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Onset of the freeze_reward or freeze_error state of each trial.

    For rig versions below 6.2.5 a vector of NaNs is returned.  Otherwise each
    trial must have entered exactly one of freeze_reward, freeze_error or no_go
    in total (checked with an assertion); no_go trials get NaN.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        # an empty version is treated as a recent rig
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version < version.parse("6.2.5"):
            return np.ones(len(self.bpod_trials)) * np.nan

        def state_entered(state):
            # a state counts as entered when neither value of its first
            # [start, end] pair is NaN
            return np.array([
                bool(np.all(~np.isnan(trial["behavior_data"]["States timestamps"][state][0])))
                for trial in self.bpod_trials
            ])

        freeze_reward = state_entered("freeze_reward")
        freeze_error = state_entered("freeze_error")
        no_go = state_entered("no_go")
        n_outcomes = np.sum(freeze_error) + np.sum(freeze_reward) + np.sum(no_go)
        assert (n_outcomes == len(self.bpod_trials))

        triggers = np.full(len(self.bpod_trials), np.nan)
        for k, (rewarded, nogo, trial) in enumerate(zip(freeze_reward, no_go, self.bpod_trials)):
            if nogo:
                # no freeze state on no_go trials, keep NaN
                continue
            state = "freeze_reward" if rewarded else "freeze_error"
            triggers[k] = trial["behavior_data"]["States timestamps"][state][0][0]
        return triggers

456 

457 

class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Onset of the state that triggers the stimulus offset on each trial.

    **Optional:** saves _ibl_trials.stimOffTrigger_times.npy

    The state depends on the rig version: hide_stim from 6.2.5, exit_state from
    5.0.0 and trial_start before that.  Except for the trial_start case, trials
    that entered the no_go state use the no_go onset instead.
    """
    var_names = 'stimOffTrigger_times'
    # This used to be the stimOn trigger file name (copy-paste), which made a
    # save of this extractor overwrite the stimOnTrigger_times dataset.
    save_names = '_ibl_trials.stimOffTrigger_times.npy'

    def _extract(self):
        # Parse once; an empty version is treated as a recent rig for every comparison
        rig_version = version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0')
        if rig_version >= version.parse("6.2.5"):
            stim_off_trigger_state = "hide_stim"
        elif rig_version >= version.parse("5.0.0"):
            stim_off_trigger_state = "exit_state"
        else:
            stim_off_trigger_state = "trial_start"

        stimOffTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"][stim_off_trigger_state][0][0]
             for tr in self.bpod_trials]
        )
        # If pre version 5.0.0 no specific nogo Off trigger was given, just return trial_starts
        if stim_off_trigger_state == "trial_start":
            return stimOffTrigger_times

        no_goTrigger_times = np.array(
            [tr["behavior_data"]["States timestamps"]["no_go"][0][0] for tr in self.bpod_trials]
        )
        # Stim off triggers are either in their own state or in the no_go state if
        # the mouse did not move.  NaNs may remain, e.g. in the last trial if the
        # session was stopped after the response.
        # Patch with the no_go state trigger times
        is_nogo = ~np.isnan(no_goTrigger_times)
        stimOffTrigger_times[is_nogo] = no_goTrigger_times[is_nogo]
        return stimOffTrigger_times

493 

494 

class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Onset of the stim_on state of each trial.

    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        # one [start, end] pair of the stim_on state per trial
        stim_on_windows = np.array([
            trial['behavior_data']['States timestamps']['stim_on'][0]
            for trial in self.bpod_trials
        ])
        onsets = stim_on_windows[:, 0]
        return onsets.T

504 

505 

class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """
    Deprecated stimulus onset extraction, superseded by StimOnOffFreezeTimes.

    **Optional:** saves _ibl_trials.stimOn_times.npy
    """
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        # The two literals are concatenated, hence the explicit separating space
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        # Version check: an empty version is treated as a recent rig
        if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the first stim_sync pulse inside the stim_on state of each trial.

        stimOn_times should be the first after the stim_on state.
        (Stim updates are in BNC1High and BNC1Low - frame2TTL device)
        Trials without any BNC1 event in the (start, stop] window of the
        stim_on state get NaN.

        :param session_path: session path, used to load the data when `data` is
         not given and in the log messages
        :param data: list of Bpod trials; loaded from `session_path` when falsy
        :param task_collection: collection passed to the raw data loader
        :return: array of stimulus onset times, one value per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # Get all stim_sync events detected
        stim_sync_all = [np.array(raw.get_port_events(tr, 'BNC1')) for tr in data]
        # Get the stim_on_state that triggers the onset of the stim
        stim_on_state = np.array([tr['behavior_data']['States timestamps']
                                  ['stim_on'][0] for tr in data])

        stimOn_times = np.full(len(stim_sync_all), np.nan)
        windows = zip(stim_sync_all, stim_on_state[:, 0], stim_on_state[:, 1])
        for i, (sync, on, off) in enumerate(windows):
            pulse = sync[np.bitwise_and((sync > on), (sync <= off))]
            if pulse.size > 0:
                # Keep a single value per trial: appending every pulse of the
                # window would make the output longer than the number of trials.
                stimOn_times[i] = np.min(pulse)

        nmissing = np.sum(np.isnan(stimOn_times))
        # Check if all stim_syncs have failed to be detected
        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')

        # Check if any stim_sync has failed be detected for every trial
        if np.any(np.isnan(stimOn_times)):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')

        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the time of the statemachine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1High and BNC1Low)

        :param session_path: session path, used to load the data when `data` is
         not given and in the log messages
        :param data: list of Bpod trials; loaded from `session_path` when falsy
        :param task_collection: collection passed to the raw data loader
        :return: array of stimulus onset times, NaN where no BNC1 front follows
         the stim_on onset
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # Stand-in for a missing event list: -inf is never later than a stim_on
        # onset.  (np.NINF no longer exists in NumPy 2.0, -np.inf is equivalent.)
        no_events = np.array([-np.inf])
        stim_on = []
        fronts = []
        for tr in data:
            stim_on.append(tr['behavior_data']['States timestamps']['stim_on'][0][0])
            events = tr['behavior_data']['Events timestamps']
            bnc_h = np.array(events['BNC1High']) if 'BNC1High' in events else no_events
            bnc_l = np.array(events['BNC1Low']) if 'BNC1Low' in events else no_events
            # all frame changes of the trial, rising and falling, in time order
            fronts.append(np.sort(np.concatenate([bnc_h, bnc_l])))

        stim_on = np.array(stim_on)

        count_missing = 0
        stimOn_times = np.zeros_like(stim_on)
        for i, (onset, hl) in enumerate(zip(stim_on, fronts)):
            stot = hl[hl > onset]
            if np.size(stot) == 0:
                stimOn_times[i] = np.nan
                count_missing += 1
            else:
                stimOn_times[i] = stot[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')

        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')

        return np.array(stimOn_times)

618 

619 

class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts.

    Each stimulus event is the first detected front of the BNC1 signal after the trigger state, but before the next
    trigger state.

    **Optional:** saves _ibl_trials.stimOn_times.npy and _ibl_trials.stimOff_times.npy;
    stimFreeze_times is returned but has no save name.
    """
    save_names = ('_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOff_times.npy', None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        """
        Match the BNC1 fronts of each trial to the stimulus trigger times.

        :return: stimOn_times, stimOff_times, stimFreeze_times - three float arrays
         with one value per trial, NaN where no matching front was found
        """
        shared = dict(bpod_trials=self.bpod_trials, task_collection=self.task_collection,
                      settings=self.settings, save=False)
        choice = Choice(self.session_path).extract(**shared)[0]
        stimOnTrigger = StimOnTriggerTimes(self.session_path).extract(**shared)[0]
        stimFreezeTrigger = StimFreezeTriggerTimes(self.session_path).extract(**shared)[0]
        stimOffTrigger = StimOffTriggerTimes(self.session_path).extract(**shared)[0]
        f2TTL = [raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials]
        assert stimOnTrigger.size == stimFreezeTrigger.size == stimOffTrigger.size == choice.size == len(f2TTL)
        # Triggers must be ordered on < freeze < off; a NaN freeze trigger is ignored
        assert all(stimOnTrigger < np.nan_to_num(stimFreezeTrigger, nan=np.inf)) and \
            all(np.nan_to_num(stimFreezeTrigger, nan=-np.inf) < stimOffTrigger)

        # Use the same fallback as the trigger extractors above: an empty version
        # counts as a recent rig.  Passing the raw value through would hand an
        # empty string to version.parse and disagree with StimFreezeTriggerTimes.
        has_freeze = version.parse(self.settings.get('IBLRIG_VERSION') or '100.0.0') >= version.parse('6.2.5')

        # Preallocated with NaN so that trials without a matching front need no handling
        n_trials = len(f2TTL)
        stimOn_times = np.full(n_trials, np.nan)
        stimOff_times = np.full(n_trials, np.nan)
        stimFreeze_times = np.full(n_trials, np.nan)
        triggers = zip(f2TTL, stimOnTrigger, stimFreezeTrigger, stimOffTrigger)
        for i, (fronts, on, freeze, off) in enumerate(triggers):
            fronts = np.array(fronts)
            # stim on - first front between the on trigger and the next trigger
            lim = freeze if has_freeze else off
            idx, = np.where(np.logical_and(on < fronts, fronts < lim))
            if idx.size > 0:
                stimOn_times[i] = fronts[idx[0]]
            # stim off - first front after the off trigger
            idx, = np.where(off < fronts)
            if idx.size > 0:
                stimOff_times[i] = fronts[idx[0]]
            # stim freeze - take last event before off trigger
            if has_freeze:
                idx, = np.where(np.logical_and(freeze < fronts, fronts < off))
            else:
                idx, = np.where(fronts <= off)
            if idx.size > 0:
                stimFreeze_times[i] = fronts[idx[-1]]
        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan

        return stimOn_times, stimOff_times, stimFreeze_times

672 

673 

class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """Extract stimulus phase, position and quiescence from Bpod data.

    Reads the ``stim_phase``, ``position`` and ``quiescent_period`` fields of
    each trial.  Only the quiescence has a save name.

    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        def trial_field(key):
            return np.array([trial[key] for trial in self.bpod_trials])

        phase = trial_field('stim_phase')
        position = trial_field('position')
        quiescence = trial_field('quiescent_period')
        return phase, position, quiescence

687 

688 

class PauseDuration(BaseBpodTrialsExtractor):
    """Extract pause duration from raw trial data.

    Trials without a ``pause_duration`` field get 0 for rig versions below
    8.9.0 and NaN otherwise.
    """
    save_names = None
    var_names = 'pause_duration'

    def _extract(self, **kwargs):
        # pausing logic added in version 8.9.0; a missing or empty version counts as '0'
        rig_version = version.parse(self.settings.get('IBLRIG_VERSION') or '0')
        if rig_version < version.parse('8.9.0'):
            fallback = 0.
        else:
            fallback = np.nan
        durations = (trial.get('pause_duration', fallback) for trial in self.bpod_trials)
        return np.fromiter(durations, dtype=float)

699 

700 

class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extracts the following into a table from Bpod raw data:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Additionally extracts the following wheel data:
        wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude

    Every extracted variable that is not listed in `var_names` becomes a table
    column; the listed ones are returned alongside the table.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
                      FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        out, _ = run_extractor_classes(
            extractors,
            session_path=self.session_path,
            bpod_trials=self.bpod_trials,
            settings=self.settings,
            save=False,
            task_collection=self.task_collection,
        )
        columns = {}
        for name, values in out.items():
            if name not in self.var_names:
                columns[name] = values
        table = AlfBunch(columns)
        # the table is expected to hold exactly 12 columns
        assert len(table.keys()) == 12

        extras = [out.pop(name) for name in self.var_names if name != 'table']
        return (table.to_df(), *extras)

724 

725 

class TrainingTrials(BaseBpodTrialsExtractor):
    """
    Run all the training trials extractors and gather their outputs.

    The extractors are run without saving and the result is returned as a
    dictionary keyed by the names in `var_names`.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  '_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt', '_ibl_trials.stimOff_times.npy',
                  None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None,
                  '_ibl_trials.quiescencePeriod.npy', None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
                 'wheelMoves_peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')

    def _extract(self) -> dict:
        extractors = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence, PauseDuration]
        out, _ = run_extractor_classes(
            extractors,
            session_path=self.session_path,
            bpod_trials=self.bpod_trials,
            settings=self.settings,
            save=False,
            task_collection=self.task_collection,
        )
        # keep only the declared variables, in the declared order
        result = {}
        for name in self.var_names:
            result[name] = out[name]
        return result