Coverage for brainbox/behavior/training.py: 58%

350 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Computing and testing IBL training status criteria. 

2 

3For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice 

4Training. 

5 

6.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\ 

7measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729 

8 

9Examples 

10-------- 

11Plot the psychometric curve for a given session. 

12 

13>>> trials = ONE().load_object(eid, 'trials') 

14>>> fig, ax = plot_psychometric(trials) 

15 

16Compute 'response times', defined as the duration of open-loop for each contrast. 

17 

18>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials) 

19 

20Compute 'reaction times', defined as the time between go cue and first detected movement. 

21NB: These may be negative! 

22 

23>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

24... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times') 

25 

26Compute 'response times', defined as the time between first detected movement and response. 

27 

28>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

29... trials, stim_on_type='firstMovement_times', stim_off_type='response_times') 

30 

31Compute 'movement times', defined as the time between last detected movement and response threshold. 

32 

33>>> import brainbox.behavior.wheel as wh 

34>>> wheel_moves = ONE().load_object(eid, 'wheelMoves') 

35>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trials.response_times) 

36>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

37... trials, stim_on_type='lastMovement_times', stim_off_type='response_times') 

38 

39""" 

40import logging 

41import datetime 

42import re 

43from enum import IntFlag, auto, unique 

44 

45import numpy as np 

46import matplotlib 

47import matplotlib.pyplot as plt 

48import seaborn as sns 

49import pandas as pd 

50from scipy.stats import bootstrap 

51from iblutil.util import Bunch 

52from one.api import ONE 

53from one.alf.io import AlfBunch 

54from one.alf.exceptions import ALFObjectNotFound 

55import psychofit as psy 

56 

_logger = logging.getLogger('ibllib')

#: list of str: The required keys in the trials object for computing training status.
#: These are the only keys that `concatenate_trials` stacks across sessions.
TRIALS_KEYS = [
    'contrastLeft', 'contrastRight', 'feedbackType', 'probabilityLeft',
    'choice', 'response_times', 'stimOn_times',
]

67 

68 

@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Enumeration allows for comparisons between training status.

    Examples
    --------
    >>> status = 'ready4delay'
    >>> assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    >>> assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    >>> assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    Get the next training status

    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - Members are single-bit flags assigned with `auto()` in declaration order (UNTRAINABLE = 1
      up to READY4RECORDING = 128), so integer comparison follows the order declared below.
    - FAILED, READY and TRAINED are compound flags (bitwise OR of their members), convenient for
      membership tests with ``in``.
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    """
    UNTRAINABLE = auto()
    UNBIASABLE = auto()
    IN_TRAINING = auto()
    TRAINED_1A = auto()
    TRAINED_1B = auto()
    READY4EPHYSRIG = auto()
    READY4DELAY = auto()
    READY4RECORDING = auto()
    # Compound training statuses for convenience
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B

108 

109 

def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Compute and print the training status of every alive, water-restricted subject in a lab.

    The subjects are listed from Alyx and each nickname is handed to
    `get_subject_training_status`, which prints its result to stdout. Nothing is returned.

    Parameters
    ----------
    lab : str
        Lab name (must match the name registered on Alyx).
    date : str
        The ISO date ('YYYY-MM-DD') from which to compute training status. If not specified the
        status is computed from the latest date with available data.
    details : bool
        If true, also display the metrics behind each status, e.g. performance on easy trials,
        number of trials and psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created when None (or otherwise falsy).
    """
    if not one:
        one = ONE()
    records = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True)
    nicknames = [record['nickname'] for record in records]
    for nickname in nicknames:
        get_subject_training_status(nickname, date=date, details=details, one=one)

135 

136 

def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Computes the training status of specified subject and prints results to std out.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE.

    Notes
    -----
    Returns early without printing anything if `get_sessions` yields no trials data.
    """
    one = one or ONE()

    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    if not trials:
        return
    sess_dates = list(trials.keys())
    status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)

    if not details:
        return

    # Choose the display format by which fit is present, not by the truthiness of its values:
    # a fit whose parameters happen to all be zero is still a fit and must be displayed.
    if info.get('psych') is not None:
        display_status(subj, sess_dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych=info.psych, rt=info.rt)
    elif info.get('psych_20') is not None:
        display_status(subj, sess_dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych_20=info.psych_20, psych_80=info.psych_80,
                       rt=info.rt)
    else:
        display_status(subj, sess_dates, status)

