Coverage for brainbox/behavior/training.py: 64%
403 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""Computing and testing IBL training status criteria.
3For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice
4Training.
6.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\
7measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729
9Examples
10--------
11Plot the psychometric curve for a given session.
13>>> trials = ONE().load_object(eid, 'trials')
14>>> fig, ax = plot_psychometric(trials)
16Compute 'response times', defined as the duration of open-loop for each contrast.
18>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials)
20Compute 'reaction times', defined as the time between go cue and first detected movement.
21NB: These may be negative!
23>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
24... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times')
26Compute 'response times', defined as the time between first detected movement and response.
28>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
29... trials, stim_on_type='firstMovement_times', stim_off_type='response_times')
31Compute 'movement times', defined as the time between last detected movement and response threshold.
33>>> import brainbox.behavior.wheel as wh
34>>> wheel_moves = ONE().load_object(eid, 'wheelMoves')
35>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trials.response_times)
36>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
37... trials, stim_on_type='lastMovement_times', stim_off_type='response_times')
39"""
40import logging
41import datetime
42import re
43from enum import IntFlag, auto, unique
45import numpy as np
46import matplotlib
47import matplotlib.pyplot as plt
48import seaborn as sns
49import pandas as pd
50from scipy.stats import bootstrap
51from iblutil.util import Bunch
52from one.api import ONE
53from one.alf.io import AlfBunch
54from one.alf.exceptions import ALFObjectNotFound
55import psychofit as psy
# Module logger, registered under the shared 'ibllib' logger name.
_logger = logging.getLogger('ibllib')

TRIALS_KEYS = [
    'contrastLeft', 'contrastRight',  # stimulus contrast on each side (NaN when absent)
    'feedbackType',                   # 1 for correct trials
    'probabilityLeft',                # block type
    'choice',
    'response_times',
    'stimOn_times',
]
"""list of str: The required keys in the trials object for computing training status."""
@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Enumeration allows for comparisons between training status. The single-bit members are
    assigned increasing powers of two by ``auto()`` in the order of training progression, so
    ordinary integer comparison orders statuses from least to most advanced.

    Examples
    --------
    >>> status = 'ready4delay'
    >>> assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    >>> assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    >>> assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    Get the next training status

    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    """
    UNTRAINABLE = auto()
    UNBIASABLE = auto()
    IN_TRAINING = auto()
    TRAINED_1A = auto()
    TRAINED_1B = auto()
    READY4EPHYSRIG = auto()
    READY4DELAY = auto()
    READY4RECORDING = auto()
    # Compound training statuses for convenience; use `in` to test membership of a status.
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B
def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Compute and print the training status of every alive, water-restricted subject in a lab.

    Results are printed to stdout, one subject at a time, via `get_subject_training_status`.

    Parameters
    ----------
    lab : str
        Lab name (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.
    """
    one = one or ONE()
    lab_subjects = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True)
    nicknames = [record['nickname'] for record in lab_subjects]
    for nickname in nicknames:
        get_subject_training_status(nickname, date=date, details=details, one=one)
def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Compute the training status of a single subject and print the result to stdout.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.
    """
    one = one or ONE()

    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    # get_sessions signals "nothing to evaluate" with a falsy trials value
    if not trials:
        return
    sess_dates = list(trials.keys())
    status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay)

    if not details:
        return

    # Pick the metrics to show: unbiased psychometric fit (training protocol),
    # per-block fits (biased/ephys protocols), or only the status when neither is present.
    metrics = {}
    if np.any(info.get('psych')):
        metrics = dict(perf_easy=info.perf_easy, n_trials=info.n_trials,
                       psych=info.psych, rt=info.rt)
    elif np.any(info.get('psych_20')):
        metrics = dict(perf_easy=info.perf_easy, n_trials=info.n_trials,
                       psych_20=info.psych_20, psych_80=info.psych_80, rt=info.rt)
    display_status(subj, sess_dates, status, **metrics)
