Coverage for brainbox/behavior/training.py: 70%
350 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""Computing and testing IBL training status criteria.
3For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice
4Training.
6.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\
7measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729
9Examples
10--------
11Plot the psychometric curve for a given session.
13>>> trials = ONE().load_object(eid, 'trials')
>>> fig, ax = plot_psychometric(trials)
Compute 'response times', defined as the duration of the open-loop period, for each contrast.
18>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials)
20Compute 'reaction times', defined as the time between go cue and first detected movement.
21NB: These may be negative!
23>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
24... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times')
26Compute 'response times', defined as the time between first detected movement and response.
28>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
29... trials, stim_on_type='firstMovement_times', stim_off_type='response_times')
31Compute 'movement times', defined as the time between last detected movement and response threshold.
33>>> import brainbox.behavior.wheel as wh
>>> wheel_moves = ONE().load_object(eid, 'wheelMoves')
>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trials.response_times)
36>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
37... trials, stim_on_type='lastMovement_times', stim_off_type='response_times')
39"""
40import logging
41import datetime
42import re
43from enum import IntFlag, auto, unique
45import numpy as np
46import matplotlib
47import matplotlib.pyplot as plt
48import seaborn as sns
49import pandas as pd
50from scipy.stats import bootstrap
51from iblutil.util import Bunch
52from one.api import ONE
53from one.alf.io import AlfBunch
54from one.alf.exceptions import ALFObjectNotFound
56import psychofit as psy
_logger = logging.getLogger('ibllib')

TRIALS_KEYS = [
    'contrastLeft', 'contrastRight', 'feedbackType', 'probabilityLeft',
    'choice', 'response_times', 'stimOn_times',
]
"""list of str: The trials keys concatenated across sessions when computing training status."""
@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Each status occupies its own bit, with values increasing as training progresses, so statuses
    can be ordered with comparison operators and grouped into compound flags.

    Examples
    --------
    >>> status = 'ready4delay'
    >>> assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    >>> assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    >>> assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    Get the next training status:

    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    """
    # Single-bit statuses, in training order (the same values auto() would assign)
    UNTRAINABLE = 1 << 0
    UNBIASABLE = 1 << 1
    IN_TRAINING = 1 << 2
    TRAINED_1A = 1 << 3
    TRAINED_1B = 1 << 4
    READY4EPHYSRIG = 1 << 5
    READY4DELAY = 1 << 6
    READY4RECORDING = 1 << 7
    # Compound training statuses for convenience
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B
def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Compute and print the training status of every alive, water-restricted subject in a lab.

    Parameters
    ----------
    lab : str
        Lab name (must match the name registered on Alyx).
    date : str
        The ISO date ('YYYY-MM-DD') from which to compute training status. If not specified,
        status is computed from the latest date with available data.
    details : bool
        Whether to display all information about the training status computation, e.g.
        performance, number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.
    """
    one = one or ONE()
    records = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True)
    # Collect all nicknames up front, then report on each subject in turn
    nicknames = [record['nickname'] for record in records]
    for nickname in nicknames:
        get_subject_training_status(nickname, date=date, details=details, one=one)
def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Compute the training status of a subject and, if `details` is true, print it to std out.

    Nothing is printed when `details` is false, and nothing is computed when no sessions with
    trials data are found.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date ('YYYY-MM-DD') from which to compute training status. If not specified,
        status is computed from the latest date with available data.
    details : bool
        Whether to display all information about the training status computation, e.g.
        performance, number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.
    """
    one = one or ONE()

    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    if not trials:
        return
    sess_dates = list(trials.keys())
    status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)

    if not details:
        return

    # Unbiased sessions populate `psych`; biased ones populate `psych_20`/`psych_80`
    if np.any(info.get('psych')):
        extra = dict(perf_easy=info.perf_easy, n_trials=info.n_trials, psych=info.psych,
                     rt=info.rt)
    elif np.any(info.get('psych_20')):
        extra = dict(perf_easy=info.perf_easy, n_trials=info.n_trials,
                     psych_20=info.psych_20, psych_80=info.psych_80, rt=info.rt)
    else:
        extra = {}
    display_status(subj, sess_dates, status, **extra)
