Coverage for ibllib/io/extractors/training_trials.py: 93%

377 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1import logging 

2import numpy as np 

3from pkg_resources import parse_version 

4from one.alf.io import AlfBunch 

5 

6import ibllib.io.raw_data_loaders as raw 

7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

8from ibllib.io.extractors.training_wheel import Wheel 

9 

10 

11_logger = logging.getLogger(__name__) 

12__all__ = ['TrainingTrials', 'extract_all'] 

13 

14 

class FeedbackType(BaseBpodTrialsExtractor):
    """
    Get the feedback that was delivered to subject.
    **Optional:** saves _ibl_trials.feedbackType.npy

    For every trial the Bpod state timestamps are checked for the outcome states
    'correct', 'error', 'no_go', 'omit_correct', 'omit_error' and 'omit_no_go'.
    A state counts as triggered when its first onset timestamp is not NaN; states
    absent from the trial are treated as not triggered.
    An AssertionError is raised unless exactly one of these mutually exclusive
    states was triggered.

    Sets feedbackType to +1 if the 'correct' state was triggered
    Sets feedbackType to -1 if the 'error' or 'no_go' state was triggered
    Leaves feedbackType at 0 for the 'omit_*' states
    """
    save_names = '_ibl_trials.feedbackType.npy'
    var_names = 'feedbackType'

    def _extract(self):
        """Return an int64 array with one feedback value (-1, 0 or +1) per trial."""
        state_names = ('correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go')
        # Placeholder with the same nesting as a real state entry; np.nan replaces the
        # np.NaN alias that was removed in NumPy 2.0
        not_triggered = [[np.nan]]
        feedbackType = np.zeros(len(self.bpod_trials), np.int64)
        for i, trial in enumerate(self.bpod_trials):
            states = trial['behavior_data']['States timestamps']
            triggered = [
                name for name in state_names
                if not np.isnan(states.get(name, not_triggered)[0][0])
            ]
            # The outcome states are mutually exclusive: exactly one must have been entered
            assert len(triggered) == 1
            outcome = triggered[0]
            if outcome == 'correct':
                feedbackType[i] = 1
            elif outcome in ('error', 'no_go'):
                feedbackType[i] = -1
        return feedbackType

42 

43 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Get left and right contrasts from raw datafile. Optionally, saves
    _ibl_trials.contrastLeft.npy and _ibl_trials.contrastRight.npy to alf folder.

    The contrast of a trial is assigned to the left vector when the sign of the
    trial 'position' is negative and to the right vector when it is positive; the
    other side (and both sides when the position is 0) is filled with NaN.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self):
        """Return the tuple (contrastLeft, contrastRight) of float arrays, one value per trial."""
        def contrast_value(trial):
            # The trial table stores the contrast either as a flat number (iblrigv8, per the
            # original comment) or as a mapping holding the number under 'value'.
            # The format is checked per trial so that integer contrasts and empty sessions work.
            contrast = trial['contrast']
            if isinstance(contrast, (int, float, np.number)):
                return contrast
            return contrast['value']

        contrast = np.array([contrast_value(t) for t in self.bpod_trials], dtype=float)
        side = np.sign(np.array([t['position'] for t in self.bpod_trials], dtype=float))
        contrastLeft = np.where(side < 0, contrast, np.nan)
        contrastRight = np.where(side > 0, contrast, np.nan)

        return contrastLeft, contrastRight

69 

70 

class ProbabilityLeft(BaseBpodTrialsExtractor):
    """
    Collect the 'stim_probability_left' field of every Bpod trial.
    **Optional:** saves _ibl_trials.probabilityLeft.npy
    """
    save_names = '_ibl_trials.probabilityLeft.npy'
    var_names = 'probabilityLeft'

    def _extract(self, **kwargs):
        """Return an array holding one 'stim_probability_left' value per trial."""
        probability_left = [trial['stim_probability_left'] for trial in self.bpod_trials]
        return np.array(probability_left)

77 

78 

