Coverage for ibllib/pipes/training_status.py: 91%
584 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1import one.alf.io as alfio
2from one.alf.exceptions import ALFObjectNotFound
4from ibllib.io.raw_data_loaders import load_bpod
5from ibllib.oneibl.registration import _get_session_times
6from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type
7from ibllib.io.session_params import read_params
8import ibllib.pipes.dynamic_pipeline as dyn
10from iblutil.util import setup_logger
11from ibllib.plots.snapshot import ReportSnapshot
12from iblutil.numerical import ismember
13from brainbox.behavior import training
15import numpy as np
16import pandas as pd
17from pathlib import Path
18import matplotlib.pyplot as plt
19import matplotlib.dates as mdates
20from matplotlib.lines import Line2D
21from datetime import datetime
22import seaborn as sns
23import boto3
24from botocore.exceptions import ProfileNotFound, ClientError
logger = setup_logger(__name__)

# Training status name -> (rank, RGBA colour).
# The rank gives a strict ordering of statuses so a subject can never regress
# (see pass_through_training_hierachy, which compares element [0]).
# The colour is presumably used by the status plotting helpers — verify against
# the plotting code; ranks <= 0 carry a fully transparent colour.
TRAINING_STATUS = {'untrainable': (-4, (0, 0, 0, 0)),
                   'unbiasable': (-3, (0, 0, 0, 0)),
                   'not_computed': (-2, (0, 0, 0, 0)),
                   'habituation': (-1, (0, 0, 0, 0)),
                   'in training': (0, (0, 0, 0, 0)),
                   'trained 1a': (1, (195, 90, 80, 255)),
                   'trained 1b': (2, (255, 153, 20, 255)),
                   'ready4ephysrig': (3, (28, 20, 255, 255)),
                   'ready4delay': (4, (117, 117, 117, 255)),
                   'ready4recording': (5, (20, 255, 91, 255))}
def get_training_table_from_aws(lab, subject):
    """
    Download the latest training table for a subject from the private AWS S3 bucket.

    Requires an 'ibl_training' AWS profile to be configured on the local server;
    returns None when the profile or the remote object is unavailable.

    :param lab: lab name, used to build the S3 object key
    :param subject: subject nickname
    :return: pandas.DataFrame of the training table, or None on failure
    """
    try:
        aws_session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return None

    destination = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    remote_key = f'resources/training/{lab}/{subject}/training.csv'
    try:
        bucket = aws_session.resource('s3').Bucket(name='ibl-brain-wide-map-private')
        bucket.download_file(remote_key, destination)
        return pd.read_csv(destination)
    except ClientError:
        return None
def upload_training_table_to_aws(lab, subject):
    """
    Upload the subject's training table to the private AWS S3 bucket.

    Requires an 'ibl_training' AWS profile to be configured on the local server;
    silently returns when the profile, the local file or the bucket is unavailable.

    :param lab: lab name, used to build the S3 object key
    :param subject: subject nickname
    :return:
    """
    try:
        aws_session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    source = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    remote_key = f'resources/training/{lab}/{subject}/training.csv'
    try:
        bucket = aws_session.resource('s3').Bucket(name='ibl-brain-wide-map-private')
        bucket.upload_file(source, remote_key)
    except (ClientError, FileNotFoundError):
        return
def get_trials_task(session_path, one):
    """
    Return the list of trials-extraction task instances for a session.

    If an experiment description file exists the tasks come from the dynamic pipeline;
    otherwise fall back to the legacy training/ephys pipelines, and as a last resort
    look for a custom extractor in the personal projects repository.

    :param session_path: path to the session
    :param one: ONE instance
    :return: list of task instances (possibly empty)
    """
    # If experiment description file then process this
    experiment_description_file = read_params(session_path)
    if experiment_description_file is not None:
        tasks = []
        pipeline = dyn.make_pipeline(session_path)
        trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
        for task in trials_tasks:
            t = pipeline.tasks.get(task)
            # re-initialise with the task kwargs so it points at this session path
            t.__init__(session_path, **t.kwargs)
            tasks.append(t)
    else:
        # Otherwise default to old way of doing things
        pipeline = get_pipeline(session_path)
        if pipeline == 'training':
            from ibllib.pipes.training_preprocessing import TrainingTrials
            tasks = [TrainingTrials(session_path)]
        elif pipeline == 'ephys':
            from ibllib.pipes.ephys_preprocessing import EphysTrials
            tasks = [EphysTrials(session_path)]
        else:
            try:
                # try and look if there is a custom extractor in the personal projects extraction class
                import projects.base
                task_type = get_session_extractor_type(session_path)
                PipelineClass = projects.base.get_pipeline(task_type)
                pipeline = PipelineClass(session_path, one)
                trials_task_name = next(task for task in pipeline.tasks if 'Trials' in task)
                task = pipeline.tasks.get(trials_task_name)
                task.__init__(session_path)
                tasks = [task]
            except Exception:
                # no custom extractor available — nothing to run
                tasks = []

    return tasks