def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject.

    If a date is given, data are loaded from up to three of the most recent sessions on or before
    that date; otherwise from up to three of the most recent sessions up to yesterday. Sessions
    are first searched for within the preceding week. If fewer than three are found there, all
    earlier sessions are considered. Sessions whose trials object cannot be loaded are skipped.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    one : one.api.OneAlyx
        An instance of ONE.

    Returns
    -------
    iblutil.util.Bunch
        Dictionary of trials objects where each key is the ISO session date string.
    list of str
        The task protocol of each loaded session: the part of the Alyx task protocol string
        between 'tasks_' and 'Choice' (e.g. 'training', 'biased', 'ephys').
    list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions were on a training rig, or if any loaded session used the training protocol.
    int
        Number of sessions in the loaded date range with a start delay of at least 900 s
        (15 min). 0 if no ephys-rig sessions were found.

    A list of four None values is returned instead if no session with trials data is found.
    """
    one = one or ONE()

    if date is None:
        # compute from yesterday
        specified_date = (datetime.date.today() - datetime.timedelta(days=1))
        latest_sess = specified_date.strftime("%Y-%m-%d")
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime("%Y-%m-%d")
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d")
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj, date_range=[latest_minus_week,
                             latest_sess], dataset_types='trials.goCueTrigger_times')

    # If not enough sessions in the last week, then fetch all sessions up to the specified date
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)

    # If still 0 sessions then return with warning
    if len(sessions) == 0:
        _logger.warning(f"No training sessions detected for {subj}")
        return [None] * 4

    trials = Bunch()
    task_protocol = []
    sess_dates = []
    # Load sessions in the order returned until three have trials data. This presumably means
    # most recent first; the date range queries below rely on that order. Iterating over
    # `sessions` rather than indexing with an unbounded counter avoids an IndexError when fewer
    # than three sessions have loadable trials.
    for sess in sessions:
        if len(trials) >= 3:
            break
        eid = sess['url'].split('/')[-1]
        _logger.debug('Loading trials for session %s', eid)
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            trials_ = None
        if not trials_:
            continue
        # NOTE(review): two sessions on the same date share a key in `trials`, so task_protocol
        # and sess_dates can end up longer than `trials` - confirm whether to skip duplicates.
        sess_date = sess['start_time'][:10]
        task_protocol.append(re.search('tasks_(.*)Choice', sess['task_protocol']).group(1))
        sess_dates.append(sess_date)
        trials[sess_date] = trials_

    if not trials:
        # No candidate session had a trials object; the date range queries below need a date
        _logger.warning(f"No trials data could be loaded for {subj}")
        return [None] * 4

    # Ephys-rig sessions are only looked up once the subject has left trainingChoiceWorld
    if np.any(np.array(task_protocol) == 'training'):
        return trials, task_protocol, [], 0

    date_range = [sess_dates[-1], sess_dates[0]]
    ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                               django='json__PYBPOD_BOARD__icontains,ephys')
    if len(ephys_sess) == 0:
        return trials, task_protocol, [], 0

    ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]
    n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                                django='json__SESSION_START_DELAY_SEC__gte,900'))
    return trials, task_protocol, ephys_sess_dates, n_delay
def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from consecutive training datasets.

    For IBL, training status is calculated using trials from the last three consecutive sessions.
    The rules applied depend on which task protocols are present:

    - all trainingChoiceWorld: 'in training', 'trained 1a' or 'trained 1b';
    - some, but not all, trainingChoiceWorld (fewer than three biased sessions since reaching
      trained 1b): 'trained 1b';
    - no trainingChoiceWorld: 'trained 1b', 'ready4ephysrig', 'ready4delay' or
      'ready4recording', depending on the number of ephys-rig sessions and delayed sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    task_protocol : list of str
        Task protocol used for each training session in `trials`, can be 'training', 'biased' or
        'ephys'.
    ephys_sess_dates : list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions that had a delay prior to starting the session (see `get_sessions`).

    Returns
    -------
    str
        Training status of the subject.
    iblutil.util.Bunch
        Bunch containing performance metrics that decide training status i.e. performance on easy
        trials, number of trials, psychometric fit parameters, reaction time.
    """
    info = Bunch()
    trials_all = concatenate_trials(trials)
    is_training = np.array(task_protocol) == 'training'

    if np.all(is_training):
        # Every session is trainingChoiceWorld
        signed_contrast = get_signed_contrast(trials_all)
        info.perf_easy, info.n_trials, info.psych, info.rt = compute_training_info(trials, trials_all)
        if not np.any(signed_contrast == 0):
            # Zero contrast not yet introduced, so the criteria cannot be assessed
            status = 'in training'
        elif criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
            status = 'trained 1b'
        elif criterion_1a(info.psych, info.n_trials, info.perf_easy):
            status = 'trained 1a'
        else:
            status = 'in training'
        return status, info

    if np.any(is_training):
        # Fewer than three biasedChoiceWorld sessions since reaching the trained 1b criterion
        info.perf_easy, info.n_trials, info.psych, info.rt = compute_training_info(trials, trials_all)
        return 'trained 1b', info

    # Only biasedChoiceWorld / ephysChoiceWorld sessions from here on
    (info.perf_easy, info.n_trials,
     info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all)
    n_ephys = len(ephys_sess_dates)

    if n_ephys == 0:
        # Still on the training rig, so every session should be biased
        assert np.all(np.array(task_protocol) == 'biased')
        passed = criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                                 info.rt)
        status = 'ready4ephysrig' if passed else 'trained 1b'
    elif n_ephys < 3:
        # Only the ephys-rig sessions count towards the delay criterion
        assert all(date in trials for date in ephys_sess_dates)
        perf_ephys_easy = np.array([compute_performance_easy(trials[d]) for d in ephys_sess_dates])
        n_ephys_trials = np.array([compute_n_trials(trials[d]) for d in ephys_sess_dates])
        status = 'ready4delay' if criterion_delay(n_ephys_trials, perf_ephys_easy) else 'ready4ephysrig'
    elif n_delay > 0 and criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
                                         info.perf_easy, info.rt):
        status = 'ready4recording'
    elif criterion_delay(info.n_trials, info.perf_easy):
        status = 'ready4delay'
    else:
        status = 'ready4ephysrig'

    return status, info