class Choice(BaseBpodTrialsExtractor):
    """
    Get the subject's choice in every trial.
    **Optional:** saves _ibl_trials.choice.npy to alf folder.

    The choice is derived from the sign of the stimulus position, the
    'trial_correct' flag and the 'no_go' state of each trial:
    -1 is a CCW turn (towards the left)
    +1 is a CW turn (towards the right)
    0 is a no_go trial
    On a correct trial the choice is the inverse of the sign of the position,
    otherwise it is the sign of the position itself.

    >>> choice[t] = -np.sign(position[t]) if trial_correct[t]
    """
    save_names = '_ibl_trials.choice.npy'
    var_names = 'choice'

    def _extract(self):
        """Return an integer array with one choice value per trial."""
        trials = self.bpod_trials
        stim_side = np.array([np.sign(trial['position']) for trial in trials])
        is_correct = np.array([trial['trial_correct'] for trial in trials])
        # A trial is a no-go when the onset of its 'no_go' state is not NaN
        is_nogo = np.array([
            not np.isnan(trial['behavior_data']['States timestamps']['no_go'][0][0])
            for trial in trials
        ])
        choice = stim_side.copy()
        # Correct trials: the wheel was turned against the side of the stimulus
        choice[is_correct] = -choice[is_correct]
        choice[is_nogo] = 0
        return choice.astype(int)

107 

108 

class RepNum(BaseBpodTrialsExtractor):
    """
    Count the consecutive repeated trials.
    **Optional:** saves _ibl_trials.repNum.npy to alf folder.

    A trial is flagged as repeated from its 'debias_trial' field when present,
    otherwise from trial['contrast']['type'] == 'RepeatContrast'. The counter
    increases along each run of repeated trials and resets on a non-repeated one.

    >>> trial_repeated = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    >>> repNum = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0]
    """
    save_names = '_ibl_trials.repNum.npy'
    var_names = 'repNum'

    def _extract(self):
        """Return an integer array with the running repeat count of each trial."""
        def is_repeat(trial):
            if 'debias_trial' in trial:
                return trial['debias_trial']
            return trial['contrast']['type'] == 'RepeatContrast'

        repeated = np.array([is_repeat(trial) for trial in self.bpod_trials]).astype(int)
        repNum = repeated.copy()
        streak = 0
        for i, flag in enumerate(repeated):
            # Extend the current run on a repeated trial, restart from zero otherwise
            streak = streak + 1 if flag != 0 else 0
            repNum[i] = streak
        return repNum

140 

141 

class RewardVolume(BaseBpodTrialsExtractor):
    """
    Load reward volume delivered for each trial.
    **Optional:** saves _ibl_trials.rewardVolume.npy

    Uses the 'reward_amount' field of trials flagged 'trial_correct'; every
    other trial gets a volume of 0.
    """
    save_names = '_ibl_trials.rewardVolume.npy'
    var_names = 'rewardVolume'

    def _extract(self):
        """Return a float64 array with one reward volume per trial."""
        volumes = []
        for trial in self.bpod_trials:
            # 'reward_amount' is only read for correct trials
            volumes.append(trial['reward_amount'] if trial['trial_correct'] else 0)
        reward_volume = np.array(volumes).astype(np.float64)
        assert len(reward_volume) == len(self.bpod_trials)
        return reward_volume

158 

159 

