Coverage for brainbox/behavior/training.py: 58%
350 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Computing and testing IBL training status criteria.
3For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice
4Training.
6.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\
7measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729
9Examples
10--------
11Plot the psychometric curve for a given session.
13>>> trials = ONE().load_object(eid, 'trials')
14>>> fix, ax = plot_psychometric(trials)
16Compute 'response times', defined as the duration of open-loop for each contrast.
18>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials)
20Compute 'reaction times', defined as the time between go cue and first detected movement.
21NB: These may be negative!
23>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
24... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times')
26Compute 'response times', defined as the time between first detected movement and response.
28>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
29... trials, stim_on_type='firstMovement_times', stim_off_type='response_times')
31Compute 'movement times', defined as the time between last detected movement and response threshold.
33>>> import brainbox.behavior.wheel as wh
34>>> wheel_moves = ONE().load_object(eid, 'wheeMoves')
35>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trial_data.response_times)
36>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
37... trials, stim_on_type='lastMovement_times', stim_off_type='response_times')
39"""
40import logging
41import datetime
42import re
43from enum import IntFlag, auto, unique
45import numpy as np
46import matplotlib
47import matplotlib.pyplot as plt
48import seaborn as sns
49import pandas as pd
50from scipy.stats import bootstrap
51from iblutil.util import Bunch
52from one.api import ONE
53from one.alf.io import AlfBunch
54from one.alf.exceptions import ALFObjectNotFound
55import psychofit as psy
# Module-level logger; shares the 'ibllib' logger so messages integrate with the wider package.
_logger = logging.getLogger('ibllib')

# Minimal set of trials attributes needed by the training-status computations below
# (see concatenate_trials, compute_psychometric, compute_median_reaction_time).
TRIALS_KEYS = ['contrastLeft',
               'contrastRight',
               'feedbackType',
               'probabilityLeft',
               'choice',
               'response_times',
               'stimOn_times']
"""list of str: The required keys in the trials object for computing training status."""
@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Enumeration allows for comparisons between training status.

    Examples
    --------
    >>> status = 'ready4delay'
    ... assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    ... assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    ... assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    ... assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    ... assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    ... assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    Get the next training status

    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    """
    # Members are declared in ascending order of progression so that integer comparison
    # (e.g. `status >= TrainingStatus.TRAINED_1B`) reflects training progress.
    UNTRAINABLE = auto()
    UNBIASABLE = auto()
    IN_TRAINING = auto()
    TRAINED_1A = auto()
    TRAINED_1B = auto()
    READY4EPHYSRIG = auto()
    READY4DELAY = auto()
    READY4RECORDING = auto()
    # Compound training statuses for convenience
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B
def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Compute the training status of all alive and water-restricted subjects in a specified lab.

    The results are printed to std out.

    Parameters
    ----------
    lab : str
        Lab name (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE.
    """
    one = one or ONE()
    restricted = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True)
    for nickname in (record['nickname'] for record in restricted):
        get_subject_training_status(nickname, date=date, details=details, one=one)
def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Compute the training status of a specified subject and print the results to std out.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE.
    """
    one = one or ONE()

    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    if not trials:
        return  # no sessions found; get_sessions has already logged a warning
    dates = list(trials.keys())
    status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)

    if not details:
        display_status(subj, dates, status)
        return

    # Display the metrics relevant to the protocol: single psychometric fit for
    # trainingChoiceWorld, per-block (0.2/0.8) fits for biased/ephysChoiceWorld.
    if np.any(info.get('psych')):
        display_status(subj, dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych=info.psych, rt=info.rt)
    elif np.any(info.get('psych_20')):
        display_status(subj, dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych_20=info.psych_20,
                       psych_80=info.psych_80, rt=info.rt)