172 

173 

def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject. If a date is given it will load
    data from the three (or as many as are available) previous sessions up to the specified date.
    If not it will load data from the last three training sessions that have data available.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    one : one.api.OneAlyx
        An instance of ONE.

    Returns
    -------
    iblutil.util.Bunch
        Dictionary of trials objects where each key is the ISO session date string, in the order
        the sessions were returned by Alyx (the date-range queries below assume most recent first).
    list of str
        List of the task protocol used for each of the sessions.
    list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Notes
    -----
    - If no sessions are found, or none of them has loadable trials, a warning is logged and a
      list of four None values is returned.
    - At most one session is used per date, so that the trials keys, task protocols and session
      dates stay aligned. When several sessions share a date, the first one with loadable trials
      (in Alyx order) is kept.
    """
    one = one or ONE()

    if date is None:
        # compute from yesterday
        specified_date = datetime.date.today() - datetime.timedelta(days=1)
        latest_sess = specified_date.strftime("%Y-%m-%d")
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime("%Y-%m-%d")
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d")
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj,
                             date_range=[latest_minus_week, latest_sess],
                             dataset_types='trials.goCueTrigger_times')

    # If not enough sessions in the last week, then just fetch them all
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)

    # If still 0 sessions then return with warning
    if len(sessions) == 0:
        _logger.warning(f"No training sessions detected for {subj}")
        return [None] * 4

    trials, task_protocol, sess_dates = _load_session_trials(sessions, one, n_sessions=3)
    if not trials:
        # Previously this fell through to `sess_dates[-1]` and raised an IndexError
        _logger.warning(f"No training sessions with trials data detected for {subj}")
        return [None] * 4

    if np.any(np.array(task_protocol) == 'training'):
        # At least one trainingChoiceWorld session: the ephys rig queries are skipped
        return trials, task_protocol, [], 0

    date_range = [sess_dates[-1], sess_dates[0]]
    ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                               django='json__PYBPOD_BOARD__icontains,ephys')
    if len(ephys_sess) == 0:
        return trials, task_protocol, [], 0

    ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]
    n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                                django='json__SESSION_START_DELAY_SEC__gte,900'))
    return trials, task_protocol, ephys_sess_dates, n_delay


def _load_session_trials(sessions, one, n_sessions=3):
    """Load trials for up to `n_sessions` Alyx session records, one per date, in the given order.

    Sessions whose trials object is missing or empty, or whose date already has trials loaded,
    are skipped. Returns a Bunch of trials keyed by ISO date, and the aligned lists of task
    protocols (the text between 'tasks_' and 'Choice') and ISO dates.
    """
    trials = Bunch()
    task_protocol = []
    sess_dates = []
    for sess in sessions:
        if len(trials) >= n_sessions:
            break
        sess_date = sess['start_time'][:10]
        if sess_date in trials:
            # Only one session per date, otherwise the three outputs fall out of alignment
            continue
        eid = sess['url'].split('/')[-1]
        _logger.debug('Loading trials for session %s', eid)
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            continue
        if not trials_:
            continue
        task_protocol.append(re.search('tasks_(.*)Choice', sess['task_protocol']).group(1))
        sess_dates.append(sess_date)
        trials[sess_date] = trials_
    return trials, task_protocol, sess_dates

283 

284 

def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from consecutive training datasets.

    For IBL, training status is calculated using trials from the last three consecutive sessions.
    The task protocols of those sessions select one of three regimes:

    - all trainingChoiceWorld: 'in training', 'trained 1a' or 'trained 1b';
    - some, but not all, trainingChoiceWorld: 'trained 1b';
    - no trainingChoiceWorld: 'trained 1b', 'ready4ephysrig', 'ready4delay' or 'ready4recording',
      depending on how many of the sessions were on the ephys rig.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    task_protocol : list of str
        Task protocol used for each training session in `trials`, can be 'training', 'biased' or
        'ephys'.
    ephys_sess_dates : list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Returns
    -------
    str
        Training status of the subject.
    iblutil.util.Bunch
        The metrics that decided the status: 'perf_easy', 'n_trials' and 'rt', plus 'psych' when
        any session is trainingChoiceWorld, otherwise 'psych_20' and 'psych_80'.

    Raises
    ------
    AssertionError
        If no session was on the ephys rig but not every session is 'biased', or if (with fewer
        than three ephys sessions) an ephys session date is missing from `trials`.
    """
    info = Bunch()
    trials_all = concatenate_trials(trials)
    is_training = np.array(task_protocol) == 'training'

    if np.all(is_training):
        # Every session is trainingChoiceWorld
        signed_contrast = get_signed_contrast(trials_all)
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)
        if not np.any(signed_contrast == 0):
            # Zero-contrast trials have not been introduced yet
            status = 'in training'
        elif criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
            status = 'trained 1b'
        elif criterion_1a(info.psych, info.n_trials, info.perf_easy):
            status = 'trained 1a'
        else:
            status = 'in training'
        return status, info

    if np.any(is_training):
        # Fewer than three biasedChoiceWorld sessions since reaching the trained 1b criterion
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)
        return 'trained 1b', info

    # Only biasedChoiceWorld and/or ephysChoiceWorld sessions remain
    (info.perf_easy, info.n_trials,
     info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all)
    n_ephys = len(ephys_sess_dates)

    if n_ephys == 0:
        # Still on the training rig, so every session should be biasedChoiceWorld
        assert np.all(np.array(task_protocol) == 'biased')
        ready = criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                                info.rt)
        status = 'ready4ephysrig' if ready else 'trained 1b'
    elif n_ephys < 3:
        # Only the sessions run on the ephys rig feed the delay criterion
        assert all(date in trials for date in ephys_sess_dates)
        perf_ephys_easy = np.array([compute_performance_easy(trials[d]) for d in ephys_sess_dates])
        n_ephys_trials = np.array([compute_n_trials(trials[d]) for d in ephys_sess_dates])
        ready = criterion_delay(n_ephys_trials, perf_ephys_easy)
        status = 'ready4delay' if ready else 'ready4ephysrig'
    elif n_delay > 0 and criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
                                         info.perf_easy, info.rt):
        status = 'ready4recording'
    elif criterion_delay(info.n_trials, info.perf_easy):
        status = 'ready4delay'
    else:
        status = 'ready4ephysrig'

    return status, info