def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject. If a date is given it will load
    data from the three (or as many as are available) previous sessions up to the specified date.
    If not it will load data from the last three training sessions that have data available.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    one : one.api.OneAlyx
        An instance of ONE.

    Returns
    -------
    iblutil.util.Bunch
        Dictionary of trials objects where each key is the ISO session date string.
    list of str
        List of the task protocol used for each of the sessions.
    list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    If no sessions are found, or none of the sessions found have loadable trials, a warning is
    logged and a list of four None values is returned instead.
    """
    one = one or ONE()

    if date is None:
        # compute from yesterday
        specified_date = (datetime.date.today() - datetime.timedelta(days=1))
        latest_sess = specified_date.strftime("%Y-%m-%d")
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime("%Y-%m-%d")
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d")
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj,
                             date_range=[latest_minus_week, latest_sess],
                             dataset_types='trials.goCueTrigger_times')

    # If not enough sessions in the last week, then just fetch them all
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)

    # If still 0 sessions then return with warning
    if len(sessions) == 0:
        _logger.warning(f"No training sessions detected for {subj}")
        return [None] * 4

    trials = Bunch()
    task_protocol = []
    sess_dates = []
    # Walk the session list (assumed most recent first - TODO confirm Alyx ordering) until three
    # sessions with trials have been loaded, stopping at the end of the list if fewer exist.
    for sess in sessions:
        if len(trials) >= 3:
            break
        eid = sess['url'].split('/')[-1]
        _logger.debug('Loading trials for session %s', eid)
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            trials_ = None

        if trials_:
            sess_date = sess['start_time'][:10]
            task_protocol.append(re.search('tasks_(.*)Choice', sess['task_protocol']).group(1))
            sess_dates.append(sess_date)
            trials[sess_date] = trials_

    # Sessions existed but none had loadable trials: nothing to evaluate, and the date range
    # below would otherwise index an empty list.
    if not trials:
        _logger.warning(f"No trials data could be loaded for {subj}")
        return [None] * 4

    if not np.any(np.array(task_protocol) == 'training'):
        # sess_dates[-1] / sess_dates[0] bound the range of loaded sessions
        ephys_sess = one.alyx.rest('sessions', 'list', subject=subj,
                                   date_range=[sess_dates[-1], sess_dates[0]],
                                   django='location__name__icontains,ephys')
        if len(ephys_sess) > 0:
            ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]
            n_delay = len(one.alyx.rest('sessions', 'list', subject=subj,
                                        date_range=[sess_dates[-1], sess_dates[0]],
                                        django='json__SESSION_DELAY_START__gte,900'))
        else:
            ephys_sess_dates = []
            n_delay = 0
    else:
        ephys_sess_dates = []
        n_delay = 0

    return trials, task_protocol, ephys_sess_dates, n_delay
def _pending_biased_criteria(protocols):
    """Build the criteria Bunch reported when a subject still needs biased sessions."""
    pending = Bunch()
    pending['NBiased'] = {'val': protocols, 'pass': False}
    pending['Criteria'] = {'val': 'ready4ephysrig', 'pass': False}
    return pending


def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from consecutive training datasets.

    For IBL, training status is calculated using trials from the last three consecutive sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    task_protocol : list of str
        Task protocol used for each training session in `trials`, can be 'training', 'biased' or
        'ephys'.
    ephys_sess_dates : list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Returns
    -------
    str
        Training status of the subject.
    iblutil.util.Bunch
        Bunch containing performance metrics that decide training status i.e. performance on easy
        trials, number of trials, psychometric fit parameters, reaction time.
    iblutil.util.Bunch
        The criteria for the next status up that the subject has not (yet) met.
    """
    info = Bunch()
    trials_all = concatenate_trials(trials)
    info.session_dates = list(trials.keys())
    info.protocols = list(task_protocol)

    on_training = np.array(task_protocol) == 'training'

    if np.all(on_training):
        # Only trainingChoiceWorld sessions: check trained 1b first, then trained 1a.
        signed_contrast = np.unique(get_signed_contrast(trials_all))
        info.perf_easy, info.n_trials, info.psych, info.rt = compute_training_info(trials, trials_all)

        passed_1b, crit_1b = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt,
                                          signed_contrast)
        if passed_1b:
            return 'trained 1b', info, _pending_biased_criteria(info.protocols)

        passed_1a, crit_1a = criterion_1a(info.psych, info.n_trials, info.perf_easy, signed_contrast)
        if passed_1a:
            return 'trained 1a', info, crit_1b
        return 'in training', info, crit_1a

    if np.any(on_training):
        # Fewer than three biasedChoiceWorld sessions since reaching trained 1b.
        info.perf_easy, info.n_trials, info.psych, info.rt = compute_training_info(trials, trials_all)
        return 'trained 1b', info, _pending_biased_criteria(info.protocols)

    # Only biasedChoiceWorld / ephysChoiceWorld sessions remain.
    (info.perf_easy, info.n_trials,
     info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all)
    n_ephys = len(ephys_sess_dates)
    info.n_ephys = n_ephys
    info.n_delay = n_delay

    passed, crit_recording = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80,
                                                info.n_trials, info.perf_easy, info.rt)
    if passed:
        # Nothing higher to reach, so the (passed) recording criteria are reported.
        return 'ready4recording', info, crit_recording

    assert all(date in trials for date in ephys_sess_dates)
    perf_ephys_easy = np.array([compute_performance_easy(trials[day]) for day in ephys_sess_dates])
    n_ephys_trials = np.array([compute_n_trials(trials[day]) for day in ephys_sess_dates])

    passed, crit_delay = criterion_delay(n_ephys_trials, perf_ephys_easy, n_ephys=n_ephys)
    if passed:
        return 'ready4delay', info, crit_recording

    passed, crit_ephys = criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
                                         info.perf_easy, info.rt)
    if passed:
        return 'ready4ephysrig', info, crit_delay
    return 'trained 1b', info, crit_ephys