class FeedbackTimes(BaseBpodTrialsExtractor):
    """
    Get the times the water or error tone was delivered to the animal.
    **Optional:** saves _ibl_trials.feedback_times.npy

    The extraction method depends on the IBLRIG_VERSION_TAG of the settings:
    below 5.0.0 the onsets of the 'reward', 'error' and 'no_go' states are merged,
    from 5.0.0 the 'reward' state onset is merged with the last BNC2High event of
    the non-rewarded trials.
    """
    save_names = '_ibl_trials.feedback_times.npy'
    var_names = 'feedback_times'

    @staticmethod
    def get_feedback_times_lt5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Merge the onsets of the 'reward', 'error' and 'no_go' states of each trial.

        :param session_path: session path, used to load the trials when `data` is empty
        :param task_collection: collection passed to the raw data loader
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :return: array of the non-NaN state onsets
        :raises AssertionError: when a trial has a NaN onset for all three states
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)

        def state_onsets(state):
            return [trial['behavior_data']['States timestamps'][state][0][0] for trial in data]

        rw_times = state_onsets('reward')
        err_times = state_onsets('error')
        nogo_times = state_onsets('no_go')
        # Every trial must have entered at least one of the three states
        no_feedback = np.isnan(rw_times) & np.isnan(err_times) & np.isnan(nogo_times)
        assert sum(no_feedback) == 0
        # Keep, for each trial, the onset(s) that are not NaN
        per_trial = [
            np.array(onsets)[~np.isnan(onsets)]
            for onsets in zip(rw_times, err_times, nogo_times)
        ]
        merge = np.array(per_trial).squeeze()

        return np.array(merge)

    @staticmethod
    def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', data=False):
        """
        Merge the 'reward' state onsets with the error sound times read from BNC2High.

        For trials whose reward onset is NaN, the last BNC2High event is used provided
        at least two events remain after removing near-duplicates.

        :param session_path: session path, used to load the trials when `data` is empty
        :param task_collection: collection passed to the raw data loader
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :return: float array with one feedback time per trial, NaN where none was found
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        n_trials = len(data)
        missed_bnc2 = 0
        rw_times = np.zeros([n_trials, ])
        err_sound_times = np.zeros([n_trials, ])

        for ind, tr in enumerate(data):
            behavior = tr['behavior_data']
            st = behavior['Events timestamps'].get('BNC2High', None)
            if not st:
                st = np.array([np.nan, np.nan])
                missed_bnc2 += 1
            # xonar soundcard duplicates events, remove consecutive events too close together
            too_close = np.where(np.diff(st) < 0.020)[0] + 1
            st = np.delete(st, too_close)
            rw_times[ind] = behavior['States timestamps']['reward'][0][0]
            # get the error sound only if the reward is nan
            use_error_sound = st.size >= 2 and np.isnan(rw_times[ind])
            err_sound_times[ind] = st[-1] if use_error_sound else np.nan
        if missed_bnc2 == len(data):
            _logger.warning('No BNC2 for feedback times, filling error trials NaNs')
        # Start from all NaNs, then fill in rewarded trials followed by error-sound trials
        merge = np.full(n_trials, np.nan)
        rewarded = ~np.isnan(rw_times)
        merge[rewarded] = rw_times[rewarded]
        with_error_sound = ~np.isnan(err_sound_times)
        merge[with_error_sound] = err_sound_times[with_error_sound]

        return merge

    def _extract(self):
        """Return the feedback times using the method matching the rig version."""
        # Version check
        version = parse_version(self.settings['IBLRIG_VERSION_TAG'])
        if version >= parse_version('5.0.0'):
            getter = self.get_feedback_times_ge5
        else:
            getter = self.get_feedback_times_lt5
        merge = getter(self.session_path, task_collection=self.task_collection, data=self.bpod_trials)
        return np.array(merge)

224 

225 

class Intervals(BaseBpodTrialsExtractor):
    """
    Trial start to trial end. Trial end includes 1 or 2 seconds after feedback,
    (depending on the feedback) and 0.5 seconds of iti.
    **Optional:** saves _ibl_trials.intervals.npy

    Uses the 'Trial start timestamp' and 'Trial end timestamp' values of the
    behavior data of each trial.
    """
    save_names = '_ibl_trials.intervals.npy'
    var_names = 'intervals'

    def _extract(self):
        """Return an array with one (start, end) row per trial."""
        behavior = [trial['behavior_data'] for trial in self.bpod_trials]
        trial_starts = [b['Trial start timestamp'] for b in behavior]
        trial_ends = [b['Trial end timestamp'] for b in behavior]
        # Stack as two rows then transpose so that each trial is one row
        return np.transpose(np.array([trial_starts, trial_ends]))

241 

242 

class ResponseTimes(BaseBpodTrialsExtractor):
    """
    Time (in absolute seconds from session start) when a response was recorded.
    **Optional:** saves _ibl_trials.response_times.npy

    Uses the timestamp of the end of the closed_loop state.
    """
    save_names = '_ibl_trials.response_times.npy'
    var_names = 'response_times'

    def _extract(self):
        """Return an array with the end time of the first 'closed_loop' state of each trial."""
        closed_loop_ends = [
            trial['behavior_data']['States timestamps']['closed_loop'][0][1]
            for trial in self.bpod_trials
        ]
        return np.array(closed_loop_ends)

257 

258 

class ItiDuration(BaseBpodTrialsExtractor):
    """
    Calculate duration of iti from state timestamps.
    **Optional:** saves _ibl_trials.itiDuration.npy

    The duration is the 'Trial end timestamp' of each trial minus the response
    time returned by the ResponseTimes extractor.
    """
    save_names = '_ibl_trials.itiDuration.npy'
    var_names = 'iti_dur'

    def _extract(self):
        """Return an array with one iti duration per trial."""
        response_extractor = ResponseTimes(self.session_path)
        # Run in memory only (save=False) on the trials and settings already loaded
        rt, _ = response_extractor.extract(
            save=False,
            task_collection=self.task_collection,
            bpod_trials=self.bpod_trials,
            settings=self.settings,
        )
        trial_ends = np.array([trial['behavior_data']['Trial end timestamp'] for trial in self.bpod_trials])
        return trial_ends - rt

275 

276 

class GoCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get trigger times of goCue from state machine.

    Current software solution for triggering sounds uses PyBpod soft codes.
    Delays can be in the order of 10's of ms. This is the time when the command
    to play the sound was executed, not the measured sound onset.

    The trigger is the onset of the 'play_tone' state for IBLRIG_VERSION_TAG
    >= 5.0.0 and the onset of the 'closed_loop' state for older versions.
    """
    save_names = '_ibl_trials.goCueTrigger_times.npy'
    var_names = 'goCueTrigger_times'

    def _extract(self):
        """Return an array with one go cue trigger time per trial."""
        version = parse_version(self.settings['IBLRIG_VERSION_TAG'])
        trigger_state = 'play_tone' if version >= parse_version('5.0.0') else 'closed_loop'
        return np.array([
            trial['behavior_data']['States timestamps'][trigger_state][0][0]
            for trial in self.bpod_trials
        ])