def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject.

    If a date is given it will load data from the three (or as many as are available) previous
    sessions up to the specified date. If not it will load data from the last three training
    sessions that have data available.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    one : one.api.OneAlyx
        An instance of ONE.

    Returns
    -------
    iblutil.util.Bunch
        Dictionary of trials objects where each key is the ISO session date string.
    list of str
        List of the task protocol used for each of the sessions.
    list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.
    """
    one = one or ONE()

    if date is None:
        # compute from yesterday; search the preceding week for candidate sessions
        specified_date = datetime.date.today() - datetime.timedelta(days=1)
        latest_sess = specified_date.strftime("%Y-%m-%d")
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime("%Y-%m-%d")
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d")
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj,
                             date_range=[latest_minus_week, latest_sess],
                             dataset_types='trials.goCueTrigger_times')

    # If not enough sessions in the last week, then just fetch them all
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)

    # If still 0 sessions then return with warning
    if len(sessions) == 0:
        _logger.warning(f"No training sessions detected for {subj}")
        return [None] * 4

    trials = Bunch()
    task_protocol = []
    sess_dates = []
    # Load trials from the most recent sessions until three with data are found. Bounding the
    # loop by the session list (rather than looping unconditionally until three loads succeed)
    # fixes a potential IndexError when fewer than three sessions have loadable trials. A stray
    # debug print of the session eid has also been removed.
    for session in sessions:
        if len(trials) == 3:
            break
        eid = session['url'].split('/')[-1]
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            continue  # no trials data registered for this session; try the next one
        if trials_:
            task_protocol.append(re.search('tasks_(.*)Choice',
                                           session['task_protocol']).group(1))
            sess_dates.append(session['start_time'][:10])
            trials[session['start_time'][:10]] = trials_

    if not np.any(np.array(task_protocol) == 'training'):
        # Subject is beyond trainingChoiceWorld: check whether sessions ran on an ephys rig
        ephys_sess = one.alyx.rest('sessions', 'list', subject=subj,
                                   date_range=[sess_dates[-1], sess_dates[0]],
                                   django='json__PYBPOD_BOARD__icontains,ephys')
        if len(ephys_sess) > 0:
            ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]
            # Sessions with a pre-session delay of at least 900 s (15 min)
            n_delay = len(one.alyx.rest('sessions', 'list', subject=subj,
                                        date_range=[sess_dates[-1], sess_dates[0]],
                                        django='json__SESSION_START_DELAY_SEC__gte,900'))
        else:
            ephys_sess_dates = []
            n_delay = 0
    else:
        ephys_sess_dates = []
        n_delay = 0

    return trials, task_protocol, ephys_sess_dates, n_delay
def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from consecutive training datasets.

    For IBL, training status is calculated using trials from the last three consecutive sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    task_protocol : list of str
        Task protocol used for each training session in `trials`, can be 'training', 'biased' or
        'ephys'.
    ephys_sess_dates : list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Returns
    -------
    str
        Training status of the subject.
    iblutil.util.Bunch
        Bunch containing performance metrics that decide training status i.e. performance on easy
        trials, number of trials, psychometric fit parameters, reaction time.
    """
    info = Bunch()
    trials_all = concatenate_trials(trials)

    # Case when all sessions are trainingChoiceWorld
    if np.all(np.array(task_protocol) == 'training'):
        signed_contrast = get_signed_contrast(trials_all)
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)
        # No zero-contrast trials yet: the hardest contrasts have not been introduced,
        # so the 1a/1b criteria cannot apply
        if not np.any(signed_contrast == 0):
            status = 'in training'
        else:
            # Test the stricter criterion first so 1b is reported when both are met
            if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
                status = 'trained 1b'
            elif criterion_1a(info.psych, info.n_trials, info.perf_easy):
                status = 'trained 1a'
            else:
                status = 'in training'

        return status, info

    # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion
    # (i.e. a mixture of 'training' and non-'training' protocols in the last three sessions)
    if ~np.all(np.array(task_protocol) == 'training') and \
            np.any(np.array(task_protocol) == 'training'):
        status = 'trained 1b'
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)

        return status, info

    # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions
    if not np.any(np.array(task_protocol) == 'training'):

        (info.perf_easy, info.n_trials,
         info.psych_20, info.psych_80,
         info.rt) = compute_bias_info(trials, trials_all)
        # We are still on training rig and so all sessions should be biased
        if len(ephys_sess_dates) == 0:
            assert np.all(np.array(task_protocol) == 'biased')
            if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                               info.rt):
                status = 'ready4ephysrig'
            else:
                status = 'trained 1b'

        elif len(ephys_sess_dates) < 3:
            # Some but not all of the last three sessions were on the ephys rig;
            # the delay criterion is evaluated on the ephys-rig sessions only
            assert all(date in trials for date in ephys_sess_dates)
            perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in
                                        ephys_sess_dates])
            n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates])

            if criterion_delay(n_ephys_trials, perf_ephys_easy):
                status = 'ready4delay'
            else:
                status = 'ready4ephysrig'

        elif len(ephys_sess_dates) >= 3:
            # All three sessions on the ephys rig; require at least one delayed session
            # before the subject is ready for recording
            if n_delay > 0 and \
                    criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                                    info.rt):
                status = 'ready4recording'
            elif criterion_delay(info.n_trials, info.perf_easy):
                status = 'ready4delay'
            else:
                status = 'ready4ephysrig'

        return status, info