def save_path(subj_path):
    """Return the location of the training table (training.csv) inside a subject folder."""
    return Path(subj_path) / 'training.csv'
def save_dataframe(df, subj_path):
    """
    Write the training dataframe to the subject folder as training.csv (no index column).

    :param df: dataframe to save
    :param subj_path: path to subject folder
    :return:
    """
    destination = save_path(subj_path)
    df.to_csv(destination, index=False)
def load_existing_dataframe(subj_path):
    """
    Read the training table for a subject from disk.

    When no table exists yet, the subject folder is created and None is returned.

    :param subj_path: path to subject folder
    :return: pandas.DataFrame or None
    """
    table_path = save_path(subj_path)
    if not table_path.exists():
        table_path.parent.mkdir(exist_ok=True, parents=True)
        return None
    return pd.read_csv(table_path)
def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
    """
    Load trials data for session. First attempts to load from local session path, if this fails will attempt to download via ONE,
    if this also fails, will then attempt to re-extract locally
    :param sess_path: session path
    :param one: ONE instance
    :param collections: optional list of collections to restrict the search to; when None, all collections are searched
    :param force: when True and if the session trials can't be found, will attempt to re-extract from the disk
    :param mode: 'raise' or 'warn', if 'raise', will error when forcing re-extraction of past sessions
    :return: trials object, or None when nothing could be loaded
    """
    try:
        # try and load all trials that are found locally in the session path locally
        if collections is None:
            trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.npy'))
        else:
            trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.npy') for c in collections]

        if len(trial_locations) > 1:
            # multiple collections (chained protocols) -> concatenate into one trials object
            trial_dict = {}
            for i, loc in enumerate(trial_locations):
                trial_dict[i] = alfio.load_object(loc.parent, 'trials', short_keys=True)
            trials = training.concatenate_trials(trial_dict)
        elif len(trial_locations) == 1:
            trials = alfio.load_object(trial_locations[0].parent, 'trials', short_keys=True)
        else:
            raise ALFObjectNotFound

        # probabilityLeft is required downstream; treat its absence as not-found
        if 'probabilityLeft' not in trials.keys():
            raise ALFObjectNotFound
    except ALFObjectNotFound:
        # Next try and load all trials data through ONE
        try:
            if not force:
                return None
            eid = one.path2eid(sess_path)
            if collections is None:
                trial_collections = one.list_datasets(eid, '_ibl_trials.goCueTrigger_times.npy')
                if len(trial_collections) > 0:
                    # strip the dataset filename to keep just the collection part
                    trial_collections = ['/'.join(c.split('/')[:-1]) for c in trial_collections]
            else:
                trial_collections = collections

            if len(trial_collections) > 1:
                trial_dict = {}
                for i, collection in enumerate(trial_collections):
                    trial_dict[i] = one.load_object(eid, 'trials', collection=collection)
                trials = training.concatenate_trials(trial_dict)
            elif len(trial_collections) == 1:
                trials = one.load_object(eid, 'trials', collection=trial_collections[0])
            else:
                raise ALFObjectNotFound

            if 'probabilityLeft' not in trials.keys():
                raise ALFObjectNotFound
        except Exception:
            # Finally try to re-extract the trials data locally
            try:
                # Get the tasks that need to be run
                tasks = get_trials_task(sess_path, one)
                if len(tasks) > 0:
                    for task in tasks:
                        status = task.run()
                        if status == 0:
                            # force=False prevents infinite recursion if extraction succeeded but loading still fails
                            return load_trials(sess_path, collections=collections, one=one, force=False)
                        else:
                            return
                else:
                    trials = None
            except Exception as e:
                if mode == 'raise':
                    raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
                else:
                    logger.warning(f'Exhausted all possibilities for loading trials for {sess_path}')
                    return

    return trials