def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    sess_dates : list of str
        ISO date strings of training sessions used to determine training status. When no
        performance metrics are passed, at most the first three dates are printed.
    status : str
        Training status of subject.
    perf_easy : numpy.array
        Proportion of correct high contrast trials for each training session.
    n_trials : numpy.array
        Total number of trials for each training session.
    psych : numpy.array
        Psychometric parameters fit to data from all training sessions; printed as bias,
        threshold, lapse_low, lapse_high.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
    rt : float
        The median response time for zero contrast trials across all training sessions. NaN
        indicates no zero contrast stimuli in training sessions.
    """

    if perf_easy is None:
        # Join instead of indexing [0], [1], [2] so fewer than three sessions do not raise.
        dates = ', '.join(str(d) for d in list(sess_dates)[:3])
        print(f"\n{subj} : {status} \nSession dates=[{dates}]")
    elif psych_20 is None:
        print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
              f"Perf easy={[np.around(pe, 2) for pe in perf_easy]}, "
              f"N trials={[nt for nt in n_trials]} "
              f"\nPsych fit over last 3 sessions: "
              f"bias={np.around(psych[0], 2)}, thres={np.around(psych[1], 2)}, "
              f"lapse_low={np.around(psych[2], 2)}, lapse_high={np.around(psych[3], 2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")

    else:
        print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
              f"Perf easy={[np.around(pe, 2) for pe in perf_easy]}, "
              f"N trials={[nt for nt in n_trials]} "
              f"\nPsych fit over last 3 sessions (20): "
              f"bias={np.around(psych_20[0], 2)}, thres={np.around(psych_20[1], 2)}, "
              f"lapse_low={np.around(psych_20[2], 2)}, lapse_high={np.around(psych_20[3], 2)} "
              f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0], 2)}, "
              f"thres={np.around(psych_80[1], 2)}, lapse_low={np.around(psych_80[2], 2)}, "
              f"lapse_high={np.around(psych_80[3], 2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")
def concatenate_trials(trials):
    """
    Join the trials of several sessions into a single trials object.

    Only the keys listed in `TRIALS_KEYS` are concatenated, in session (dict) order.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.

    Returns
    -------
    one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.
    """
    per_session = [trials[date] for date in trials.keys()]
    combined = AlfBunch()
    for key in TRIALS_KEYS:
        combined[key] = np.concatenate([session[key] for session in per_session])
    return combined
def compute_training_info(trials, trials_all):
    """
    Collect the performance metrics used to grade trainingChoiceWorld sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.

    Returns
    -------
    numpy.array
        Proportion of correct high contrast trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Psychometric parameters fit to `trials_all` over all blocks.
    float
        The median response time for all zero-contrast trials across all sessions. NaN if there
        are no zero-contrast trials.
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
    return perf_easy, n_trials, psych, rt
