# Coverage for ibllib/pipes/training_status.py: 83% (613 statements)
import logging
from pathlib import Path
from datetime import datetime
from itertools import chain

import numpy as np
import pandas as pd
from iblutil.numerical import ismember
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
import one.alf.path as alfiles
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError

from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_bpod_extractor_class
from ibllib.io.session_params import read_params
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.plots.snapshot import ReportSnapshot
from brainbox.behavior import training

logger = logging.getLogger(__name__)

TRAINING_STATUS = {'untrainable': (-4, (0, 0, 0, 0)),
                   'unbiasable': (-3, (0, 0, 0, 0)),
                   'not_computed': (-2, (0, 0, 0, 0)),
                   'habituation': (-1, (0, 0, 0, 0)),
                   'in training': (0, (0, 0, 0, 0)),
                   'trained 1a': (1, (195, 90, 80, 255)),
                   'trained 1b': (2, (255, 153, 20, 255)),
                   'ready4ephysrig': (3, (28, 20, 255, 255)),
                   'ready4delay': (4, (117, 117, 117, 255)),
                   'ready4recording': (5, (20, 255, 91, 255))}
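
# A minimal sketch of how the TRAINING_STATUS mapping is used throughout this module:
# each status maps to a (rank, RGBA colour) tuple, where ranks order the training
# hierarchy and colours drive the plot annotations. The names below come from the
# dict above; nothing new is introduced.
#
#     rank, rgba = TRAINING_STATUS['trained 1a']       # -> 1, (195, 90, 80, 255)
#     assert TRAINING_STATUS['trained 1b'][0] > rank   # 'trained 1b' outranks 'trained 1a'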


def get_training_table_from_aws(lab, subject):
    """
    If AWS credentials exist on the local server, download the latest training table from the AWS S3 private bucket.

    :param lab: lab name
    :param subject: subject nickname
    :return: the training dataframe, or None if the credentials or file could not be found
    """
    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.download_file(f'resources/training/{lab}/{subject}/training.csv',
                             local_file_path)
        df = pd.read_csv(local_file_path)
    except ClientError:
        return

    return df


def upload_training_table_to_aws(lab, subject):
    """
    If AWS credentials exist on the local server, upload the training table to the AWS S3 private bucket.

    :param lab: lab name
    :param subject: subject nickname
    :return:
    """
    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.upload_file(local_file_path,
                           f'resources/training/{lab}/{subject}/training.csv')
    except (ClientError, FileNotFoundError):
        return
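
# A hedged usage sketch of the two AWS helpers above. It assumes the machine has an
# 'ibl_training' profile in its AWS config and that subject data live under
# /mnt/s0/Data/Subjects; 'cortexlab' and 'SW001' are placeholder lab/subject names.
#
#     df = get_training_table_from_aws('cortexlab', 'SW001')  # None if unavailable
#     if df is not None:
#         upload_training_table_to_aws('cortexlab', 'SW001')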


def save_path(subj_path):
    return Path(subj_path).joinpath('training.csv')


def save_dataframe(df, subj_path):
    """Save training dataframe to disk.

    :param df: dataframe to save
    :param subj_path: path to subject folder
    :return:
    """
    df.to_csv(save_path(subj_path), index=False)


def load_existing_dataframe(subj_path):
    """Load training dataframe from disk. If the dataframe doesn't exist, returns None.

    :param subj_path: path to subject folder
    :return:
    """
    df_location = save_path(subj_path)
    if df_location.exists():
        return pd.read_csv(df_location)
    else:
        df_location.parent.mkdir(exist_ok=True, parents=True)
        return None


def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
    """
    Load trials data for a session. First attempts to load from the local session path; if this fails, attempts to
    download via ONE; if this also fails, attempts to re-extract locally.

    :param sess_path: session path
    :param one: ONE instance
    :param collections: collections in which to look for the trials data
    :param force: when True and the session trials can't be found, will attempt to re-extract from disk
    :param mode: 'raise' or 'warn'; if 'raise', will error when forcing re-extraction of past sessions
    :return:
    """
    try:
        # try and load all trials that are found locally in the session path
        if collections is None:
            trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.*npy'))
        else:
            trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.*npy') for c in collections]

        if len(trial_locations) > 1:
            trial_dict = {}
            for i, loc in enumerate(trial_locations):
                trial_dict[i] = alfio.load_object(loc.parent, 'trials', short_keys=True)
            trials = training.concatenate_trials(trial_dict)
        elif len(trial_locations) == 1:
            trials = alfio.load_object(trial_locations[0].parent, 'trials', short_keys=True)
        else:
            raise ALFObjectNotFound

        if 'probabilityLeft' not in trials.keys():
            raise ALFObjectNotFound
    except ALFObjectNotFound:
        # next, try and load the trials data through ONE
        try:
            if not force:
                return None
            eid = one.path2eid(sess_path)
            if collections is None:
                trial_collections = one.list_datasets(eid, '_ibl_trials.goCueTrigger_times.npy')
                if len(trial_collections) > 0:
                    trial_collections = ['/'.join(c.split('/')[:-1]) for c in trial_collections]
            else:
                trial_collections = collections

            if len(trial_collections) > 1:
                trial_dict = {}
                for i, collection in enumerate(trial_collections):
                    trial_dict[i] = one.load_object(eid, 'trials', collection=collection)
                trials = training.concatenate_trials(trial_dict)
            elif len(trial_collections) == 1:
                trials = one.load_object(eid, 'trials', collection=trial_collections[0])
            else:
                raise ALFObjectNotFound

            if 'probabilityLeft' not in trials.keys():
                raise ALFObjectNotFound
        except Exception:
            # finally, try to re-extract the trials data locally
            try:
                raw_collections, _ = get_data_collection(sess_path)

                if len(raw_collections) == 0:
                    return None

                trials_dict = {}
                for i, collection in enumerate(raw_collections):
                    extractor = get_bpod_extractor(sess_path, task_collection=collection)
                    trials_data, _ = extractor.extract(task_collection=collection, save=False)
                    trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table'])

                if len(trials_dict) > 1:
                    trials = training.concatenate_trials(trials_dict)
                else:
                    trials = trials_dict[0]

            except Exception as e:
                if mode == 'raise':
                    raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
                else:
                    logger.warning(f'Exhausted all possibilities for loading trials for {sess_path}')
                    return

    return trials
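
# A hedged sketch of calling the loader above with an explicit fallback policy. The
# session path is a placeholder; ONE() is the standard one.api entry point.
#
#     from one.api import ONE
#     one = ONE()
#     trials = load_trials(Path('/mnt/s0/Data/Subjects/SW001/2023-01-01/001'),
#                          one, force=True, mode='warn')  # None if nothing loads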


def load_combined_trials(sess_paths, one, force=True):
    """
    Load and concatenate trials for multiple sessions. Used when we want to concatenate trials for two sessions on
    the same day.

    :param sess_paths: list of paths to sessions
    :param one: ONE instance
    :param force: when True and the session trials can't be found, will attempt to re-extract from disk
    :return:
    """
    trials_dict = {}
    for sess_path in sess_paths:
        trials = load_trials(Path(sess_path), one, force=force, mode='warn')
        if trials is not None:
            trials_dict[Path(sess_path).stem] = trials

    return training.concatenate_trials(trials_dict)


def get_latest_training_information(sess_path, one, save=True):
    """
    Extracts the latest training status.

    Parameters
    ----------
    sess_path : pathlib.Path
        The session path from which to load the data.
    one : one.api.One
        An ONE instance.
    save : bool
        When true, saves the training table to disk (and uploads it to AWS when not in local mode).

    Returns
    -------
    pandas.DataFrame
        A table of training information.
    """

    subj_path = sess_path.parent.parent
    sub = subj_path.parts[-1]
    if one.mode != 'local':
        lab = one.alyx.rest('subjects', 'list', nickname=sub)[0]['lab']
        df = get_training_table_from_aws(lab, sub)
    else:
        df = None

    if df is None:
        df = load_existing_dataframe(subj_path)

    # Find the dates and associated session paths where we don't have data stored in our dataframe
    missing_dates = check_up_to_date(subj_path, df)

    # Iterate through the dates to fill up our training dataframe
    for _, grp in missing_dates.groupby('date'):
        sess_dicts = get_training_info_for_session(grp.session_path.values, one)
        if len(sess_dicts) == 0:
            continue

        for sess_dict in sess_dicts:
            if df is None:
                df = pd.DataFrame.from_dict(sess_dict)
            else:
                df = pd.concat([df, pd.DataFrame.from_dict(sess_dict)])

    # Sort values by date and reset the index
    df = df.sort_values('date')
    df = df.reset_index(drop=True)
    # Save our dataframe
    if save:
        save_dataframe(df, subj_path)

    # Now go through the backlog and compute the training status for sessions. As the status is cumulative, if a
    # session was missing we need to recompute all statuses from that date onwards.
    # Find the earliest date in missing dates that we need to recompute the training status for
    missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True))
    for date in missing_status:
        df, _, _, _ = compute_training_status(df, date, one)

    df_lim = df.drop_duplicates(subset='session_path', keep='first')

    # Detect untrainable
    if 'untrainable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['training_status'] == 'in training'].sort_values('date')
        if len(un_df) >= 40:
            sess = un_df.iloc[39].session_path
            df.loc[df['session_path'] == sess, 'training_status'] = 'untrainable'

    # Detect unbiasable
    if 'unbiasable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['task_protocol'] == 'biased'].sort_values('date')
        if len(un_df) >= 40:
            tr_st = un_df[0:40].training_status.unique()
            if 'ready4ephysrig' not in tr_st:
                sess = un_df.iloc[39].session_path
                df.loc[df['session_path'] == sess, 'training_status'] = 'unbiasable'
    if save:
        save_dataframe(df, subj_path)

    if one.mode != 'local' and save:
        upload_training_table_to_aws(lab, sub)

    return df
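
# A hedged end-to-end sketch: compute the latest training table for one subject
# without persisting it. The path layout .../Subjects/<subject>/<date>/<number> is
# assumed and 'SW001' is a placeholder subject nickname.
#
#     from one.api import ONE
#     one = ONE()
#     sess_path = Path('/mnt/s0/Data/Subjects/SW001/2023-01-01/001')
#     df = get_latest_training_information(sess_path, one, save=False)
#     print(df[['date', 'task_protocol', 'training_status']].tail())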


def find_earliest_recompute_date(df):
    """
    Find the earliest date from which we need to recompute the training status. Training status depends on previous
    sessions, so if a session was missing and has now been added we need to recompute everything from that date
    onwards.

    :param df: training dataframe
    :return: array of dates from the first 'not_computed' session onwards
    """
    missing_df = df[df['training_status'] == 'not_computed']
    if len(missing_df) == 0:
        return []
    missing_df = missing_df.sort_values('date')
    first_index = missing_df.index[0]

    return df[first_index:].date.values
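
# A small worked example of the recompute logic above, on toy data: with statuses
# ['trained 1a', 'not_computed', 'not_computed'] the function returns the last two
# dates, i.e. everything from the first uncomputed session onwards.
#
#     toy = pd.DataFrame({'date': ['2023-01-01', '2023-01-02', '2023-01-03'],
#                         'training_status': ['trained 1a', 'not_computed', 'not_computed']})
#     find_earliest_recompute_date(toy)  # -> array(['2023-01-02', '2023-01-03'], dtype=object)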


def compute_training_status(df, compute_date, one, force=True, populate=True):
    """
    Compute the training status for a given date, based on the training from that session and the two previous days.

    Parameters
    ----------
    df : pandas.DataFrame
        A training data frame, e.g. one generated from :func:`get_training_info_for_session`.
    compute_date : str, datetime.datetime, pandas.Timestamp
        The date to compute training on.
    one : one.api.One
        An instance of ONE for loading trials data.
    force : bool
        When true and if the session trials can't be found, will attempt to re-extract from disk.
    populate : bool
        Whether to update the training data frame with the new training status value.

    Returns
    -------
    pandas.DataFrame
        The input data frame with a 'training_status' column populated for `compute_date` if populate=True.
    Bunch
        Bunch containing the fit parameters for the combined sessions.
    Bunch
        Bunch containing the training status criteria information.
    str
        The training status.
    """

    # compute_date = str(alfiles.session_path_parts(session_path, as_dict=True)['date'])
    df_temp = df[df['date'] <= compute_date]
    df_temp = df_temp.drop_duplicates(subset=['session_path', 'task_protocol'])
    df_temp = df_temp.sort_values('date')

    dates = df_temp.date.values

    n_sess_for_date = len(np.where(dates == compute_date)[0])
    n_dates = np.min([2 + n_sess_for_date, len(dates)]).astype(int)
    compute_dates = dates[(-1 * n_dates):]
    if n_sess_for_date > 1:
        compute_dates = compute_dates[:(-1 * (n_sess_for_date - 1))]

    assert compute_dates[-1] == compute_date

    df_temp_group = df_temp.groupby('date')

    trials = {}
    n_delay = 0
    ephys_sessions = []
    protocol = []
    status = []
    for date in compute_dates:

        df_date = df_temp_group.get_group(date)

        # If habituation, skip
        if df_date.iloc[-1]['task_protocol'] == 'habituation':
            continue
        # Here we should split by protocol in an ideal world but that world isn't today. This is only really relevant
        # for chained protocols
        trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one, force=force)
        protocol.append(df_date.iloc[-1]['task_protocol'])
        status.append(df_date.iloc[-1]['training_status'])
        if df_date.iloc[-1]['combined_n_delay'] >= 900:  # delay of 15 mins
            n_delay += 1
        if df_date.iloc[-1]['location'] == 'ephys_rig':
            ephys_sessions.append(df_date.iloc[-1]['date'])

    n_status = np.max([-2, -1 * len(status)])
    training_status, info, criteria = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
    training_status = pass_through_training_hierachy(training_status, status[n_status])
    if populate:
        df.loc[df['date'] == compute_date, 'training_status'] = training_status

    return df, info, criteria, training_status
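
# A hedged sketch of evaluating (without writing back) the status for the most
# recent date in an existing table; this mirrors how display_info below calls it.
# `df` is a table produced by the functions above and `one` an ONE instance.
#
#     df, info, criteria, status = compute_training_status(
#         df, df['date'].values[-1], one, populate=False)
#     print(status, list(criteria.keys()))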


def pass_through_training_hierachy(status_new, status_old):
    """
    Makes sure that the new training status is not less than the one from the previous day, i.e. a subject cannot
    regress in performance.

    :param status_new: latest training status
    :param status_old: previous training status
    :return:
    """

    if TRAINING_STATUS[status_old][0] > TRAINING_STATUS[status_new][0]:
        return status_old
    else:
        return status_new
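
# The hierarchy is monotonic: once a subject has reached 'trained 1b' (rank 2), a
# 'trained 1a' day (rank 1) cannot demote it; a higher new status passes through.
#
#     pass_through_training_hierachy('trained 1a', 'trained 1b')      # -> 'trained 1b'
#     pass_through_training_hierachy('ready4ephysrig', 'trained 1b')  # -> 'ready4ephysrig'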


def compute_session_duration_delay_location(sess_path, collections=None, **kwargs):
    """
    Get meta information about a task. Extracts session duration, delay before session start and location of session.

    Parameters
    ----------
    sess_path : pathlib.Path, str
        The session path with the pattern subject/yyyy-mm-dd/nnn.
    collections : list
        The location within the session path directory of task settings and data.

    Returns
    -------
    int
        The session duration in minutes, rounded to the nearest minute.
    int
        The delay between session start time and the first trial in seconds.
    str {'ephys_rig', 'training_rig'}
        The location of the session.
    """
    if collections is None:
        collections, _ = get_data_collection(sess_path)

    session_duration = 0
    session_delay = 0
    session_location = 'training_rig'
    for collection in collections:
        md, sess_data = load_bpod(sess_path, task_collection=collection)
        if md is None:
            continue
        try:
            start_time, end_time = _get_session_times(sess_path, md, sess_data)
            session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
            session_delay = session_delay + md.get('SESSION_DELAY_START',
                                                   md.get('SESSION_START_DELAY_SEC', 0))
        except Exception:
            pass

        # default to an empty string so the membership test cannot fail on None
        if 'ephys' in md.get('RIG_NAME', md.get('PYBPOD_BOARD', '')):
            session_location = 'ephys_rig'
        else:
            session_location = 'training_rig'

    return session_duration, session_delay, session_location


def get_data_collection(session_path):
    """Return the location of the raw behavioral data and extracted trials data for a given session.

    For multiple locations in one session (e.g. chained protocols), returns all collections.
    Passive protocols are excluded.

    Parameters
    ----------
    session_path : pathlib.Path
        A session path in the form subject/date/number.

    Returns
    -------
    list of str
        A list of sub-directory names that contain raw behaviour data.
    list of str
        A list of sub-directory names that contain ALF trials data.

    Examples
    --------
    An iblrig v7 session

    >>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    (['raw_behavior_data'], ['alf'])

    An iblrig v8 session where two protocols were run

    >>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    (['raw_task_data_00', 'raw_task_data_01'], ['alf/task_00', 'alf/task_01'])
    """
    experiment_description = read_params(session_path)
    collections = []
    if experiment_description is not None:
        task_protocols = experiment_description.get('tasks', [])
        for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
            if 'passiveChoiceWorld' in protocol:
                continue
            collection = task_info.get('collection', f'raw_task_data_{i:02}')
            if collection == 'raw_passive_data':
                continue
            collections.append(collection)
    else:
        settings = Path(session_path).rglob('_iblrig_taskSettings.raw*.json')
        for setting in settings:
            if setting.parent.name != 'raw_passive_data':
                collections.append(setting.parent.name)

    if len(collections) == 1 and collections[0] == 'raw_behavior_data':
        alf_collections = ['alf']
    elif all(['raw_task_data' in c for c in collections]):
        alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
    else:
        alf_collections = None

    return collections, alf_collections


def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collections=None, force=True):
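    # Build one row of the training table for a single session: date/path/protocol
    # metadata plus performance metrics and psychometric fits. Returns None when no
    # trials data can be loaded.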
    sess_dict = {}
    sess_dict['date'] = str(alfiles.session_path_parts(session_path, as_dict=True)['date'])
    sess_dict['session_path'] = str(session_path)
    sess_dict['task_protocol'] = protocol

    if sess_dict['task_protocol'] == 'habituation':
        nan_array = np.array([np.nan])
        sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
        sess_dict['performance_easy'] = np.nan
        sess_dict['reaction_time'] = np.nan
        sess_dict['n_trials'] = np.nan
        sess_dict['sess_duration'] = np.nan
        sess_dict['n_delay'] = np.nan
        sess_dict['location'] = np.nan
        sess_dict['training_status'] = 'habituation'
        sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
            (np.nan, np.nan, np.nan, np.nan)

    else:
        # if we can't compute trials then we need to pass
        trials = load_trials(session_path, one, collections=alf_collections, force=force, mode='warn')
        if trials is None:
            return

        sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
        if sess_dict['task_protocol'] == 'training':
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
                training.compute_psychometric(trials)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
                (np.nan, np.nan, np.nan, np.nan)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
                (np.nan, np.nan, np.nan, np.nan)
        else:
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
                training.compute_psychometric(trials, block=0.5)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
                training.compute_psychometric(trials, block=0.2)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
                training.compute_psychometric(trials, block=0.8)

        sess_dict['performance_easy'] = training.compute_performance_easy(trials)
        sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
        sess_dict['n_trials'] = training.compute_n_trials(trials)
        sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
            compute_session_duration_delay_location(session_path, collections=raw_collections)
        sess_dict['training_status'] = 'not_computed'

    return sess_dict


def get_training_info_for_session(session_paths, one, force=True):
    """
    Extract the training information needed for plots for each session.

    Parameters
    ----------
    session_paths : list of pathlib.Path
        List of session paths on same date.
    one : one.api.One
        An ONE instance.
    force : bool
        When true and if the session trials can't be found, will attempt to re-extract from disk.

    Returns
    -------
    list of dict
        A list of dictionaries the length of `session_paths` containing individual and aggregate
        performance information.
    """

    # return list of dicts to add
    sess_dicts = []
    for session_path in session_paths:
        collections, alf_collections = get_data_collection(session_path)
        session_path = Path(session_path)
        protocols = []
        for c in collections:
            try:
                prot = get_bpod_extractor_class(session_path, task_collection=c)
                prot = prot[:-6].lower()
                protocols.append(prot)
            except ValueError:
                continue

        un_protocols = np.unique(protocols)
        # For example training, training, biased - the training protocols would be combined, the biased one not
        sess_dict = None
        if len(un_protocols) != 1:
            print(f'Different protocols in same session {session_path} : {protocols}')
            for prot in un_protocols:
                if prot is False:
                    continue
                try:
                    alf = alf_collections[np.where(protocols == prot)[0]]
                    raw = collections[np.where(protocols == prot)[0]]
                except TypeError:
                    alf = None
                    raw = None
                sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf, raw_collections=raw, force=force)
        else:
            prot = un_protocols[0]
            sess_dict = get_sess_dict(
                session_path, one, prot, alf_collections=alf_collections, raw_collections=collections, force=force)

        if sess_dict is not None:
            sess_dicts.append(sess_dict)

    protocols = [s['task_protocol'] for s in sess_dicts]

    if len(protocols) > 0 and len(set(protocols)) != 1:
        print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')

    # Only combine if all protocols are the same and are not habituation
    if len(sess_dicts) > 1 and len(set(protocols)) == 1 and protocols[0] != 'habituation':
        print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
        combined_trials = load_combined_trials(session_paths, one, force=force)
        performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
        psychs = {}
        psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
        psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
        psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)

        performance_easy = training.compute_performance_easy(combined_trials)
        reaction_time = training.compute_median_reaction_time(combined_trials)
        n_trials = training.compute_n_trials(combined_trials)

        sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
        n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])

        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = performance
            sess_dict['combined_contrasts'] = contrasts
            sess_dict['combined_performance_easy'] = performance_easy
            sess_dict['combined_reaction_time'] = reaction_time
            sess_dict['combined_n_trials'] = n_trials
            sess_dict['combined_sess_duration'] = sess_duration
            sess_dict['combined_n_delay'] = n_delay

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
                sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
                sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][2]
                sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][3]

            # Case where two sessions on the same day have a different number of contrasts! Oh boy
            if sess_dict['combined_performance'].size != sess_dict['performance'].size:
                sess_dict['performance'] = \
                    np.r_[sess_dict['performance'],
                          np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
                sess_dict['contrasts'] = \
                    np.r_[sess_dict['contrasts'],
                          np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]

    else:
        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = sess_dict['performance']
            sess_dict['combined_contrasts'] = sess_dict['contrasts']
            sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
            sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
            sess_dict['combined_n_trials'] = sess_dict['n_trials']
            sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
            sess_dict['combined_n_delay'] = sess_dict['n_delay']

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
                sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
                sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
                sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']

    return sess_dicts


def check_up_to_date(subj_path, df):
    """
    Check which sessions on the local file system are missing from the computed training table.

    Parameters
    ----------
    subj_path : pathlib.Path
        The path to the subject's dated session folders.
    df : pandas.DataFrame
        The computed training table.

    Returns
    -------
    pandas.DataFrame
        A table of dates and session paths that are missing from the computed training table.
    """
    df_session = pd.DataFrame(columns=['date', 'session_path'])

    for session in alfio.iter_sessions(subj_path, pattern='????-??-??/*'):
        s_df = pd.DataFrame({'date': session.parts[-2], 'session_path': str(session)}, index=[0])
        df_session = pd.concat([df_session, s_df], ignore_index=True)

    if df is None or 'combined_thres_50' not in df.columns:
        return df_session
    else:
        # recorded_session_paths = df['session_path'].values
        isin, _ = ismember(df_session.date.unique(), df.date.unique())
        missing_dates = df_session.date.unique()[~isin]
        return df_session[df_session['date'].isin(missing_dates)].sort_values('date')
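
# A hedged sketch: list the dates still missing from a subject's training table.
# The subject folder path below is a placeholder.
#
#     subj_path = Path('/mnt/s0/Data/Subjects/SW001')
#     missing = check_up_to_date(subj_path, load_existing_dataframe(subj_path))
#     print(missing.date.unique())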


def plot_trial_count_and_session_duration(df, subject):

    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_n_trials',
          'title': 'Trial counts',
          'lim': None,
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_sess_duration',
          'title': 'Session duration (mins)',
          'lim': None,
          'color': 'r',
          'log': False,
          'join': True}

    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_performance_easy_median_reaction_time(df, subject):
    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_performance_easy',
          'title': 'Performance on easy trials',
          'lim': [0, 1.05],
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_reaction_time',
          'title': 'Median reaction time (s)',
          'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
          'color': 'r',
          'log': True,
          'join': True}
    ax = plot_over_days(df, subject, y1, y2)

    return ax


def display_info(df, axs):
    compute_date = df['date'].values[-1]
    _, info, criteria, _ = compute_training_status(df, compute_date, None, force=False, populate=False)

    def _array_to_string(vals):
        if isinstance(vals, (str, bool, int, float)):
            if isinstance(vals, float):
                vals = np.round(vals, 3)
            return f'{vals}'

        str_vals = ''
        for v in vals:
            if isinstance(v, float):
                v = np.round(v, 3)
            str_vals += f'{v}, '
        return str_vals[:-2]

    pos = np.arange(len(info))[::-1] * 0.1
    for i, (k, v) in enumerate(info.items()):
        str_v = _array_to_string(v)
        text = axs[0].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[0].transAxes)
        axs[0].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom",
                        color='k', fontsize=7)

    pos = np.arange(len(criteria))[::-1] * 0.1
    crit_val = criteria.pop('Criteria')
    c = 'g' if crit_val['pass'] else 'r'
    str_v = _array_to_string(crit_val['val'])
    text = axs[1].text(0, pos[0], 'Criteria', color='k', weight='bold', fontsize=8, transform=axs[1].transAxes)
    axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom",
                    color=c, fontsize=7)
    pos = pos[1:]

    for i, (k, v) in enumerate(criteria.items()):
        c = 'g' if v['pass'] else 'r'
        str_v = _array_to_string(v['val'])
        text = axs[1].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[1].transAxes)
        axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom",
                        color=c, fontsize=7)

    axs[0].set_axis_off()
    axs[1].set_axis_off()


def plot_fit_params(df, subject):
    fig, axs = plt.subplots(2, 3, figsize=(12, 6), gridspec_kw={'width_ratios': [2, 2, 1]})

    try:
        display_info(df, axs=[axs[0, 2], axs[1, 2]])
    except ValueError:
        print('Could not evaluate detailed training status information')

    df = df.drop_duplicates('date').reset_index(drop=True)

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")

    y50 = {'column': 'combined_bias_50',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[1],
           'join': False}

    y80 = {'column': 'combined_bias_80',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[2],
           'join': False}

    y20 = {'column': 'combined_bias_20',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[0],
           'join': False}

    plot_over_days(df, subject, y50, ax=axs[0, 0], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[0, 0], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[0, 0], legend=False, title=False)
    axs[0, 0].axhline(16, linewidth=2, linestyle='--', color='k')
    axs[0, 0].axhline(-16, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_thres_50'
    y50['title'] = 'Threshold'
    y50['lim'] = [0, 100]
    y80['column'] = 'combined_thres_20'
    y80['title'] = 'Threshold'
    y80['lim'] = [0, 100]
    y20['column'] = 'combined_thres_80'
    y20['title'] = 'Threshold'
    y20['lim'] = [0, 100]

    plot_over_days(df, subject, y50, ax=axs[0, 1], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[0, 1], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[0, 1], legend=False, title=False)
    axs[0, 1].axhline(19, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapselow_50'
    y50['title'] = 'Lapse Low'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapselow_20'
    y80['title'] = 'Lapse Low'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapselow_80'
    y20['title'] = 'Lapse Low'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[1, 0], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[1, 0], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[1, 0], legend=False, title=False)
    axs[1, 0].axhline(0.2, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapsehigh_50'
    y50['title'] = 'Lapse High'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapsehigh_20'
    y80['title'] = 'Lapse High'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapsehigh_80'
    y20['title'] = 'Lapse High'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[1, 1], legend=False, title=False, training_lines=True)
    plot_over_days(df, subject, y80, ax=axs[1, 1], legend=False, title=False, training_lines=False)
    plot_over_days(df, subject, y20, ax=axs[1, 1], legend=False, title=False, training_lines=False)
    axs[1, 1].axhline(0.2, linewidth=2, linestyle='--', color='k')

    fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
    lines, labels = axs[1, 1].get_legend_handles_labels()
    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), facecolor='w', fancybox=True, shadow=True,
               ncol=5)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
    legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True,
                         shadow=True, facecolor='w')
    fig.add_artist(legend2)

    return axs


def plot_psychometric_curve(df, subject, one):
    df = df.drop_duplicates('date').reset_index(drop=True)
    sess_path = Path(df.iloc[-1]["session_path"])
    trials = load_trials(sess_path, one, mode='warn')

    fig, ax1 = plt.subplots(figsize=(8, 6))

    training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    return ax1


def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):

    if ax is None:
        fig, ax1 = plt.subplots(figsize=(12, 6))
    else:
        ax1 = ax

    dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
    if y1['join']:
        ax1.plot(dates, df[y1['column']], color=y1['color'])
    ax1.scatter(dates, df[y1['column']], color=y1['color'])
    ax1.set_ylabel(y1['title'])
    ax1.set_ylim(y1['lim'])

    if y2 is not None:
        ax2 = ax1.twinx()
        if y2['join']:
            ax2.plot(dates, df[y2['column']], color=y2['color'])
        ax2.scatter(dates, df[y2['column']], color=y2['color'])
        ax2.set_ylabel(y2['title'])
        ax2.yaxis.label.set_color(y2['color'])
        ax2.tick_params(axis='y', colors=y2['color'])
        ax2.set_ylim(y2['lim'])
        if y2['log']:
            ax2.set_yscale('log')

        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.spines['left'].set_visible(False)

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    if training_lines:
        ax1 = add_training_lines(df, ax1)

    if title:
        ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    # Put a legend below the current axis
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
                      box.width, box.height * 0.9])
    if legend:
        ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
                   fancybox=True, shadow=True, ncol=5, facecolor='white')

    return ax1


def add_training_lines(df, ax):

    status = df.drop_duplicates(subset='training_status', keep='first')
    for _, st in status.iterrows():

        if st['training_status'] in ['untrainable', 'unbiasable']:
            continue

        if TRAINING_STATUS[st['training_status']][0] <= 0:
            continue

        ax.axvline(datetime.strptime(st['date'], '%Y-%m-%d'), linewidth=2,
                   color=np.array(TRAINING_STATUS[st['training_status']][1]) / 255, label=st['training_status'])

    return ax


def plot_heatmap_performance_over_days(df, subject):

    df = df.drop_duplicates(subset=['date', 'combined_contrasts'])
    df_perf = df.pivot(index=['date'], columns=['combined_contrasts'], values=['combined_performance']).sort_values(
        by='combined_contrasts', axis=1, ascending=False)
    df_perf.index = pd.to_datetime(df_perf.index)
    full_date_range = pd.date_range(start=df_perf.index.min(), end=df_perf.index.max(), freq="D")
    df_perf = df_perf.reindex(full_date_range, fill_value=np.nan)

    n_contrasts = len(df.combined_contrasts.unique())

    dates = df_perf.index.to_pydatetime()
    dnum = mdates.date2num(dates)
    if len(dnum) > 1:
        start = dnum[0] - (dnum[1] - dnum[0]) / 2.
        stop = dnum[-1] + (dnum[1] - dnum[0]) / 2.
    else:
        start = dnum[0] + 0.5
        stop = dnum[0] + 1.5

    extent = [start, stop, 0, n_contrasts]

    fig, ax1 = plt.subplots(figsize=(12, 6))
    im = ax1.imshow(df_perf.T.values, extent=extent, aspect="auto", cmap='PuOr')

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')
    ax1.set_yticks(np.arange(0.5, n_contrasts + 0.5, 1))
    ax1.set_yticklabels(np.sort(df.combined_contrasts.unique()))
    ax1.set_ylabel('Contrast (%)')
    ax1.set_xlabel('Date')
    cbar = fig.colorbar(im)
    cbar.set_label('Rightward choice (%)')
    ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    return ax1


def make_plots(session_path, one, df=None, save=False, upload=False, task_collection='raw_behavior_data'):
    subject = one.path2ref(session_path)['subject']
    subj_path = session_path.parent.parent

    df = load_existing_dataframe(subj_path) if df is None else df

    df = df[df['task_protocol'] != 'habituation']

    if len(df) == 0:
        return

    ax1 = plot_trial_count_and_session_duration(df, subject)
    ax2 = plot_performance_easy_median_reaction_time(df, subject)
    ax3 = plot_heatmap_performance_over_days(df, subject)
    ax4 = plot_fit_params(df, subject)
    ax5 = plot_psychometric_curve(df, subject, one)

    outputs = []
    if save:
        save_dir = Path(subj_path)
        save_name = save_dir.joinpath('subj_trial_count_session_duration.png')
        outputs.append(save_name)
        ax1.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_dir.joinpath('subj_performance_easy_reaction_time.png')
        outputs.append(save_name)
        ax2.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_dir.joinpath('subj_performance_heatmap.png')
        outputs.append(save_name)
        ax3.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_dir.joinpath('subj_psychometric_fit_params.png')
        outputs.append(save_name)
        ax4[0, 0].get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_dir.joinpath('subj_psychometric_curve.png')
        outputs.append(save_name)
        ax5.get_figure().savefig(save_name, bbox_inches='tight')

    if upload:
        subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
        snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
        snp.outputs = outputs
        snp.register_images(widths=['orig'])
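
# A hedged sketch of producing the full set of training-status plots for one
# subject and saving them next to the training table; the session path is a
# placeholder and upload=True additionally registers the images to Alyx.
#
#     from one.api import ONE
#     one = ONE()
#     sess_path = Path('/mnt/s0/Data/Subjects/SW001/2023-01-01/001')
#     make_plots(sess_path, one, save=True, upload=False)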