def load_combined_trials(sess_paths, one, force=True):
    """
    Load and concatenate trials for multiple sessions. Used when we want to concatenate trials for two sessions on the same day

    :param sess_paths: list of paths to sessions
    :param one: ONE instance
    :param force: when True, attempt to re-extract trials that cannot be loaded
    :return: concatenated trials object
    """
    trials_dict = {}
    for sess_path in sess_paths:
        # BUGFIX: previously load_trials was called a second time to fill the dict,
        # loading (and potentially re-extracting) every session's trials twice.
        trials = load_trials(Path(sess_path), one, force=force)
        if trials is not None:
            trials_dict[Path(sess_path).stem] = trials

    return training.concatenate_trials(trials_dict)
def get_latest_training_information(sess_path, one):
    """
    Extracts the latest training status.

    Parameters
    ----------
    sess_path : pathlib.Path
        The session path from which to load the data.
    one : one.api.One
        An ONE instance.

    Returns
    -------
    pandas.DataFrame
        A table of training information.
    """

    subj_path = sess_path.parent.parent
    sub = subj_path.parts[-1]
    if one.mode != 'local':
        # online: prefer the shared table on AWS (may be more complete than the local one)
        lab = one.alyx.rest('subjects', 'list', nickname=sub)[0]['lab']
        df = get_training_table_from_aws(lab, sub)
    else:
        df = None

    if df is None:
        df = load_existing_dataframe(subj_path)

    # Find the dates and associated session paths where we don't have data stored in our dataframe
    missing_dates = check_up_to_date(subj_path, df)

    # Iterate through the dates to fill up our training dataframe
    for _, grp in missing_dates.groupby('date'):
        sess_dicts = get_training_info_for_session(grp.session_path.values, one)
        if len(sess_dicts) == 0:
            continue

        for sess_dict in sess_dicts:
            if df is None:
                df = pd.DataFrame.from_dict(sess_dict)
            else:
                df = pd.concat([df, pd.DataFrame.from_dict(sess_dict)])

    # Sort values by date and reset the index
    df = df.sort_values('date')
    df = df.reset_index(drop=True)
    # Save our dataframe
    save_dataframe(df, subj_path)

    # Now go through the backlog and compute the training status for sessions. If for example one was missing as it is cumulative
    # we need to go through and compute all the back log
    # Find the earliest date in missing dates that we need to recompute the training status for
    missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True))
    for date in missing_status:
        df = compute_training_status(df, date, one)

    df_lim = df.drop_duplicates(subset='session_path', keep='first')

    # Detect untrainable: 40 or more 'in training' sessions -> mark the 40th as untrainable
    if 'untrainable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['training_status'] == 'in training'].sort_values('date')
        if len(un_df) >= 40:
            sess = un_df.iloc[39].session_path
            df.loc[df['session_path'] == sess, 'training_status'] = 'untrainable'

    # Detect unbiasable: 40 biased sessions without ever reaching ready4ephysrig
    if 'unbiasable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['task_protocol'] == 'biased'].sort_values('date')
        if len(un_df) >= 40:
            tr_st = un_df[0:40].training_status.unique()
            if 'ready4ephysrig' not in tr_st:
                sess = un_df.iloc[39].session_path
                df.loc[df['session_path'] == sess, 'training_status'] = 'unbiasable'

    save_dataframe(df, subj_path)

    # push the updated table back to the shared bucket when online
    if one.mode != 'local':
        upload_training_table_to_aws(lab, sub)

    return df
def find_earliest_recompute_date(df):
    """
    Find the earliest date that we need to compute the training status from. Training status depends on previous sessions
    so if a session was missing and now has been added we need to recompute everything from that date onwards

    :param df: training dataframe with 'date' and 'training_status' columns
    :return: array of dates from the first not-computed session onwards (empty list if none)
    """
    not_computed = df[df['training_status'] == 'not_computed']
    if not_computed.empty:
        return []
    start = not_computed.sort_values('date').index[0]
    return df[start:].date.values