def compute_bias_info(trials, trials_all):
    """
    Collect the performance metrics used to grade biasedChoiceWorld sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects from consecutive training sessions, keyed by session date.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over those sessions.

    Returns
    -------
    numpy.array
        Performance on easy trials for each session.
    numpy.array
        Number of trials in each session.
    numpy.array
        Psychometric fit parameters for trials with probabilityLeft == 0.2, over all sessions.
    numpy.array
        Psychometric fit parameters for trials with probabilityLeft == 0.8, over all sessions.
    float
        Median response time for zero contrast stimuli over all sessions.
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = list(trials.values())
    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    psych_20 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.2)
    psych_80 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.8)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
    return perf_easy, n_trials, psych_20, psych_80, rt
def get_signed_contrast(trials):
    """
    Compute the signed contrast of each trial.

    :param trials: trials object that must contain contrastLeft and contrastRight keys
    :type trials: dict
    returns: array of signed contrasts in percent (right minus left, NaN treated as 0), where
        -ve values are on the left
    """
    # One row per trial, columns (left, right); missing (NaN) contrasts become zero.
    pairs = np.nan_to_num(np.column_stack((trials['contrastLeft'], trials['contrastRight'])))
    left, right = pairs[:, 0], pairs[:, 1]
    return (right - left) * 100
def compute_performance_easy(trials):
    """
    Compute the proportion of correct trials among easy (|contrast| >= 50 %) trials.

    :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType
        keys
    :type trials: dict
    returns: float, fraction of easy trials with feedbackType == 1 (NaN, with a numpy
        warning, if there are no easy trials)
    """
    is_easy = np.abs(get_signed_contrast(trials)) >= 50
    n_correct = np.sum(trials['feedbackType'][is_easy] == 1)
    return n_correct / np.count_nonzero(is_easy)
def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
    """
    Compute performance on all trials at each contrast level from trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing at least 'probabilityLeft', 'feedbackType' and 'choice'
        (plus 'contrastLeft' and 'contrastRight' when `signed_contrast` is None).
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included.
    prob_right : bool
        If false, compute the proportion of correct trials (feedbackType == 1) at each contrast.
        If true, compute the proportion of trials with choice == -1 (treated as rightward choices
        in this module).

    Returns
    -------
    numpy.array
        The proportion correct (or rightward) for each unique signed contrast.
    numpy.array
        The unique signed contrasts present in the selected trials.
    numpy.array
        The number of trials for each unique signed contrast.

    Notes
    -----
    If no trials match `block`, a single array of three NaNs is returned instead of the tuple; it
    still unpacks into three (NaN) values.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        return np.nan * np.zeros(3)

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)

    # Per-trial outcome whose mean is taken within each contrast
    outcome = (trials.choice == -1) if prob_right else (trials.feedbackType == 1)
    performance = np.vectorize(
        lambda x: np.mean(outcome[(x == signed_contrast) & block_idx]),
        otypes=[float]
    )(contrasts)

    return performance, contrasts, n_contrasts
def compute_n_trials(trials):
    """
    Count the trials in a trials object.

    :param trials: trials object with an array under the 'choice' key
    :type trials: dict
    returns: int, the length of the first axis of trials['choice']
    """
    choices = trials['choice']
    return choices.shape[0]
def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032):
    """
    Compute psychometric fit parameters for trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    plotting : bool
        Which set of psychofit model parameters to use (see notes).
    compute_ci : bool
        If true, computes and returns the confidence intervals for response at each contrast.
    alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        Array of four psychometric fit parameters: bias, threshold and the two lapse rates, in the
        order returned by psychofit for the 'erf_psycho_2gammas' model. All NaN if no trials match
        `block`.
    numpy.array
        Only returned if `compute_ci` is true: a (2, n_contrasts) array holding the lower and
        upper confidence bounds of the proportion of rightward (choice == -1) responses at each
        unique contrast, each expressed relative to the observed proportion (bound minus
        observed). If no trials match `block`, an empty (2, 0) array.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.

    Notes
    -----
    The psychofit starting parameters and model constraints used for the fit when computing the
    training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep
    the precise criteria the same for all subjects, these parameters have not changed. To produce a
    better fit for plotting purposes, or to calculate the training status in a manner inconsistent
    with the IBL training pipeline, use plotting=True.
    """

    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        # No trials in the requested block, so there is nothing to fit.
        nan_psych = np.nan * np.zeros(4)
        if compute_ci:
            # Keep the (psych, ci) return shape so callers can always unpack two values; with no
            # contrasts there are no CI columns.
            return nan_psych, np.full((2, 0), np.nan)
        return nan_psych

    prob_choose_right, contrasts, n_contrasts = compute_performance(
        trials, signed_contrast=signed_contrast, block=block, prob_right=True)

    if plotting:
        # These starting parameters and constraints tend to produce a better fit, and are therefore
        # used for plotting.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([0., 40., 0.1, 0.1]),
            parmin=np.array([-50., 10., 0., 0.]),
            parmax=np.array([50., 50., 0.2, 0.2]),
            nfits=10)
    else:
        # These starting parameters and constraints are not ideal but are still used for computing
        # the training status for consistency.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
            parmin=np.array([np.min(contrasts), 0., 0., 0.]),
            parmax=np.array([np.max(contrasts), 100., 1, 1]))

    if not compute_ci:
        return psych

    # NOTE(review): compute_reaction_time defaults alpha to 0.32; confirm 0.032 here is intended.
    import statsmodels.stats.proportion as smp  # noqa
    # choice == -1 means contrast on right hand side
    n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts)
    # proportion_confint returns (lower, upper); subtracting broadcasts to a (2, n_contrasts) array
    ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right
    return psych, ci