380 

381 

def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    sess_dates : list of str
        ISO date strings of training sessions used to determine training status.
    status : str
        Training status of subject.
    perf_easy : numpy.array
        Proportion of correct high contrast trials for each training session.
    n_trials : numpy.array
        Total number of trials for each training session.
    psych : numpy.array
        Psychometric parameters fit to data from all training sessions, in the order bias,
        threshold, lapse low, lapse high.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
    rt : float
        The median response time for zero contrast trials across all training sessions. NaN
        indicates no zero contrast stimuli in training sessions.

    Notes
    -----
    Three formats are printed: status and session dates only (when `perf_easy` is None); metrics
    with a single psychometric fit (when `psych_20` is None); or metrics with separate fits for
    the 0.2 and 0.8 probabilityLeft blocks.
    """
    if perf_easy is None:
        # Join rather than index so that fewer than three sessions does not raise IndexError
        print(f"\n{subj} : {status} \nSession dates=[{', '.join(map(str, sess_dates))}]")
        return

    def _format_fit(label, pars):
        """Format one psychometric fit (bias, threshold, lapse low, lapse high)."""
        return (f"\nPsych fit over last 3 sessions{label}: "
                f"bias={np.around(pars[0], 2)}, thres={np.around(pars[1], 2)}, "
                f"lapse_low={np.around(pars[2], 2)}, lapse_high={np.around(pars[3], 2)} ")

    # Convert to built-in Python numbers before printing the lists: under NumPy 2 a list of
    # NumPy scalars prints as e.g. [np.float64(0.9)] rather than [0.9].
    summary = (f"\n{subj} : {status} \nSession dates={list(sess_dates)}, "
               f"Perf easy={np.around(perf_easy, 2).tolist()}, "
               f"N trials={np.asarray(n_trials).tolist()} ")
    if psych_20 is None:
        fits = _format_fit('', psych)
    else:
        fits = _format_fit(' (20)', psych_20) + _format_fit(' (80)', psych_80)
    print(summary + fits + "\nMedian reaction time at 0 contrast over last 3 sessions = "
          f"{np.around(rt, 2)}")

437 

438 

def concatenate_trials(trials):
    """
    Merge the per-session trials objects into a single trials object.

    Only the keys in `TRIALS_KEYS` are kept. For each of them, the per-session arrays are joined
    end to end in the iteration order of `trials`.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.

    Returns
    -------
    one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.
    """
    per_session = list(trials.values())
    merged = AlfBunch()
    for key in TRIALS_KEYS:
        merged[key] = np.concatenate([session[key] for session in per_session])
    return merged

458 

459 

def compute_training_info(trials, trials_all):
    """
    Gather the metrics used to grade a subject on trainingChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Psychometric parameters fit to `trials_all` (see `compute_psychometric`).
    float
        The median response time for all zero-contrast trials across all sessions (NaN if there
        are no zero-contrast trials).
    """
    contrast_all = get_signed_contrast(trials_all)
    sessions = [trials[date] for date in trials.keys()]

    perf_easy = np.array(list(map(compute_performance_easy, sessions)))
    n_trials = np.array(list(map(compute_n_trials, sessions)))
    psych = compute_psychometric(trials_all, signed_contrast=contrast_all)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=contrast_all)

    return perf_easy, n_trials, psych, rt

492 

493 

def compute_bias_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on biasedChoiceWorld.

    Parameters
    ----------
    trials : iblutil.util.Bunch
        Trials objects from three consecutive training sessions, keyed by ISO session date.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over those training sessions.

    Returns
    -------
    numpy.array
        Proportion of correct easy trials for each session.
    numpy.array
        Number of trials in each session.
    numpy.array
        Psychometric fit parameters for trials in the probabilityLeft == 0.2 blocks, over all
        sessions.
    numpy.array
        Psychometric fit parameters for trials in the probabilityLeft == 0.8 blocks, over all
        sessions.
    float
        Median response time for zero-contrast trials over all sessions.
    """
    contrast_all = get_signed_contrast(trials_all)
    sessions = [trials[date] for date in trials.keys()]

    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    # One fit per biased block type, evaluated in order 0.2 then 0.8
    psych_20, psych_80 = (
        compute_psychometric(trials_all, signed_contrast=contrast_all, block=prob_left)
        for prob_left in (0.2, 0.8)
    )
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=contrast_all)

    return perf_easy, n_trials, psych_20, psych_80, rt