def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    sess_dates : list of str
        ISO date strings of training sessions used to determine training status.
    status : str
        Training status of subject.
    perf_easy : numpy.array
        Proportion of correct high contrast trials for each training session.
    n_trials : numpy.array
        Total number of trials for each training session.
    psych : numpy.array
        Psychometric parameters fit to data from all training sessions - bias, threshold, lapse
        high, lapse low.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
    rt : float
        The median response time for zero contrast trials across all training sessions. NaN
        indicates no zero contrast stimuli in training sessions.
    """
    # Status only (no metrics supplied)
    if perf_easy is None:
        print(f"\n{subj} : {status} \nSession dates=[{sess_dates[0]}, {sess_dates[1]}, "
              f"{sess_dates[2]}]")
        return

    # Common first line: dates, easy-contrast performance and trial counts
    header = (f"\n{subj} : {status} \nSession dates={list(sess_dates)}, "
              f"Perf easy={[np.around(pe, 2) for pe in perf_easy]}, "
              f"N trials={list(n_trials)} ")

    if psych_20 is None:
        # trainingChoiceWorld: single psychometric fit
        print(header +
              f"\nPsych fit over last 3 sessions: "
              f"bias={np.around(psych[0], 2)}, thres={np.around(psych[1], 2)}, "
              f"lapse_low={np.around(psych[2], 2)}, lapse_high={np.around(psych[3], 2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")
    else:
        # biased/ephysChoiceWorld: separate fits for the 0.2 and 0.8 blocks
        print(header +
              f"\nPsych fit over last 3 sessions (20): "
              f"bias={np.around(psych_20[0], 2)}, thres={np.around(psych_20[1], 2)}, "
              f"lapse_low={np.around(psych_20[2], 2)}, lapse_high={np.around(psych_20[3], 2)} "
              f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0], 2)}, "
              f"thres={np.around(psych_80[1], 2)}, lapse_low={np.around(psych_80[2], 2)}, "
              f"lapse_high={np.around(psych_80[3], 2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")
def concatenate_trials(trials):
    """
    Concatenate trials from different training sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.

    Returns
    -------
    one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.
    """
    sessions = list(trials.values())
    combined = AlfBunch()
    # Only the keys required for training-status computation are kept
    for key in TRIALS_KEYS:
        combined[key] = np.concatenate([session[key] for session in sessions])
    return combined
def compute_training_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on trainingChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Array of psychometric parameters fit to `trials_all` - bias, threshold, lapse high,
        lapse low.
    float
        The median response time for all zero-contrast trials across all sessions. Returns NaN
        if there are no zero-contrast trials.
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    # Per-session metrics
    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    # Metrics pooled over all three sessions
    psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)

    return perf_easy, n_trials, psych, rt
def compute_bias_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on biasedChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Psychometric parameters fit to trials in the 0.2 probability-left blocks over all
        sessions.
    numpy.array
        Psychometric parameters fit to trials in the 0.8 probability-left blocks over all
        sessions.
    float
        Median response time for zero contrast stimuli over all sessions.
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    # Per-session metrics
    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    # Separate psychometric fits for each bias block, pooled over all three sessions
    psych_20 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.2)
    psych_80 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.8)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)

    return perf_easy, n_trials, psych_20, psych_80, rt
def get_signed_contrast(trials):
    """
    Compute signed contrast from trials object.

    Parameters
    ----------
    trials : dict
        Trials object that must contain the keys 'contrastLeft' and 'contrastRight'.

    Returns
    -------
    numpy.array
        Signed contrasts in percent, where -ve values are on the left.
    """
    # NaN marks the side without a stimulus; treat it as zero contrast
    left = np.nan_to_num(trials['contrastLeft'])
    right = np.nan_to_num(trials['contrastRight'])
    return (right - left) * 100
def compute_performance_easy(trials):
    """
    Compute performance on easy trials (absolute contrast >= 50%) from trials object.

    Parameters
    ----------
    trials : dict
        Trials object that must contain the keys 'contrastLeft', 'contrastRight' and
        'feedbackType'.

    Returns
    -------
    float
        Proportion of correct (feedbackType == 1) responses on easy contrast trials.
    """
    # Signed contrast in percent; NaN marks the side without a stimulus
    stacked = np.nan_to_num(np.c_[trials['contrastLeft'], trials['contrastRight']])
    signed_contrast = np.diff(stacked).flatten() * 100
    easy = np.abs(signed_contrast) >= 50
    return np.sum(trials['feedbackType'][easy] == 1) / np.count_nonzero(easy)
def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
    """
    Compute performance at each contrast level from trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        Trials object that must contain 'probabilityLeft' and either 'feedbackType' (when
        `prob_right` is False) or 'choice' (when `prob_right` is True).
    signed_contrast : numpy.array
        Signed contrasts in percent the length of trials; computed from `trials` when None.
    block : float
        If given, restrict to trials where probabilityLeft matches this value.
    prob_right : bool
        If True, return the proportion of rightward choices per contrast instead of the
        proportion correct.

    Returns
    -------
    numpy.array
        Performance (or proportion rightward) at each unique signed contrast.
    numpy.array
        The unique signed contrasts.
    numpy.array
        The number of trials at each unique signed contrast.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        return np.nan * np.zeros(3)

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)

    # choice == -1 means a rightward stimulus; feedbackType == 1 means a correct response
    outcome = (trials.choice == -1) if prob_right else (trials.feedbackType == 1)
    performance = np.array(
        [np.mean(outcome[(c == signed_contrast) & block_idx]) for c in contrasts])

    return performance, contrasts, n_contrasts