def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None,
                                 stim_off_type='response_times'):
    """
    Compute the median response time from a trials object, optionally for a single contrast.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key marking the start of the interval. The difference between `stim_off_type`
        and this is used (see notes).
    contrast : float
        If None, the median response time is calculated for all trials, regardless of contrast,
        otherwise only trials where the matching signed percent contrast was presented are used.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    stim_off_type : str, default='response_times'
        The trials key marking the end of the interval.

    Returns
    -------
    float
        The median response time for trials with `contrast` (returns NaN if no trials matching
        `contrast` in trials object). NaN intervals are ignored.

    Notes
    -----
    - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is
      sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on
      times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of
      the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard
      output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the
      audio go cue command).
    - With the default keys, the response/reaction time here is the time between stim on and
      response, i.e. the entire open-loop trial duration.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if contrast is None:
        contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        contrast_idx = signed_contrast == contrast

    if not np.any(contrast_idx):
        return np.nan

    durations = trials[stim_off_type] - trials[stim_on_type]
    return np.nanmedian(durations[contrast_idx])
def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None,
                          compute_ci=False, alpha=0.32):
    """
    Compute median response time for all contrasts.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_off_type` is used (see notes).
    stim_off_type : str, default='response_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_on_type` is used (see notes).
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    compute_ci : bool
        If true, computes and returns the confidence intervals for response time at each contrast.
    alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        The median response times for each unique signed contrast.
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.
    (numpy.array)
        If `compute_ci` is true, an array of confidence intervals is returned in the shape
        (n_contrasts, 2), columns being the lower and upper bounds. Contrasts with fewer than two
        non-NaN response times have a NaN interval.

    Notes
    -----
    - The response/reaction time by default is the time between stim on and response, i.e. the
      entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to
      get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true
      response times, or calculate the last movement onset times and calculate the true movement
      times. See module examples for how to calculate this.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """

    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
    # Per-trial interval, computed once and reused for every contrast
    durations = trials[stim_off_type] - trials[stim_on_type]
    reaction_time = np.vectorize(
        lambda x: np.nanmedian(durations[(x == signed_contrast) & block_idx]),
        otypes=[float]
    )(contrasts)

    if not compute_ci:
        return reaction_time, contrasts, n_contrasts

    ci = np.full((contrasts.size, 2), np.nan)
    for i, x in enumerate(contrasts):
        data = durations[(x == signed_contrast) & block_idx]
        # scipy.stats.bootstrap rejects samples with fewer than two observations; the interval is
        # undefined there, so leave it as NaN rather than raising.
        if np.count_nonzero(~np.isnan(data)) < 2:
            continue
        bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha)
        ci[i, 0] = bt.confidence_interval.low
        ci[i, 1] = bt.confidence_interval.high

    return reaction_time, contrasts, n_contrasts, ci
def criterion_1a(psych, n_trials, perf_easy, signed_contrast):
    """
    Return whether the criteria for training status 'trained_1a' are met.

    Criteria
    --------
    - Zero contrast trials must be present
    - Lapse rate on both sides is less than 0.2
    - Absolute bias is less than 16
    - Threshold is less than 19
    - The number of trials is greater than 200 for every session
    - Performance on easy contrasts is greater than 80% for every session

    Parameters
    ----------
    psych : numpy.array
        The psychometric parameters fit over three consecutive sessions, ordered bias, threshold,
        then the two lapse rates (index 2 is reported as 'LapseLow_50', index 3 as
        'LapseHigh_50').
    n_trials : numpy.array of int
        The number of trials for each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials for each session.
    signed_contrast : numpy.array
        Unique list of contrasts displayed.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1a'.
    Bunch
        Bunch containing a breakdown of the passing/failing criteria. Each entry maps a criterion
        name to a dict with keys 'val' (the tested value) and 'pass' (the outcome); the final
        'Criteria' entry holds the overall result.

    Notes
    -----
    The parameter thresholds chosen here were originally determined by averaging the parameter fits
    for a number of sessions determined to be of 'good' performance by an experimenter.
    """
    # (criterion name, tested value, outcome), in the order they are reported
    checks = (
        ('Zero_contrast', signed_contrast, np.any(signed_contrast == 0)),
        ('LapseLow_50', psych[2], psych[2] < 0.2),
        ('LapseHigh_50', psych[3], psych[3] < 0.2),
        ('Bias', psych[0], abs(psych[0]) < 16),
        ('Threshold', psych[1], psych[1] < 19),
        ('N_trials', n_trials, np.all(n_trials > 200)),
        ('Perf_easy', perf_easy, np.all(perf_easy > 0.8)),
    )

    criteria = Bunch()
    for name, value, ok in checks:
        criteria[name] = {'val': value, 'pass': ok}

    passing = np.all([entry['pass'] for entry in criteria.values()])
    criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing}
    return passing, criteria
def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast):
    """
    Returns bool indicating whether criteria for trained_1b are met.

    Criteria
    --------
    - Absolute bias is less than 10
    - Threshold is less than 20 (see notes)
    - Lapse rate on both sides is less than 0.1
    - The total number of trials is greater than 400 for each session
    - Performance on easy contrasts > 90% for all sessions
    - The median response time across all zero contrast trials is less than 2 seconds
    - Zero contrast trials must be present

    Parameters
    ----------
    psych : numpy.array
        The fit psychometric parameters for three consecutive sessions, ordered bias, threshold,
        then the two lapse rates (index 2 is reported as 'LapseLow_50', index 3 as
        'LapseHigh_50').
    n_trials : numpy.array of int
        The number of trials for each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials for each session.
    rt : float
        The median response time for zero contrast trials.
    signed_contrast : numpy.array
        Unique list of contrasts displayed.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1b'.
    Bunch
        Bunch containing a breakdown of the passing/failing criteria. Each entry maps a criterion
        name to a dict with keys 'val' and 'pass'; the final 'Criteria' entry holds the overall
        result.

    Notes
    -----
    The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a,
    however it was decided to use round numbers so that readers would not assume a level of
    precision that isn't there (remember, these parameters were not chosen with any rigor). This
    regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
    slope of the psychometric curve may be slightly less steep than 1a.
    """
    criteria = Bunch()
    criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
    criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1}
    criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1}
    criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10}
    criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20}
    criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
    # Key was previously misspelled 'Perf_tasy'; use the same key as the other criterion functions
    criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
    criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}

    passing = np.all([v['pass'] for v in criteria.values()])

    criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing}

    return passing, criteria
def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Return whether the criteria for training status 'ready4ephysrig' are met.

    Criteria
    --------
    - Lapse on both sides < 0.1 for both bias blocks
    - Bias shift between blocks (0.8 block bias minus 0.2 block bias) > 5
    - Total number of trials > 400 for all sessions
    - Performance on easy contrasts > 90% for all sessions
    - Median response time for zero contrast stimuli < 2 seconds

    Parameters
    ----------
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2,
        ordered bias, threshold, then the two lapse rates (index 2 is reported as 'LapseLow_20',
        index 3 as 'LapseHigh_20').
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8,
        in the same order as `psych_20`.
    n_trials : numpy.array
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if subject passes the ready4ephysrig criteria.
    Bunch
        Bunch containing a breakdown of the passing/failing criteria. Each entry maps a criterion
        name to a dict with keys 'val' and 'pass'; the final 'Criteria' entry holds the overall
        result.
    """
    # Signed difference in bias between the two biased blocks
    bias_shift = psych_80[0] - psych_20[0]

    criteria = Bunch()
    criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1}
    criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1}
    criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1}
    criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1}
    criteria['Bias_shift'] = {'val': bias_shift, 'pass': bias_shift > 5}
    criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
    criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
    criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}

    passing = np.all([entry['pass'] for entry in criteria.values()])
    criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing}
    return passing, criteria
