Coverage for brainbox/behavior/training.py: 64%

403 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1"""Computing and testing IBL training status criteria. 

2 

3For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice 

4Training. 

5 

6.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\ 

7measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729 

8 

9Examples 

10-------- 

11Plot the psychometric curve for a given session. 

12 

13>>> trials = ONE().load_object(eid, 'trials') 

14>>> fig, ax = plot_psychometric(trials) 

15 

16Compute 'response times', defined as the duration of the open-loop period for each contrast. 

17 

18>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials) 

19 

20Compute 'reaction times', defined as the time between go cue and first detected movement. 

21NB: These may be negative! 

22 

23>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

24... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times') 

25 

26Compute 'response times', defined as the time between first detected movement and response. 

27 

28>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

29... trials, stim_on_type='firstMovement_times', stim_off_type='response_times') 

30 

31Compute 'movement times', defined as the time between last detected movement and response threshold. 

32 

33>>> import brainbox.behavior.wheel as wh 

34>>> wheel_moves = ONE().load_object(eid, 'wheelMoves') 

35>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trials.response_times) 

36>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

37... trials, stim_on_type='lastMovement_times', stim_off_type='response_times') 

38 

39""" 

40import logging 

41import datetime 

42import re 

43from enum import IntFlag, auto, unique 

44 

45import numpy as np 

46import matplotlib 

47import matplotlib.pyplot as plt 

48import seaborn as sns 

49import pandas as pd 

50from scipy.stats import bootstrap 

51from iblutil.util import Bunch 

52from one.api import ONE 

53from one.alf.io import AlfBunch 

54from one.alf.exceptions import ALFObjectNotFound 

55import psychofit as psy 

56 

# Module logger; shares the 'ibllib' logger so output follows the ibllib logging configuration.
_logger = logging.getLogger('ibllib')

#: list of str: The required keys in the trials object for computing training status.
TRIALS_KEYS = [
    'contrastLeft',
    'contrastRight',
    'feedbackType',
    'probabilityLeft',
    'choice',
    'response_times',
    'stimOn_times',
]

67 

68 

@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Each stage is a distinct bit flag whose value increases with training progress, so members can
    be ordered with comparison operators and grouped/tested with bitwise operators.

    Examples
    --------
    >>> status = 'ready4delay'
    >>> assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    >>> assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    >>> assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    Get the next training status

    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    """
    # Individual stages, in order of training progress (auto() assigns increasing powers of two)
    UNTRAINABLE = auto()
    UNBIASABLE = auto()
    IN_TRAINING = auto()
    TRAINED_1A = auto()
    TRAINED_1B = auto()
    READY4EPHYSRIG = auto()
    READY4DELAY = auto()
    READY4RECORDING = auto()
    # Compound training statuses for convenience (unions of the stages above, not stages themselves)
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B

108 

109 

def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Compute and print the training status of every alive, water-restricted subject in a lab.

    Nothing is returned; each subject is handed to `get_subject_training_status`, which prints
    its result to stdout.

    Parameters
    ----------
    lab : str
        Lab name (must match the name registered on Alyx).
    date : str
        The ISO date ('YYYY-MM-DD') from which to compute training status. If not specified,
        the latest date with available data is used.
    details : bool
        Whether to display all information about the training status computation, e.g.
        performance, number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created when None.
    """
    if not one:
        one = ONE()

    # Only subjects that are currently alive and on water restriction are considered
    subject_records = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True)
    nicknames = [record['nickname'] for record in subject_records]

    for nickname in nicknames:
        get_subject_training_status(nickname, date=date, details=details, one=one)

135 

136 

def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Computes the training status of specified subject and prints results to std out.

    The status line is always printed; when `details` is true and psychometric fits are
    available, performance, trial counts, fit parameters and reaction time are printed as well.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE.
    """
    one = one or ONE()

    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    if not trials:
        # get_sessions returns Nones (after logging a warning) when there is no usable data
        return
    sess_dates = list(trials.keys())
    status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay)

    if details and np.any(info.get('psych')):
        # trainingChoiceWorld: a single psychometric fit over all sessions
        display_status(subj, sess_dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych=info.psych, rt=info.rt)
    elif details and np.any(info.get('psych_20')):
        # biased/ephys sessions: separate fits for the 0.2 and 0.8 probabilityLeft blocks
        display_status(subj, sess_dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych_20=info.psych_20, psych_80=info.psych_80,
                       rt=info.rt)
    else:
        # Always report the status itself, even when details are not requested
        display_status(subj, sess_dates, status)

172 

173 

def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject.

    If a date is given it will load data from the three (or as many as are available) previous
    sessions up to the specified date. If not it will load data from the last three training
    sessions that have data available.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    one : one.api.OneAlyx
        An instance of ONE.

    Returns
    -------
    iblutil.util.Bunch
        Dictionary of trials objects where each key is the ISO session date string.
    list of str
        List of the task protocol used for each of the sessions.
    list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Notes
    -----
    If no sessions are found, or none of the sessions found has a loadable trials object, a
    warning is logged and a list of four Nones is returned instead.
    """
    one = one or ONE()

    if date is None:
        # compute from yesterday
        specified_date = (datetime.date.today() - datetime.timedelta(days=1))
        latest_sess = specified_date.strftime("%Y-%m-%d")
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime("%Y-%m-%d")
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d")
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj, date_range=[latest_minus_week,
                             latest_sess], dataset_types='trials.goCueTrigger_times')

    # If not enough sessions in the last week, then just fetch them all
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)

    # If still 0 sessions then return with warning
    if len(sessions) == 0:
        _logger.warning(f"No training sessions detected for {subj}")
        return [None] * 4

    trials = Bunch()
    task_protocol = []
    sess_dates = []
    # Load sessions in the order returned until three have trials data. Iterating over the list
    # (rather than indexing with a counter) stops cleanly when fewer than three can be loaded.
    for sess in sessions:
        if len(trials) >= 3:
            break
        eid = sess['url'].split('/')[-1]
        _logger.debug('Loading trials for session %s', eid)
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            trials_ = None

        if not trials_:
            continue
        task_protocol.append(re.search('tasks_(.*)Choice', sess['task_protocol']).group(1))
        sess_date = sess['start_time'][:10]
        sess_dates.append(sess_date)
        trials[sess_date] = trials_

    # Sessions were registered but none had loadable trials; the date range below needs at least one
    if not trials:
        _logger.warning(f"No trials data could be loaded for {subj}")
        return [None] * 4

    ephys_sess_dates = []
    n_delay = 0
    if not np.any(np.array(task_protocol) == 'training'):
        # sess_dates[-1] / sess_dates[0] are the oldest / newest dates, assuming the sessions
        # list is ordered newest first
        ephys_sess = one.alyx.rest('sessions', 'list', subject=subj,
                                   date_range=[sess_dates[-1], sess_dates[0]],
                                   django='location__name__icontains,ephys')
        if len(ephys_sess) > 0:
            ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]
            # Sessions whose start was delayed by at least 900 s (15 min)
            n_delay = len(one.alyx.rest('sessions', 'list', subject=subj,
                                        date_range=[sess_dates[-1], sess_dates[0]],
                                        django='json__SESSION_DELAY_START__gte,900'))

    return trials, task_protocol, ephys_sess_dates, n_delay