def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal.

    When `perf_easy` is None, only the subject, status and session dates are printed. Otherwise
    the per-session metrics are printed together with one of two fits:

    - the unbiased psychometric fit (`psych`), when `psych_20` is None;
    - the fits for the 0.2 and 0.8 blocks (`psych_20`, `psych_80`), otherwise.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    sess_dates : list of str
        ISO date strings of training sessions used to determine training status. Any number of
        dates is accepted.
    status : str
        Training status of subject.
    perf_easy : numpy.array
        Proportion of correct high contrast trials for each training session.
    n_trials : numpy.array
        Total number of trials for each training session.
    psych : numpy.array
        Psychometric parameters fit to data from all training sessions - bias, threshold, lapse
        low, lapse high.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
    rt : float
        The median response time for zero contrast trials across all training sessions. NaN
        indicates no zero contrast stimuli in training sessions.
    """
    def fmt_psych(pars):
        # Render the four psychometric parameters rounded to two decimal places
        return (f"bias={np.around(pars[0], 2)}, thres={np.around(pars[1], 2)}, "
                f"lapse_low={np.around(pars[2], 2)}, lapse_high={np.around(pars[3], 2)}")

    if perf_easy is None:
        # Join the available dates; indexing exactly three raised IndexError for fewer sessions
        dates = ', '.join(str(d) for d in sess_dates)
        print(f"\n{subj} : {status} \nSession dates=[{dates}]")
        return

    summary = (f"\n{subj} : {status} \nSession dates={list(sess_dates)}, "
               f"Perf easy={[np.around(pe, 2) for pe in perf_easy]}, "
               f"N trials={list(n_trials)} ")
    rt_line = f"\nMedian reaction time at 0 contrast over last 3 sessions = {np.around(rt, 2)}"

    if psych_20 is None:
        print(summary + f"\nPsych fit over last 3 sessions: {fmt_psych(psych)} " + rt_line)
    else:
        print(summary
              + f"\nPsych fit over last 3 sessions (20): {fmt_psych(psych_20)} "
              + f"\nPsych fit over last 3 sessions (80): {fmt_psych(psych_80)} "
              + rt_line)
def concatenate_trials(trials):
    """
    Concatenate trials from different training sessions.

    Only the keys listed in `TRIALS_KEYS` are concatenated, in the iteration order of `trials`.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.

    Returns
    -------
    one.alf.io.AlfBunch
        Trials object with each required key concatenated across all sessions in `trials`.
    """
    per_session = list(trials.values())
    trials_all = AlfBunch()
    for key in TRIALS_KEYS:
        trials_all[key] = np.concatenate([session[key] for session in per_session])
    return trials_all
def compute_training_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on trainingChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over the sessions in `trials`.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Psychometric parameters fit to `trials_all` - bias, threshold, lapse low, lapse high.
    float
        The median response time for all zero-contrast trials across all sessions (NaN if there
        are no zero-contrast trials).
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    perf_easy = np.array([compute_performance_easy(sess) for sess in sessions])
    n_trials = np.array([compute_n_trials(sess) for sess in sessions])
    psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
    return perf_easy, n_trials, psych, rt
def compute_bias_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on biasedChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over the sessions in `trials`.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each session.
    numpy.array
        Psychometric fit parameters for trials where probabilityLeft == 0.2, over all sessions.
    numpy.array
        Psychometric fit parameters for trials where probabilityLeft == 0.8, over all sessions.
    float
        The median response time for all zero-contrast trials across all sessions (NaN if there
        are no zero-contrast trials).
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    perf_easy = np.array([compute_performance_easy(sess) for sess in sessions])
    n_trials = np.array([compute_n_trials(sess) for sess in sessions])
    fits = {p_left: compute_psychometric(trials_all, signed_contrast=signed_contrast, block=p_left)
            for p_left in (0.2, 0.8)}
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
    return perf_easy, n_trials, fits[0.2], fits[0.8], rt
def get_signed_contrast(trials):
    """
    Compute signed contrast from trials object.

    Parameters
    ----------
    trials : dict
        Trials object that must contain the 'contrastLeft' and 'contrastRight' keys.

    Returns
    -------
    numpy.array
        Signed contrast in percent for each trial: right contrast minus left contrast, so values
        for left stimuli are negative. NaN contrasts are treated as zero.
    """
    # Column 0 holds the left contrast, column 1 the right; NaNs (no stimulus) become 0
    stacked = np.nan_to_num(np.c_[trials['contrastLeft'], trials['contrastRight']])
    return (stacked[:, 1] - stacked[:, 0]) * 100
def compute_performance_easy(trials):
    """
    Compute performance on easy trials (absolute contrast >= 50 %) from trials object.

    Parameters
    ----------
    trials : dict
        Trials object that must contain the 'contrastLeft', 'contrastRight' and 'feedbackType'
        keys.

    Returns
    -------
    float
        Proportion of easy trials with positive feedback (feedbackType == 1). NaN, with a
        runtime warning, if there are no easy trials.
    """
    signed_contrast = get_signed_contrast(trials)
    is_easy = np.abs(signed_contrast) >= 50
    easy_feedback = trials['feedbackType'][is_easy]
    return np.sum(easy_feedback == 1) / easy_feedback.shape[0]
def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
    """
    Compute performance, or the proportion of choice == -1 trials, at each contrast level.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        Trials object containing 'probabilityLeft' and either 'feedbackType' (when `prob_right`
        is false) or 'choice' (when true). 'contrastLeft' and 'contrastRight' are also required
        when `signed_contrast` is None.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        If None, all trials are included, otherwise only trials where probabilityLeft matches
        this value.
    prob_right : bool
        If true, return the proportion of trials with choice == -1 (rightward choices) rather than
        the proportion of correct trials.

    Returns
    -------
    numpy.array
        Proportion correct (or rightward, if `prob_right`) for each unique signed contrast.
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.

    If no trials match `block`, a single array of three NaNs is returned in place of the tuple.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        in_block = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        in_block = trials.probabilityLeft == block

    if not np.any(in_block):
        return np.nan * np.zeros(3)

    contrasts, n_contrasts = np.unique(signed_contrast[in_block], return_counts=True)

    # Per-trial outcome averaged at each contrast
    outcome = (trials.choice == -1) if prob_right else (trials.feedbackType == 1)
    performance = np.array([np.mean(outcome[(c == signed_contrast) & in_block]) for c in contrasts])
    return performance, contrasts, n_contrasts