def compute_training_status(df, compute_date, one, force=True, task_collection='raw_behavior_data'):
    """
    Compute the training status for compute date based on training from that session and two previous days

    :param df: training dataframe
    :param compute_date: date to compute training on
    :param one: ONE instance
    :param force: when True, attempt to re-extract trials that cannot be loaded
    :param task_collection: not used within this function body — kept for interface compatibility
    :return: dataframe with the training status filled in for compute_date
    """
    df_temp = df[df['date'] <= compute_date]
    df_temp = df_temp.drop_duplicates(subset=['session_path', 'task_protocol'])
    # BUGFIX: sort_values returns a copy; previously the result was discarded so the
    # frame was never actually sorted here
    df_temp = df_temp.sort_values('date')

    dates = df_temp.date.values

    # window of dates to consider: the compute date plus up to two previous days
    n_sess_for_date = len(np.where(dates == compute_date)[0])
    n_dates = np.min([2 + n_sess_for_date, len(dates)]).astype(int)
    compute_dates = dates[(-1 * n_dates):]
    if n_sess_for_date > 1:
        # drop the duplicated trailing entries for the compute date itself
        compute_dates = compute_dates[:(-1 * (n_sess_for_date - 1))]

    assert compute_dates[-1] == compute_date

    df_temp_group = df_temp.groupby('date')

    trials = {}
    n_delay = 0
    ephys_sessions = []
    protocol = []
    status = []
    for date in compute_dates:

        df_date = df_temp_group.get_group(date)

        # If habituation skip
        if df_date.iloc[-1]['task_protocol'] == 'habituation':
            continue
        # Here we should split by protocol in an ideal world but that world isn't today. This is only really relevant for
        # chained protocols
        trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one, force=force)
        protocol.append(df_date.iloc[-1]['task_protocol'])
        status.append(df_date.iloc[-1]['training_status'])
        if df_date.iloc[-1]['combined_n_delay'] >= 900:  # delay of 15 mins
            n_delay += 1
        if df_date.iloc[-1]['location'] == 'ephys_rig':
            ephys_sessions.append(df_date.iloc[-1]['date'])

    # compare against the status from up to two sessions back
    n_status = np.max([-2, -1 * len(status)])
    training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
    # the subject can never regress in status
    training_status = pass_through_training_hierachy(training_status, status[n_status])
    df.loc[df['date'] == compute_date, 'training_status'] = training_status

    return df
def pass_through_training_hierachy(status_new, status_old):
    """
    Makes sure that the new training status is not less than the one from the previous day. e.g Subject cannot regress in
    performance

    :param status_new: latest training status
    :param status_old: previous training status
    :return: whichever of the two statuses ranks higher in TRAINING_STATUS
    """
    old_rank = TRAINING_STATUS[status_old][0]
    new_rank = TRAINING_STATUS[status_new][0]
    return status_old if old_rank > new_rank else status_new
def compute_session_duration_delay_location(sess_path, collections=None, **kwargs):
    """
    Get meta information about task. Extracts session duration, delay before session start and location of session

    Parameters
    ----------
    sess_path : pathlib.Path, str
        The session path with the pattern subject/yyyy-mm-dd/nnn.
    collections : list
        The location within the session path directory of task settings and data.

    Returns
    -------
    int
        The session duration in minutes, rounded to the nearest minute.
    int
        The delay between session start time and the first trial in seconds.
    str {'ephys_rig', 'training_rig'}
        The location of the session.
    """
    if collections is None:
        collections, _ = get_data_collection(sess_path)

    session_duration = 0
    session_delay = 0
    session_location = 'training_rig'
    for collection in collections:
        md, sess_data = load_bpod(sess_path, task_collection=collection)
        if md is None:
            continue
        try:
            start_time, end_time = _get_session_times(sess_path, md, sess_data)
            session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
            session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0)
        except Exception:
            # best effort: leave duration/delay unchanged if the session times can't be computed
            session_duration = session_duration + 0
            session_delay = session_delay + 0

        # BUGFIX: default to '' instead of None — `'ephys' in None` raises TypeError
        # when the PYBPOD_BOARD key is missing from the settings
        if 'ephys' in md.get('PYBPOD_BOARD', ''):
            session_location = 'ephys_rig'
        else:
            session_location = 'training_rig'
    return session_duration, session_delay, session_location
def get_data_collection(session_path):
    """
    Returns the location of the raw behavioral data and extracted trials data for the session path. If
    multiple locations in one session (e.g for dynamic) returns all of these

    :param session_path: path of session
    :return: tuple of (raw task collections, alf collections or None)
    """
    experiment_description_file = read_params(session_path)
    if experiment_description_file is None:
        # legacy session without an experiment description file
        return ['raw_behavior_data'], ['alf']

    pipeline = dyn.make_pipeline(session_path)
    trials_task_names = [name for name in pipeline.tasks if 'Trials' in name]
    collections = [pipeline.tasks.get(name).kwargs['collection'] for name in trials_task_names]

    if collections == ['raw_behavior_data']:
        alf_collections = ['alf']
    elif all('raw_task_data' in c for c in collections):
        # e.g. raw_task_data_00 -> alf/task_00
        alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
    else:
        alf_collections = None

    return collections, alf_collections