283 

284 

def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from consecutive training datasets.

    For IBL, training status is calculated using trials from the last three consecutive sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    task_protocol : list of str
        Task protocol used for each training session in `trials`, can be 'training', 'biased' or
        'ephys'.
    ephys_sess_dates : list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Returns
    -------
    str
        Training status of the subject: one of 'in training', 'trained 1a', 'trained 1b',
        'ready4ephysrig', 'ready4delay' or 'ready4recording'.
    iblutil.util.Bunch
        Bunch containing performance metrics that decide training status i.e. performance on easy
        trials, number of trials, psychometric fit parameters, reaction time.
    iblutil.util.Bunch
        The criteria for the next training stage that were not met. For 'ready4recording' there
        is no further stage, so the evaluated recording criteria are returned instead.
    """
    info = Bunch()
    trials_all = concatenate_trials(trials)
    info.session_dates = list(trials.keys())
    info.protocols = list(task_protocol)
    is_training = np.array(task_protocol) == 'training'

    # Case when all sessions are trainingChoiceWorld
    if np.all(is_training):
        signed_contrast = np.unique(get_signed_contrast(trials_all))
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)

        pass_criteria, criteria = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt,
                                               signed_contrast)
        if pass_criteria:
            # The next stage needs biased sessions, which training-only data cannot satisfy
            failed_criteria = Bunch()
            failed_criteria['NBiased'] = {'val': info.protocols, 'pass': False}
            failed_criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}
            status = 'trained 1b'
        else:
            failed_criteria = criteria
            pass_criteria, criteria = criterion_1a(info.psych, info.n_trials, info.perf_easy,
                                                   signed_contrast)
            if pass_criteria:
                status = 'trained 1a'
            else:
                failed_criteria = criteria
                status = 'in training'

        return status, info, failed_criteria

    # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion
    # (some, but not all, sessions are trainingChoiceWorld)
    if np.any(is_training):
        status = 'trained 1b'
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)

        criteria = Bunch()
        criteria['NBiased'] = {'val': info.protocols, 'pass': False}
        criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}

        return status, info, criteria

    # Remaining case: only biasedChoiceWorld or ephysChoiceWorld in the last three sessions
    (info.perf_easy, info.n_trials,
     info.psych_20, info.psych_80,
     info.rt) = compute_bias_info(trials, trials_all)

    n_ephys = len(ephys_sess_dates)
    info.n_ephys = n_ephys
    info.n_delay = n_delay

    # Criteria are checked from the most to the least advanced stage
    pass_criteria, criteria = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80,
                                                 info.n_trials, info.perf_easy, info.rt)
    failed_criteria = criteria
    if pass_criteria:
        # Here the criteria doesn't actually fail but we have no other criteria to meet so we return this
        status = 'ready4recording'
    else:
        assert all(date in trials for date in ephys_sess_dates), \
            'every ephys session date must be one of the loaded session dates'
        perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in ephys_sess_dates])
        n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates])

        pass_criteria, criteria = criterion_delay(n_ephys_trials, perf_ephys_easy, n_ephys=n_ephys)
        if pass_criteria:
            status = 'ready4delay'
        else:
            failed_criteria = criteria
            pass_criteria, criteria = criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
                                                      info.perf_easy, info.rt)
            if pass_criteria:
                status = 'ready4ephysrig'
            else:
                failed_criteria = criteria
                status = 'trained 1b'

    return status, info, failed_criteria

396 

397 

def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    sess_dates : list of str
        ISO date strings of training sessions used to determine training status. Any number of
        dates is supported.
    status : str
        Training status of subject.
    perf_easy : numpy.array
        Proportion of correct high contrast trials for each training session. If None, only the
        status and session dates are printed.
    n_trials : numpy.array
        Total number of trials for each training session.
    psych : numpy.array
        Psychometric parameters fit to data from all training sessions - bias, threshold, lapse
        low, lapse high.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
    rt : float
        The median response time for zero contrast trials across all training sessions. NaN
        indicates no zero contrast stimuli in training sessions.
    """
    if perf_easy is None:
        # Join however many dates there are: get_sessions may return fewer than three sessions
        dates = ', '.join(str(d) for d in sess_dates)
        print(f"\n{subj} : {status} \nSession dates=[{dates}]")
    elif psych_20 is None:
        print(f"\n{subj} : {status} \nSession dates={list(sess_dates)}, "
              f"Perf easy={[np.around(pe, 2) for pe in perf_easy]}, "
              f"N trials={list(n_trials)} "
              f"\nPsych fit over last 3 sessions: "
              f"bias={np.around(psych[0], 2)}, thres={np.around(psych[1], 2)}, "
              f"lapse_low={np.around(psych[2], 2)}, lapse_high={np.around(psych[3], 2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")
    else:
        print(f"\n{subj} : {status} \nSession dates={list(sess_dates)}, "
              f"Perf easy={[np.around(pe, 2) for pe in perf_easy]}, "
              f"N trials={list(n_trials)} "
              f"\nPsych fit over last 3 sessions (20): "
              f"bias={np.around(psych_20[0], 2)}, thres={np.around(psych_20[1], 2)}, "
              f"lapse_low={np.around(psych_20[2], 2)}, lapse_high={np.around(psych_20[3], 2)} "
              f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0], 2)}, "
              f"thres={np.around(psych_80[1], 2)}, lapse_low={np.around(psych_80[2], 2)}, "
              f"lapse_high={np.around(psych_80[3], 2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")