def criterion_delay(n_trials, perf_easy, n_ephys=1):
    """
    Return whether the criteria for training status 'ready4delay' are met.

    Criteria
    --------
    - At least one session on an ephys rig
    - Total number of trials for any of the sessions is greater than 400
    - Performance on easy contrasts is greater than 90% for any of the sessions

    Parameters
    ----------
    n_trials : numpy.array of int
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).
    n_ephys : int, default=1
        The number of sessions on an ephys rig. With the default of 1 the 'N_ephys' criterion
        passes; it fails only when a value <= 0 is supplied.

    Returns
    -------
    bool
        True if subject passes the 'ready4delay' criteria.
    Bunch
        Bunch containing a breakdown of the passing/failing criteria. Each entry maps a criterion
        name to a dict with keys 'val' and 'pass'; the final 'Criteria' entry holds the overall
        result.
    """
    # Unlike the other criteria, trial count and performance need only pass in ANY session
    checks = (
        ('N_ephys', n_ephys, n_ephys > 0),
        ('N_trials', n_trials, np.any(n_trials > 400)),
        ('Perf_easy', perf_easy, np.any(perf_easy > 0.9)),
    )

    criteria = Bunch()
    for name, value, ok in checks:
        criteria[name] = {'val': value, 'pass': ok}

    passing = np.all([entry['pass'] for entry in criteria.values()])
    criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing}
    return passing, criteria