def compute_n_trials(trials):
    """
    Compute the number of trials in a trials object.

    Parameters
    ----------
    trials : dict
        Trials object that must contain the key 'choice'.

    Returns
    -------
    int
        Number of trials in the session.
    """
    return len(trials['choice'])
def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032):
    """
    Compute psychometric fit parameters for trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    plotting : bool
        Which set of psychofit model parameters to use (see notes).
    compute_ci : bool
        If true, computes and returns the confidence intervals for response at each contrast.
    alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        Array of psychometric fit parameters - bias, threshold, lapse high, lapse low.
    (tuple of numpy.array)
        If `compute_ci` is true, a tuple of the lower and upper confidence-interval bounds,
        expressed relative to the fitted proportion rightward at each contrast.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
     interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.

    Notes
    -----
    The psychofit starting parameters and model constraints used for the fit when computing the
    training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep
    the precise criteria the same for all subjects, these parameters have not changed. To produce a
    better fit for plotting purposes, or to calculate the training status in a manner inconsistent
    with the IBL training pipeline, use plotting=True.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    # Restrict to the requested probabilityLeft block (all trials when block is None)
    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        # No trials in this block: return NaN for all four fit parameters
        return np.nan * np.zeros(4)

    # Proportion of rightward choices at each contrast, used as the fit target
    prob_choose_right, contrasts, n_contrasts = compute_performance(
        trials, signed_contrast=signed_contrast, block=block, prob_right=True)

    if plotting:
        # These starting parameters and constraints tend to produce a better fit, and are therefore
        # used for plotting.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([0., 40., 0.1, 0.1]),
            parmin=np.array([-50., 10., 0., 0.]),
            parmax=np.array([50., 50., 0.2, 0.2]),
            nfits=10)
    else:
        # These starting parameters and constraints are not ideal but are still used for computing
        # the training status for consistency.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
            parmin=np.array([np.min(contrasts), 0., 0., 0.]),
            parmax=np.array([np.max(contrasts), 100., 1, 1]))

    if compute_ci:
        # Deferred import: statsmodels is only needed when confidence intervals are requested
        import statsmodels.stats.proportion as smp  # noqa
        # choice == -1 means contrast on right hand side
        n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts)
        ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right
        return psych, ci
    else:
        return psych
def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None):
    """
    Compute the median response time, optionally restricted to one contrast.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use as the start event; the difference between 'response_times' and
        this event is the response time (see notes).
    contrast : float
        If None, all trials are included regardless of contrast, otherwise only trials where the
        matching signed percent contrast was presented are used.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts
        are -ve. If None, these are computed from the trials object.

    Returns
    -------
    float
        The median response time for trials with `contrast` (NaN if no trials match `contrast`).

    Notes
    -----
    - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is
      sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on
      times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time
      of the stimulus onset command), if available, or the 'goCue_times' (the time of the
      soundcard output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time
      of the audio go cue command).
    - The response/reaction time here is defined as the time between stim on and response, i.e.
      the entire open-loop trial duration.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    # Select all trials when no contrast is given, otherwise only the matching contrast
    if contrast is None:
        mask = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        mask = signed_contrast == contrast

    if not np.any(mask):
        return np.nan

    durations = trials.response_times - trials[stim_on_type]
    # nanmedian: stim on times may be NaN for uncalibrated rigs
    return np.nanmedian(durations[mask])