453 

454 

def concatenate_trials(trials):
    """
    Join the trials of several training sessions into a single trials object.

    Only the keys listed in `TRIALS_KEYS` are kept. For each key, the per-session arrays are
    joined end to end in the iteration order of `trials`.

    Parameters
    ----------
    trials : dict of str
        Mapping of ISO session date string to that session's trials object.

    Returns
    -------
    one.alf.io.AlfBunch
        A trials object holding the concatenated data of all sessions.
    """
    per_session = list(trials.values())
    combined = AlfBunch()
    for key in TRIALS_KEYS:
        combined[key] = np.concatenate([session[key] for session in per_session])
    return combined

474 

475 

def compute_training_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on trainingChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Array of psychometric parameters fit to `trials_all` - bias, threshold, lapse low,
        lapse high.
    float
        The median response time for all zero-contrast trials across all sessions. Returns NaN if
        there are no zero-contrast trials.
    """
    # Signed contrast is computed once and shared by the pooled (all-session) computations
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)

    return perf_easy, n_trials, psych, rt

508 

509 

def compute_bias_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on biasedChoiceWorld.

    Parameters
    ----------
    trials : iblutil.util.Bunch
        Trials objects from consecutive training sessions, keyed by session date.
    trials_all : iblutil.util.Bunch
        Trials object with data concatenated over those sessions.

    Returns
    -------
    numpy.array
        Performance on easy trials for each session.
    numpy.array
        Number of trials in each session.
    numpy.array
        Psychometric fit parameters for trials in the probabilityLeft == 0.2 block, pooled over
        all sessions.
    numpy.array
        Psychometric fit parameters for trials in the probabilityLeft == 0.8 block, pooled over
        all sessions.
    float
        Median reaction time for zero-contrast trials pooled over all sessions.
    """
    signed_contrast = get_signed_contrast(trials_all)

    per_session = list(trials.values())
    perf_easy = np.array([compute_performance_easy(sess) for sess in per_session])
    n_trials = np.array([compute_n_trials(sess) for sess in per_session])

    # One fit per biased block, evaluated in the order 0.2 then 0.8
    psych_20, psych_80 = [
        compute_psychometric(trials_all, signed_contrast=signed_contrast, block=prob_left)
        for prob_left in (0.2, 0.8)
    ]
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)

    return perf_easy, n_trials, psych_20, psych_80, rt

535 

536 

def get_signed_contrast(trials):
    """
    Compute the signed contrast of each trial.

    Parameters
    ----------
    trials : dict
        Trials object that must contain 'contrastLeft' and 'contrastRight' keys.

    Returns
    -------
    numpy.array
        Signed contrast in percent for each trial: right contrast minus left contrast, so
        left-side stimuli are negative.
    """
    # A NaN contrast means no stimulus on that side, so it is counted as zero
    left_right = np.nan_to_num(np.column_stack((trials['contrastLeft'], trials['contrastRight'])))
    return 100 * np.diff(left_right, axis=-1).ravel()

548 

549 

def compute_performance_easy(trials):
    """
    Compute performance on easy trials (absolute signed contrast >= 50 %) from a trials object.

    Parameters
    ----------
    trials : dict
        Trials object that must contain 'contrastLeft', 'contrastRight' and 'feedbackType' keys.

    Returns
    -------
    float
        Proportion of easy trials with feedbackType == 1. NaN if there are no easy trials.
    """
    signed_contrast = get_signed_contrast(trials)
    easy = np.abs(signed_contrast) >= 50
    if not np.any(easy):
        # The proportion is undefined; return NaN explicitly instead of triggering a 0/0 warning
        return np.nan
    return np.mean(trials['feedbackType'][easy] == 1)

562 

563 

