Coverage for ibllib/pipes/training_status.py: 90%
573 statements

import logging
from pathlib import Path
from datetime import datetime
from itertools import chain

import numpy as np
import pandas as pd
from iblutil.numerical import ismember
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError

from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_session_extractor_type, get_bpod_extractor_class
from ibllib.io.session_params import read_params
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.plots.snapshot import ReportSnapshot
from brainbox.behavior import training

logger = logging.getLogger(__name__)

TRAINING_STATUS = {'untrainable': (-4, (0, 0, 0, 0)),
                   'unbiasable': (-3, (0, 0, 0, 0)),
                   'not_computed': (-2, (0, 0, 0, 0)),
                   'habituation': (-1, (0, 0, 0, 0)),
                   'in training': (0, (0, 0, 0, 0)),
                   'trained 1a': (1, (195, 90, 80, 255)),
                   'trained 1b': (2, (255, 153, 20, 255)),
                   'ready4ephysrig': (3, (28, 20, 255, 255)),
                   'ready4delay': (4, (117, 117, 117, 255)),
                   'ready4recording': (5, (20, 255, 91, 255))}
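
# Illustrative sketch of how this mapping is used elsewhere in this module: each status maps to a
# (rank, RGBA) pair; the rank orders the hierarchy so a subject can only move up (see
# pass_through_training_hierachy below), and the RGBA tuple (0-255 per channel) is divided by 255
# when drawing the status lines in add_training_lines.
# >>> TRAINING_STATUS['trained 1a'][0] > TRAINING_STATUS['in training'][0]
# True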


def get_training_table_from_aws(lab, subject):
    """
    If AWS credentials exist on the local server, download the latest training table for a subject
    from the private AWS S3 bucket.

    :param lab: lab name
    :param subject: subject nickname
    :return: the training table as a pandas.DataFrame, or None if the credentials or file are not found
    """
    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.download_file(f'resources/training/{lab}/{subject}/training.csv',
                             local_file_path)
        df = pd.read_csv(local_file_path)
    except ClientError:
        return

    return df


def upload_training_table_to_aws(lab, subject):
    """
    If AWS credentials exist on the local server, upload the training table for a subject to the
    private AWS S3 bucket.

    :param lab: lab name
    :param subject: subject nickname
    :return:
    """
    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.upload_file(local_file_path,
                           f'resources/training/{lab}/{subject}/training.csv')
    except (ClientError, FileNotFoundError):
        return
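
# Illustrative usage sketch of the two AWS helpers together (assumes AWS credentials under the
# 'ibl_training' profile and the local-server file layout; lab and subject names are hypothetical):
# >>> df = get_training_table_from_aws('cortexlab', 'KS022')
# >>> if df is not None:
# ...     save_dataframe(df, '/mnt/s0/Data/Subjects/KS022')
# ...     upload_training_table_to_aws('cortexlab', 'KS022')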


def save_path(subj_path):
    """Return the path to the training table within a subject folder."""
    return Path(subj_path).joinpath('training.csv')


def save_dataframe(df, subj_path):
    """Save training dataframe to disk.

    :param df: dataframe to save
    :param subj_path: path to subject folder
    :return:
    """
    df.to_csv(save_path(subj_path), index=False)


def load_existing_dataframe(subj_path):
    """Load training dataframe from disk; returns None if the dataframe doesn't exist.

    :param subj_path: path to subject folder
    :return: the training dataframe, or None
    """
    df_location = save_path(subj_path)
    if df_location.exists():
        return pd.read_csv(df_location)
    else:
        df_location.parent.mkdir(exist_ok=True, parents=True)
        return None
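
# Minimal round-trip sketch (hypothetical subject folder): save_dataframe writes
# <subj_path>/training.csv and load_existing_dataframe reads it back, returning None (and creating
# the folder) when the file does not yet exist.
# >>> subj_path = Path('/mnt/s0/Data/Subjects/KS022')
# >>> df = load_existing_dataframe(subj_path)  # None on first run
# >>> save_dataframe(pd.DataFrame({'date': ['2023-01-01']}), subj_path)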


def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
    """
    Load trials data for a session. First attempts to load from the local session path; if this fails,
    attempts to download via ONE; if this also fails, attempts to re-extract the trials locally.

    :param sess_path: session path
    :param one: ONE instance
    :param collections: list of collections to load the trials from; if None, all collections found locally are used
    :param force: when True and the session trials can't be found, will attempt to re-extract from disk
    :param mode: 'raise' or 'warn'; if 'raise', will error when forcing re-extraction of past sessions
    :return: the trials object, or None
    """
    try:
        # try and load all trials that are found locally in the session path
        if collections is None:
            trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.*npy'))
        else:
            trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.*npy') for c in collections]

        if len(trial_locations) > 1:
            trial_dict = {}
            for i, loc in enumerate(trial_locations):
                trial_dict[i] = alfio.load_object(loc.parent, 'trials', short_keys=True)
            trials = training.concatenate_trials(trial_dict)
        elif len(trial_locations) == 1:
            trials = alfio.load_object(trial_locations[0].parent, 'trials', short_keys=True)
        else:
            raise ALFObjectNotFound

        if 'probabilityLeft' not in trials.keys():
            raise ALFObjectNotFound
    except ALFObjectNotFound:
        # Next try and load all trials data through ONE
        try:
            if not force:
                return None
            eid = one.path2eid(sess_path)
            if collections is None:
                trial_collections = one.list_datasets(eid, '_ibl_trials.goCueTrigger_times.npy')
                if len(trial_collections) > 0:
                    trial_collections = ['/'.join(c.split('/')[:-1]) for c in trial_collections]
            else:
                trial_collections = collections

            if len(trial_collections) > 1:
                trial_dict = {}
                for i, collection in enumerate(trial_collections):
                    trial_dict[i] = one.load_object(eid, 'trials', collection=collection)
                trials = training.concatenate_trials(trial_dict)
            elif len(trial_collections) == 1:
                trials = one.load_object(eid, 'trials', collection=trial_collections[0])
            else:
                raise ALFObjectNotFound

            if 'probabilityLeft' not in trials.keys():
                raise ALFObjectNotFound
        except Exception:
            # Finally try to re-extract the trials data locally
            try:
                raw_collections, _ = get_data_collection(sess_path)

                if len(raw_collections) == 0:
                    return None

                trials_dict = {}
                for i, collection in enumerate(raw_collections):
                    extractor = get_bpod_extractor(sess_path, task_collection=collection)
                    trials_data, _ = extractor.extract(task_collection=collection, save=False)
                    trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table'])

                if len(trials_dict) > 1:
                    trials = training.concatenate_trials(trials_dict)
                else:
                    trials = trials_dict[0]

            except Exception as e:
                if mode == 'raise':
                    raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
                else:
                    logger.warning(f'Exhausted all possibilities for loading trials for {sess_path}')
                    return

    return trials
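
# Illustrative usage sketch (path and collection are hypothetical; assumes a configured ONE
# instance). The returned object is an ALF bunch with the fields used by brainbox.behavior.training,
# including 'probabilityLeft':
# >>> from one.api import ONE
# >>> one = ONE()
# >>> trials = load_trials(Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001'), one,
# ...                      collections=['alf'], force=False, mode='warn')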


def load_combined_trials(sess_paths, one, force=True):
    """
    Load and concatenate trials for multiple sessions, e.g. when concatenating trials for two sessions
    run on the same day.

    :param sess_paths: list of paths to sessions
    :param one: ONE instance
    :param force: when True and the session trials can't be found, will attempt to re-extract from disk
    :return: the concatenated trials object
    """
    trials_dict = {}
    for sess_path in sess_paths:
        trials = load_trials(Path(sess_path), one, force=force)
        if trials is not None:
            trials_dict[Path(sess_path).stem] = trials

    return training.concatenate_trials(trials_dict)
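
# Sketch of concatenating two sessions run on the same day (hypothetical paths):
# >>> sess_paths = ['/mnt/s0/Data/Subjects/KS022/2023-01-01/001',
# ...               '/mnt/s0/Data/Subjects/KS022/2023-01-01/002']
# >>> trials = load_combined_trials(sess_paths, one, force=False)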


def get_latest_training_information(sess_path, one):
    """
    Extracts the latest training status.

    Parameters
    ----------
    sess_path : pathlib.Path
        The session path from which to load the data.
    one : one.api.One
        An ONE instance.

    Returns
    -------
    pandas.DataFrame
        A table of training information.
    """

    subj_path = sess_path.parent.parent
    sub = subj_path.parts[-1]
    if one.mode != 'local':
        lab = one.alyx.rest('subjects', 'list', nickname=sub)[0]['lab']
        df = get_training_table_from_aws(lab, sub)
    else:
        df = None

    if df is None:
        df = load_existing_dataframe(subj_path)

    # Find the dates and associated session paths where we don't have data stored in our dataframe
    missing_dates = check_up_to_date(subj_path, df)

    # Iterate through the dates to fill up our training dataframe
    for _, grp in missing_dates.groupby('date'):
        sess_dicts = get_training_info_for_session(grp.session_path.values, one)
        if len(sess_dicts) == 0:
            continue

        for sess_dict in sess_dicts:
            if df is None:
                df = pd.DataFrame.from_dict(sess_dict)
            else:
                df = pd.concat([df, pd.DataFrame.from_dict(sess_dict)])

    # Sort values by date and reset the index
    df = df.sort_values('date')
    df = df.reset_index(drop=True)
    # Save our dataframe
    save_dataframe(df, subj_path)

    # Now go through the backlog and compute the training status for sessions. Training status is
    # cumulative, so if a session was missing we need to recompute everything from that date onwards.
    # Find the earliest date in missing dates that we need to recompute the training status for
    missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True))
    for date in missing_status:
        df = compute_training_status(df, date, one)

    df_lim = df.drop_duplicates(subset='session_path', keep='first')

    # Detect untrainable
    if 'untrainable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['training_status'] == 'in training'].sort_values('date')
        if len(un_df) >= 40:
            sess = un_df.iloc[39].session_path
            df.loc[df['session_path'] == sess, 'training_status'] = 'untrainable'

    # Detect unbiasable
    if 'unbiasable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['task_protocol'] == 'biased'].sort_values('date')
        if len(un_df) >= 40:
            tr_st = un_df[0:40].training_status.unique()
            if 'ready4ephysrig' not in tr_st:
                sess = un_df.iloc[39].session_path
                df.loc[df['session_path'] == sess, 'training_status'] = 'unbiasable'

    save_dataframe(df, subj_path)

    if one.mode != 'local':
        upload_training_table_to_aws(lab, sub)

    return df
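
# Typical entry point, e.g. from the training status pipeline task (hypothetical session path;
# assumes a configured ONE instance):
# >>> sess_path = Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001')
# >>> df = get_latest_training_information(sess_path, one)
# >>> df.iloc[-1][['date', 'training_status']]  # latest computed status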


def find_earliest_recompute_date(df):
    """
    Find the earliest date from which the training status needs to be recomputed. Training status
    depends on previous sessions, so if a session was missing and has now been added we need to
    recompute everything from that date onwards.

    :param df: training dataframe with one row per date
    :return: array of dates to recompute
    """
    missing_df = df[df['training_status'] == 'not_computed']
    if len(missing_df) == 0:
        return []
    missing_df = missing_df.sort_values('date')
    first_index = missing_df.index[0]

    return df[first_index:].date.values
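
# Toy example of the recompute logic: everything from the first 'not_computed' date onwards is
# returned, because training status is cumulative across days.
# >>> toy = pd.DataFrame({'date': ['2023-01-01', '2023-01-02', '2023-01-03'],
# ...                     'training_status': ['trained 1a', 'not_computed', 'trained 1b']})
# >>> find_earliest_recompute_date(toy)
# array(['2023-01-02', '2023-01-03'], dtype=object)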


def compute_training_status(df, compute_date, one, force=True):
    """
    Compute the training status for a given date, based on the sessions from that date and the two
    previous days.

    Parameters
    ----------
    df : pandas.DataFrame
        A training data frame, e.g. one generated from :func:`get_training_info_for_session`.
    compute_date : str, datetime.datetime, pandas.Timestamp
        The date to compute training on.
    one : one.api.One
        An instance of ONE for loading trials data.
    force : bool
        When true and if the session trials can't be found, will attempt to re-extract from disk.

    Returns
    -------
    pandas.DataFrame
        The input data frame with a 'training_status' column populated for `compute_date`.
    """

    # compute_date = str(one.path2ref(session_path)['date'])
    df_temp = df[df['date'] <= compute_date]
    df_temp = df_temp.drop_duplicates(subset=['session_path', 'task_protocol'])
    df_temp = df_temp.sort_values('date')

    dates = df_temp.date.values

    n_sess_for_date = len(np.where(dates == compute_date)[0])
    n_dates = np.min([2 + n_sess_for_date, len(dates)]).astype(int)
    compute_dates = dates[(-1 * n_dates):]
    if n_sess_for_date > 1:
        compute_dates = compute_dates[:(-1 * (n_sess_for_date - 1))]

    assert compute_dates[-1] == compute_date

    df_temp_group = df_temp.groupby('date')

    trials = {}
    n_delay = 0
    ephys_sessions = []
    protocol = []
    status = []
    for date in compute_dates:

        df_date = df_temp_group.get_group(date)

        # If habituation skip
        if df_date.iloc[-1]['task_protocol'] == 'habituation':
            continue
        # Here we should ideally split by protocol, but that isn't implemented yet. This is only
        # really relevant for chained protocols
        trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one, force=force)
        protocol.append(df_date.iloc[-1]['task_protocol'])
        status.append(df_date.iloc[-1]['training_status'])
        if df_date.iloc[-1]['combined_n_delay'] >= 900:  # delay of 15 mins
            n_delay += 1
        if df_date.iloc[-1]['location'] == 'ephys_rig':
            ephys_sessions.append(df_date.iloc[-1]['date'])

    n_status = np.max([-2, -1 * len(status)])
    training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
    training_status = pass_through_training_hierachy(training_status, status[n_status])
    df.loc[df['date'] == compute_date, 'training_status'] = training_status

    return df


def pass_through_training_hierachy(status_new, status_old):
    """
    Makes sure that the new training status is not less than the one from the previous day, i.e. a
    subject cannot regress in performance.

    :param status_new: latest training status
    :param status_old: previous training status
    :return: the higher of the two statuses in the training hierarchy
    """

    if TRAINING_STATUS[status_old][0] > TRAINING_STATUS[status_new][0]:
        return status_old
    else:
        return status_new
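
# Example: a status can never regress, so a day that evaluates as 'in training' after the subject
# has already reached 'trained 1a' keeps the higher status.
# >>> pass_through_training_hierachy('in training', 'trained 1a')
# 'trained 1a'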


def compute_session_duration_delay_location(sess_path, collections=None, **kwargs):
    """
    Get meta information about a task. Extracts session duration, delay before session start and
    location of the session.

    Parameters
    ----------
    sess_path : pathlib.Path, str
        The session path with the pattern subject/yyyy-mm-dd/nnn.
    collections : list
        The location within the session path directory of task settings and data.

    Returns
    -------
    int
        The session duration in minutes, rounded to the nearest minute.
    int
        The delay between session start time and the first trial in seconds.
    str {'ephys_rig', 'training_rig'}
        The location of the session.
    """
    if collections is None:
        collections, _ = get_data_collection(sess_path)

    session_duration = 0
    session_delay = 0
    session_location = 'training_rig'
    for collection in collections:
        md, sess_data = load_bpod(sess_path, task_collection=collection)
        if md is None:
            continue
        try:
            start_time, end_time = _get_session_times(sess_path, md, sess_data)
            session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
            session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0)
        except Exception:
            session_duration = session_duration + 0
            session_delay = session_delay + 0

        if 'ephys' in md.get('PYBPOD_BOARD', ''):  # default to '' so a missing key doesn't raise
            session_location = 'ephys_rig'
        else:
            session_location = 'training_rig'

    return session_duration, session_delay, session_location


def get_data_collection(session_path):
    """Return the location of the raw behavioral data and extracted trials data for a given session.

    For multiple locations in one session (e.g. chained protocols), returns all collections.
    Passive protocols are excluded.

    Parameters
    ----------
    session_path : pathlib.Path
        A session path in the form subject/date/number.

    Returns
    -------
    list of str
        A list of sub-directory names that contain raw behaviour data.
    list of str
        A list of sub-directory names that contain ALF trials data.

    Examples
    --------
    An iblrig v7 session

    >>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    ['raw_behavior_data'], ['alf']

    An iblrig v8 session where two protocols were run

    >>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    ['raw_task_data_00', 'raw_task_data_01'], ['alf/task_00', 'alf/task_01']
    """
    experiment_description = read_params(session_path)
    collections = []
    if experiment_description is not None:
        task_protocols = experiment_description.get('tasks', [])
        for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
            if 'passiveChoiceWorld' in protocol:
                continue
            collection = task_info.get('collection', f'raw_task_data_{i:02}')
            if collection == 'raw_passive_data':
                continue
            collections.append(collection)
    else:
        settings = Path(session_path).rglob('_iblrig_taskSettings.raw.json')
        for setting in settings:
            if setting.parent.name != 'raw_passive_data':
                collections.append(setting.parent.name)

    if len(collections) == 1 and collections[0] == 'raw_behavior_data':
        alf_collections = ['alf']
    elif all(['raw_task_data' in c for c in collections]):
        alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
    else:
        alf_collections = None

    return collections, alf_collections


def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collections=None, force=True):
    """Compute the performance and psychometric metrics for a single session and return them as a dict."""

    sess_dict = {}
    sess_dict['date'] = str(one.path2ref(session_path)['date'])
    sess_dict['session_path'] = str(session_path)
    sess_dict['task_protocol'] = protocol

    if sess_dict['task_protocol'] == 'habituation':
        nan_array = np.array([np.nan])
        sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
        sess_dict['performance_easy'] = np.nan
        sess_dict['reaction_time'] = np.nan
        sess_dict['n_trials'] = np.nan
        sess_dict['sess_duration'] = np.nan
        sess_dict['n_delay'] = np.nan
        sess_dict['location'] = np.nan
        sess_dict['training_status'] = 'habituation'
        sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
            (np.nan, np.nan, np.nan, np.nan)

    else:
        # if we can't load the trials then we skip this session
        trials = load_trials(session_path, one, collections=alf_collections, force=force, mode='warn')
        if trials is None:
            return

        sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
        if sess_dict['task_protocol'] == 'training':
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                (np.nan, np.nan, np.nan, np.nan)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                (np.nan, np.nan, np.nan, np.nan)
        else:
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials, block=0.5)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                training.compute_psychometric(trials, block=0.2)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                training.compute_psychometric(trials, block=0.8)

        sess_dict['performance_easy'] = training.compute_performance_easy(trials)
        sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
        sess_dict['n_trials'] = training.compute_n_trials(trials)
        sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
            compute_session_duration_delay_location(session_path, collections=raw_collections)
        sess_dict['training_status'] = 'not_computed'

    return sess_dict


def get_training_info_for_session(session_paths, one, force=True):
    """
    Extract the training information needed for plots for each session.

    Parameters
    ----------
    session_paths : list of pathlib.Path
        List of session paths on same date.
    one : one.api.One
        An ONE instance.
    force : bool
        When true and if the session trials can't be found, will attempt to re-extract from disk.

    Returns
    -------
    list of dict
        A list of dictionaries the length of `session_paths` containing individual and aggregate
        performance information.
    """

    # return list of dicts to add
    sess_dicts = []
    for session_path in session_paths:
        collections, alf_collections = get_data_collection(session_path)
        session_path = Path(session_path)
        protocols = []
        for c in collections:
            try:
                prot = get_bpod_extractor_class(session_path, task_collection=c)
                prot = prot[:-6].lower()
            except Exception:
                prot = get_session_extractor_type(session_path, task_collection=c)
            protocols.append(prot)

        un_protocols = np.unique(protocols)
        # For example with protocols training, training, biased: the training sessions would be
        # combined but the biased session would not
        sess_dict = None
        if len(un_protocols) != 1:
            print(f'Different protocols in same session {session_path} : {protocols}')
            for prot in un_protocols:
                if prot is False:
                    continue
                try:
                    alf = alf_collections[np.where(protocols == prot)[0]]
                    raw = collections[np.where(protocols == prot)[0]]
                except TypeError:
                    alf = None
                    raw = None
                sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf, raw_collections=raw, force=force)
        else:
            prot = un_protocols[0]
            sess_dict = get_sess_dict(
                session_path, one, prot, alf_collections=alf_collections, raw_collections=collections, force=force)

        if sess_dict is not None:
            sess_dicts.append(sess_dict)

    protocols = [s['task_protocol'] for s in sess_dicts]

    if len(protocols) > 0 and len(set(protocols)) != 1:
        print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')

    # Only combine sessions if all protocols are the same and are not habituation
    if len(sess_dicts) > 1 and len(set(protocols)) == 1 and protocols[0] != 'habituation':
        print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
        combined_trials = load_combined_trials(session_paths, one, force=force)
        performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
        psychs = {}
        psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
        psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
        psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)

        performance_easy = training.compute_performance_easy(combined_trials)
        reaction_time = training.compute_median_reaction_time(combined_trials)
        n_trials = training.compute_n_trials(combined_trials)

        sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
        n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])

        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = performance
            sess_dict['combined_contrasts'] = contrasts
            sess_dict['combined_performance_easy'] = performance_easy
            sess_dict['combined_reaction_time'] = reaction_time
            sess_dict['combined_n_trials'] = n_trials
            sess_dict['combined_sess_duration'] = sess_duration
            sess_dict['combined_n_delay'] = n_delay

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
                sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
                sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
                sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]

            # Case where two sessions on the same day have a different number of contrasts! Oh boy
            if sess_dict['combined_performance'].size != sess_dict['performance'].size:
                sess_dict['performance'] = \
                    np.r_[sess_dict['performance'],
                          np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
                sess_dict['contrasts'] = \
                    np.r_[sess_dict['contrasts'],
                          np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]

    else:
        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = sess_dict['performance']
            sess_dict['combined_contrasts'] = sess_dict['contrasts']
            sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
            sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
            sess_dict['combined_n_trials'] = sess_dict['n_trials']
            sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
            sess_dict['combined_n_delay'] = sess_dict['n_delay']

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
                sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
                sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
                sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']

    return sess_dicts


def check_up_to_date(subj_path, df):
    """
    Check which sessions on the local file system are missing from the computed training table.

    Parameters
    ----------
    subj_path : pathlib.Path
        The path to the subject's dated session folders.
    df : pandas.DataFrame
        The computed training table.

    Returns
    -------
    pandas.DataFrame
        A table of dates and session paths that are missing from the computed training table.
    """
    df_session = pd.DataFrame(columns=['date', 'session_path'])

    for session in alfio.iter_sessions(subj_path, pattern='????-??-??/*'):
        s_df = pd.DataFrame({'date': session.parts[-2], 'session_path': str(session)}, index=[0])
        df_session = pd.concat([df_session, s_df], ignore_index=True)

    if df is None or 'combined_thres_50' not in df.columns:
        return df_session
    else:
        # recorded_session_paths = df['session_path'].values
        isin, _ = ismember(df_session.date.unique(), df.date.unique())
        missing_dates = df_session.date.unique()[~isin]
        return df_session[df_session['date'].isin(missing_dates)].sort_values('date')
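
# Sketch: with an existing training table, only dates absent from df are returned (hypothetical
# subject folder):
# >>> df = load_existing_dataframe(subj_path)
# >>> missing = check_up_to_date(subj_path, df)
# >>> missing.date.unique()  # dates still to be added to the training table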


def plot_trial_count_and_session_duration(df, subject):

    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_n_trials',
          'title': 'Trial counts',
          'lim': None,
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_sess_duration',
          'title': 'Session duration (mins)',
          'lim': None,
          'color': 'r',
          'log': False,
          'join': True}

    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_performance_easy_median_reaction_time(df, subject):
    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_performance_easy',
          'title': 'Performance on easy trials',
          'lim': [0, 1.05],
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_reaction_time',
          'title': 'Median reaction time (s)',
          'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
          'color': 'r',
          'log': True,
          'join': True}
    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_fit_params(df, subject):
    fig, axs = plt.subplots(2, 2, figsize=(12, 6))
    axs = axs.ravel()

    df = df.drop_duplicates('date').reset_index(drop=True)

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")

    y50 = {'column': 'combined_bias_50',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[1],
           'join': False}

    y80 = {'column': 'combined_bias_80',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[2],
           'join': False}

    y20 = {'column': 'combined_bias_20',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[0],
           'join': False}

    plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
    axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
    axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_thres_50'
    y50['title'] = 'Threshold'
    y50['lim'] = [0, 100]
    y80['column'] = 'combined_thres_20'
    y80['title'] = 'Threshold'
    y20['lim'] = [0, 100]
    y20['column'] = 'combined_thres_80'
    y20['title'] = 'Threshold'
    y80['lim'] = [0, 100]

    plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
    axs[1].axhline(19, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapselow_50'
    y50['title'] = 'Lapse Low'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapselow_20'
    y80['title'] = 'Lapse Low'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapselow_80'
    y20['title'] = 'Lapse Low'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
    axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapsehigh_50'
    y50['title'] = 'Lapse High'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapsehigh_20'
    y80['title'] = 'Lapse High'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapsehigh_80'
    y20['title'] = 'Lapse High'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
    plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
    plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
    axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')

    fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
    lines, labels = axs[3].get_legend_handles_labels()
    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
    legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
    fig.add_artist(legend2)

    return axs


def plot_psychometric_curve(df, subject, one):
    df = df.drop_duplicates('date').reset_index(drop=True)
    sess_path = Path(df.iloc[-1]["session_path"])
    trials = load_trials(sess_path, one)

    fig, ax1 = plt.subplots(figsize=(8, 6))

    training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    return ax1


def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):

    if ax is None:
        fig, ax1 = plt.subplots(figsize=(12, 6))
    else:
        ax1 = ax

    dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
    if y1['join']:
        ax1.plot(dates, df[y1['column']], color=y1['color'])

    ax1.scatter(dates, df[y1['column']], color=y1['color'])
    ax1.set_ylabel(y1['title'])
    ax1.set_ylim(y1['lim'])

    if y2 is not None:
        ax2 = ax1.twinx()
        if y2['join']:
            ax2.plot(dates, df[y2['column']], color=y2['color'])

        ax2.scatter(dates, df[y2['column']], color=y2['color'])
        ax2.set_ylabel(y2['title'])
        ax2.yaxis.label.set_color(y2['color'])
        ax2.tick_params(axis='y', colors=y2['color'])
        ax2.set_ylim(y2['lim'])
        if y2['log']:
            ax2.set_yscale('log')

        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.spines['left'].set_visible(False)

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    if training_lines:
        ax1 = add_training_lines(df, ax1)

    if title:
        ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    # Put a legend below current axis
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
                      box.width, box.height * 0.9])
    if legend:
        ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
                   fancybox=True, shadow=True, ncol=5)

    return ax1
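
# The y1/y2 arguments are plain dicts describing each axis; a minimal sketch of the expected keys
# (a second axis additionally needs 'log'; the subject name is hypothetical):
# >>> y1 = {'column': 'combined_n_trials', 'title': 'Trial counts', 'lim': None,
# ...       'color': 'k', 'join': True}
# >>> ax = plot_over_days(df.drop_duplicates('date').reset_index(drop=True), 'KS022', y1)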


def add_training_lines(df, ax):

    status = df.drop_duplicates(subset='training_status', keep='first')
    for _, st in status.iterrows():

        if st['training_status'] in ['untrainable', 'unbiasable']:
            continue

        if TRAINING_STATUS[st['training_status']][0] <= 0:
            continue

        ax.axvline(datetime.strptime(st['date'], '%Y-%m-%d'), linewidth=2,
                   color=np.array(TRAINING_STATUS[st['training_status']][1]) / 255, label=st['training_status'])

    return ax


def plot_heatmap_performance_over_days(df, subject):

    df = df.drop_duplicates(subset=['date', 'combined_contrasts'])
    df_perf = df.pivot(index=['date'], columns=['combined_contrasts'], values=['combined_performance']).sort_values(
        by='combined_contrasts', axis=1, ascending=False)
    df_perf.index = pd.to_datetime(df_perf.index)
    full_date_range = pd.date_range(start=df_perf.index.min(), end=df_perf.index.max(), freq="D")
    df_perf = df_perf.reindex(full_date_range, fill_value=np.nan)

    n_contrasts = len(df.combined_contrasts.unique())

    dates = df_perf.index.to_pydatetime()
    dnum = mdates.date2num(dates)
    if len(dnum) > 1:
        start = dnum[0] - (dnum[1] - dnum[0]) / 2.
        stop = dnum[-1] + (dnum[1] - dnum[0]) / 2.
    else:
        start = dnum[0] + 0.5
        stop = dnum[0] + 1.5

    extent = [start, stop, 0, n_contrasts]

    fig, ax1 = plt.subplots(figsize=(12, 6))
    im = ax1.imshow(df_perf.T.values, extent=extent, aspect="auto", cmap='PuOr')

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')
    ax1.set_yticks(np.arange(0.5, n_contrasts + 0.5, 1))
    ax1.set_yticklabels(np.sort(df.combined_contrasts.unique()))
    ax1.set_ylabel('Contrast (%)')
    ax1.set_xlabel('Date')
    cbar = fig.colorbar(im)
    cbar.set_label('Rightward choice (%)')
    ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    return ax1


def make_plots(session_path, one, df=None, save=False, upload=False, task_collection='raw_behavior_data'):
    subject = one.path2ref(session_path)['subject']
    subj_path = session_path.parent.parent

    df = load_existing_dataframe(subj_path) if df is None else df

    df = df[df['task_protocol'] != 'habituation']

    if len(df) == 0:
        return

    ax1 = plot_trial_count_and_session_duration(df, subject)
    ax2 = plot_performance_easy_median_reaction_time(df, subject)
    ax3 = plot_heatmap_performance_over_days(df, subject)
    ax4 = plot_fit_params(df, subject)
    ax5 = plot_psychometric_curve(df, subject, one)

    outputs = []
    if save:
        save_path = Path(subj_path)
        save_name = save_path.joinpath('subj_trial_count_session_duration.png')
        outputs.append(save_name)
        ax1.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_performance_easy_reaction_time.png')
        outputs.append(save_name)
        ax2.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_performance_heatmap.png')
        outputs.append(save_name)
        ax3.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_psychometric_fit_params.png')
        outputs.append(save_name)
        ax4[0].get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_psychometric_curve.png')
        outputs.append(save_name)
        ax5.get_figure().savefig(save_name, bbox_inches='tight')

    if upload:
        subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
        snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
        snp.outputs = outputs
        snp.register_images(widths=['orig'])
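
# End-to-end sketch tying the table computation to the plots (hypothetical session path; uploading
# the snapshots requires Alyx access via ONE):
# >>> df = get_latest_training_information(sess_path, one)
# >>> make_plots(sess_path, one, df=df, save=True, upload=False)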