def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None,
                          compute_ci=False, alpha=0.32):
    """
    Compute the median response time for each contrast.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use as the start event; the difference between `stim_off_type` and this
        is used (see notes).
    stim_off_type : str, default='response_times'
        The trials key to use as the end event; the difference between this and `stim_on_type` is
        used (see notes).
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are
        -ve. If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    compute_ci : bool
        If true, computes and returns the confidence intervals for response time at each contrast.
    alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        The median response times for each unique signed contrast.
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.
    (numpy.array)
        If `compute_ci` is true, an array of confidence intervals of shape (n_contrasts, 2).

    Notes
    -----
    - The response/reaction time by default is the time between stim on and response, i.e. the
      entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to
      get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true
      response times, or calculate the last movement onset times and calculate the true movement
      times. See module examples for how to calculate this.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    # Restrict to the requested probabilityLeft block (all trials when block is None)
    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
    durations = trials[stim_off_type] - trials[stim_on_type]
    # nanmedian: event times may be NaN (e.g. uncalibrated photodiode)
    reaction_time = np.array(
        [np.nanmedian(durations[(c == signed_contrast) & block_idx]) for c in contrasts])

    if not compute_ci:
        return reaction_time, contrasts, n_contrasts

    # Bootstrap a confidence interval on the median per contrast
    ci = np.full((contrasts.size, 2), np.nan)
    for i, c in enumerate(contrasts):
        data = durations[(c == signed_contrast) & block_idx]
        bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha)
        ci[i, 0] = bt.confidence_interval.low
        ci[i, 1] = bt.confidence_interval.high

    return reaction_time, contrasts, n_contrasts, ci
def criterion_1a(psych, n_trials, perf_easy):
    """
    Returns bool indicating whether criteria for status 'trained_1a' are met.

    Criteria
    --------
    - Bias is less than 16
    - Threshold is less than 19
    - Lapse rate on both sides is less than 0.2
    - The total number of trials is greater than 200 for each session
    - Performance on easy contrasts > 80% for all sessions

    Parameters
    ----------
    psych : numpy.array
        The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold,
        lapse high, lapse low.
    n_trials : numpy.array of int
        The number for trials for each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials for each session.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1a'.

    Notes
    -----
    The parameter thresholds chosen here were originally determined by averaging the parameter fits
    for a number of sessions determined to be of 'good' performance by an experimenter.
    """
    criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and
                 np.all(n_trials > 200) and np.all(perf_easy > 0.8))
    # Cast explicitly: when all conditions hold, the `and` chain returns the last
    # operand, a numpy.bool_, rather than the documented built-in bool.
    return bool(criterion)
def criterion_1b(psych, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criteria for trained_1b are met.

    Criteria
    --------
    - Bias is less than 10
    - Threshold is less than 20 (see notes)
    - Lapse rate on both sides is less than 0.1
    - The total number of trials is greater than 400 for each session
    - Performance on easy contrasts > 90% for all sessions
    - The median response time across all zero contrast trials is less than 2 seconds

    Parameters
    ----------
    psych : numpy.array
        The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold,
        lapse high, lapse low.
    n_trials : numpy.array of int
        The number for trials for each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials for each session.
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1b'.

    Notes
    -----
    The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a,
    however it was decided to use round numbers so that readers would not assume a level of
    precision that isn't there (remember, these parameters were not chosen with any rigor). This
    regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
    slope of the psychometric curve may be slightly less steep than 1a.
    """
    criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and
                 np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2)
    # Cast explicitly so callers always receive a built-in bool rather than the
    # numpy.bool_ the `and` chain would otherwise return on success.
    return bool(criterion)