297 

298 

class TrialType(BaseBpodTrialsExtractor):
    """
    Classify every trial from the Bpod state that was entered.
    **Optional:** saves _ibl_trials.type.npy

    A state counts as entered when its first onset timestamp is not NaN. The states
    are tested in this order:
    +1 if the 'reward' state was entered
    -1 if the 'error' state was entered
    0 if the 'no_go' state was entered
    NaN (with a warning) if none of them was entered
    """
    save_names = '_ibl_trials.type.npy'
    # Every sibling extractor declares `var_names`; the original class only declared `var_name`
    var_names = 'trial_type'
    # Kept so that code reading the original attribute name keeps working
    var_name = var_names

    def _extract(self):
        """Return an array with one trial type per trial."""
        trial_type = []
        for tr in self.bpod_trials:
            states = tr["behavior_data"]["States timestamps"]
            if not np.isnan(states["reward"][0][0]):
                trial_type.append(1)
            elif not np.isnan(states["error"][0][0]):
                trial_type.append(-1)
            elif not np.isnan(states["no_go"][0][0]):
                trial_type.append(0)
            else:
                _logger.warning("Trial is not in set {-1, 0, 1}, appending NaN to trialType")
                trial_type.append(np.nan)
        return np.array(trial_type)

316 

317 

class GoCueTimes(BaseBpodTrialsExtractor):
    """
    Get the go cue times from the BNC2 events recorded by Bpod.
    **Optional:** saves _ibl_trials.goCue_times.npy

    For every trial with BNC2 port events, the first 'BNC2High' event is used.
    When there is no 'BNC2High' event, the first 'BNC2Low' event minus 0.1 is
    used instead. Trials without usable BNC2 events get NaN, and a warning
    reports how many trials are missing.
    """
    save_names = '_ibl_trials.goCue_times.npy'
    var_names = 'goCue_times'

    def _extract(self):
        """Return a float array with one go cue time per trial (NaN when not detected)."""
        # Start from NaN everywhere and only fill in the trials where an event is found
        go_cue_times = np.full(len(self.bpod_trials), np.nan)
        for ind, tr in enumerate(self.bpod_trials):
            if not raw.get_port_events(tr, 'BNC2'):
                continue
            events = tr['behavior_data']['Events timestamps']
            bnchigh = events.get('BNC2High', None)
            if bnchigh:
                go_cue_times[ind] = bnchigh[0]
                continue
            bnclow = events.get('BNC2Low', None)
            if bnclow:
                # No rising front recorded: infer the onset from the first falling front
                go_cue_times[ind] = bnclow[0] - 0.1

        missing = np.isnan(go_cue_times)
        nmissing = np.sum(missing)
        if np.all(missing):
            # All stim_syncs have failed to be detected
            _logger.warning(
                f'{self.session_path}: Missing ALL !! BNC2 TTLs ({nmissing} trials)')
        elif np.any(missing):
            # Some trials have no detected stim_sync
            _logger.warning(f'{self.session_path}: Missing BNC2 TTLs on {nmissing} trials')

        return go_cue_times