519 

520 

def get_signed_contrast(trials):
    """
    Compute the signed contrast of each trial.

    Parameters
    ----------
    trials : dict
        A trials object containing the 'contrastLeft' and 'contrastRight' keys. NaN entries are
        treated as zero contrast.

    Returns
    -------
    numpy.array
        Signed contrast per trial in percent, i.e. (contrastRight - contrastLeft) * 100, so that
        left-side stimuli are negative.
    """
    left = np.nan_to_num(np.asarray(trials['contrastLeft']))
    right = np.nan_to_num(np.asarray(trials['contrastRight']))
    return np.ravel(right - left) * 100

532 

533 

def compute_performance_easy(trials, signed_contrast=None):
    """
    Compute performance on easy trials (absolute signed contrast >= 50 %) from a trials object.

    Parameters
    ----------
    trials : dict
        A trials object containing the 'feedbackType' key, plus 'contrastLeft' and
        'contrastRight' when `signed_contrast` is not given.
    signed_contrast : numpy.array, optional
        Signed contrasts in percent, one per trial, where left contrasts are -ve. If None, these
        are computed from the trials object.

    Returns
    -------
    float
        Proportion of easy trials with feedbackType == 1. NaN if there are no easy trials.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)
    easy = np.abs(signed_contrast) >= 50
    n_easy = np.count_nonzero(easy)
    if n_easy == 0:
        # Undefined without easy trials; return NaN explicitly instead of dividing 0 by 0
        return np.nan
    return np.sum(np.asarray(trials['feedbackType'])[easy] == 1) / n_easy

546 

547 

def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
    """
    Compute performance (or the proportion of rightward choices) at each contrast level.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys 'probabilityLeft', 'feedbackType' and 'choice',
        plus 'contrastLeft' and 'contrastRight' when `signed_contrast` is not given.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included.
    prob_right : bool
        If true, compute the proportion of trials with choice == -1 (treated as a rightward
        choice in this module) instead of the proportion with feedbackType == 1 (correct).

    Returns
    -------
    numpy.array
        Proportion correct (or rightward) for each unique signed contrast.
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.

    Notes
    -----
    If no trials match `block`, a single array of three NaNs is returned instead of a tuple; it
    can still be unpacked into three (NaN) values.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        return np.nan * np.zeros(3)

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)

    # Boolean outcome averaged at each contrast: rightward choice or correct response
    outcome = (trials.choice == -1) if prob_right else (trials.feedbackType == 1)
    performance = np.array(
        [np.mean(outcome[(signed_contrast == c) & block_idx]) for c in contrasts], dtype=float)

    return performance, contrasts, n_contrasts