def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met.

    NB: The difference between these two is whether the sessions were acquired on a recording rig
    with a delay before the first trial. Neither of these two things are tested here.

    Criteria
    --------
    - Lapse on both sides < 0.1 for both bias blocks
    - Bias shift between blocks > 5
    - Total number of trials > 400 for all sessions
    - Performance on easy contrasts > 90% for all sessions
    - Median response time for zero contrast stimuli < 2 seconds

    Parameters
    ----------
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
        Parameters are bias, threshold, lapse high, lapse low.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
        Parameters are bias, threshold, lapse high, lapse low.
    n_trials : numpy.array
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if subject passes the ready4ephysrig or ready4recording criteria.
    """
    criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and  # lapse
                 psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and  # bias shift and n trials
                 np.all(perf_easy > 0.9) and rt < 2)  # overall performance and response times
    # Cast explicitly: the `and` chain would otherwise return numpy.bool_ on success,
    # not the documented built-in bool.
    return bool(criterion)
def criterion_delay(n_trials, perf_easy):
    """
    Returns bool indicating whether criteria for 'ready4delay' is met.

    Criteria
    --------
    - Total number of trials for any of the sessions is greater than 400
    - Performance on easy contrasts is greater than 90% for any of the sessions

    Parameters
    ----------
    n_trials : numpy.array of int
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).

    Returns
    -------
    bool
        True if subject passes the 'ready4delay' criteria.
    """
    # np.any returns numpy.bool_; cast so the documented built-in bool is returned
    # whichever operand short-circuits the `and`.
    return bool(np.any(n_trials > 400) and np.any(perf_easy > 0.9))
def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
    """
    Function to plot psychometric curve plots a la datajoint webpage.

    Fits and plots a psychometric curve, together with the measured proportion of rightward
    choices per signed contrast, separately for each probabilityLeft block {0.5, 0.2, 0.8}.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure handle containing the plot.
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
      interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.
    psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability
      using the fit parameters.
    """
    signed_contrast = get_signed_contrast(trials)
    contrasts_fit = np.arange(-100, 100)

    # One colour per block; the dark middle colour marks the unbiased 0.5 block
    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    blocks = ((0.5, cmap[1]), (0.2, cmap[0]), (0.8, cmap[2]))

    # For each block gather: measured P(right) per contrast, the psychometric fit output
    # (parameters, plus CI when requested), and the fitted curve over the full contrast range
    per_block = {}
    for p_left, _ in blocks:
        prob_right, contrasts, _ = compute_performance(
            trials, signed_contrast=signed_contrast, block=p_left, prob_right=True)
        out = compute_psychometric(trials, signed_contrast=signed_contrast, block=p_left,
                                   plotting=True, compute_ci=plot_ci, alpha=ci_alpha)
        pars = out[0] if plot_ci else out
        per_block[p_left] = (contrasts, prob_right, psy.erf_psycho_2gammas(pars, contrasts_fit), out)

    if not ax:
        fig, ax = plt.subplots(**kwargs)
    else:
        fig = plt.gcf()

    # Plot fit line then data points for each block, collecting handles for the legend
    handles = []
    for p_left, colour in blocks:
        contrasts, prob_right, fit_curve, _ = per_block[p_left]
        handles.append(ax.plot(contrasts_fit, fit_curve, color=colour)[0])
        handles.append(ax.scatter(contrasts, prob_right, color=colour))

    if plot_ci:
        # out[1] holds the two CI bound arrays; magnitudes are stacked as (2, n) yerr
        for p_left, colour in blocks:
            contrasts, prob_right, _, out = per_block[p_left]
            yerr = np.c_[np.abs(out[1][0]), np.abs(out[1][1])].T
            ax.errorbar(contrasts, prob_right, yerr=yerr, ecolor=colour, fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles,
              ['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data',
               'p_left=0.8 fit', 'p_left=0.8 data'],
              loc='upper left')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('Probability choosing right')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)

    return fig, ax
def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs):
    """
    Function to plot reaction time against contrast a la datajoint webpage.

    The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure handle containing the plot.
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    signed_contrast = get_signed_contrast(trials)

    # One colour per block; the dark middle colour marks the unbiased 0.5 block
    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    blocks = ((0.5, cmap[1]), (0.2, cmap[0]), (0.8, cmap[2]))

    # Median reaction time per contrast for each block (plus CI when requested)
    per_block = {p_left: compute_reaction_time(trials, signed_contrast=signed_contrast,
                                               block=p_left, compute_ci=plot_ci, alpha=ci_alpha)
                 for p_left, _ in blocks}

    if not ax:
        fig, ax = plt.subplots(**kwargs)
    else:
        fig = plt.gcf()

    handles = []
    for p_left, colour in blocks:
        out = per_block[p_left]
        handles.append(ax.plot(out[1], out[0], '-o', color=colour)[0])

    if plot_ci:
        # out[3] is an (n, 2) array of CI bounds; convert to offsets from the median
        for p_left, colour in blocks:
            out = per_block[p_left]
            yerr = np.c_[out[0] - out[3][:, 0], out[3][:, 1] - out[0]].T
            ax.errorbar(out[1], out[0], yerr=yerr, ecolor=colour, fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles,
              ['p_left=0.5 data', 'p_left=0.2 data', 'p_left=0.8 data'],
              loc='upper left')
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Contrasts')

    if title:
        ax.set_title(title)

    return fig, ax