def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Return whether the criteria for training status 'ready4recording' are met.

    These are the 'ready4ephysrig' criteria (see `criterion_ephys`) plus two extra checks.

    Criteria
    --------
    - At least 3 ephys sessions
    - Delay on any session > 0
    - Lapse on both sides < 0.1 for both bias blocks
    - Bias shift between blocks > 5
    - Total number of trials > 400 for all sessions
    - Performance on easy contrasts > 90% for all sessions
    - Median response time for zero contrast stimuli < 2 seconds

    Parameters
    ----------
    n_ephys : int
        The number of sessions on an ephys rig; must be at least 3 to pass.
    delay : float
        The delay value checked by the 'N_delay' criterion; must be greater than 0 to pass.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2,
        ordered bias, threshold, then the two lapse rates.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8,
        in the same order as `psych_20`.
    n_trials : numpy.array
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if subject passes the ready4recording criteria.
    Bunch
        Bunch containing a breakdown of the passing/failing criteria. Each entry maps a criterion
        name to a dict with keys 'val' and 'pass'; the 'Criteria' entry holds the overall result.
    """
    # Start from the ephys-rig breakdown; its overall verdict is discarded and recomputed below
    _, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt)
    criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3}
    criteria['N_delay'] = {'val': delay, 'pass': delay > 0}

    # The inherited 'Criteria' entry is the AND of the ephys checks, so including it here does not
    # change the result; it is then overwritten in place (keeping its position in the Bunch).
    passing = np.all([entry['pass'] for entry in criteria.values()])
    criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing}
    return passing, criteria
def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
    """
    Plot psychometric curves a la datajoint webpage.

    The choice data and psychometric fit are plotted individually for each of the following
    probabilityLeft blocks: {0.5, 0.2, 0.8}.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (newly created if `ax` was None).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.
    psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability
        using the fit parameters.
    """
    # NOTE(review): plot_reaction_time defaults ci_alpha to 0.32; confirm 0.032 is intended here.
    signed_contrast = get_signed_contrast(trials)
    contrasts_fit = np.arange(-100, 100)

    # Data points and fitted curve per block; this order fixes both plotting and legend order.
    block_data = []
    for block in (0.5, 0.2, 0.8):
        prob_right, contrasts, _ = compute_performance(
            trials, signed_contrast=signed_contrast, block=block, prob_right=True)
        out = compute_psychometric(trials, signed_contrast=signed_contrast, block=block, plotting=True,
                                   compute_ci=plot_ci, alpha=ci_alpha)
        # With compute_ci the fit parameters are the first element of the returned tuple
        pars = out[0] if plot_ci else out
        prob_right_fit = psy.erf_psycho_2gammas(pars, contrasts_fit)
        block_data.append((block, contrasts, prob_right, prob_right_fit, out))

    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    colors = {0.5: cmap[1], 0.2: cmap[0], 0.8: cmap[2]}

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Return the figure that owns `ax`; the current figure (plt.gcf) may be a different one
        fig = ax.get_figure()

    handles, labels = [], []
    for block, contrasts, prob_right, prob_right_fit, _ in block_data:
        fit = ax.plot(contrasts_fit, prob_right_fit, color=colors[block])
        data = ax.scatter(contrasts, prob_right, color=colors[block])
        handles.extend([fit[0], data])
        labels.extend([f'p_left={block} fit', f'p_left={block} data'])

    if plot_ci:
        for block, contrasts, prob_right, _, out in block_data:
            # errorbar expects a (2, N) array of lower/upper offsets
            errbar = np.c_[np.abs(out[1][0]), np.abs(out[1][1])].T
            ax.errorbar(contrasts, prob_right, yerr=errbar, ecolor=colors[block], fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, labels, loc='upper left')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('Probability choosing right')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)

    return fig, ax
def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs):
    """
    Plot reaction time against contrast a la datajoint webpage.

    The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (newly created if `ax` was None).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    signed_contrast = get_signed_contrast(trials)
    # Per block: (median times, contrasts, n_contrasts[, ci]); dict order fixes plot/legend order
    outputs = {
        block: compute_reaction_time(trials, signed_contrast=signed_contrast, block=block,
                                     compute_ci=plot_ci, alpha=ci_alpha)
        for block in (0.5, 0.2, 0.8)
    }

    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    colors = {0.5: cmap[1], 0.2: cmap[0], 0.8: cmap[2]}

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Return the figure that owns `ax`; the current figure (plt.gcf) may be a different one
        fig = ax.get_figure()

    handles = []
    for block, out in outputs.items():
        handles.append(ax.plot(out[1], out[0], '-o', color=colors[block])[0])

    if plot_ci:
        for block, out in outputs.items():
            median_rt, contrasts, ci = out[0], out[1], out[3]
            # Convert absolute CI bounds into the (2, N) offsets expected by errorbar
            errbar = np.c_[median_rt - ci[:, 0], ci[:, 1] - median_rt].T
            ax.errorbar(contrasts, median_rt, yerr=errbar, ecolor=colors[block], fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, [f'p_left={block} data' for block in outputs], loc='upper left')
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Contrasts')

    if title:
        ax.set_title(title)

    return fig, ax