579 

580 

def compute_n_trials(trials):
    """
    Count the trials in a trials object.

    Parameters
    ----------
    trials : dict
        A trials object containing a 'choice' array.

    Returns
    -------
    int
        The length of the first axis of ``trials['choice']``.
    """
    choice_array = trials['choice']
    return choice_array.shape[0]

590 

591 

def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032):
    """
    Compute psychometric fit parameters for trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    plotting : bool
        Which set of psychofit model parameters to use (see notes).
    compute_ci : bool
        If true, computes and returns the confidence intervals for response at each contrast.
    alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        Array of psychometric fit parameters - bias, threshold, lapse low, lapse high. All NaN if
        no trials match `block`.
    numpy.array
        Only returned when `compute_ci` is true. An array of shape (2, n_contrasts) holding the
        lower and upper confidence bounds on the proportion of rightward choices at each unique
        contrast, expressed relative to the observed proportion (bound minus observed value).
        Has shape (2, 0) if no trials match `block`.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.

    Notes
    -----
    The psychofit starting parameters and model constraints used for the fit when computing the
    training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep
    the precise criteria the same for all subjects, these parameters have not changed. To produce a
    better fit for plotting purposes, or to calculate the training status in a manner inconsistent
    with the IBL training pipeline, use plotting=True.
    """

    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        psych = np.nan * np.zeros(4)
        # Keep the return arity consistent so callers doing `psych, ci = ...` still work
        return (psych, np.empty((2, 0))) if compute_ci else psych

    prob_choose_right, contrasts, n_contrasts = compute_performance(
        trials, signed_contrast=signed_contrast, block=block, prob_right=True)

    if plotting:
        # These starting parameters and constraints tend to produce a better fit, and are therefore
        # used for plotting.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([0., 40., 0.1, 0.1]),
            parmin=np.array([-50., 10., 0., 0.]),
            parmax=np.array([50., 50., 0.2, 0.2]),
            nfits=10)
    else:
        # These starting parameters and constraints are not ideal but are still used for computing
        # the training status for consistency.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
            parmin=np.array([np.min(contrasts), 0., 0., 0.]),
            parmax=np.array([np.max(contrasts), 100., 1, 1]))

    if compute_ci:
        import statsmodels.stats.proportion as smp  # noqa
        # choice == -1 means contrast on right hand side
        n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts)
        # proportion_confint returns (lower, upper); subtracting broadcasts to shape (2, n)
        ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right

        return psych, ci
    else:
        return psych

681 

682 

def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None,
                                 stim_off_type='response_times'):
    """
    Compute median response time on zero contrast trials from trials object

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between
        `stim_off_type` and this is used (see notes).
    contrast : float
        If None, the median response time is calculated for all trials, regardless of contrast,
        otherwise only trials where the matching signed percent contrast was presented are used.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    stim_off_type : str, default='response_times'
        The trials key marking the end of the interval; the response time is
        ``trials[stim_off_type] - trials[stim_on_type]``.

    Returns
    -------
    float
        The median response time for trials with `contrast` (returns NaN if no trials matching
        `contrast` in trials object). NaN values are ignored when taking the median.

    Notes
    -----
    - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is
      sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on
      times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of
      the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard
      output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the
      audio go cue command).
    - With the default keys, the response/reaction time here is defined as the time between stim
      on and response, i.e. the entire open-loop trial duration.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if contrast is None:
        contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        contrast_idx = signed_contrast == contrast

    if not np.any(contrast_idx):
        return np.nan
    return np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[contrast_idx])

734 

735 

def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None,
                          compute_ci=False, alpha=0.32):
    """
    Compute median response time for all contrasts.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_off_type` is used (see notes).
    stim_off_type : str, default='response_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_on_type` is used (see notes).
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    compute_ci : bool
        If true, computes and returns the confidence intervals for response time at each contrast.
    alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        The median response times for each unique signed contrast.
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.
    (numpy.array)
        If `compute_ci` is true, an array of confidence intervals of shape (n_contrasts, 2), where
        the columns are the lower and upper bounds. A contrast with fewer than two trials cannot be
        bootstrapped, so both of its bounds are NaN.

    Notes
    -----
    - The response/reaction time by default is the time between stim on and response, i.e. the
      entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to
      get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true
      response times, or calculate the last movement onset times and calculate the true movement
      times. See module examples for how to calculate this.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    # Per-trial interval between the two events; computed once and reused for every contrast
    durations = trials[stim_off_type] - trials[stim_on_type]

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
    reaction_time = np.vectorize(
        lambda x: np.nanmedian(durations[(x == signed_contrast) & block_idx]),
        otypes=[float]
    )(contrasts)

    if not compute_ci:
        return reaction_time, contrasts, n_contrasts

    ci = np.full((contrasts.size, 2), np.nan)
    for i, x in enumerate(contrasts):
        data = durations[(x == signed_contrast) & block_idx]
        if data.size < 2:
            # scipy.stats.bootstrap needs at least two observations; leave this contrast's CI as NaN
            continue
        bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha)
        ci[i, 0] = bt.confidence_interval.low
        ci[i, 1] = bt.confidence_interval.high

    return reaction_time, contrasts, n_contrasts, ci