357 

358 

class IncludedTrials(BaseBpodTrialsExtractor):
    """
    Flag the trials to include.
    **Optional:** saves _ibl_trials.included.npy

    For IBLRIG_VERSION_TAG < 5.0.0 every trial is included. From 5.0.0, when the
    settings hold a 'SUBJECT_DISENGAGED_TRIGGERED' value other than False, the
    trials from index SUBJECT_DISENGAGED_TRIALNUM - 1 onwards are excluded.
    """
    save_names = '_ibl_trials.included.npy'
    var_names = 'included'

    def _extract(self):
        """Return a boolean array with one inclusion flag per trial."""
        if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
            trials_included = self.get_included_trials_ge5(
                data=self.bpod_trials, settings=self.settings)
        else:
            trials_included = self.get_included_trials_lt5(data=self.bpod_trials)
        return trials_included

    @staticmethod
    def get_included_trials_lt5(data=False):
        """
        Include every trial.

        :param data: list of Bpod trial dicts
        :return: boolean array of True values, one per trial
        """
        # An explicit dtype keeps the mask boolean even when there are no trials
        return np.ones(len(data), dtype=bool)

    @staticmethod
    def get_included_trials_ge5(data=False, settings=False):
        """
        Include the trials that precede the subject disengagement, if any.

        :param data: list of Bpod trial dicts
        :param settings: Bpod settings dict
        :return: boolean array, one flag per trial
        """
        # An explicit dtype keeps the mask boolean even when there are no trials
        trials_included = np.ones(len(data), dtype=bool)
        if settings.get('SUBJECT_DISENGAGED_TRIGGERED', False) is not False:
            idx = settings['SUBJECT_DISENGAGED_TRIALNUM'] - 1
            trials_included[idx:] = False
        return trials_included

384 

385 

class ItiInTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the 'exit_state' state of every trial.

    For IBLRIG_VERSION_TAG < 5.0.0 the state is not read and NaNs are returned.
    """
    var_names = 'itiIn_times'

    def _extract(self):
        """Return a float array with one iti onset per trial."""
        if parse_version(self.settings["IBLRIG_VERSION_TAG"]) < parse_version("5.0.0"):
            return np.full(len(self.bpod_trials), np.nan)
        return np.array([
            trial["behavior_data"]["States timestamps"]["exit_state"][0][0]
            for trial in self.bpod_trials
        ])

398 

399 

class ErrorCueTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the state that triggered the error cue.

    The onset of the 'no_go' state is used when it is not NaN, otherwise the
    onset of the 'error' state when that one is not NaN. Other trials get NaN.
    """
    var_names = 'errorCueTrigger_times'

    def _extract(self):
        """Return a float array with one error cue trigger time per trial."""
        errorCueTrigger_times = np.full(len(self.bpod_trials), np.nan)
        for i, trial in enumerate(self.bpod_trials):
            states = trial["behavior_data"]["States timestamps"]
            nogo_onset = states["no_go"][0][0]
            error_onset = states["error"][0][0]
            # The no_go state takes precedence over the error state
            if not np.any(np.isnan(nogo_onset)):
                errorCueTrigger_times[i] = nogo_onset
            elif not np.any(np.isnan(error_onset)):
                errorCueTrigger_times[i] = error_onset
        return errorCueTrigger_times

413 

414 

class StimFreezeTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the state that triggered the stimulus freeze.

    For IBLRIG_VERSION_TAG < 6.2.5 NaNs are returned. Otherwise every trial must
    have entered one of the 'freeze_reward', 'freeze_error' or 'no_go' states:
    no_go trials get NaN, the others get the onset of their freeze state.
    """
    var_names = 'stimFreezeTrigger_times'

    def _extract(self):
        """
        Return a float array with one stimulus freeze trigger time per trial.

        :raises AssertionError: when the number of entered freeze_reward, freeze_error and
         no_go states does not add up to the number of trials
        """
        if parse_version(self.settings["IBLRIG_VERSION_TAG"]) < parse_version("6.2.5"):
            return np.ones(len(self.bpod_trials)) * np.nan

        def entered(trial, state):
            # A state was entered when neither bound of its first interval is NaN
            interval = trial["behavior_data"]["States timestamps"][state][0]
            return bool(np.all(~np.isnan(interval)))

        freeze_reward = np.array([entered(tr, "freeze_reward") for tr in self.bpod_trials])
        freeze_error = np.array([entered(tr, "freeze_error") for tr in self.bpod_trials])
        no_go = np.array([entered(tr, "no_go") for tr in self.bpod_trials])
        assert (np.sum(freeze_error) + np.sum(freeze_reward) +
                np.sum(no_go) == len(self.bpod_trials))

        # Collect in a list and convert once, rather than growing an array with np.append
        stimFreezeTrigger = []
        for rewarded, nogo, tr in zip(freeze_reward, no_go, self.bpod_trials):
            if nogo:
                stimFreezeTrigger.append(np.nan)
                continue
            state = "freeze_reward" if rewarded else "freeze_error"
            stimFreezeTrigger.append(tr["behavior_data"]["States timestamps"][state][0][0])
        return np.array(stimFreezeTrigger, dtype=float)

457 

458 

class StimOffTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the state that triggered the stimulus offset.

    The state depends on IBLRIG_VERSION_TAG: 'hide_stim' from 6.2.5, 'exit_state'
    from 5.0.0 and 'trial_start' before that. From 5.0.0 the onset of the 'no_go'
    state replaces the value on the trials where that state was entered.
    """
    var_names = 'stimOffTrigger_times'

    def _extract(self):
        """Return an array with one stimulus off trigger time per trial."""
        version = parse_version(self.settings["IBLRIG_VERSION_TAG"])
        if version >= parse_version("6.2.5"):
            stim_off_trigger_state = "hide_stim"
        elif version >= parse_version("5.0.0"):
            stim_off_trigger_state = "exit_state"
        else:
            stim_off_trigger_state = "trial_start"

        def state_onsets(state):
            return np.array([
                trial["behavior_data"]["States timestamps"][state][0][0]
                for trial in self.bpod_trials
            ])

        stimOffTrigger_times = state_onsets(stim_off_trigger_state)
        # If pre version 5.0.0 no specific nogo Off trigger was given, just return trial_starts
        if stim_off_trigger_state == "trial_start":
            return stimOffTrigger_times

        no_goTrigger_times = state_onsets("no_go")
        # Stim off triggers are either in their own state or in the no_go state if the
        # mouse did not move. Per the original notes, NaNs might still happen on the
        # last trial if the session was stopped after the response.
        # Patch with the no_go states trig times
        is_nogo = ~np.isnan(no_goTrigger_times)
        stimOffTrigger_times[is_nogo] = no_goTrigger_times[is_nogo]
        return stimOffTrigger_times

493 

494 

class StimOnTriggerTimes(BaseBpodTrialsExtractor):
    """
    Get the onset of the 'stim_on' state of every trial.
    **Optional:** saves _ibl_trials.stimOnTrigger_times.npy
    """
    save_names = '_ibl_trials.stimOnTrigger_times.npy'
    var_names = 'stimOnTrigger_times'

    def _extract(self):
        """Return an array with the start of the first 'stim_on' state of each trial."""
        # One row per trial holding the first interval of the stim_on state
        stim_on_intervals = np.array([
            trial['behavior_data']['States timestamps']['stim_on'][0]
            for trial in self.bpod_trials
        ])
        # The first column holds the state onsets (transposing a 1-D slice is a no-op)
        stim_on_starts = stim_on_intervals[:, 0]
        return stim_on_starts

504 

505 