def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
    """
    Compute performance at each contrast level from a trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        Trials object with attribute access to 'probabilityLeft', 'feedbackType' and 'choice'
        (and 'contrastLeft'/'contrastRight' when `signed_contrast` is None).
    signed_contrast : numpy.array
        Signed contrast in percent for each trial. If None, computed from the trials object.
    block : float
        If None, all trials are used; otherwise only trials whose probabilityLeft equals this
        value.
    prob_right : bool
        If false, return the proportion of correct trials (feedbackType == 1) per contrast. If
        true, return the proportion of choices equal to -1 (counted as rightward) per contrast.

    Returns
    -------
    numpy.array
        Proportion correct (or proportion rightward) for each unique signed contrast.
    numpy.array
        The unique signed contrasts present in the selected trials.
    numpy.array
        Number of selected trials for each unique signed contrast.

    Notes
    -----
    If no trials match `block`, a single array of three NaNs is returned instead of the tuple;
    unpacking it into three names still works, yielding three NaN values.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        return np.nan * np.zeros(3)

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)

    # choice == -1 is counted as a rightward choice
    outcome = (trials.choice == -1) if prob_right else (trials.feedbackType == 1)
    # Proportion of the chosen outcome among trials of each contrast in the selected block
    performance = np.array([np.mean(outcome[(signed_contrast == c) & block_idx]) for c in contrasts])

    return performance, contrasts, n_contrasts

595 

596 

def compute_n_trials(trials):
    """
    Count the trials in a trials object.

    Parameters
    ----------
    trials : dict
        Trials object containing a 'choice' array.

    Returns
    -------
    int
        Number of trials in the session (length of the first axis of 'choice').
    """
    choices = trials['choice']
    return choices.shape[0]

606 

607 

def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032):
    """
    Compute psychometric fit parameters for trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    plotting : bool
        Which set of psychofit model parameters to use (see notes).
    compute_ci : bool
        If true, computes and returns the confidence intervals for response at each contrast.
    alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        Array of psychometric fit parameters - bias, threshold, lapse low, lapse high. All NaN if
        no trials match `block`.
    numpy.array
        Only returned when `compute_ci` is true. Lower (row 0) and upper (row 1) bounds of the
        normal-approximation confidence interval on the proportion of rightward choices at each
        contrast, given as offsets from the observed proportion; shape (2, n_contrasts). Empty,
        with shape (2, 0), if no trials match `block`.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.

    Notes
    -----
    The psychofit starting parameters and model constraints used for the fit when computing the
    training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep
    the precise criteria the same for all subjects, these parameters have not changed. To produce a
    better fit for plotting purposes, or to calculate the training status in a manner inconsistent
    with the IBL training pipeline, use plotting=True.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        # Nothing to fit. Keep the return arity consistent with `compute_ci` so that
        # `psych, ci = compute_psychometric(..., compute_ci=True)` can always be unpacked.
        psych = np.nan * np.zeros(4)
        return (psych, np.empty((2, 0))) if compute_ci else psych

    prob_choose_right, contrasts, n_contrasts = compute_performance(
        trials, signed_contrast=signed_contrast, block=block, prob_right=True)

    if plotting:
        # These starting parameters and constraints tend to produce a better fit, and are therefore
        # used for plotting.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([0., 40., 0.1, 0.1]),
            parmin=np.array([-50., 10., 0., 0.]),
            parmax=np.array([50., 50., 0.2, 0.2]),
            nfits=10)
    else:
        # These starting parameters and constraints are not ideal but are still used for computing
        # the training status for consistency.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
            parmin=np.array([np.min(contrasts), 0., 0., 0.]),
            parmax=np.array([np.max(contrasts), 100., 1, 1]))

    if compute_ci:
        # NOTE(review): compute_reaction_time defaults alpha to 0.32 (~1 SD); confirm whether
        # the 0.032 default here is intentional.
        import statsmodels.stats.proportion as smp  # noqa
        # choice == -1 means contrast on right hand side
        n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts)
        # proportion_confint returns (lower, upper); subtracting the observed proportion gives
        # a (2, n_contrasts) array of offsets
        ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right

        return psych, ci
    else:
        return psych

697 

698 

def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None,
                                 stim_off_type='response_times'):
    """
    Compute the median response time for trials of a given contrast from a trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key marking the start of the interval. The difference between `stim_off_type`
        and this is used (see notes).
    contrast : float
        If None, the median response time is calculated for all trials, regardless of contrast,
        otherwise only trials where the matching signed percent contrast was presented are used.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    stim_off_type : str, default='response_times'
        The trials key marking the end of the interval.

    Returns
    -------
    float
        The median response time for trials with `contrast` (returns NaN if no trials matching
        `contrast` in trials object). NaN intervals are ignored.

    Notes
    -----
    - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is
      sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on
      times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of
      the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard
      output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the
      audio go cue command).
    - With the defaults, the response/reaction time is the time between stim on and response,
      i.e. the entire open-loop trial duration.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if contrast is None:
        contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        contrast_idx = signed_contrast == contrast

    if np.any(contrast_idx):
        # nanmedian skips trials where either time is missing
        reaction_time = np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[contrast_idx])
    else:
        reaction_time = np.nan

    return reaction_time

750 

751 

752def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None, 

753 compute_ci=False, alpha=0.32): 

754 """ 

755 Compute median response time for all contrasts. 

756 

757 Parameters 

758 ---------- 

759 trials : one.alf.io.AlfBunch 

760 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

761 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

762 stim_on_type : str, default='stimOn_times' 

763 The trials key to use when calculating the response times. The difference between this and 

764 `stim_off_type` is used (see notes). 

765 stim_off_type : str, default='response_times' 