def compute_n_trials(trials):
    """
    Return the number of trials in a session.

    Parameters
    ----------
    trials : dict
        Trials object; the length of its 'choice' array is used.

    Returns
    -------
    int
        The size of the first dimension of ``trials['choice']``.
    """
    choices = trials['choice']
    return choices.shape[0]
def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032):
    """
    Compute psychometric fit parameters for trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    plotting : bool
        Which set of psychofit model parameters to use (see notes).
    compute_ci : bool
        If true, also computes and returns confidence intervals for the proportion of rightward
        (choice == -1) responses at each contrast.
    alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        Array of psychometric fit parameters - bias, threshold, lapse low, lapse high. All NaN if
        no trials match `block`.
    numpy.array
        Only returned if `compute_ci` is true. A (2, n_contrasts) array of the lower and upper
        normal-approximation confidence bounds on the proportion of rightward responses at each
        unique contrast. The bounds are given relative to the observed proportion, i.e. with the
        observed proportion subtracted. Has shape (2, 0) if no trials match `block`.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.

    Notes
    -----
    The psychofit starting parameters and model constraints used for the fit when computing the
    training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep
    the precise criteria the same for all subjects, these parameters have not changed. To produce a
    better fit for plotting purposes, or to calculate the training status in a manner inconsistent
    with the IBL training pipeline, use plotting=True.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        psych = np.nan * np.zeros(4)
        # Keep the return arity consistent with the normal path so `psych, ci = ...` still unpacks
        return (psych, np.full((2, 0), np.nan)) if compute_ci else psych

    prob_choose_right, contrasts, n_contrasts = compute_performance(
        trials, signed_contrast=signed_contrast, block=block, prob_right=True)

    if plotting:
        # These starting parameters and constraints tend to produce a better fit, and are therefore
        # used for plotting.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([0., 40., 0.1, 0.1]),
            parmin=np.array([-50., 10., 0., 0.]),
            parmax=np.array([50., 50., 0.2, 0.2]),
            nfits=10)
    else:
        # These starting parameters and constraints are not ideal but are still used for computing
        # the training status for consistency.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
            parmin=np.array([np.min(contrasts), 0., 0., 0.]),
            parmax=np.array([np.max(contrasts), 100., 1, 1]))

    if not compute_ci:
        return psych

    import statsmodels.stats.proportion as smp  # noqa
    # choice == -1 means contrast on right hand side
    n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts)
    # proportion_confint returns (low, high); subtracting the observed proportion yields a
    # (2, n_contrasts) array of offsets
    ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right
    return psych, ci
def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None,
                                 stim_off_type='response_times'):
    """
    Compute the median response time, optionally restricted to trials of one signed contrast.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key marking the start of the interval. The response time is
        ``trials[stim_off_type] - trials[stim_on_type]``.
    contrast : float
        If None, the median response time is calculated for all trials, regardless of contrast,
        otherwise only trials where the matching signed percent contrast was presented are used.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    stim_off_type : str, default='response_times'
        The trials key marking the end of the interval (see `stim_on_type`).

    Returns
    -------
    float
        The median response time (ignoring NaNs) for trials with `contrast`. Returns NaN if no
        trials match `contrast`.

    Notes
    -----
    - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is
      sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on
      times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of
      the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard
      output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the
      audio go cue command).
    - With the defaults, the response time is the time between stim on and response, i.e. the
      entire open-loop trial duration.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if contrast is None:
        contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        contrast_idx = signed_contrast == contrast

    if np.any(contrast_idx):
        reaction_time = np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[contrast_idx])
    else:
        reaction_time = np.nan

    return reaction_time