def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collections=None, force=True):
    """
    Build the dictionary of training metrics for a single session.

    :param session_path: session path
    :param one: ONE instance
    :param protocol: task protocol of the session (e.g. 'habituation', 'training', 'biased')
    :param alf_collections: collections holding the extracted trials data
    :param raw_collections: collections holding the raw task data
    :param force: when True, attempt to re-extract trials that cannot be loaded
    :return: dict of metrics for the session, or None when trials could not be loaded
    """
    sess_dict = {}
    sess_dict['date'] = str(one.path2ref(session_path)['date'])
    sess_dict['session_path'] = str(session_path)
    sess_dict['task_protocol'] = protocol

    if sess_dict['task_protocol'] == 'habituation':
        # Habituation sessions have no choice data: fill all performance metrics with nan
        nan_array = np.array([np.nan])
        sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
        sess_dict['performance_easy'] = np.nan
        sess_dict['reaction_time'] = np.nan
        sess_dict['n_trials'] = np.nan
        sess_dict['sess_duration'] = np.nan
        sess_dict['n_delay'] = np.nan
        sess_dict['location'] = np.nan
        sess_dict['training_status'] = 'habituation'
        sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
            (np.nan, np.nan, np.nan, np.nan)

    else:
        # if we can't compute trials then we need to pass
        trials = load_trials(session_path, one, collections=alf_collections, force=force, mode='warn')
        if trials is None:
            return

        sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
        if sess_dict['task_protocol'] == 'training':
            # training protocol only has the 50/50 block; biased-block fits are nan
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                (np.nan, np.nan, np.nan, np.nan)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                (np.nan, np.nan, np.nan, np.nan)
        else:
            # biased/ephys protocols: fit one psychometric per probability-left block
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials, block=0.5)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                training.compute_psychometric(trials, block=0.2)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                training.compute_psychometric(trials, block=0.8)

        sess_dict['performance_easy'] = training.compute_performance_easy(trials)
        sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
        sess_dict['n_trials'] = training.compute_n_trials(trials)
        sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
            compute_session_duration_delay_location(session_path, collections=raw_collections)
        sess_dict['training_status'] = 'not_computed'

    return sess_dict