766 The trials key to use when calculating the response times. The difference between this and 

767 `stim_on_type` is used (see notes). 

768 signed_contrast : numpy.array 

769 An array of signed contrasts in percent the length of trials, where left contrasts are -ve. 

770 If None, these are computed from the trials object. 

771 block : float 

772 The block type to compute. If None, all trials are included, otherwise only trials where 

773 probabilityLeft matches this value are included. For biasedChoiceWorld, the 

774 probabilityLeft set is {0.5, 0.2, 0.8}. 

775 compute_ci : bool 

776 If true, computes and returns the confidence intervals for response time at each contrast. 

777 alpha : float, default=0.32 

778 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, 

779 this value is ignored. 

780 

781 Returns 

782 ------- 

783 numpy.array 

784 The median response times for each unique signed contrast. 

785 numpy.array 

786 The set of unique signed contrasts. 

787 numpy.array 

788 The number of trials for each unique signed contrast. 

789 (numpy.array) 

790 If `compute_ci` is true, an array of confidence intervals is return in the shape (n_trials, 

791 2). 

792 

793 Notes 

794 ----- 

795 - The response/reaction time by default is the time between stim on and response, i.e. the 

796 entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to 

797 get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true 

798 response times, or calculate the last movement onset times and calculate the true movement 

799 times. See module examples for how to calculate this. 

800 

801 See Also 

802 -------- 

803 scipy.stats.bootstrap - the function used to compute the confidence interval. 

804 """ 

805 

806 if signed_contrast is None: 

807 signed_contrast = get_signed_contrast(trials) 

808 

809 if block is None: 

810 block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) 

811 else: 

812 block_idx = trials.probabilityLeft == block 

813 

814 contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True) 

815 reaction_time = np.vectorize( 

816 lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx]), 

817 otypes=[float] 

818 )(contrasts) 

819 

820 if compute_ci: 

821 ci = np.full((contrasts.size, 2), np.nan) 

822 for i, x in enumerate(contrasts): 

823 data = (trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx] 

824 bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha) 

825 ci[i, 0] = bt.confidence_interval.low 

826 ci[i, 1] = bt.confidence_interval.high 

827 

828 return reaction_time, contrasts, n_contrasts, ci 

829 else: 

830 return reaction_time, contrasts, n_contrasts, 

831 

832 

833def criterion_1a(psych, n_trials, perf_easy, signed_contrast): 

834 """ 

835 Returns bool indicating whether criteria for status 'trained_1a' are met. 

836 

837 Criteria 

838 -------- 

839 - Bias is less than 16 

840 - Threshold is less than 19 

841 - Lapse rate on both sides is less than 0.2 

842 - The total number of trials is greater than 200 for each session 

843 - Performance on easy contrasts > 80% for all sessions 

844 - Zero contrast trials must be present 

845 

846 Parameters 

847 ---------- 

848 psych : numpy.array 

849 The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold, 

850 lapse high, lapse low. 

851 n_trials : numpy.array of int 

852 The number for trials for each session. 

853 perf_easy : numpy.array of float 

854 The proportion of correct high contrast trials for each session. 

855 signed_contrast: numpy.array 

856 Unique list of contrasts displayed 

857 

858 Returns 

859 ------- 

860 bool 

861 True if the criteria are met for 'trained_1a'. 

862 Bunch 

863 Bunch containing breakdown of the passing/ failing critieria 

864 

865 Notes 

866 ----- 

867 The parameter thresholds chosen here were originally determined by averaging the parameter fits 

868 for a number of sessions determined to be of 'good' performance by an experimenter. 

869 """ 

870 

871 criteria = Bunch() 1efa

872 criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} 1efa

873 criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2} 1efa

874 criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2} 1efa

875 criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16} 1efa

876 criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19} 1efa

877 criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)} 1efa

878 criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)} 1efa

879 

880 passing = np.all([v['pass'] for k, v in criteria.items()]) 1efa

881 

882 criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing} 1efa

883 

884 return passing, criteria 1efa

885 

886 

887def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast): 

888 """ 

889 Returns bool indicating whether criteria for trained_1b are met. 

890 

891 Criteria 

892 -------- 

893 - Bias is less than 10 

894 - Threshold is less than 20 (see notes) 

895 - Lapse rate on both sides is less than 0.1 

896 - The total number of trials is greater than 400 for each session 

897 - Performance on easy contrasts > 90% for all sessions 

898 - The median response time across all zero contrast trials is less than 2 seconds 

899 - Zero contrast trials must be present 

900 

901 Parameters 

902 ---------- 

903 psych : numpy.array 

904 The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold, 

905 lapse high, lapse low. 

906 n_trials : numpy.array of int 

907 The number for trials for each session. 

908 perf_easy : numpy.array of float 

909 The proportion of correct high contrast trials for each session. 

910 rt : float 

911 The median response time for zero contrast trials. 

912 signed_contrast: numpy.array 

913 Unique list of contrasts displayed 

914 

915 Returns 

916 ------- 

917 bool 

918 True if the criteria are met for 'trained_1b'. 

919 Bunch 

920 Bunch containing breakdown of the passing/ failing critieria 

921 

922 Notes 

923 ----- 

924 The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a, 

925 however it was decided to use round numbers so that readers would not assume a level of 

926 precision that isn't there (remember, these parameters were not chosen with any rigor). This 

927 regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the 

928 slope of the psychometric curve may be slightly less steep than 1a. 