def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
    """
    Plot reaction time against trial number a la datajoint webpage.

    Individual reaction times are drawn as grey points and a 10-trial rolling median as a black
    line, on a log-scaled y-axis limited to [0.1, 100] seconds. Trials are numbered from 1.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the reaction times. The difference between
        'response_times' and this key is used (see notes for `compute_reaction_time`).
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (newly created if `ax` was None).
    matplotlib.pyplot.Axes
        The plotted axes.
    """
    reaction_time = pd.Series(np.asarray(trials.response_times - trials[stim_on_type], dtype=float))
    # 1-based trial numbers for the x-axis (previously computed but not used, so the axis began at 0)
    trial_number = np.arange(1, reaction_time.size + 1)
    # Count-based window: the first 9 values are NaN, and matplotlib simply leaves them undrawn
    reaction_time_rolled = reaction_time.rolling(window=10).median()

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Return the figure that owns `ax`; the current figure (plt.gcf) may be a different one
        fig = ax.get_figure()

    ax.scatter(trial_number, reaction_time.values, s=16, color='darkgray')
    ax.plot(trial_number, reaction_time_rolled.values, color='k', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(0.1, 100)
    # Plain decimal tick labels instead of the default 10^n notation on the log axis
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Trial number')
    if title:
        ax.set_title(title)

    return fig, ax
def query_criterion(subject, status, from_status=None, one=None, validate=True):
    """Get the session for which a given training criterion was met.

    Parameters
    ----------
    subject : str
        The subject name.
    status : str
        The training status to query for.
    from_status : str, optional
        Count number of sessions and days from reaching `from_status` to `status`. If None, or if
        the subject never reached `from_status`, no lower date bound is applied to the search.
    one : one.api.OneAlyx, optional
        An instance of ONE.
    validate : bool
        If true, normalise `status` and `from_status` (case-insensitive, spaces treated as
        underscores) and check they are in the TrainingStatus enumeration. Set to false for
        non-standard training pipelines.

    Returns
    -------
    str
        The eID of the first session where this training status was reached, or None if the
        subject has no record of reaching it.
    int
        The number of sessions it took to reach `status` (optionally from reaching `from_status`),
        or None if no sessions were found.
    int
        The number of days it took to reach `status` (optionally from reaching `from_status`), or
        None if no sessions were found.

    Raises
    ------
    ValueError
        If `validate` is true and `status` or `from_status` is not a TrainingStatus member.
    """
    def _normalise(name):
        # Map e.g. 'Trained 1a' onto the lower-case TrainingStatus member name 'trained_1a'
        key = name.lower().replace(' ', '_')
        try:
            return TrainingStatus[key.upper()].name.lower()
        except KeyError as ex:
            raise ValueError(
                f'Unknown status "{key}". For non-standard training protocols set validate=False'
            ) from ex

    if validate:
        status = _normalise(status)
        # Normalise from_status too, otherwise e.g. 'trained 1a' silently misses the criteria key
        if from_status is not None:
            from_status = _normalise(from_status)

    one = one or ONE()
    subject_json = one.alyx.rest('subjects', 'read', id=subject)['json']
    criteria = subject_json.get('trained_criteria')
    if not criteria or status not in criteria:
        return None, None, None

    # Each criteria entry is a (date, eid) pair
    to_date, eid = criteria[status]
    from_date, _ = criteria.get(from_status, (None, None))
    eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True)
    if len(eids) == 0:
        return eid, None, None
    # NOTE(review): assumes search results are ordered newest first, so det[0] is the latest session
    delta_date = det[0]['date'] - det[-1]['date']
    return eid, len(eids), delta_date.days