class StimOnTimes_deprecated(BaseBpodTrialsExtractor):
    """
    Deprecated extraction of the stimulus onset times from the BNC1 events.
    **Optional:** saves _ibl_trials.stimOn_times.npy

    A warning is logged on every extraction; StimOnOffFreezeTimes is the replacement.
    """
    save_names = '_ibl_trials.stimOn_times.npy'
    var_names = 'stimOn_times'

    def _extract(self):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1 High and BNC1 Low)
        """
        # Version check
        _logger.warning("Deprecation Warning: this is an old version of stimOn extraction. "
                        "From version 5., use StimOnOffFreezeTimes")
        if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
            stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        else:
            stimOn_times = self.get_stimOn_times_lt5(self.session_path, data=self.bpod_trials,
                                                     task_collection=self.task_collection)
        return np.array(stimOn_times)

    @staticmethod
    def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the first stim_sync pulse that falls within the stim_on state of each trial.
        (Stim updates are in BNC1High and BNC1Low - frame2TTL device)
        A trial without any BNC1 event in the interval (stim_on start, stim_on end] means
        the HW device failed to detect the stim_sync square change: its value is NaN.

        :param session_path: session path, used to load the trials when `data` is empty
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :param task_collection: collection passed to the raw data loader
        :return: float array with exactly one stimOn time per trial
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)
        # Get all stim_sync events detected
        stim_sync_all = [np.array(raw.get_port_events(tr, 'BNC1')) for tr in data]
        # Get the stim_on_state that triggers the onset of the stim
        stim_on_state = [tr['behavior_data']['States timestamps']['stim_on'][0] for tr in data]

        stimOn_times = np.full(len(stim_sync_all), np.nan)
        for i, (sync, state) in enumerate(zip(stim_sync_all, stim_on_state)):
            on, off = state[0], state[1]
            pulse = sync[(sync > on) & (sync <= off)]
            if pulse.size:
                # Keep only the earliest pulse so that the output holds one value per trial
                stimOn_times[i] = np.min(pulse)

        missing = np.isnan(stimOn_times)
        nmissing = np.sum(missing)
        # Check if all stim_syncs have failed to be detected
        if np.all(missing):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({nmissing} trials)')

        # Check if any stim_sync has failed be detected for every trial
        if np.any(missing):
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {nmissing} trials')

        return stimOn_times

    @staticmethod
    def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
        """
        Find the time of the state machine command to turn on the stim
        (state stim_on start or rotary_encoder_event2)
        Find the next frame change from the photodiode after that TS.
        Screen is not displaying anything until then.
        (Frame changes are in BNC1High and BNC1Low)

        :param session_path: session path, used to load the trials when `data` is empty
        :param data: list of Bpod trial dicts; loaded from disk when falsy
        :param task_collection: collection passed to the raw data loader
        :return: float array with one stimOn time per trial, NaN when no front was found
        """
        if not data:
            data = raw.load_data(session_path, task_collection=task_collection)

        def fronts_of(events, key):
            # A missing key yields no fronts; this replaces the former np.NINF sentinel
            # (alias removed in NumPy 2.0), which could never exceed the stim_on time
            return np.atleast_1d(np.asarray(events.get(key, []), dtype=float))

        count_missing = 0
        stimOn_times = np.full(len(data), np.nan)
        for i, tr in enumerate(data):
            behavior = tr['behavior_data']
            stim_on = behavior['States timestamps']['stim_on'][0][0]
            events = behavior['Events timestamps']
            hl = np.sort(np.concatenate([fronts_of(events, 'BNC1High'), fronts_of(events, 'BNC1Low')]))
            # First frame change that follows the stim_on command
            stot = hl[hl > stim_on]
            if stot.size == 0:
                count_missing += 1
                continue
            stimOn_times[i] = stot[0]

        if np.all(np.isnan(stimOn_times)):
            _logger.error(f'{session_path}: Missing ALL BNC1 TTLs ({count_missing} trials)')

        if count_missing > 0:
            _logger.warning(f'{session_path}: Missing BNC1 TTLs on {count_missing} trials')

        return np.array(stimOn_times)

618 

619 

class StimOnOffFreezeTimes(BaseBpodTrialsExtractor):
    """
    Extracts stim on / off and freeze times from Bpod BNC1 detected fronts

    For a trial with at least two BNC1 events, the first event is the stim on time and
    the last one the stim off time. With three events or more, the second to last is
    the stim freeze time. Trials with fewer than two events get NaN for all three.
    Only the stim on times are saved (_ibl_trials.stimOn_times.npy).
    """
    save_names = ('_ibl_trials.stimOn_times.npy', None, None)
    var_names = ('stimOn_times', 'stimOff_times', 'stimFreeze_times')

    def _extract(self):
        """Return the tuple (stimOn_times, stimOff_times, stimFreeze_times) of float arrays."""
        choice = Choice(self.session_path).extract(
            bpod_trials=self.bpod_trials, task_collection=self.task_collection, settings=self.settings, save=False
        )[0]
        f2TTL = [raw.get_port_events(tr, name='BNC1') for tr in self.bpod_trials]

        # Pre-allocate with NaN and fill by index, rather than growing arrays with np.append
        n_trials = len(f2TTL)
        stimOn_times = np.full(n_trials, np.nan)
        stimOff_times = np.full(n_trials, np.nan)
        stimFreeze_times = np.full(n_trials, np.nan)
        for i, fronts in enumerate(f2TTL):
            if not fronts or len(fronts) < 2:
                continue
            stimOn_times[i] = fronts[0]
            stimOff_times[i] = fronts[-1]
            if len(fronts) >= 3:
                stimFreeze_times[i] = fronts[-2]

        # In no_go trials no stimFreeze happens just stim Off
        stimFreeze_times[choice == 0] = np.nan
        # Check for trigger times
        # 2nd order criteria:
        # stimOn -> Closest one to stimOnTrigger?
        # stimOff -> Closest one to stimOffTrigger?
        # stimFreeze -> Closest one to stimFreezeTrigger?

        return stimOn_times, stimOff_times, stimFreeze_times

659 

660 

class PhasePosQuiescence(BaseBpodTrialsExtractor):
    """Extracts stimulus phase, position and quiescence from Bpod data.

    The values are read from the 'stim_phase', 'position' and 'quiescent_period'
    fields of each trial; only the quiescence is saved (_ibl_trials.quiescencePeriod.npy).
    For extraction of pre-generated events, use the ProbaContrasts extractor instead.
    """
    save_names = (None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('phase', 'position', 'quiescence')

    def _extract(self, **kwargs):
        """Return the tuple (phase, position, quiescence) of arrays, one value per trial."""
        def column(field):
            return np.array([trial[field] for trial in self.bpod_trials])

        phase = column('stim_phase')
        position = column('position')
        quiescence = column('quiescent_period')
        return phase, position, quiescence

673 

674 

class TrialsTable(BaseBpodTrialsExtractor):
    """
    Extracts the following into a table from Bpod raw data:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Additionally extracts the following wheel data:
        wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude

    Every extracted variable that is not listed in `var_names` becomes a column of the
    table; the table is expected to hold exactly 12 columns.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        """
        Run the trials and wheel extractors and assemble the trials table.

        :param extractor_classes: not used by this method
        :return: the table as a data frame followed by the other variables of `var_names`, in order
        :raises AssertionError: when the table does not hold exactly 12 columns
        """
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes,
                      FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        # Whatever is not returned as a separate variable belongs to the table
        columns = {name: value for name, value in out.items() if name not in self.var_names}
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        table_df = table.to_df()
        others = [out.pop(name) for name in self.var_names if name != 'table']
        return (table_df, *others)