def get_training_info_for_session(session_paths, one, force=True):
    """
    Extract the training information needed for plots for each session

    :param session_paths: list of session paths on same date
    :param one: ONE instance
    :param force: when True, attempt to re-extract trials that cannot be loaded
    :return: list of session metric dicts
    """

    # return list of dicts to add
    sess_dicts = []
    for session_path in session_paths:
        collections, alf_collections = get_data_collection(session_path)
        session_path = Path(session_path)
        protocols = []
        for c in collections:
            protocols.append(get_session_extractor_type(session_path, task_collection=c))

        un_protocols = np.unique(protocols)
        # Example, training, training, biased - training would be combined, biased not
        if len(un_protocols) != 1:
            print(f'Different protocols in same session {session_path} : {protocols}')
            for prot in un_protocols:
                if prot is False:
                    continue
                try:
                    # NOTE(review): `protocols == prot` compares a list with a string so is
                    # always False; the list indexing below then raises TypeError and we fall
                    # back to alf/raw = None, letting load_trials search the whole session
                    # path. Confirm intent before changing.
                    alf = alf_collections[np.where(protocols == prot)[0]]
                    raw = collections[np.where(protocols == prot)[0]]
                except TypeError:
                    alf = None
                    raw = None
                sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf, raw_collections=raw, force=force)
        else:
            prot = un_protocols[0]
            sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf_collections, raw_collections=collections,
                                      force=force)

        # NOTE(review): in the multi-protocol branch sess_dict is overwritten on each
        # loop iteration, so only the last protocol's dict is appended — verify intent
        if sess_dict is not None:
            sess_dicts.append(sess_dict)

    protocols = [s['task_protocol'] for s in sess_dicts]

    if len(protocols) > 0 and len(set(protocols)) != 1:
        print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')

    # Only if all protocols are the same and are not habituation
    if len(sess_dicts) > 1 and len(set(protocols)) == 1 and protocols[0] != 'habituation':  # Only if all protocols are the same
        print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
        combined_trials = load_combined_trials(session_paths, one, force=force)
        performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
        psychs = {}
        psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
        psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
        psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)

        performance_easy = training.compute_performance_easy(combined_trials)
        reaction_time = training.compute_median_reaction_time(combined_trials)
        n_trials = training.compute_n_trials(combined_trials)

        # durations add across sessions; the delay criterion uses the max over sessions
        sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
        n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])

        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = performance
            sess_dict['combined_contrasts'] = contrasts
            sess_dict['combined_performance_easy'] = performance_easy
            sess_dict['combined_reaction_time'] = reaction_time
            sess_dict['combined_n_trials'] = n_trials
            sess_dict['combined_sess_duration'] = sess_duration
            sess_dict['combined_n_delay'] = n_delay

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
                sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
                sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
                sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]

            # Case where two sessions on same day with different number of contrasts! Oh boy
            if sess_dict['combined_performance'].size != sess_dict['performance'].size:
                # pad the per-session arrays with nan so every dict has the same shape
                sess_dict['performance'] = \
                    np.r_[sess_dict['performance'],
                          np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
                sess_dict['contrasts'] = \
                    np.r_[sess_dict['contrasts'],
                          np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]

    else:
        # single session (or habituation): combined metrics are just the per-session metrics
        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = sess_dict['performance']
            sess_dict['combined_contrasts'] = sess_dict['contrasts']
            sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
            sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
            sess_dict['combined_n_trials'] = sess_dict['n_trials']
            sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
            sess_dict['combined_n_delay'] = sess_dict['n_delay']

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
                sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
                sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
                sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']

    return sess_dicts
def check_up_to_date(subj_path, df):
    """
    Check which sessions on local file system are missing from the computed training table.

    Parameters
    ----------
    subj_path : pathlib.Path
        The path to the subject's dated session folders.
    df : pandas.DataFrame
        The computed training table.

    Returns
    -------
    pandas.DataFrame
        A table of dates and session paths that are missing from the computed training table.
    """
    # Build one row per session found on disk and concatenate once; previously the
    # dataframe was grown with pd.concat inside the loop (quadratic copying).
    rows = [{'date': session.parts[-2], 'session_path': str(session)}
            for session in alfio.iter_sessions(subj_path)]
    # Explicit columns so an empty directory still yields a frame with the expected schema
    df_session = pd.DataFrame(rows, columns=['date', 'session_path'])

    if df is None or 'combined_thres_50' not in df.columns:
        # No computed table (or an old-format one): everything needs recomputing
        return df_session
    else:
        # Only keep the dates that are not yet in the computed table
        isin, _ = ismember(df_session.date.unique(), df.date.unique())
        missing_dates = df_session.date.unique()[~isin]
        return df_session[df_session['date'].isin(missing_dates)].sort_values('date')
def plot_trial_count_and_session_duration(df, subject):
    """Plot trial counts (left axis) and session duration (right axis) over days."""
    df = df.drop_duplicates('date').reset_index(drop=True)

    left_axis = {
        'column': 'combined_n_trials',
        'title': 'Trial counts',
        'lim': None,
        'color': 'k',
        'join': True,
    }
    right_axis = {
        'column': 'combined_sess_duration',
        'title': 'Session duration (mins)',
        'lim': None,
        'color': 'r',
        'log': False,
        'join': True,
    }
    return plot_over_days(df, subject, left_axis, right_axis)
def plot_performance_easy_median_reaction_time(df, subject):
    """Plot easy-trial performance (left axis) and median reaction time (right axis, log scale) over days."""
    df = df.drop_duplicates('date').reset_index(drop=True)

    rt_upper = np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])
    left_axis = {
        'column': 'combined_performance_easy',
        'title': 'Performance on easy trials',
        'lim': [0, 1.05],
        'color': 'k',
        'join': True,
    }
    right_axis = {
        'column': 'combined_reaction_time',
        'title': 'Median reaction time (s)',
        'lim': [0.1, rt_upper],
        'color': 'r',
        'log': True,
        'join': True,
    }
    return plot_over_days(df, subject, left_axis, right_axis)