929 """ 

930 

931 criteria = Bunch() 1efha

932 criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)} 1efha

933 criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1} 1efha

934 criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1} 1efha

935 criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10} 1efha

936 criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20} 1efha

937 criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} 1efha

938 criteria['Perf_tasy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} 1efha

939 criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2} 1efha

940 

941 passing = np.all([v['pass'] for k, v in criteria.items()]) 1efha

942 

943 criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing} 1efha

944 

945 return passing, criteria 1efha

946 

947 

948def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): 

949 """ 

950 Returns bool indicating whether criteria for ready4ephysrig are met. 

951 

952 Criteria 

953 -------- 

954 - Lapse on both sides < 0.1 for both bias blocks 

955 - Bias shift between blocks > 5 

956 - Total number of trials > 400 for all sessions 

957 - Performance on easy contrasts > 90% for all sessions 

958 - Median response time for zero contrast stimuli < 2 seconds 

959 

960 Parameters 

961 ---------- 

962 psych_20 : numpy.array 

963 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. 

964 Parameters are bias, threshold, lapse high, lapse low. 

965 psych_80 : numpy.array 

966 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. 

967 Parameters are bias, threshold, lapse high, lapse low. 

968 n_trials : numpy.array 

969 The number of trials for each session (typically three consecutive sessions). 

970 perf_easy : numpy.array 

971 The proportion of correct high contrast trials for each session (typically three 

972 consecutive sessions). 

973 rt : float 

974 The median response time for zero contrast trials. 

975 

976 Returns 

977 ------- 

978 bool 

979 True if subject passes the ready4ephysrig criteria. 

980 Bunch 

981 Bunch containing breakdown of the passing/ failing critieria 

982 """ 

983 criteria = Bunch() 1dcgba

984 criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1} 1dcgba

985 criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1} 1dcgba

986 criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1} 1dcgba

987 criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1} 1dcgba

988 criteria['Bias_shift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5} 1dcgba

989 criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)} 1dcgba

990 criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)} 1dcgba

991 criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2} 1dcgba

992 

993 passing = np.all([v['pass'] for k, v in criteria.items()]) 1dcgba

994 

995 criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing} 1dcgba

996 

997 return passing, criteria 1dcgba

998 

999 

1000def criterion_delay(n_trials, perf_easy, n_ephys=1): 

1001 """ 

1002 Returns bool indicating whether criteria for 'ready4delay' is met. 

1003 

1004 Criteria 

1005 -------- 

1006 - At least one session on an ephys rig 

1007 - Total number of trials for any of the sessions is greater than 400 

1008 - Performance on easy contrasts is greater than 90% for any of the sessions 

1009 

1010 Parameters 

1011 ---------- 

1012 n_trials : numpy.array of int 

1013 The number of trials for each session (typically three consecutive sessions). 

1014 perf_easy : numpy.array 

1015 The proportion of correct high contrast trials for each session (typically three 

1016 consecutive sessions). 

1017 

1018 Returns 

1019 ------- 

1020 bool 

1021 True if subject passes the 'ready4delay' criteria. 

1022 Bunch 

1023 Bunch containing breakdown of the passing/ failing critieria 

1024 """ 

1025 

1026 criteria = Bunch() 1dcba

1027 criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys > 0} 1dcba

1028 criteria['N_trials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)} 1dcba

1029 criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)} 1dcba

1030 

1031 passing = np.all([v['pass'] for k, v in criteria.items()]) 1dcba

1032 

1033 criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing} 1dcba

1034 

1035 return passing, criteria 1dcba

1036 

1037 

1038def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt): 

1039 """ 

1040 Returns bool indicating whether criteria for ready4recording are met. 

1041 

1042 Criteria 

1043 -------- 

1044 - At least 3 ephys sessions 

1045 - Delay on any session > 0 

1046 - Lapse on both sides < 0.1 for both bias blocks 

1047 - Bias shift between blocks > 5 

1048 - Total number of trials > 400 for all sessions 

1049 - Performance on easy contrasts > 90% for all sessions 

1050 - Median response time for zero contrast stimuli < 2 seconds 

1051 

1052 Parameters 

1053 ---------- 

1054 psych_20 : numpy.array 

1055 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. 

1056 Parameters are bias, threshold, lapse high, lapse low. 

1057 psych_80 : numpy.array 

1058 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. 

1059 Parameters are bias, threshold, lapse high, lapse low. 

1060 n_trials : numpy.array 

1061 The number of trials for each session (typically three consecutive sessions). 

1062 perf_easy : numpy.array 

1063 The proportion of correct high contrast trials for each session (typically three 

1064 consecutive sessions). 

1065 rt : float 

1066 The median response time for zero contrast trials. 

1067 

1068 Returns 

1069 ------- 

1070 bool 

1071 True if subject passes the ready4recording criteria. 

1072 Bunch 

1073 Bunch containing breakdown of the passing/ failing critieria 

1074 """ 

1075 

1076 _, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt) 1dcgba

1077 criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3} 1dcgba

1078 criteria['N_delay'] = {'val': delay, 'pass': delay > 0} 1dcgba

1079 

1080 passing = np.all([v['pass'] for k, v in criteria.items()]) 1dcgba

1081 

1082 criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing} 1dcgba

1083 

1084 return passing, criteria 1dcgba

1085 

1086 

1087def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs): 

1088 """ 

1089 Function to plot psychometric curve plots a la datajoint webpage. 

1090 

1091 Parameters 

1092 ---------- 

1093 trials : one.alf.io.AlfBunch 

1094 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

1095 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

1096 ax : matplotlib.pyplot.Axes 