698 

699 

class TrainingTrials(BaseBpodTrialsExtractor):
    """
    Run the training task extractors and return their outputs in the order of `var_names`.

    Combines the trigger time extractors, the trials table (with the wheel data) and
    the phase, position and quiescence extractor.
    """
    save_names = ('_ibl_trials.repNum.npy', '_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  None, None, None, '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None)
    var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', 'wheel_moves_peak_amplitude',
                 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence')

    def _extract(self):
        """Return a tuple with one extracted variable per entry of `var_names`."""
        extractors = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence]
        # The sub-extractors run in memory only; nothing is saved at this stage (save=False)
        out, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
            save=False, task_collection=self.task_collection)
        return tuple([out.pop(name) for name in self.var_names])

716 

717 

def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None):
    """Extract trials and wheel data.

    For task versions >= 5.0.0, outputs wheel data and trials.table dataset (+ some extra datasets)

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    save : bool
        If true save the data files to ALF
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset.
        Loaded from the session when empty.
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset.
        Loaded from the session when empty. If the settings cannot be loaded or hold an
        empty IBLRIG_VERSION_TAG, version '100.0.0' is assumed.
    task_collection : str
        The collection passed to the raw data and settings loaders and to the extractors
    save_path : str, pathlib.Path
        Forwarded to the extractors as `path_out` (presumably the output folder - to confirm)

    Returns
    -------
    A list of extracted data and a list of file paths if save is True (otherwise None)
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None or settings['IBLRIG_VERSION_TAG'] == '':
        # Without a version tag, fall back on a version above every threshold tested here
        settings = {'IBLRIG_VERSION_TAG': '100.0.0'}

    # Version check
    is_ge5 = parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0')
    if is_ge5:
        # We now extract a single trials table
        extractors = [TrainingTrials]
    else:
        extractors = [
            RepNum, GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice, IncludedTrials,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    return run_extractor_classes(
        extractors, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
        task_collection=task_collection, path_out=save_path)