def plot_fit_params(df, subject):
    """
    Plot the psychometric fit parameters (bias, threshold, lapse low/high) over days.

    One subplot per parameter; within each subplot the fits for the three
    probability-left blocks are plotted with the colours given in the legend
    (p=0.5 -> cmap[1], p=0.2 -> cmap[0], p=0.8 -> cmap[2]).

    :param df: training dataframe
    :param subject: subject nickname (used in the figure title)
    :return: array of the four axes
    """
    fig, axs = plt.subplots(2, 2, figsize=(12, 6))
    axs = axs.ravel()

    df = df.drop_duplicates('date').reset_index(drop=True)

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")

    # One style dict per probability block; colours must stay consistent with the legend below
    y50 = {'column': 'combined_bias_50',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[1],
           'join': False}

    y80 = {'column': 'combined_bias_80',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[2],
           'join': False}

    y20 = {'column': 'combined_bias_20',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[0],
           'join': False}

    plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
    axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
    axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')

    # BUGFIX: in the three panels below the *_20 and *_80 columns were previously
    # assigned to the opposite style dicts (y80 plotted the 0.2 block and vice versa),
    # so the plotted colours contradicted the p=0.2 / p=0.8 legend.
    y50['column'] = 'combined_thres_50'
    y80['column'] = 'combined_thres_80'
    y20['column'] = 'combined_thres_20'
    for style in (y50, y80, y20):
        style['title'] = 'Threshold'
        style['lim'] = [0, 100]

    plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
    axs[1].axhline(19, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapselow_50'
    y80['column'] = 'combined_lapselow_80'
    y20['column'] = 'combined_lapselow_20'
    for style in (y50, y80, y20):
        style['title'] = 'Lapse Low'
        style['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
    axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapsehigh_50'
    y80['column'] = 'combined_lapsehigh_80'
    y20['column'] = 'combined_lapsehigh_20'
    for style in (y50, y80, y20):
        style['title'] = 'Lapse High'
        style['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
    plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
    plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
    axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')

    fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
    lines, labels = axs[3].get_legend_handles_labels()
    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
    legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
    fig.add_artist(legend2)

    return axs
def plot_psychometric_curve(df, subject, one):
    """Plot the psychometric curve fitted to the trials of the most recent session."""
    df = df.drop_duplicates('date').reset_index(drop=True)
    latest_session = Path(df.iloc[-1]["session_path"])
    trials = load_trials(latest_session, one)

    fig, ax1 = plt.subplots(figsize=(8, 6))

    training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    return ax1
def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
    """
    Plot one (optionally two) training metrics against session date.

    :param df: training status dataframe with a 'date' column of YYYY-MM-DD strings
    :param subject: subject nickname, used in the title
    :param y1: dict describing the primary metric; reads keys 'column', 'title',
        'lim', 'color' and 'join'
    :param y2: optional dict for a secondary metric drawn on a twin y-axis;
        additionally reads a 'log' key to enable a log scale
    :param ax: axis to draw into; a new figure is created when None
    :param legend: when True draw a legend below the axis
    :param title: when True set a subject/date/status title
    :param training_lines: when True mark training-status transitions as vertical lines
    :return: the primary matplotlib axis
    """
    if ax is None:
        fig, main_ax = plt.subplots(figsize=(12, 6))
    else:
        main_ax = ax

    day_dates = [datetime.strptime(d, '%Y-%m-%d') for d in df['date']]
    if y1['join']:
        main_ax.plot(day_dates, df[y1['column']], color=y1['color'])
    main_ax.scatter(day_dates, df[y1['column']], color=y1['color'])
    main_ax.set_ylabel(y1['title'])
    main_ax.set_ylim(y1['lim'])

    if y2 is not None:
        # Secondary metric shares the x-axis but gets its own colour-coded y-axis
        twin_ax = main_ax.twinx()
        if y2['join']:
            twin_ax.plot(day_dates, df[y2['column']], color=y2['color'])
        twin_ax.scatter(day_dates, df[y2['column']], color=y2['color'])
        twin_ax.set_ylabel(y2['title'])
        twin_ax.yaxis.label.set_color(y2['color'])
        twin_ax.tick_params(axis='y', colors=y2['color'])
        twin_ax.set_ylim(y2['lim'])
        if y2['log']:
            twin_ax.set_yscale('log')

        for side in ('right', 'top', 'left'):
            twin_ax.spines[side].set_visible(False)

    # Major ticks on month boundaries, minor grid lines on Mondays
    main_ax.xaxis.set_major_locator(mdates.MonthLocator())
    main_ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
    main_ax.xaxis.set_minor_locator(mdates.WeekdayLocator(byweekday=mdates.MO, interval=1))
    main_ax.grid(True, which='minor', axis='x', linestyle='--')

    for side in ('left', 'right', 'top'):
        main_ax.spines[side].set_visible(False)

    if training_lines:
        main_ax = add_training_lines(df, main_ax)

    if title:
        main_ax.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    # Shrink the axis upwards so a legend fits beneath it
    box = main_ax.get_position()
    main_ax.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9])
    if legend:
        main_ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
                       fancybox=True, shadow=True, ncol=5)

    return main_ax