def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None,
                          compute_ci=False, alpha=0.32):
    """
    Compute median response time for all contrasts.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_off_type` is used (see notes).
    stim_off_type : str, default='response_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_on_type` is used (see notes).
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    compute_ci : bool
        If true, computes and returns the confidence intervals for response time at each contrast.
    alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        The median response times for each unique signed contrast (empty if no trials match
        `block`).
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.
    (numpy.array)
        If `compute_ci` is true, an (n_contrasts, 2) array of the lower and upper bootstrap
        confidence bounds on the median. Rows for contrasts with fewer than two trials are NaN,
        as the bootstrap requires at least two observations.

    Notes
    -----
    - The response/reaction time by default is the time between stim on and response, i.e. the
      entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to
      get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true
      response times, or calculate the last movement onset times and calculate the true movement
      times. See module examples for how to calculate this.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
    # Compute the per-trial durations once, then group them by contrast
    durations = trials[stim_off_type] - trials[stim_on_type]
    per_contrast = [durations[(x == signed_contrast) & block_idx] for x in contrasts]
    # Build the array explicitly: np.vectorize raises on the empty input of an unmatched block
    reaction_time = np.array([np.nanmedian(d) for d in per_contrast], dtype=float)

    if not compute_ci:
        return reaction_time, contrasts, n_contrasts

    ci = np.full((contrasts.size, 2), np.nan)
    for i, data in enumerate(per_contrast):
        if data.size < 2:
            # scipy.stats.bootstrap needs at least two observations; leave this interval as NaN
            continue
        bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha)
        ci[i] = bt.confidence_interval.low, bt.confidence_interval.high
    return reaction_time, contrasts, n_contrasts, ci
def criterion_1a(psych, n_trials, perf_easy):
    """
    Returns bool indicating whether criteria for status 'trained_1a' are met.

    Criteria
    --------
    - Bias is less than 16
    - Threshold is less than 19
    - Lapse rate on both sides is less than 0.2
    - The total number of trials is greater than 200 for each session
    - Performance on easy contrasts > 80% for all sessions

    Parameters
    ----------
    psych : numpy.array
        The psychometric parameters fit to three consecutive sessions. Parameters are bias,
        threshold, lapse high, lapse low.
    n_trials : array_like of int
        The number of trials for each session.
    perf_easy : array_like of float
        The proportion of correct high contrast trials for each session.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1a'.

    Notes
    -----
    The parameter thresholds chosen here were originally determined by averaging the parameter fits
    for a number of sessions determined to be of 'good' performance by an experimenter.
    """
    # asarray so that plain lists (or a single scalar) can be compared element-wise
    n_trials, perf_easy = np.asarray(n_trials), np.asarray(perf_easy)
    criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and
                 np.all(n_trials > 200) and np.all(perf_easy > 0.8))
    # The `and` chain yields a numpy.bool_; return a built-in bool as documented
    return bool(criterion)
def criterion_1b(psych, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criteria for trained_1b are met.

    Criteria
    --------
    - Bias is less than 10
    - Threshold is less than 20 (see notes)
    - Lapse rate on both sides is less than 0.1
    - The total number of trials is greater than 400 for each session
    - Performance on easy contrasts > 90% for all sessions
    - The median response time across all zero contrast trials is less than 2 seconds

    Parameters
    ----------
    psych : numpy.array
        The psychometric parameters fit to three consecutive sessions. Parameters are bias,
        threshold, lapse high, lapse low.
    n_trials : array_like of int
        The number of trials for each session.
    perf_easy : array_like of float
        The proportion of correct high contrast trials for each session.
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1b'.

    Notes
    -----
    The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a,
    however it was decided to use round numbers so that readers would not assume a level of
    precision that isn't there (remember, these parameters were not chosen with any rigor). This
    regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
    slope of the psychometric curve may be slightly less steep than 1a.
    """
    # asarray so that plain lists (or a single scalar) can be compared element-wise
    n_trials, perf_easy = np.asarray(n_trials), np.asarray(perf_easy)
    criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and
                 np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2)
    # Depending on where the `and` chain stops this is a numpy.bool_ or a bool; normalize it
    return bool(criterion)