1097 An axis object to plot onto. 

1098 title : str 

1099 An optional plot title. 

1100 plot_ci : bool 

1101 If true, computes and plots the confidence intervals for response at each contrast. 

1102 ci_alpha : float, default=0.032 

1103 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, 

1104 this value is ignored. 

1105 **kwargs 

1106 If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. 

1107 

1108 Returns 

1109 ------- 

1110 matplotlib.pyplot.Figure 

1111 The figure handle containing the plot. 

1112 matplotlib.pyplot.Axes 

1113 The plotted axes. 

1114 

1115 See Also 

1116 -------- 

1117 statsmodels.stats.proportion.proportion_confint - The function used to compute confidence 

1118 interval. 

1119 psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. 

1120 psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability 

1121 using the fit parameters. 

1122 """ 

1123 

1124 signed_contrast = get_signed_contrast(trials) 1b

1125 contrasts_fit = np.arange(-100, 100) 1b

1126 

1127 prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5, prob_right=True) 1b

1128 out_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5, plotting=True, 1b

1129 compute_ci=plot_ci, alpha=ci_alpha) 

1130 pars_50 = out_50[0] if plot_ci else out_50 1b

1131 prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit) 1b

1132 

1133 prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2, prob_right=True) 1b

1134 out_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2, plotting=True, 1b

1135 compute_ci=plot_ci, alpha=ci_alpha) 

1136 pars_20 = out_20[0] if plot_ci else out_20 1b

1137 prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit) 1b

1138 

1139 prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8, prob_right=True) 1b

1140 out_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8, plotting=True, 1b

1141 compute_ci=plot_ci, alpha=ci_alpha) 

1142 pars_80 = out_80[0] if plot_ci else out_80 1b

1143 prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit) 1b

1144 

1145 cmap = sns.diverging_palette(20, 220, n=3, center='dark') 1b

1146 

1147 if not ax: 1b

1148 fig, ax = plt.subplots(**kwargs) 

1149 else: 

1150 fig = plt.gcf() 1b

1151 

1152 fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1]) 1b

1153 data_50 = ax.scatter(contrasts_50, prob_right_50, color=cmap[1]) 1b

1154 fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0]) 1b

1155 data_20 = ax.scatter(contrasts_20, prob_right_20, color=cmap[0]) 1b

1156 fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2]) 1b

1157 data_80 = ax.scatter(contrasts_80, prob_right_80, color=cmap[2]) 1b

1158 

1159 if plot_ci: 1b

1160 errbar_50 = np.c_[np.abs(out_50[1][0]), np.abs(out_50[1][1])].T 

1161 errbar_20 = np.c_[np.abs(out_20[1][0]), np.abs(out_20[1][1])].T 

1162 errbar_80 = np.c_[np.abs(out_80[1][0]), np.abs(out_80[1][1])].T 

1163 

1164 ax.errorbar(contrasts_50, prob_right_50, yerr=errbar_50, ecolor=cmap[1], fmt='none', capsize=5, alpha=0.4) 

1165 ax.errorbar(contrasts_20, prob_right_20, yerr=errbar_20, ecolor=cmap[0], fmt='none', capsize=5, alpha=0.4) 

1166 ax.errorbar(contrasts_80, prob_right_80, yerr=errbar_80, ecolor=cmap[2], fmt='none', capsize=5, alpha=0.4) 

1167 

1168 ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80], 1b

1169 ['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'], 

1170 loc='upper left') 

1171 ax.set_ylim(-0.05, 1.05) 1b

1172 ax.set_ylabel('Probability choosing right') 1b

1173 ax.set_xlabel('Contrasts') 1b

1174 if title: 1b

1175 ax.set_title(title) 1b

1176 

1177 return fig, ax 1b

1178 

1179 

1180def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs): 

1181 """ 

1182 Function to plot reaction time against contrast a la datajoint webpage. 

1183 

1184 The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}. 

1185 

1186 Parameters 

1187 ---------- 

1188 trials : one.alf.io.AlfBunch 

1189 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

1190 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

1191 ax : matplotlib.pyplot.Axes 

1192 An axis object to plot onto. 

1193 title : str 

1194 An optional plot title. 

1195 plot_ci : bool 

1196 If true, computes and plots the confidence intervals for response at each contrast. 

1197 ci_alpha : float, default=0.32 

1198 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, 

1199 this value is ignored. 

1200 **kwargs 

1201 If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. 

1202 

1203 Returns 

1204 ------- 

1205 matplotlib.pyplot.Figure 

1206 The figure handle containing the plot. 

1207 matplotlib.pyplot.Axes 

1208 The plotted axes. 

1209 

1210 See Also 

1211 -------- 

1212 scipy.stats.bootstrap - the function used to compute the confidence interval. 

