Coverage for brainbox/behavior/training.py: 70%

350 statements  

coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1"""Computing and testing IBL training status criteria. 

2 

3For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice 

4Training. 

5 

6.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\ 

7measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729 

8 

9Examples 

10-------- 

11Plot the psychometric curve for a given session. 

12 

13>>> trials = ONE().load_object(eid, 'trials') 

14>>> fig, ax = plot_psychometric(trials) 

15 

16Compute 'response times', defined as the duration of the open-loop period for each contrast. 

17 

18>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials) 

19 

20Compute 'reaction times', defined as the time between go cue and first detected movement. 

21NB: These may be negative! 

22 

23>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

24... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times') 

25 

26Compute 'response times', defined as the time between first detected movement and response. 

27 

28>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

29... trials, stim_on_type='firstMovement_times', stim_off_type='response_times') 

30 

31Compute 'movement times', defined as the time between last detected movement and response threshold. 

32 

33>>> import brainbox.behavior.wheel as wh 

34>>> wheel_moves = ONE().load_object(eid, 'wheelMoves') 

35>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trials.response_times) 

36>>> reaction_time, contrasts, n_contrasts = compute_reaction_time( 

37... trials, stim_on_type='lastMovement_times', stim_off_type='response_times') 

38 

39""" 

40import logging 

41import datetime 

42import re 

43from enum import IntFlag, auto, unique 

44 

45import numpy as np 

46import matplotlib 

47import matplotlib.pyplot as plt 

48import seaborn as sns 

49import pandas as pd 

50from scipy.stats import bootstrap 

51from iblutil.util import Bunch 

52from one.api import ONE 

53from one.alf.io import AlfBunch 

54from one.alf.exceptions import ALFObjectNotFound 

55 

56import psychofit as psy 

57 

58_logger = logging.getLogger('ibllib') 

59 

60TRIALS_KEYS = ['contrastLeft', 

61 'contrastRight', 

62 'feedbackType', 

63 'probabilityLeft', 

64 'choice', 

65 'response_times', 

66 'stimOn_times'] 

67"""list of str: The required keys in the trials object for computing training status.""" 

68 
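# A minimal sketch (illustrative, not part of the original module): check that a loaded
# trials object contains every key required for computing the training status.
#
#     trials = ONE().load_object(eid, 'trials')  # `eid` is a session UUID
#     missing = set(TRIALS_KEYS) - set(trials.keys())
#     assert not missing, f'trials object missing keys: {missing}'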

69 

70@unique 

71class TrainingStatus(IntFlag): 

72 """Standard IBL training criteria. 

73 

74 Enumeration allows for comparisons between training status. 

75 

76 Examples 

77 -------- 

78 >>> status = 'ready4delay' 

79 ... assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY 

80 ... assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training' 

81 ... assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained' 

82 ... assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained' 

83 ... assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained' 

84 ... assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY 

85 

86 # Get the next training status 

87 >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()]) 

88 <TrainingStatus.READY4RECORDING: 128> 

89 

90 Notes 

91 ----- 

92 - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b. 

93 - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it 

94 is possible to have skipped the TRAINED_1A session status. 

95 """ 

96 UNTRAINABLE = auto() 

97 UNBIASABLE = auto() 

98 IN_TRAINING = auto() 

99 TRAINED_1A = auto() 

100 TRAINED_1B = auto() 

101 READY4EPHYSRIG = auto() 

102 READY4DELAY = auto() 

103 READY4RECORDING = auto() 

104 # Compound training statuses for convenience 

105 FAILED = UNTRAINABLE | UNBIASABLE 

106 READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING 

107 TRAINED = TRAINED_1A | TRAINED_1B 

108 
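# Illustrative sketch: the statuses computed by `get_training_status` below contain spaces
# (e.g. 'trained 1b'), so normalise them before looking up the enumeration, as
# `query_criterion` does.
#
#     status = 'trained 1b'  # hypothetical value returned by get_training_status
#     member = TrainingStatus[status.upper().replace(' ', '_')]
#     assert member in TrainingStatus.TRAINED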

109 

110def get_lab_training_status(lab, date=None, details=True, one=None): 

111 """ 

112 Computes the training status of all alive and water restricted subjects in a specified lab. 

113 

114 The results are printed to std out. 

115 

116 Parameters 

117 ---------- 

118 lab : str 

119 Lab name (must match the name registered on Alyx). 

120 date : str 

121 The ISO date from which to compute training status. If not specified, the status is 

122 computed from the latest date with available data. Format should be 'YYYY-MM-DD'. 

123 details : bool 

124 Whether to display all information about training status computation e.g. performance, 

125 number of trials, psychometric fit parameters. 

126 one : one.api.OneAlyx 

127 An instance of ONE. 

128 

129 """ 

130 one = one or ONE() 

131 subj_lab = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True) 

132 subjects = [subj['nickname'] for subj in subj_lab] 

133 for subj in subjects: 

134 get_subject_training_status(subj, date=date, details=details, one=one) 

135 

136 

137def get_subject_training_status(subj, date=None, details=True, one=None): 

138 """ 

139 Computes the training status of specified subject and prints results to std out. 

140 

141 Parameters 

142 ---------- 

143 subj : str 

144 Subject nickname (must match the name registered on Alyx). 

145 date : str 

146 The ISO date from which to compute training status. If not specified, the status is 

147 computed from the latest date with available data. Format should be 'YYYY-MM-DD'. 

148 details : bool 

149 Whether to display all information about training status computation e.g. performance, 

150 number of trials, psychometric fit parameters. 

151 one : one.api.OneAlyx 

152 An instance of ONE. 

153 """ 

154 one = one or ONE() 

155 

156 trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one) 

157 if not trials: 

158 return 

159 sess_dates = list(trials.keys()) 

160 status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay) 

161 

162 if details: 

163 if np.any(info.get('psych')): 

164 display_status(subj, sess_dates, status, perf_easy=info.perf_easy, 

165 n_trials=info.n_trials, psych=info.psych, rt=info.rt) 

166 elif np.any(info.get('psych_20')): 

167 display_status(subj, sess_dates, status, perf_easy=info.perf_easy, 

168 n_trials=info.n_trials, psych_20=info.psych_20, psych_80=info.psych_80, 

169 rt=info.rt) 

170 else: 

171 display_status(subj, sess_dates, status) 

172 
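# Usage sketch (lab and subject names are hypothetical): print training statuses to
# stdout for a whole lab, or for a single subject on a given date.
#
#     one = ONE()
#     get_lab_training_status('cortexlab', date='2023-09-30', one=one)
#     get_subject_training_status('KS023', details=False, one=one)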

173 

174def get_sessions(subj, date=None, one=None): 

175 """ 

176 Download and load training data for a specified subject. If a date is given, it will load 

177 data from the three (or as many as are available) previous sessions up to the specified date. 

178 If not, it will load data from the last three training sessions that have data available. 

179 

180 Parameters 

181 ---------- 

182 subj : str 

183 Subject nickname (must match the name registered on Alyx). 

184 date : str 

185 The ISO date from which to compute training status. If not specified, the status is 

186 computed from the latest date with available data. Format should be 'YYYY-MM-DD'. 

187 one : one.api.OneAlyx 

188 An instance of ONE. 

189 

190 Returns 

191 ------- 

192 iblutil.util.Bunch 

193 Dictionary of trials objects where each key is the ISO session date string. 

194 list of str 

195 List of the task protocol used for each of the sessions. 

196 list of str 

197 List of ISO date strings where training was conducted on ephys rig. Empty list if all 

198 sessions on training rig. 

199 n_delay : int 

200 Number of sessions on the ephys rig that had a delay of more than 15 minutes before 

201 starting. Returns 0 if no such sessions are detected. 

202 """ 

203 one = one or ONE() 

204 

205 if date is None: 

206 # compute from yesterday 

207 specified_date = (datetime.date.today() - datetime.timedelta(days=1)) 

208 latest_sess = specified_date.strftime("%Y-%m-%d") 

209 latest_minus_week = (datetime.date.today() - 

210 datetime.timedelta(days=8)).strftime("%Y-%m-%d") 

211 else: 

212 # compute from the date specified 

213 specified_date = datetime.datetime.strptime(date, '%Y-%m-%d') 

214 latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d") 

215 latest_sess = date 

216 

217 sessions = one.alyx.rest('sessions', 'list', subject=subj, date_range=[latest_minus_week, 

218 latest_sess], dataset_types='trials.goCueTrigger_times') 

219 

220 # If not enough sessions in the last week, then just fetch them all 

221 if len(sessions) < 3: 

222 specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d") 

223 django_query = 'start_time__lte,' + specified_date_plus 

224 sessions = one.alyx.rest('sessions', 'list', subject=subj, 

225 dataset_types='trials.goCueTrigger_times', django=django_query) 

226 

227 # If still 0 sessions then return with warning 

228 if len(sessions) == 0: 

229 _logger.warning(f"No training sessions detected for {subj}") 

230 return [None] * 4 

231 

232 trials = Bunch() 

233 task_protocol = [] 

234 sess_dates = [] 

235 if len(sessions) < 3: 

236 for n, _ in enumerate(sessions): 

237 try: 

238 trials_ = one.load_object(sessions[n]['url'].split('/')[-1], 'trials') 

239 except ALFObjectNotFound: 

240 trials_ = None 

241 

242 if trials_: 

243 task_protocol.append(re.search('tasks_(.*)Choice', 

244 sessions[n]['task_protocol']).group(1)) 

245 sess_dates.append(sessions[n]['start_time'][:10]) 

246 trials[sessions[n]['start_time'][:10]] = trials_ 

247 

248 else: 

249 n = 0 

250 while len(trials) < 3: 

251 print(sessions[n]['url'].split('/')[-1]) 

252 try: 

253 trials_ = one.load_object(sessions[n]['url'].split('/')[-1], 'trials') 

254 except ALFObjectNotFound: 

255 trials_ = None 

256 

257 if trials_: 

258 task_protocol.append(re.search('tasks_(.*)Choice', 

259 sessions[n]['task_protocol']).group(1)) 

260 sess_dates.append(sessions[n]['start_time'][:10]) 

261 trials[sessions[n]['start_time'][:10]] = trials_ 

262 

263 n += 1 

264 

265 if not np.any(np.array(task_protocol) == 'training'): 

266 ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, 

267 date_range=[sess_dates[-1], sess_dates[0]], 

268 django='json__PYBPOD_BOARD__icontains,ephys') 

269 if len(ephys_sess) > 0: 

270 ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess] 

271 

272 n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, 

273 date_range=[sess_dates[-1], sess_dates[0]], 

274 django='json__SESSION_START_DELAY_SEC__gte,900')) 

275 else: 

276 ephys_sess_dates = [] 

277 n_delay = 0 

278 else: 

279 ephys_sess_dates = [] 

280 n_delay = 0 

281 

282 return trials, task_protocol, ephys_sess_dates, n_delay 

283 

284 

285def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): 

286 """ 

287 Compute training status of a subject from consecutive training datasets. 

288 

289 For IBL, training status is calculated using trials from the last three consecutive sessions. 

290 

291 Parameters 

292 ---------- 

293 trials : dict of str 

294 Dictionary of trials objects where each key is the ISO session date string. 

295 task_protocol : list of str 

296 Task protocol used for each training session in `trials`, can be 'training', 'biased' or 

297 'ephys'. 

298 ephys_sess_dates : list of str 

299 List of ISO date strings where training was conducted on ephys rig. Empty list if all 

300 sessions on training rig. 

301 n_delay : int 

302 Number of sessions on the ephys rig that had a delay of more than 15 minutes before 

303 starting; 0 if no such sessions were detected. 

304 

305 Returns 

306 ------- 

307 str 

308 Training status of the subject. 

309 iblutil.util.Bunch 

310 Bunch containing the performance metrics that determine the training status, i.e. performance on easy 

311 trials, number of trials, psychometric fit parameters, reaction time. 

312 """ 

313 

314 info = Bunch() 1heifjkldabc

315 trials_all = concatenate_trials(trials) 1heifjkldabc

316 

317 # Case when all sessions are trainingChoiceWorld 

318 if np.all(np.array(task_protocol) == 'training'): 1heifjkldabc

319 signed_contrast = get_signed_contrast(trials_all) 1hjkbc

320 (info.perf_easy, info.n_trials, 1hjkbc

321 info.psych, info.rt) = compute_training_info(trials, trials_all) 

322 if not np.any(signed_contrast == 0): 1hjkbc

323 status = 'in training' 1bc

324 else: 

325 if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt): 1hjkc

326 status = 'trained 1b' 1kc

327 elif criterion_1a(info.psych, info.n_trials, info.perf_easy): 1hjc

328 status = 'trained 1a' 1jc

329 else: 

330 status = 'in training' 1h

331 

332 return status, info 1hjkbc

333 

334 # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion 

335 if ~np.all(np.array(task_protocol) == 'training') and \ 1eifldabc

336 np.any(np.array(task_protocol) == 'training'): 

337 status = 'trained 1b' 1lbc

338 (info.perf_easy, info.n_trials, 1lbc

339 info.psych, info.rt) = compute_training_info(trials, trials_all) 

340 

341 return status, info 1lbc

342 

343 # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions 

344 if not np.any(np.array(task_protocol) == 'training'): 1eifdabc

345 

346 (info.perf_easy, info.n_trials, 1eifdabc

347 info.psych_20, info.psych_80, 

348 info.rt) = compute_bias_info(trials, trials_all) 

349 # We are still on training rig and so all sessions should be biased 

350 if len(ephys_sess_dates) == 0: 1eifdabc

351 assert np.all(np.array(task_protocol) == 'biased') 1idbc

352 if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, 1idbc

353 info.rt): 

354 status = 'ready4ephysrig' 1ibc

355 else: 

356 status = 'trained 1b' 1dbc

357 

358 elif len(ephys_sess_dates) < 3: 1efac

359 assert all(date in trials for date in ephys_sess_dates) 1eac

360 perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in 1eac

361 ephys_sess_dates]) 

362 n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates]) 1eac

363 

364 if criterion_delay(n_ephys_trials, perf_ephys_easy): 1eac

365 status = 'ready4delay' 1eac

366 else: 

367 status = 'ready4ephysrig' 

368 

369 elif len(ephys_sess_dates) >= 3: 1fc

370 if n_delay > 0 and \ 1fc

371 criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, 

372 info.rt): 

373 status = 'ready4recording' 1fc

374 elif criterion_delay(info.n_trials, info.perf_easy): 1c

375 status = 'ready4delay' 1c

376 else: 

377 status = 'ready4ephysrig' 

378 

379 return status, info 1eifdabc

380 
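# A minimal sketch tying the two steps together (subject name is hypothetical): fetch the
# last three sessions' trials with `get_sessions`, then derive the training status.
#
#     trials, task_protocol, ephys_sess_dates, n_delay = get_sessions('KS023', one=ONE())
#     if trials:
#         status, info = get_training_status(trials, task_protocol, ephys_sess_dates, n_delay)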

381 

382def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None, 

383 psych_20=None, psych_80=None, rt=None): 

384 """ 

385 Display training status of subject to terminal. 

386 

387 Parameters 

388 ---------- 

389 subj : str 

390 Subject nickname (must match the name registered on Alyx). 

391 sess_dates : list of str 

392 ISO date strings of training sessions used to determine training status. 

393 status : str 

394 Training status of subject. 

395 perf_easy : numpy.array 

396 Proportion of correct high contrast trials for each training session. 

397 n_trials : numpy.array 

398 Total number of trials for each training session. 

399 psych : numpy.array 

400 Psychometric parameters fit to data from all training sessions - bias, threshold, lapse 

401 high, lapse low. 

402 psych_20 : numpy.array 

403 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. 

404 psych_80 : numpy.array 

405 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. 

406 rt : float 

407 The median response time for zero contrast trials across all training sessions. NaN 

408 indicates no zero contrast stimuli in training sessions. 

409 

410 """ 

411 

412 if perf_easy is None: 

413 print(f"\n{subj} : {status} \nSession dates=[{sess_dates[0]}, {sess_dates[1]}, " 

414 f"{sess_dates[2]}]") 

415 elif psych_20 is None: 

416 print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, " 

417 f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, " 

418 f"N trials={[nt for nt in n_trials]} " 

419 f"\nPsych fit over last 3 sessions: " 

420 f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, " 

421 f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} " 

422 f"\nMedian reaction time at 0 contrast over last 3 sessions = " 

423 f"{np.around(rt,2)}") 

424 

425 else: 

426 print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, " 

427 f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, " 

428 f"N trials={[nt for nt in n_trials]} " 

429 f"\nPsych fit over last 3 sessions (20): " 

430 f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, " 

431 f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} " 

432 f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, " 

433 f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, " 

434 f"lapse_high={np.around(psych_80[3],2)} " 

435 f"\nMedian reaction time at 0 contrast over last 3 sessions = " 

436 f"{np.around(rt, 2)}") 

437 

438 

439def concatenate_trials(trials): 

440 """ 

441 Concatenate trials from different training sessions. 

442 

443 Parameters 

444 ---------- 

445 trials : dict of str 

446 Dictionary of trials objects where each key is the ISO session date string. 

447 

448 Returns 

449 ------- 

450 one.alf.io.AlfBunch 

451 Trials object with data concatenated over three training sessions. 

452 """ 

453 trials_all = AlfBunch() 1mhoeifjkldabpc

454 for k in TRIALS_KEYS: 1mhoeifjkldabpc

455 trials_all[k] = np.concatenate(list(trials[kk][k] for kk in trials.keys())) 1mhoeifjkldabpc

456 

457 return trials_all 1mhoeifjkldabpc

458 

459 

460def compute_training_info(trials, trials_all): 

461 """ 

462 Compute all relevant performance metrics for when subject is on trainingChoiceWorld. 

463 

464 Parameters 

465 ---------- 

466 trials : dict of str 

467 Dictionary of trials objects where each key is the ISO session date string. 

468 trials_all : one.alf.io.AlfBunch 

469 Trials object with data concatenated over three training sessions. 

470 

471 Returns 

472 ------- 

473 numpy.array 

474 Proportion of correct high contrast trials for each session. 

475 numpy.array 

476 Total number of trials for each training session. 

477 numpy.array 

478 Array of psychometric parameters fit to `all_trials` - bias, threshold, lapse high, 

479 lapse low. 

480 float 

481 The median response time for all zero-contrast trials across all sessions. Returns NaN if 

482 there are no zero-contrast trials. 

483 """ 

484 

485 signed_contrast = get_signed_contrast(trials_all) 1hjklbc

486 perf_easy = np.array([compute_performance_easy(trials[k]) for k in trials.keys()]) 1hjklbc

487 n_trials = np.array([compute_n_trials(trials[k]) for k in trials.keys()]) 1hjklbc

488 psych = compute_psychometric(trials_all, signed_contrast=signed_contrast) 1hjklbc

489 rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast) 1hjklbc

490 

491 return perf_easy, n_trials, psych, rt 1hjklbc

492 
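# Illustrative sketch: `trials` is assumed to be the per-date Bunch returned by
# `get_sessions`; the training metrics are computed over the concatenated sessions.
#
#     trials_all = concatenate_trials(trials)
#     perf_easy, n_trials, psych, rt = compute_training_info(trials, trials_all)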

493 

494def compute_bias_info(trials, trials_all): 

495 """ 

496 Compute all relevant performance metrics for when subject is on biasedChoiceWorld. 

497 

498 :param trials: dict containing trials objects from three consecutive training sessions, 

499 keys are session dates 

500 :type trials: Bunch 

501 :param trials_all: trials object with data concatenated over three training sessions 

502 :type trials_all: Bunch 

503 :returns: 

504 - perf_easy - performance of easy trials for each session 

505 - n_trials - number of trials in each session 

506 - psych_20 - parameters for psychometric curve fit to trials in 20 block over all sessions 

507 - psych_80 - parameters for psychometric curve fit to trials in 80 block over all sessions 

508 - rt - median reaction time for zero contrast stimuli over all sessions 

509 """ 

510 

511 signed_contrast = get_signed_contrast(trials_all) 1eifdabc

512 perf_easy = np.array([compute_performance_easy(trials[k]) for k in trials.keys()]) 1eifdabc

513 n_trials = np.array([compute_n_trials(trials[k]) for k in trials.keys()]) 1eifdabc

514 psych_20 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.2) 1eifdabc

515 psych_80 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.8) 1eifdabc

516 rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast) 1eifdabc

517 

518 return perf_easy, n_trials, psych_20, psych_80, rt 1eifdabc

519 

520 

521def get_signed_contrast(trials): 

522 """ 

523 Compute signed contrast from trials object 

524 

525 :param trials: trials object that must contain contrastLeft and contrastRight keys 

526 :type trials: dict 

527 :returns: array of signed contrasts in percent, where -ve values are on the left 

528 """ 

529 # Replace NaNs with zeros, stack and take the difference 

530 contrast = np.nan_to_num(np.c_[trials['contrastLeft'], trials['contrastRight']]) 1mhoeifjkldabc

531 return np.diff(contrast).flatten() * 100 1mhoeifjkldabc

532 
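# Worked example (values are illustrative): NaNs are zeroed and the difference
# right - left is taken, so contrastLeft=0.5 maps to -50 and contrastRight=1.0 to +100.
#
#     trials = {'contrastLeft': np.array([0.5, np.nan]), 'contrastRight': np.array([np.nan, 1.])}
#     get_signed_contrast(trials)  # -> array([-50., 100.])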

533 

534def compute_performance_easy(trials): 

535 """ 

536 Compute performance on easy trials (absolute contrast >= 50%) from trials object 

537 

538 :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType 

539 keys 

540 :type trials: dict 

541 :returns: float containing performance on easy contrast trials 

542 """ 

543 signed_contrast = get_signed_contrast(trials) 1mheifjkldabc

544 easy_trials = np.where(np.abs(signed_contrast) >= 50)[0] 1mheifjkldabc

545 return np.sum(trials['feedbackType'][easy_trials] == 1) / easy_trials.shape[0] 1mheifjkldabc

546 

547 

548def compute_performance(trials, signed_contrast=None, block=None, prob_right=False): 

549 """ 

550 Compute performance on all trials at each contrast level from trials object 

551 

552 :param trials: trials object that must contain contrastLeft, contrastRight, probabilityLeft, 

553 feedbackType and choice keys 

554 :type trials: dict 

555 :returns: performance at each contrast, the unique signed contrasts and the number of trials per contrast 

556 """ 

557 if signed_contrast is None: 1mheifjkldabc

558 signed_contrast = get_signed_contrast(trials) 1dabc

559 

560 if block is None: 1mheifjkldabc

561 block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) 1mhjkldabc

562 else: 

563 block_idx = trials.probabilityLeft == block 1eifdabc

564 

565 if not np.any(block_idx): 1mheifjkldabc

566 return np.nan * np.zeros(3) 1db

567 

568 contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True) 1mheifjkldabc

569 

570 if not prob_right: 1mheifjkldabc

571 correct = trials.feedbackType == 1 

572 performance = np.vectorize(lambda x: np.mean(correct[(x == signed_contrast) & block_idx]))(contrasts) 

573 else: 

574 rightward = trials.choice == -1 1mheifjkldabc

575 # Calculate the proportion rightward for each contrast type 

576 performance = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) & block_idx]))(contrasts) 1mheifjkldabc

577 

578 return performance, contrasts, n_contrasts 1mheifjkldabc

579 

580 

581def compute_n_trials(trials): 

582 """ 

583 Compute number of trials in trials object 

584 

585 :param trials: trials object 

586 :type trials: dict 

587 :returns: int containing number of trials in session 

588 """ 

589 return trials['choice'].shape[0] 1mheifjkldabc

590 

591 

592def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032): 

593 """ 

594 Compute psychometric fit parameters for trials object. 

595 

596 Parameters 

597 ---------- 

598 trials : one.alf.io.AlfBunch 

599 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

600 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

601 signed_contrast : numpy.array 

602 An array of signed contrasts in percent the length of trials, where left contrasts are -ve. 

603 If None, these are computed from the trials object. 

604 block : float 

605 The block type to compute. If None, all trials are included, otherwise only trials where 

606 probabilityLeft matches this value are included. For biasedChoiceWorld, the 

607 probabilityLeft set is {0.5, 0.2, 0.8}. 

608 plotting : bool 

609 Which set of psychofit model parameters to use (see notes). 

610 compute_ci : bool 

611 If true, computes and returns the confidence intervals for response at each contrast. 

612 alpha : float, default=0.032 

613 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, 

614 this value is ignored. 

615 

616 Returns 

617 ------- 

618 numpy.array 

619 Array of psychometric fit parameters - bias, threshold, lapse high, lapse low. 

620 (tuple of numpy.array) 

621 If `compute_ci` is true, a tuple of 

622 

623 See Also 

624 -------- 

625 statsmodels.stats.proportion.proportion_confint - The function used to compute confidence 

626 interval. 

627 psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. 

628 

629 Notes 

630 ----- 

631 The psychofit starting parameters and model constraints used for the fit when computing the 

632 training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep 

633 the precise criteria the same for all subjects, these parameters have not changed. To produce a 

634 better fit for plotting purposes, or to calculate the training status in a manner inconsistent 

635 with the IBL training pipeline, use plotting=True. 

636 """ 

637 

638 if signed_contrast is None: 1mhoeifjkldabc

639 signed_contrast = get_signed_contrast(trials) 1modabc

640 

641 if block is None: 1mhoeifjkldabc

642 block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) 1mhjklbc

643 else: 

644 block_idx = trials.probabilityLeft == block 1oeifdabc

645 

646 if not np.any(block_idx): 1mhoeifjkldabc

647 return np.nan * np.zeros(4) 1odbc

648 

649 prob_choose_right, contrasts, n_contrasts = compute_performance( 1mheifjkldabc

650 trials, signed_contrast=signed_contrast, block=block, prob_right=True) 

651 

652 if plotting: 1mheifjkldabc

653 # These starting parameters and constraints tend to produce a better fit, and are therefore 

654 # used for plotting. 

655 psych, _ = psy.mle_fit_psycho( 1dab

656 np.vstack([contrasts, n_contrasts, prob_choose_right]), 

657 P_model='erf_psycho_2gammas', 

658 parstart=np.array([0., 40., 0.1, 0.1]), 

659 parmin=np.array([-50., 10., 0., 0.]), 

660 parmax=np.array([50., 50., 0.2, 0.2]), 

661 nfits=10) 

662 else: 

663 # These starting parameters and constraints are not ideal but are still used for computing 

664 # the training status for consistency. 

665 psych, _ = psy.mle_fit_psycho( 1mheifjkldabc

666 np.vstack([contrasts, n_contrasts, prob_choose_right]), 

667 P_model='erf_psycho_2gammas', 

668 parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]), 

669 parmin=np.array([np.min(contrasts), 0., 0., 0.]), 

670 parmax=np.array([np.max(contrasts), 100., 1, 1])) 

671 

672 if compute_ci: 1mheifjkldabc

673 import statsmodels.stats.proportion as smp # noqa 

674 # choice == -1 means contrast on right hand side 

675 n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts) 

676 ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right 

677 

678 return psych, ci 

679 else: 

680 return psych 1mheifjkldabc

681 
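# Usage sketch (assumes `trials` is a loaded biasedChoiceWorld trials object): fit the
# psychometric curve for the 0.2 probability-left block with plotting parameters and CIs.
#
#     psych_20, ci = compute_psychometric(trials, block=0.2, plotting=True, compute_ci=True)
#     bias, threshold = psych_20[0], psych_20[1]  # lapse rates are in psych_20[2:4]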

682 

683def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None): 

684 """ 

685 Compute the median response time for trials of a given contrast from a trials object. 

686 

687 Parameters 

688 ---------- 

689 trials : one.alf.io.AlfBunch 

690 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

691 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

692 stim_on_type : str, default='stimOn_times' 

693 The trials key to use when calculating the response times. The difference between this and 

694 'response_times' is used (see notes). 

695 contrast : float 

696 If None, the median response time is calculated for all trials, regardless of contrast, 

697 otherwise only trials where the matching signed percent contrast was presented are used. 

698 signed_contrast : numpy.array 

699 An array of signed contrasts in percent the length of trials, where left contrasts are -ve. 

700 If None, these are computed from the trials object. 

701 

702 Returns 

703 ------- 

704 float 

705 The median response time for trials with `contrast` (returns NaN if no trials matching 

706 `contrast` in trials object). 

707 

708 Notes 

709 ----- 

710 - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is 

711 sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on 

712 times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of 

713 the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard 

714 output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the 

715 audio go cue command). 

716 - The response/reaction time here is defined as the time between stim on and response, i.e. the 

717 entire open-loop trial duration. 

718 """ 

719 if signed_contrast is None: 1mheifjkldabc

720 signed_contrast = get_signed_contrast(trials) 1mdabc

721 

722 if contrast is None: 1mheifjkldabc

723 contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) 1dabc

724 else: 

725 contrast_idx = signed_contrast == contrast 1mheifjkldabc

726 

727 if np.any(contrast_idx): 1mheifjkldabc

728 reaction_time = np.nanmedian((trials.response_times - trials[stim_on_type]) 1mheifjkldabc

729 [contrast_idx]) 

730 else: 

731 reaction_time = np.nan 1bc

732 

733 return reaction_time 1mheifjkldabc

734 

735 

736def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None, 

737 compute_ci=False, alpha=0.32): 

738 """ 

739 Compute median response time for all contrasts. 

740 

741 Parameters 

742 ---------- 

743 trials : one.alf.io.AlfBunch 

744 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

745 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

746 stim_on_type : str, default='stimOn_times' 

747 The trials key to use when calculating the response times. The difference between this and 

748 `stim_off_type` is used (see notes). 

749 stim_off_type : str, default='response_times' 

750 The trials key to use when calculating the response times. The difference between this and 

751 `stim_on_type` is used (see notes). 

752 signed_contrast : numpy.array 

753 An array of signed contrasts in percent the length of trials, where left contrasts are -ve. 

754 If None, these are computed from the trials object. 

755 block : float 

756 The block type to compute. If None, all trials are included, otherwise only trials where 

757 probabilityLeft matches this value are included. For biasedChoiceWorld, the 

758 probabilityLeft set is {0.5, 0.2, 0.8}. 

759 compute_ci : bool 

760 If true, computes and returns the confidence intervals for response time at each contrast. 

761 alpha : float, default=0.32 

762 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, 

763 this value is ignored. 

764 

765 Returns 

766 ------- 

767 numpy.array 

768 The median response times for each unique signed contrast. 

769 numpy.array 

770 The set of unique signed contrasts. 

771 numpy.array 

772 The number of trials for each unique signed contrast. 

773 (numpy.array) 

774 If `compute_ci` is true, an array of confidence intervals is returned in the shape 

775 (n_contrasts, 2). 

776 

777 Notes 

778 ----- 

779 - The response/reaction time by default is the time between stim on and response, i.e. the 

780 entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to 

781 get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true 

782 response times, or calculate the last movement onset times and calculate the true movement 

783 times. See module examples for how to calculate this. 

784 

785 See Also 

786 -------- 

787 scipy.stats.bootstrap - the function used to compute the confidence interval. 

788 """ 

789 

790 if signed_contrast is None: 1a

791 signed_contrast = get_signed_contrast(trials) 

792 

793 if block is None: 1a

794 block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) 

795 else: 

796 block_idx = trials.probabilityLeft == block 1a

797 

798 contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True) 1a

799 reaction_time = np.vectorize(lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type]) 1a

800 [(x == signed_contrast) & block_idx]))(contrasts) 

801 if compute_ci: 1a

802 ci = np.full((contrasts.size, 2), np.nan) 

803 for i, x in enumerate(contrasts): 

804 data = (trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx] 

805 bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha) 

806 ci[i, 0] = bt.confidence_interval.low 

807 ci[i, 1] = bt.confidence_interval.high 

808 

809 return reaction_time, contrasts, n_contrasts, ci 

810 else: 

811 return reaction_time, contrasts, n_contrasts 1a

812 
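# Usage sketch (assumes `trials` is a loaded trials object): median response time per
# contrast with bootstrapped confidence intervals.
#
#     rts, contrasts, n_contrasts, ci = compute_reaction_time(trials, compute_ci=True)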

813 

814def criterion_1a(psych, n_trials, perf_easy): 

815 """ 

816 Returns bool indicating whether criteria for status 'trained_1a' are met. 

817 

818 Criteria 

819 -------- 

820 - Bias is less than 16 

821 - Threshold is less than 19 

822 - Lapse rate on both sides is less than 0.2 

823 - The total number of trials is greater than 200 for each session 

824 - Performance on easy contrasts > 80% for all sessions 

825 

826 Parameters 

827 ---------- 

828 psych : numpy.array 

829 The fit psychometric parameters over three consecutive sessions. Parameters are bias, threshold, 

830 lapse high, lapse low. 

831 n_trials : numpy.array of int 

832 The number of trials for each session. 

833 perf_easy : numpy.array of float 

834 The proportion of correct high contrast trials for each session. 

835 

836 Returns 

837 ------- 

838 bool 

839 True if the criteria are met for 'trained_1a'. 

840 

841 Notes 

842 ----- 

843 The parameter thresholds chosen here were originally determined by averaging the parameter fits 

844 for a number of sessions determined to be of 'good' performance by an experimenter. 

845 """ 

846 

847 criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and 1hjc

848 np.all(n_trials > 200) and np.all(perf_easy > 0.8)) 

849 return criterion 1hjc

850 
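# Worked instance (numbers are made up for illustration): a bias of -5, threshold of 15
# and lapse rates of 0.1 pass, provided every session has more than 200 trials and over
# 80% performance on easy contrasts.
#
#     criterion_1a(psych=np.array([-5., 15., 0.1, 0.1]),
#                  n_trials=np.array([350, 420, 510]),
#                  perf_easy=np.array([0.85, 0.9, 0.92]))  # -> True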

851 

852def criterion_1b(psych, n_trials, perf_easy, rt): 

853 """ 

854 Returns bool indicating whether criteria for trained_1b are met. 

855 

856 Criteria 

857 -------- 

858 - Bias is less than 10 

859 - Threshold is less than 20 (see notes) 

860 - Lapse rate on both sides is less than 0.1 

861 - The total number of trials is greater than 400 for each session 

862 - Performance on easy contrasts > 90% for all sessions 

863 - The median response time across all zero contrast trials is less than 2 seconds 

864 

865 Parameters 

866 ---------- 

867 psych : numpy.array 

868 The fit psychometric parameters over three consecutive sessions. Parameters are bias, threshold, 

869 lapse high, lapse low. 

870 n_trials : numpy.array of int 

871 The number of trials for each session. 

872 perf_easy : numpy.array of float 

873 The proportion of correct high contrast trials for each session. 

874 rt : float 

875 The median response time for zero contrast trials. 

876 

877 Returns 

878 ------- 

879 bool 

880 True if the criteria are met for 'trained_1b'. 

881 

882 Notes 

883 ----- 

884 The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a, 

885 however it was decided to use round numbers so that readers would not assume a level of 

886 precision that isn't there (remember, these parameters were not chosen with any rigor). This 

887 regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the 

888 slope of the psychometric curve may be slightly less steep than 1a. 

889 """ 

890 criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and 1hjkc

891 np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2) 

892 return criterion 1hjkc

893 

894 

895def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): 

896 """ 

897 Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met. 

898 

899 NB: The difference between these two is whether the sessions were acquired on a recording rig 

900 with a delay before the first trial. Neither of these two things is tested here. 

901 

902 Criteria 

903 -------- 

904 - Lapse on both sides < 0.1 for both bias blocks 

905 - Bias shift between blocks > 5 

906 - Total number of trials > 400 for all sessions 

907 - Performance on easy contrasts > 90% for all sessions 

908 - Median response time for zero contrast stimuli < 2 seconds 

909 

910 Parameters 

911 ---------- 

912 psych_20 : numpy.array 

913 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. 

914 Parameters are bias, threshold, lapse high, lapse low. 

915 psych_80 : numpy.array 

916 The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. 

917 Parameters are bias, threshold, lapse high, lapse low. 

918 n_trials : numpy.array 

919 The number of trials for each session (typically three consecutive sessions). 

920 perf_easy : numpy.array 

921 The proportion of correct high contrast trials for each session (typically three 

922 consecutive sessions). 

923 rt : float 

924 The median response time for zero contrast trials. 

925 

926 Returns 

927 ------- 

928 bool 

929 True if subject passes the ready4ephysrig or ready4recording criteria. 

930 """ 

931 

932 criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and # lapse 1ifdbc

933 psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and # bias shift and n trials 

934 np.all(perf_easy > 0.9) and rt < 2) # overall performance and response times 

935 return criterion 1ifdbc

936 

937 

938def criterion_delay(n_trials, perf_easy): 

939 """ 

940 Returns bool indicating whether criteria for 'ready4delay' is met. 

941 

942 Criteria 

943 -------- 

944 - Total number of trials for any of the sessions is greater than 400 

945 - Performance on easy contrasts is greater than 90% for any of the sessions 

946 

947 Parameters 

948 ---------- 

949 n_trials : numpy.array of int 

950 The number of trials for each session (typically three consecutive sessions). 

951 perf_easy : numpy.array 

952 The proportion of correct high contrast trials for each session (typically three 

953 consecutive sessions). 

954 

955 Returns 

956 ------- 

957 bool 

958 True if subject passes the 'ready4delay' criteria. 

959 """ 

960 criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9) 1eac

961 return criterion 1eac

962 

963 

964def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs): 

965 """ 

966 Function to plot the psychometric curve a la datajoint webpage. 

967 

968 Parameters 

969 ---------- 

970 trials : one.alf.io.AlfBunch 

971 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

972 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

973 ax : matplotlib.pyplot.Axes 

974 An axis object to plot onto. 

975 title : str 

976 An optional plot title. 

977 plot_ci : bool 

978 If true, computes and plots the confidence intervals for response at each contrast. 

979 ci_alpha : float, default=0.032 

980 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, 

981 this value is ignored. 

982 **kwargs 

983 If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. 

984 

985 Returns 

986 ------- 

987 matplotlib.pyplot.Figure 

988 The figure handle containing the plot. 

989 matplotlib.pyplot.Axes 

990 The plotted axes. 

991 

992 See Also 

993 -------- 

994 statsmodels.stats.proportion.proportion_confint - The function used to compute confidence 

995 interval. 

996 psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. 

997 psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability 

998 using the fit parameters. 

999 """ 

1000 

1001 signed_contrast = get_signed_contrast(trials) 1dab

1002 contrasts_fit = np.arange(-100, 100) 1dab

1003 

1004 prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5, prob_right=True) 1dab

1005 out_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5, plotting=True, 1dab

1006 compute_ci=plot_ci, alpha=ci_alpha) 

1007 pars_50 = out_50[0] if plot_ci else out_50 1dab

1008 prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit) 1dab

1009 

1010 prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2, prob_right=True) 1dab

1011 out_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2, plotting=True, 1dab

1012 compute_ci=plot_ci, alpha=ci_alpha) 

1013 pars_20 = out_20[0] if plot_ci else out_20 1dab

1014 prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit) 1dab

1015 

1016 prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8, prob_right=True) 1dab

1017 out_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8, plotting=True, 1dab

1018 compute_ci=plot_ci, alpha=ci_alpha) 

1019 pars_80 = out_80[0] if plot_ci else out_80 1dab

1020 prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit) 1dab

1021 

1022 cmap = sns.diverging_palette(20, 220, n=3, center='dark') 1dab

1023 

1024 if not ax: 1dab

1025 fig, ax = plt.subplots(**kwargs) 1a

1026 else: 

1027 fig = plt.gcf() 1dab

1028 

1029 fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1]) 1dab

1030 data_50 = ax.scatter(contrasts_50, prob_right_50, color=cmap[1]) 1dab

1031 fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0]) 1dab

1032 data_20 = ax.scatter(contrasts_20, prob_right_20, color=cmap[0]) 1dab

1033 fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2]) 1dab

1034 data_80 = ax.scatter(contrasts_80, prob_right_80, color=cmap[2]) 1dab

1035 

1036 if plot_ci: 1dab

1037 errbar_50 = np.c_[np.abs(out_50[1][0]), np.abs(out_50[1][1])].T 

1038 errbar_20 = np.c_[np.abs(out_20[1][0]), np.abs(out_20[1][1])].T 

1039 errbar_80 = np.c_[np.abs(out_80[1][0]), np.abs(out_80[1][1])].T 

1040 

1041 ax.errorbar(contrasts_50, prob_right_50, yerr=errbar_50, ecolor=cmap[1], fmt='none', capsize=5, alpha=0.4) 

1042 ax.errorbar(contrasts_20, prob_right_20, yerr=errbar_20, ecolor=cmap[0], fmt='none', capsize=5, alpha=0.4) 

1043 ax.errorbar(contrasts_80, prob_right_80, yerr=errbar_80, ecolor=cmap[2], fmt='none', capsize=5, alpha=0.4) 

1044 

1045 ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80], 1dab

1046 ['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'], 

1047 loc='upper left') 

1048 ax.set_ylim(-0.05, 1.05) 1dab

1049 ax.set_ylabel('Probability choosing right') 1dab

1050 ax.set_xlabel('Contrasts') 1dab

1051 if title: 1dab

1052 ax.set_title(title) 1dab

1053 

1054 return fig, ax 1dab

1055 
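# Usage sketch (figure size is an arbitrary choice): plot the psychometric curves with
# confidence intervals for a loaded trials object.
#
#     fig, ax = plot_psychometric(trials, plot_ci=True, title='Example session', figsize=(6, 4))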

1056 

1057def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs): 

1058 """ 

1059 Function to plot reaction time against contrast a la datajoint webpage. 

1060 

1061 The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}. 

1062 

1063 Parameters 

1064 ---------- 

1065 trials : one.alf.io.AlfBunch 

1066 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

1067 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

1068 ax : matplotlib.pyplot.Axes 

1069 An axis object to plot onto. 

1070 title : str 

1071 An optional plot title. 

1072 plot_ci : bool 

1073 If true, computes and plots the confidence intervals for response at each contrast. 

1074 ci_alpha : float, default=0.32 

1075 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, 

1076 this value is ignored. 

1077 **kwargs 

1078 If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. 

1079 

1080 Returns 

1081 ------- 

1082 matplotlib.pyplot.Figure 

1083 The figure handle containing the plot. 

1084 matplotlib.pyplot.Axes 

1085 The plotted axes. 

1086 

1087 See Also 

1088 -------- 

1089 scipy.stats.bootstrap - the function used to compute the confidence interval. 

1090 """ 

1091 

1092 signed_contrast = get_signed_contrast(trials) 1a

1093 out_50 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5, compute_ci=plot_ci, alpha=ci_alpha) 1a

1094 out_20 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2, compute_ci=plot_ci, alpha=ci_alpha) 1a

1095 out_80 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8, compute_ci=plot_ci, alpha=ci_alpha) 1a

1096 

1097 cmap = sns.diverging_palette(20, 220, n=3, center='dark') 1a

1098 

1099 if not ax: 1a

1100 fig, ax = plt.subplots(**kwargs) 1a

1101 else: 

1102 fig = plt.gcf() 

1103 

1104 data_50 = ax.plot(out_50[1], out_50[0], '-o', color=cmap[1]) 1a

1105 data_20 = ax.plot(out_20[1], out_20[0], '-o', color=cmap[0]) 1a

1106 data_80 = ax.plot(out_80[1], out_80[0], '-o', color=cmap[2]) 1a

1107 

1108 if plot_ci: 1a

1109 errbar_50 = np.c_[out_50[0] - out_50[3][:, 0], out_50[3][:, 1] - out_50[0]].T 

1110 errbar_20 = np.c_[out_20[0] - out_20[3][:, 0], out_20[3][:, 1] - out_20[0]].T 

1111 errbar_80 = np.c_[out_80[0] - out_80[3][:, 0], out_80[3][:, 1] - out_80[0]].T 

1112 

1113 ax.errorbar(out_50[1], out_50[0], yerr=errbar_50, ecolor=cmap[1], fmt='none', capsize=5, alpha=0.4) 

1114 ax.errorbar(out_20[1], out_20[0], yerr=errbar_20, ecolor=cmap[0], fmt='none', capsize=5, alpha=0.4) 

1115 ax.errorbar(out_80[1], out_80[0], yerr=errbar_80, ecolor=cmap[2], fmt='none', capsize=5, alpha=0.4) 

1116 

1117 ax.legend([data_50[0], data_20[0], data_80[0]], 1a

1118 ['p_left=0.5 data', 'p_left=0.2 data', 'p_left=0.8 data'], 

1119 loc='upper left') 

1120 ax.set_ylabel('Reaction time (s)') 1a

1121 ax.set_xlabel('Contrasts') 1a

1122 

1123 if title: 1a

1124 ax.set_title(title) 1a

1125 

1126 return fig, ax 1a

1127 

1128 

1129def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs): 

1130 """ 

1131 Function to plot reaction time with trial number a la datajoint webpage. 

1132 

1133 Parameters 

1134 ---------- 

1135 trials : one.alf.io.AlfBunch 

1136 An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 

1137 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. 

1138 stim_on_type : str, default='stimOn_times' 

1139 The trials key to use when calculating the response times. The difference between this and 

1140 'response_times' is used (see notes for `compute_median_reaction_time`). 

1141 ax : matplotlib.pyplot.Axes 

1142 An axis object to plot onto. 

1143 title : str 

1144 An optional plot title. 

1145 **kwargs 

1146 If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. 

1147 

1148 Returns 

1149 ------- 

1150 matplotlib.pyplot.Figure 

1151 The figure handle containing the plot. 

1152 matplotlib.pyplot.Axes 

1153 The plotted axes. 

1154 """ 

1155 

1156 reaction_time = pd.DataFrame() 1a

1157 reaction_time['reaction_time'] = trials.response_times - trials[stim_on_type] 1a

1158 reaction_time.index = reaction_time.index + 1 1a

1159 reaction_time_rolled = reaction_time['reaction_time'].rolling(window=10).median() 1a

1160 reaction_time_rolled = reaction_time_rolled.where((pd.notnull(reaction_time_rolled)), None) 1a

1161 reaction_time = reaction_time.where((pd.notnull(reaction_time)), None) 1a

1162 

1163 if not ax: 1a

1164 fig, ax = plt.subplots(**kwargs) 1a

1165 else: 

1166 fig = plt.gcf() 

1167 

1168 ax.scatter(np.arange(len(reaction_time.values)), reaction_time.values, s=16, color='darkgray') 1a

1169 ax.plot(np.arange(len(reaction_time_rolled.values)), reaction_time_rolled.values, color='k', linewidth=2) 1a

1170 ax.set_yscale('log') 1a

1171 ax.set_ylim(0.1, 100) 1a

1172 ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) 1a

1173 ax.set_ylabel('Reaction time (s)') 1a

1174 ax.set_xlabel('Trial number') 1a

1175 if title: 1a

1176 ax.set_title(title) 1a

1177 

1178 return fig, ax 1a

1179 

1180 

1181def query_criterion(subject, status, from_status=None, one=None, validate=True): 

1182 """Get the session for which a given training criterion was met. 

1183 

1184 Parameters 

1185 ---------- 

1186 subject : str 

1187 The subject name. 

1188 status : str 

1189 The training status to query for. 

1190 from_status : str, optional 

1191 Count number of sessions and days from reaching `from_status` to `status`. 

1192 one : one.api.OneAlyx, optional 

1193 An instance of ONE. 

1194 validate : bool 

1195 If true, check if status in TrainingStatus enumeration. Set to false for non-standard 

1196 training pipelines. 

1197 

1198 Returns 

1199 ------- 

1200 str 

1201 The eID of the first session where this training status was reached. 

1202 int 

1203 The number of sessions it took to reach `status` (optionally from reaching `from_status`). 

1204 int 

1205 The number of days it took to reach `status` (optionally from reaching `from_status`). 

1206 """ 

1207 if validate: 1n

1208 status = status.lower().replace(' ', '_') 1n

1209 try: 1n

1210 status = TrainingStatus[status.upper().replace(' ', '_')].name.lower() 1n

1211 except KeyError as ex: 1n

1212 raise ValueError( 1n

1213 f'Unknown status "{status}". For non-standard training protocols set validate=False' 

1214 ) from ex 

1215 one = one or ONE() 1n

1216 subject_json = one.alyx.rest('subjects', 'read', id=subject)['json'] 1n

1217 if not (criteria := subject_json.get('trained_criteria')) or status not in criteria: 1n

1218 return None, None, None 1n

1219 to_date, eid = criteria[status] 1n

1220 from_date, _ = criteria.get(from_status, (None, None)) 1n

1221 eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True) 1n

1222 if len(eids) == 0: 1n

1223 return eid, None, None 1n

1224 delta_date = det[0]['date'] - det[-1]['date'] 1n

1225 return eid, len(eids), delta_date.days 1n
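# Usage sketch (subject name is hypothetical): find the session at which a subject reached
# 'trained 1b' and how long that took from reaching 'trained 1a'. Note that `from_status`
# is used as a raw key into the subject's trained_criteria record, hence the underscore form.
#
#     eid, n_sessions, n_days = query_criterion('KS023', 'trained 1b', from_status='trained_1a')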