def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
    """
    Function to plot reaction time with trial number a la datajoint webpage.

    Plots each trial's reaction time as a scatter point, overlaid with a 10-trial rolling
    median, on a log-scaled y axis.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between this and
        'feedback_times' is used (see notes for `compute_median_reaction_time`).
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure handle containing the plot.
    matplotlib.pyplot.Axes
        The plotted axes.
    """
    rt_frame = pd.DataFrame()
    rt_frame['reaction_time'] = trials.response_times - trials[stim_on_type]
    rt_frame.index += 1  # 1-based trial numbering
    rolling_median = rt_frame['reaction_time'].rolling(window=10).median()
    # Replace NaN with None so matplotlib leaves gaps instead of drawing through them
    rolling_median = rolling_median.where(pd.notnull(rolling_median), None)
    rt_frame = rt_frame.where(pd.notnull(rt_frame), None)

    if not ax:
        fig, ax = plt.subplots(**kwargs)
    else:
        fig = plt.gcf()

    ax.scatter(np.arange(len(rt_frame.values)), rt_frame.values, s=16, color='darkgray')
    ax.plot(np.arange(len(rolling_median.values)), rolling_median.values, color='k', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(0.1, 100)
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Trial number')
    if title:
        ax.set_title(title)

    return fig, ax
def query_criterion(subject, status, from_status=None, one=None, validate=True):
    """Get the session for which a given training criterion was met.

    Parameters
    ----------
    subject : str
        The subject name.
    status : str
        The training status to query for.
    from_status : str, optional
        Count number of sessions and days from reaching `from_status` to `status`.
    one : one.api.OneAlyx, optional
        An instance of ONE.
    validate : bool
        If true, check if status in TrainingStatus enumeration. Set to false for non-standard
        training pipelines.

    Returns
    -------
    str
        The eID of the first session where this training status was reached.
    int
        The number of sessions it took to reach `status` (optionally from reaching `from_status`).
    int
        The number of days it took to reach `status` (optionally from reaching `from_status`).
    """
    if validate:
        # Normalize to the canonical lower-case snake_case key used in the subject JSON.
        # NB: the enum lookup already upper-cases and underscores the input, so no
        # separate pre-normalization of `status` is needed.
        try:
            status = TrainingStatus[status.upper().replace(' ', '_')].name.lower()
        except KeyError as ex:
            raise ValueError(
                f'Unknown status "{status}". For non-standard training protocols set validate=False'
            ) from ex
    one = one or ONE()
    subject_json = one.alyx.rest('subjects', 'read', id=subject)['json']
    if not (criteria := subject_json.get('trained_criteria')) or status not in criteria:
        return None, None, None
    to_date, eid = criteria[status]
    from_date, _ = criteria.get(from_status, (None, None))
    eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True)
    if len(eids) == 0:
        return eid, None, None
    # Sessions are returned most recent first; subtract oldest from newest for the span
    # NOTE(review): assumes one.search returns newest-first ordering — confirm against ONE docs
    delta_date = det[0]['date'] - det[-1]['date']
    return eid, len(eids), delta_date.days