815 

816 

def criterion_1a(psych, n_trials, perf_easy):
    """
    Test whether the 'trained_1a' training status criteria are satisfied.

    Criteria
    --------
    - Absolute bias below 16
    - Threshold below 19
    - Both lapse rates below 0.2
    - More than 200 trials in every session
    - Performance on easy contrasts above 80% in every session

    Parameters
    ----------
    psych : numpy.array
        The psychometric parameters fit to three consecutive sessions, ordered as bias, threshold,
        lapse high, lapse low.
    n_trials : numpy.array of int
        The trial count of each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials in each session.

    Returns
    -------
    bool
        True when every 'trained_1a' criterion holds.

    Notes
    -----
    These thresholds were originally obtained by averaging the parameter fits of a number of
    sessions that an experimenter judged to show 'good' performance.
    """
    # Psychometric fit checks, evaluated lazily left to right
    fit_within_limits = (abs(psych[0]) < 16 and psych[1] < 19 and
                         psych[2] < 0.2 and psych[3] < 0.2)
    # Session-level checks are only evaluated when the fit checks pass
    return fit_within_limits and np.all(n_trials > 200) and np.all(perf_easy > 0.8)

853 

854 

def criterion_1b(psych, n_trials, perf_easy, rt):
    """
    Test whether the 'trained_1b' training status criteria are satisfied.

    Criteria
    --------
    - Absolute bias below 10
    - Threshold below 20 (see notes)
    - Both lapse rates below 0.1
    - More than 400 trials in every session
    - Performance on easy contrasts above 90% in every session
    - Median response time over zero contrast trials below 2 seconds

    Parameters
    ----------
    psych : numpy.array
        The psychometric parameters fit to three consecutive sessions, ordered as bias, threshold,
        lapse high, lapse low.
    n_trials : numpy.array of int
        The trial count of each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials in each session.
    rt : float
        The median response time of zero contrast trials.

    Returns
    -------
    bool
        True when every 'trained_1b' criterion holds.

    Notes
    -----
    These thresholds were meant to be slightly stricter than those of 1a. Round numbers were
    chosen instead, so that readers would not infer a precision that was never there (the values
    were not chosen rigorously). As a result, the maximum threshold allowed for 1b is larger than
    for 1a, so a 1b psychometric curve may be slightly shallower than a 1a one.
    """
    fit_within_limits = (abs(psych[0]) < 10 and psych[1] < 20 and
                         psych[2] < 0.1 and psych[3] < 0.1)
    if not fit_within_limits:
        # Short-circuit exactly like a single `and` chain: return the failing value unchanged
        return fit_within_limits
    return np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2

896 

897 