1213 """ 

1214 

1215 signed_contrast = get_signed_contrast(trials) 

1216 out_50 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5, compute_ci=plot_ci, alpha=ci_alpha) 

1217 out_20 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2, compute_ci=plot_ci, alpha=ci_alpha) 

1218 out_80 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8, compute_ci=plot_ci, alpha=ci_alpha) 

1219 

1220 cmap = sns.diverging_palette(20, 220, n=3, center='dark') 

1221 

1222 if not ax: 

1223 fig, ax = plt.subplots(**kwargs) 

1224 else: 

1225 fig = plt.gcf() 

1226 

1227 data_50 = ax.plot(out_50[1], out_50[0], '-o', color=cmap[1]) 

1228 data_20 = ax.plot(out_20[1], out_20[0], '-o', color=cmap[0]) 

1229 data_80 = ax.plot(out_80[1], out_80[0], '-o', color=cmap[2]) 

1230 

1231 if plot_ci: 

1232 errbar_50 = np.c_[out_50[0] - out_50[3][:, 0], out_50[3][:, 1] - out_50[0]].T 

1233 errbar_20 = np.c_[out_20[0] - out_20[3][:, 0], out_20[3][:, 1] - out_20[0]].T 

1234 errbar_80 = np.c_[out_80[0] - out_80[3][:, 0], out_80[3][:, 1] - out_80[0]].T 

1235 

1236 ax.errorbar(out_50[1], out_50[0], yerr=errbar_50, ecolor=cmap[1], fmt='none', capsize=5, alpha=0.4) 

1237 ax.errorbar(out_20[1], out_20[0], yerr=errbar_20, ecolor=cmap[0], fmt='none', capsize=5, alpha=0.4) 

1238 ax.errorbar(out_80[1], out_80[0], yerr=errbar_80, ecolor=cmap[2], fmt='none', capsize=5, alpha=0.4) 

1239 

1240 ax.legend([data_50[0], data_20[0], data_80[0]], 

1241 ['p_left=0.5 data', 'p_left=0.2 data', 'p_left=0.8 data'], 

1242 loc='upper left') 

1243 ax.set_ylabel('Reaction time (s)') 

1244 ax.set_xlabel('Contrasts') 

1245 

1246 if title: 

1247 ax.set_title(title) 

1248 

1249 return fig, ax 

1250 

1251 

1252def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs): 

1253 """ 

1254 Function to plot reaction time with trial number a la datajoint webpage. 

1255 

1256 Parameters 

1257 ---------- 

1258 trials : one.alf.io.AlfBunch 

1259 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

1260 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

1261 stim_on_type : str, default='stimOn_times' 

1262 The trials key to use when calculating the response times. The difference between this and 

1263 'feedback_times' is used (see notes for `compute_median_reaction_time`). 

1264 ax : matplotlib.pyplot.Axes 

1265 An axis object to plot onto. 

1266 title : str 

1267 An optional plot title. 

1268 **kwargs 

1269 If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. 

1270 

1271 Returns 

1272 ------- 

1273 matplotlib.pyplot.Figure 

1274 The figure handle containing the plot. 

1275 matplotlib.pyplot.Axes 

1276 The plotted axes. 

1277 """ 

1278 

1279 reaction_time = pd.DataFrame() 

1280 reaction_time['reaction_time'] = trials.response_times - trials[stim_on_type] 

1281 reaction_time.index = reaction_time.index + 1 

1282 reaction_time_rolled = reaction_time['reaction_time'].rolling(window=10).median() 

1283 reaction_time_rolled = reaction_time_rolled.where((pd.notnull(reaction_time_rolled)), None) 

1284 reaction_time = reaction_time.where((pd.notnull(reaction_time)), None) 

1285 

1286 if not ax: 

1287 fig, ax = plt.subplots(**kwargs) 

1288 else: 

1289 fig = plt.gcf() 

1290 

1291 ax.scatter(np.arange(len(reaction_time.values)), reaction_time.values, s=16, color='darkgray') 

1292 ax.plot(np.arange(len(reaction_time_rolled.values)), reaction_time_rolled.values, color='k', linewidth=2) 

1293 ax.set_yscale('log') 

1294 ax.set_ylim(0.1, 100) 

1295 ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) 

1296 ax.set_ylabel('Reaction time (s)') 

1297 ax.set_xlabel('Trial number') 

1298 if title: 

1299 ax.set_title(title) 

1300 

1301 return fig, ax 

1302 

1303 

1304def query_criterion(subject, status, from_status=None, one=None, validate=True): 

1305 """Get the session for which a given training criterion was met. 

1306 

1307 Parameters 

1308 ---------- 

1309 subject : str 

1310 The subject name. 

1311 status : str 

1312 The training status to query for. 

1313 from_status : str, optional 

1314 Count number of sessions and days from reaching `from_status` to `status`. 

1315 one : one.api.OneAlyx, optional 

1316 An instance of ONE. 

1317 validate : bool 

1318 If true, check if status in TrainingStatus enumeration. Set to false for non-standard 

1319 training pipelines. 

1320 

1321 Returns 

1322 ------- 

1323 str 

1324 The eID of the first session where this training status was reached. 

1325 int 

1326 The number of sessions it took to reach `status` (optionally from reaching `from_status`). 

1327 int 

1328 The number of days it tool to reach `status` (optionally from reaching `from_status`). 

1329 """ 

1330 if validate: 1l

1331 status = status.lower().replace(' ', '_') 1l

1332 try: 1l

1333 status = TrainingStatus[status.upper().replace(' ', '_')].name.lower() 1l

1334 except KeyError as ex: 1l

1335 raise ValueError( 1l

1336 f'Unknown status "{status}". For non-standard training protocols set validate=False' 

1337 ) from ex 

1338 one = one or ONE() 1l

1339 subject_json = one.alyx.rest('subjects', 'read', id=subject)['json'] 1l

1340 if not (criteria := subject_json.get('trained_criteria')) or status not in criteria: 1l

1341 return None, None, None 1l

1342 to_date, eid = criteria[status] 1l

1343 from_date, _ = criteria.get(from_status, (None, None)) 1l

1344 eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True) 1l

1345 if len(eids) == 0: 1l

1346 return eid, None, None 1l

1347 delta_date = det[0]['date'] - det[-1]['date'] 1l

1348 return eid, len(eids), delta_date.days 1l