def add_training_lines(df, ax):
    """
    Mark the first day each training status was reached with a vertical line.

    Colours come from the TRAINING_STATUS mapping; only statuses with a positive
    rank are drawn (statuses at or below 'in training' are skipped, as are
    'untrainable' and 'unbiasable').

    :param df: training status dataframe with 'date' and 'training_status' columns
    :param ax: matplotlib axis to annotate
    :return: the annotated axis
    """
    first_days = df.drop_duplicates(subset='training_status', keep='first')
    for _, row in first_days.iterrows():

        label = row['training_status']
        if label in ['untrainable', 'unbiasable']:
            continue

        rank, rgba = TRAINING_STATUS[label]
        if rank <= 0:
            continue

        ax.axvline(datetime.strptime(row['date'], '%Y-%m-%d'), linewidth=2,
                   color=np.array(rgba) / 255, label=label)

    return ax
def plot_heatmap_performance_over_days(df, subject):
    """
    Plot a heatmap of performance per signed contrast over training days.

    Days without a session are reindexed in as NaN so gaps in training show
    up as blank columns rather than being collapsed.

    :param df: training status dataframe with 'date', 'combined_contrasts',
        'combined_performance' and 'training_status' columns
    :param subject: subject nickname, used in the title
    :return: matplotlib axis holding the heatmap
    """
    df = df.drop_duplicates(subset=['date', 'combined_contrasts'])
    # Rows = date, columns = contrast (descending) so imshow renders contrasts top-down
    df_perf = df.pivot(index=['date'], columns=['combined_contrasts'], values=['combined_performance']).sort_values(
        by='combined_contrasts', axis=1, ascending=False)
    df_perf.index = pd.to_datetime(df_perf.index)
    # Reindex onto a continuous daily range so missing days appear as gaps
    full_date_range = pd.date_range(start=df_perf.index.min(), end=df_perf.index.max(), freq="D")
    df_perf = df_perf.reindex(full_date_range, fill_value=np.nan)

    n_contrasts = len(df.combined_contrasts.unique())

    # Compute the image extent in matplotlib date numbers, centring columns on their day
    dates = df_perf.index.to_pydatetime()
    dnum = mdates.date2num(dates)
    if len(dnum) > 1:
        start = dnum[0] - (dnum[1] - dnum[0]) / 2.
        stop = dnum[-1] + (dnum[1] - dnum[0]) / 2.
    else:
        start = dnum[0] + 0.5
        stop = dnum[0] + 1.5
    extent = [start, stop, 0, n_contrasts]

    fig, ax1 = plt.subplots(figsize=(12, 6))
    im = ax1.imshow(df_perf.T.values, extent=extent, aspect="auto", cmap='PuOr')

    # Major ticks on month boundaries, minor grid lines on Mondays
    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')
    ax1.set_yticks(np.arange(0.5, n_contrasts + 0.5, 1))
    ax1.set_yticklabels(np.sort(df.combined_contrasts.unique()))
    ax1.set_ylabel('Contrast (%)')
    ax1.set_xlabel('Date')
    cbar = fig.colorbar(im)
    # Fix: label previously read 'Rightward choice (%' with an unbalanced parenthesis
    cbar.set_label('Rightward choice (%)')
    ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    return ax1
def make_plots(session_path, one, df=None, save=False, upload=False, task_collection='raw_behavior_data'):
    """
    Create the full set of training overview plots for a subject.

    :param session_path: path of a session belonging to the subject
    :param one: ONE instance used to resolve the subject and load trials
    :param df: training status dataframe; loaded from the subject folder when None
    :param save: when True save each figure as a png next to the subject data
    :param upload: when True register the saved pngs as subject snapshots on Alyx
    :param task_collection: kept for interface compatibility; not used within
        this function
    """
    subject = one.path2ref(session_path)['subject']
    subj_path = session_path.parent.parent

    if df is None:
        df = load_existing_dataframe(subj_path)

    # Habituation sessions carry no psychometric data, so exclude them
    df = df[df['task_protocol'] != 'habituation']

    if len(df) == 0:
        return

    # Generate every figure up front; pair each axis with its output filename
    figures = [
        (plot_trial_count_and_session_duration(df, subject), 'subj_trial_count_session_duration.png'),
        (plot_performance_easy_median_reaction_time(df, subject), 'subj_performance_easy_reaction_time.png'),
        (plot_heatmap_performance_over_days(df, subject), 'subj_performance_heatmap.png'),
        (plot_fit_params(df, subject)[0], 'subj_psychometric_fit_params.png'),
        (plot_psychometric_curve(df, subject, one), 'subj_psychometric_curve.png'),
    ]

    outputs = []
    if save:
        save_path = Path(subj_path)
        for axis, filename in figures:
            save_name = save_path.joinpath(filename)
            outputs.append(save_name)
            axis.get_figure().savefig(save_name, bbox_inches='tight')

    if upload:
        subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
        snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
        snp.outputs = outputs
        snp.register_images(widths=['orig'])