def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Test whether the ready4ephysrig or ready4recording criteria are satisfied.

    NB: These two statuses differ only in whether the sessions were acquired on a recording rig
    and whether there was a delay before the first trial. Neither of these is checked here.

    Criteria
    --------
    - Both lapse rates < 0.1 in both bias blocks
    - Bias shift between blocks > 5
    - More than 400 trials in every session
    - Performance on easy contrasts > 90% in every session
    - Median response time for zero contrast stimuli < 2 seconds

    Parameters
    ----------
    psych_20 : numpy.array
        The psychometric parameters fit to the blocks where the probability of a left stimulus is
        0.2, ordered as bias, threshold, lapse high, lapse low.
    psych_80 : numpy.array
        The psychometric parameters fit to the blocks where the probability of a left stimulus is
        0.8, ordered as bias, threshold, lapse high, lapse low.
    n_trials : numpy.array
        The trial count of each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials in each session (typically three
        consecutive sessions).
    rt : float
        The median response time of zero contrast trials.

    Returns
    -------
    bool
        True when the subject meets the ready4ephysrig / ready4recording criteria.
    """
    # Lapse rates (parameters 2 and 3) of both bias blocks must all be small
    lapses_ok = np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1)
    return (lapses_ok
            and psych_80[0] - psych_20[0] > 5  # bias shift between blocks
            and np.all(n_trials > 400)
            and np.all(perf_easy > 0.9)
            and rt < 2)

939 

940 

def criterion_delay(n_trials, perf_easy):
    """
    Test whether the 'ready4delay' criteria are satisfied.

    Criteria
    --------
    - At least one session has more than 400 trials
    - At least one session has easy contrast performance above 90%

    Parameters
    ----------
    n_trials : numpy.array of int
        The trial count of each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials in each session (typically three
        consecutive sessions).

    Returns
    -------
    bool
        True when the subject meets the 'ready4delay' criteria.
    """
    any_long_session = np.any(n_trials > 400)
    return any_long_session and np.any(perf_easy > 0.9)

965 

966 

def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
    """
    Plot psychometric curves per probabilityLeft block, a la datajoint webpage.

    A fitted curve and the observed data points are drawn for each of the probabilityLeft
    blocks 0.5, 0.2 and 0.8.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (newly created when `ax` is None).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.
    psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability
        using the fit parameters.
    """
    # NOTE(review): the default ci_alpha (0.032) differs from the 0.32 used by plot_reaction_time
    # and compute_reaction_time; confirm this is intended.
    signed_contrast = get_signed_contrast(trials)
    # Evaluate the fitted curve over the full signed contrast range, -100 to +100 inclusive
    contrasts_fit = np.arange(-100, 101)

    # Per block: (observed P(right), contrasts, compute_psychometric output, fitted P(right))
    results = {}
    for block in (0.5, 0.2, 0.8):
        prob_right, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=block,
                                                       prob_right=True)
        out = compute_psychometric(trials, signed_contrast=signed_contrast, block=block, plotting=True,
                                   compute_ci=plot_ci, alpha=ci_alpha)
        # With compute_ci the fit parameters are the first element of the returned output
        pars = out[0] if plot_ci else out
        results[block] = (prob_right, contrasts, out, psy.erf_psycho_2gammas(pars, contrasts_fit))

    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    colours = {0.5: cmap[1], 0.2: cmap[0], 0.8: cmap[2]}

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Return the figure that owns `ax`; plt.gcf() may be a different, currently active figure
        fig = ax.get_figure()

    handles, labels = [], []
    for block, (prob_right, contrasts, _, prob_right_fit) in results.items():
        fit_line = ax.plot(contrasts_fit, prob_right_fit, color=colours[block])[0]
        data_points = ax.scatter(contrasts, prob_right, color=colours[block])
        handles.extend([fit_line, data_points])
        labels.extend([f'p_left={block} fit', f'p_left={block} data'])

    if plot_ci:
        for block, (prob_right, contrasts, out, _) in results.items():
            # NOTE(review): abs() assumes out[1] holds the lower/upper CI as offsets from the data;
            # confirm against compute_psychometric.
            errbar = np.c_[np.abs(out[1][0]), np.abs(out[1][1])].T
            ax.errorbar(contrasts, prob_right, yerr=errbar, ecolor=colours[block], fmt='none', capsize=5, alpha=0.4)

    ax.legend(handles, labels, loc='upper left')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('Probability choosing right')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)

    return fig, ax

1058 

1059 

def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs):
    """
    Plot median reaction time against signed contrast, a la datajoint webpage.

    The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}.
    They are computed with the defaults of `compute_reaction_time`, i.e. the interval between
    'stimOn_times' and 'response_times'.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (newly created when `ax` is None).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    signed_contrast = get_signed_contrast(trials)
    # Output per block: (median times, contrasts, counts[, ci])
    outputs = {
        block: compute_reaction_time(trials, signed_contrast=signed_contrast, block=block,
                                     compute_ci=plot_ci, alpha=ci_alpha)
        for block in (0.5, 0.2, 0.8)
    }

    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    colours = {0.5: cmap[1], 0.2: cmap[0], 0.8: cmap[2]}

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Return the figure that owns `ax`; plt.gcf() may be a different, currently active figure
        fig = ax.get_figure()

    handles = [ax.plot(out[1], out[0], '-o', color=colours[block])[0] for block, out in outputs.items()]

    if plot_ci:
        for block, out in outputs.items():
            median_rt, ci = out[0], out[3]
            # errorbar expects distances below and above each point, not absolute bounds
            errbar = np.c_[median_rt - ci[:, 0], ci[:, 1] - median_rt].T
            ax.errorbar(out[1], median_rt, yerr=errbar, ecolor=colours[block], fmt='none', capsize=5, alpha=0.4)

    ax.legend(handles, [f'p_left={block} data' for block in outputs], loc='upper left')
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Contrasts')

    if title:
        ax.set_title(title)

    return fig, ax