def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met.

    NB: The difference between these two is whether the sessions were acquired on a recording rig
    with a delay before the first trial. Neither of these two things are tested here.

    Criteria
    --------
    - Lapse on both sides < 0.1 for both bias blocks
    - Bias shift between blocks > 5
    - Total number of trials > 400 for all sessions
    - Performance on easy contrasts > 90% for all sessions
    - Median response time for zero contrast stimuli < 2 seconds

    Parameters
    ----------
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
        Parameters are bias, threshold, lapse high, lapse low.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
        Parameters are bias, threshold, lapse high, lapse low.
    n_trials : array_like of int
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : array_like of float
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if subject passes the ready4ephysrig or ready4recording criteria.
    """
    # asarray so that plain lists (or a single scalar) can be compared element-wise
    n_trials, perf_easy = np.asarray(n_trials), np.asarray(perf_easy)
    # Both lapse parameters (indices 2 and 3) of both biased blocks
    lapses = np.r_[psych_20[2:4], psych_80[2:4]]
    criterion = (np.all(lapses < 0.1) and  # lapse
                 psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and  # bias shift and n trials
                 np.all(perf_easy > 0.9) and rt < 2)  # overall performance and response times
    # Depending on where the `and` chain stops this is a numpy.bool_ or a bool; normalize it
    return bool(criterion)
def criterion_delay(n_trials, perf_easy):
    """
    Returns bool indicating whether criteria for 'ready4delay' is met.

    Criteria
    --------
    - Total number of trials for any of the sessions is greater than 400
    - Performance on easy contrasts is greater than 90% for any of the sessions

    NB: The two conditions are tested independently, so they may be satisfied by different
    sessions.

    Parameters
    ----------
    n_trials : array_like of int
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : array_like of float
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).

    Returns
    -------
    bool
        True if subject passes the 'ready4delay' criteria.
    """
    # asarray so that plain lists (or a single scalar) can be compared element-wise
    n_trials, perf_easy = np.asarray(n_trials), np.asarray(perf_easy)
    criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9)
    # np.any returns numpy.bool_; return a built-in bool as documented
    return bool(criterion)
def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
    """
    Function to plot psychometric curve plots a la datajoint webpage.

    For each of the blocks {0.5, 0.2, 0.8} the probability of choosing right is plotted against
    signed contrast, both as the fitted psychometric curve (line) and as the observed data
    (points).

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto. If None, a new figure and axes are created.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure handle containing the plot (the figure owning `ax` when one is passed in).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.
    psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability
        using the fit parameters.
    """
    signed_contrast = get_signed_contrast(trials)
    contrasts_fit = np.arange(-100, 100)
    # (probabilityLeft block, palette index) in the order curves are drawn and listed in the legend
    blocks = ((0.5, 1), (0.2, 0), (0.8, 2))

    results = []
    for block, _ in blocks:
        prob_right, contrasts, _ = compute_performance(
            trials, signed_contrast=signed_contrast, block=block, prob_right=True)
        out = compute_psychometric(trials, signed_contrast=signed_contrast, block=block,
                                   plotting=True, compute_ci=plot_ci, alpha=ci_alpha)
        # With compute_ci the result is (parameters, ci); otherwise it is just the parameters
        pars, ci = (out[0], out[1]) if plot_ci else (out, None)
        prob_right_fit = psy.erf_psycho_2gammas(pars, contrasts_fit)
        results.append((prob_right, contrasts, prob_right_fit, ci))

    cmap = sns.diverging_palette(20, 220, n=3, center='dark')

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure that owns `ax`: plt.gcf() may be a different, more recently created figure
        fig = ax.get_figure()

    handles, labels = [], []
    for (block, color_idx), (prob_right, contrasts, prob_right_fit, _) in zip(blocks, results):
        fit = ax.plot(contrasts_fit, prob_right_fit, color=cmap[color_idx])
        data = ax.scatter(contrasts, prob_right, color=cmap[color_idx])
        handles += [fit[0], data]
        labels += [f'p_left={block} fit', f'p_left={block} data']

    if plot_ci:
        for (_, color_idx), (prob_right, contrasts, _, ci) in zip(blocks, results):
            # NOTE(review): the CI bounds appear to be offsets from `prob_right` (hence the abs)
            # rather than absolute bounds; confirm against compute_psychometric
            errbar = np.c_[np.abs(ci[0]), np.abs(ci[1])].T
            ax.errorbar(contrasts, prob_right, yerr=errbar, ecolor=cmap[color_idx], fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, labels, loc='upper left')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('Probability choosing right')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)

    return fig, ax
def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs):
    """
    Function to plot reaction time against contrast a la datajoint webpage.

    The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto. If None, a new figure and axes are created.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure handle containing the plot (the figure owning `ax` when one is passed in).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    signed_contrast = get_signed_contrast(trials)
    # (probabilityLeft block, palette index) in the order lines are drawn and listed in the legend
    blocks = ((0.5, 1), (0.2, 0), (0.8, 2))
    # Each output is (median times, contrasts, n trials[, ci]) from compute_reaction_time
    outputs = [compute_reaction_time(trials, signed_contrast=signed_contrast, block=block,
                                     compute_ci=plot_ci, alpha=ci_alpha)
               for block, _ in blocks]

    cmap = sns.diverging_palette(20, 220, n=3, center='dark')

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure that owns `ax`: plt.gcf() may be a different, more recently created figure
        fig = ax.get_figure()

    lines = [ax.plot(out[1], out[0], '-o', color=cmap[color_idx])[0]
             for (_, color_idx), out in zip(blocks, outputs)]

    if plot_ci:
        for (_, color_idx), out in zip(blocks, outputs):
            median_rt, contrasts, ci = out[0], out[1], out[3]
            # ci columns are the absolute (low, high) bounds; errorbar wants distances from the median
            errbar = np.c_[median_rt - ci[:, 0], ci[:, 1] - median_rt].T
            ax.errorbar(contrasts, median_rt, yerr=errbar, ecolor=cmap[color_idx], fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(lines, [f'p_left={block} data' for block, _ in blocks], loc='upper left')
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Contrasts')

    if title:
        ax.set_title(title)

    return fig, ax
def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
    """
    Function to plot reaction time with trial number a la datajoint webpage.

    Each trial's time from `stim_on_type` to 'response_times' is drawn as a grey point against its
    trial number (counted from 1). A black line overlays the rolling median over a 10-trial window.
    That median is NaN, and therefore not drawn, for the first 9 trials. The y-axis is
    log-scaled and limited to 0.1-100 s.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between
        'response_times' and this key is plotted (see notes for `compute_reaction_time`).
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto. If None, a new figure and axes are created.
    title : str
        An optional plot title.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure handle containing the plot (the figure owning `ax` when one is passed in).
    matplotlib.pyplot.Axes
        The plotted axes.
    """
    reaction_time = pd.Series(
        np.asarray(trials['response_times'] - trials[stim_on_type], dtype=float))
    # 1-based trial numbers for the x-axis
    trial_number = np.arange(1, reaction_time.size + 1)
    rolling_median = reaction_time.rolling(window=10).median()

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure that owns `ax`: plt.gcf() may be a different, more recently created figure
        fig = ax.get_figure()

    # NaN values (missing times, incomplete rolling windows) are left for matplotlib to skip
    ax.scatter(trial_number, reaction_time.to_numpy(), s=16, color='darkgray')
    ax.plot(trial_number, rolling_median.to_numpy(), color='k', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(0.1, 100)
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Trial number')
    if title:
        ax.set_title(title)

    return fig, ax
def _normalize_training_status(status):
    """Return the canonical lowercase TrainingStatus name for `status`, e.g. 'In Training' ->
    'in_training'. Raises ValueError if it is not a TrainingStatus member."""
    status = status.lower().replace(' ', '_')
    try:
        return TrainingStatus[status.upper()].name.lower()
    except KeyError as ex:
        raise ValueError(
            f'Unknown status "{status}". For non-standard training protocols set validate=False'
        ) from ex


def query_criterion(subject, status, from_status=None, one=None, validate=True):
    """Get the session for which a given training criterion was met.

    Parameters
    ----------
    subject : str
        The subject name.
    status : str
        The training status to query for.
    from_status : str, optional
        Count number of sessions and days from reaching `from_status` to `status`.
    one : one.api.OneAlyx, optional
        An instance of ONE.
    validate : bool
        If true, check that `status` (and `from_status`, if given) is in the TrainingStatus
        enumeration and normalize its name. Set to false for non-standard training pipelines.

    Returns
    -------
    str or None
        The eID of the first session where this training status was reached, or None if the
        subject's 'trained_criteria' does not contain `status`.
    int or None
        The number of sessions it took to reach `status` (optionally from reaching
        `from_status`), or None if `status` was not reached or no sessions were found.
    int or None
        The number of days it took to reach `status` (optionally from reaching `from_status`),
        or None if `status` was not reached or no sessions were found.

    Raises
    ------
    ValueError
        If `validate` is true and `status` or `from_status` is not a TrainingStatus member.
    """
    if validate:
        status = _normalize_training_status(status)
        if from_status is not None:
            from_status = _normalize_training_status(from_status)
    if one is None:
        one = ONE()
    # The subject's JSON field may be null on Alyx, so fall back to an empty dict
    subject_json = one.alyx.rest('subjects', 'read', id=subject)['json'] or {}
    criteria = subject_json.get('trained_criteria')
    if not criteria or status not in criteria:
        return None, None, None
    to_date, eid = criteria[status]
    # If from_status is absent (or not yet reached) no lower date bound is applied
    from_date, _ = criteria.get(from_status, (None, None))
    eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True)
    if len(eids) == 0:
        return eid, None, None
    # Span of the session dates, independent of the order in which search returns them
    dates = [session['date'] for session in det]
    return eid, len(eids), (max(dates) - min(dates)).days