1130 

1131 

def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
    """
    Plot per-trial reaction time and its rolling median against trial number, a la datajoint webpage.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between
        'response_times' and this key is used (see notes for `compute_reaction_time`).
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (newly created when `ax` is None).
    matplotlib.pyplot.Axes
        The plotted axes.
    """
    reaction_time = pd.Series(np.asarray(trials.response_times - trials[stim_on_type], dtype=float))
    # Trials are numbered from 1 on the x-axis
    trial_number = np.arange(1, reaction_time.size + 1)
    # Rolling median over 10 trials; windows containing NaN yield NaN, which matplotlib leaves out
    reaction_time_rolled = reaction_time.rolling(window=10).median()

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Return the figure that owns `ax`; plt.gcf() may be a different, currently active figure
        fig = ax.get_figure()

    ax.scatter(trial_number, reaction_time.values, s=16, color='darkgray')
    ax.plot(trial_number, reaction_time_rolled.values, color='k', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(0.1, 100)
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Trial number')
    if title:
        ax.set_title(title)

    return fig, ax

1182 

1183 

def query_criterion(subject, status, from_status=None, one=None, validate=True):
    """Get the session for which a given training criterion was met.

    Parameters
    ----------
    subject : str
        The subject name.
    status : str
        The training status to query for.
    from_status : str, optional
        Count number of sessions and days from reaching `from_status` to `status`. If this status
        is not found in the subject's criteria, counting starts from the first session.
    one : one.api.OneAlyx, optional
        An instance of ONE.
    validate : bool
        If true, check if status in TrainingStatus enumeration. Set to false for non-standard
        training pipelines. When true, `from_status` is normalised the same way where possible.

    Returns
    -------
    str
        The eID of the first session where this training status was reached, or None if the
        status has not been reached.
    int
        The number of sessions it took to reach `status` (optionally from reaching `from_status`).
        None if the status was not reached or no sessions were found.
    int
        The number of days it took to reach `status` (optionally from reaching `from_status`).
        None if the status was not reached or no sessions were found.

    Raises
    ------
    ValueError
        If `validate` is true and `status` is not a TrainingStatus member.
    """
    def to_status_key(name):
        """Map e.g. 'Trained 1a' to the TrainingStatus member name 'trained_1a' (KeyError if unknown)."""
        key = name.lower().replace(' ', '_')
        try:
            return TrainingStatus[key.upper()].name.lower()
        except KeyError as ex:
            raise KeyError(key) from ex

    if validate:
        try:
            status = to_status_key(status)
        except KeyError as ex:
            raise ValueError(
                f'Unknown status "{ex.args[0]}". For non-standard training protocols set validate=False'
            ) from ex.__cause__
        if from_status is not None:
            try:
                from_status = to_status_key(from_status)
            except KeyError:
                pass  # unknown names are left untouched, preserving previous behaviour
    one = one or ONE()
    subject_json = one.alyx.rest('subjects', 'read', id=subject)['json']
    if not (criteria := subject_json.get('trained_criteria')) or status not in criteria:
        return None, None, None
    to_date, eid = criteria[status]
    from_date, _ = criteria.get(from_status, (None, None))
    eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True)
    if len(eids) == 0:
        return eid, None, None
    # Use the span of session dates so the result does not depend on the ordering of `det`
    dates = [d['date'] for d in det]
    delta_date = max(dates) - min(dates)
    return eid, len(eids), delta_date.days