# ibllib/pipes/training_status.py
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound

from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type
from ibllib.io.session_params import read_params
import ibllib.pipes.dynamic_pipeline as dyn

from iblutil.util import setup_logger
from ibllib.plots.snapshot import ReportSnapshot
from iblutil.numerical import ismember
from brainbox.behavior import training

import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
from datetime import datetime
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError

logger = setup_logger(__name__)


TRAINING_STATUS = {'untrainable': (-4, (0, 0, 0, 0)),
                   'unbiasable': (-3, (0, 0, 0, 0)),
                   'not_computed': (-2, (0, 0, 0, 0)),
                   'habituation': (-1, (0, 0, 0, 0)),
                   'in training': (0, (0, 0, 0, 0)),
                   'trained 1a': (1, (195, 90, 80, 255)),
                   'trained 1b': (2, (255, 153, 20, 255)),
                   'ready4ephysrig': (3, (28, 20, 255, 255)),
                   'ready4delay': (4, (117, 117, 117, 255)),
                   'ready4recording': (5, (20, 255, 91, 255))}
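
# A minimal sketch of how TRAINING_STATUS is consumed elsewhere in this module:
# the first element of each tuple is an ordinal rank used to compare statuses
# (see pass_through_training_hierachy below), the second an RGBA colour in 0-255
# used for plotting (see add_training_lines). For example:
#
#     rank, rgba = TRAINING_STATUS['trained 1a']  # (1, (195, 90, 80, 255))
#     colour = np.array(rgba) / 255               # matplotlib expects floats in 0-1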


def get_training_table_from_aws(lab, subject):
    """
    Download the latest training table from the AWS S3 private bucket, provided AWS credentials
    (profile 'ibl_training') exist on the local server.

    :param lab: lab name of the subject
    :param subject: subject nickname
    :return: pandas.DataFrame of training data, or None if the credentials or file are missing
    """
    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.download_file(f'resources/training/{lab}/{subject}/training.csv',
                             local_file_path)
        df = pd.read_csv(local_file_path)
    except ClientError:
        return

    return df


def upload_training_table_to_aws(lab, subject):
    """
    Upload the training table to the AWS S3 private bucket, provided AWS credentials
    (profile 'ibl_training') exist on the local server.

    :param lab: lab name of the subject
    :param subject: subject nickname
    :return:
    """
    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.upload_file(local_file_path,
                           f'resources/training/{lab}/{subject}/training.csv')
    except (ClientError, FileNotFoundError):
        return
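
# Illustrative usage of the two AWS helpers above (hypothetical lab/subject
# names; requires an 'ibl_training' AWS profile on the local server):
#
#     df = get_training_table_from_aws('cortexlab', 'KS022')  # pull the latest table
#     ...                                                     # update it locally
#     upload_training_table_to_aws('cortexlab', 'KS022')      # push it back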


def get_trials_task(session_path, one):
    # If an experiment description file exists, use the dynamic pipeline to find the trials tasks
    experiment_description_file = read_params(session_path)
    if experiment_description_file is not None:
        tasks = []
        pipeline = dyn.make_pipeline(session_path)
        trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
        for task in trials_tasks:
            t = pipeline.tasks.get(task)
            t.__init__(session_path, **t.kwargs)
            tasks.append(t)
    else:
        # Otherwise default to the old way of doing things
        pipeline = get_pipeline(session_path)
        if pipeline == 'training':
            from ibllib.pipes.training_preprocessing import TrainingTrials
            tasks = [TrainingTrials(session_path)]
        elif pipeline == 'ephys':
            from ibllib.pipes.ephys_preprocessing import EphysTrials
            tasks = [EphysTrials(session_path)]
        else:
            try:
                # Try to find a custom extractor in the personal projects extraction class
                import projects.base
                task_type = get_session_extractor_type(session_path)
                PipelineClass = projects.base.get_pipeline(task_type)
                pipeline = PipelineClass(session_path, one)
                trials_task_name = next(task for task in pipeline.tasks if 'Trials' in task)
                task = pipeline.tasks.get(trials_task_name)
                task.__init__(session_path)
                tasks = [task]
            except Exception:
                tasks = []

    return tasks
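
# Usage sketch (hypothetical path; requires the raw task data on disk). This is
# how load_trials() below uses the returned tasks to re-extract trials locally:
#
#     tasks = get_trials_task(Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001'), one)
#     for task in tasks:
#         status = task.run()  # 0 on success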


def save_path(subj_path):
    """Return the path of the training table CSV within a subject folder."""
    return Path(subj_path).joinpath('training.csv')


def save_dataframe(df, subj_path):
    """
    Save the training dataframe to disk.

    :param df: dataframe to save
    :param subj_path: path to subject folder
    :return:
    """
    df.to_csv(save_path(subj_path), index=False)


def load_existing_dataframe(subj_path):
    """
    Load the training dataframe from disk; returns None if the dataframe doesn't exist.

    :param subj_path: path to subject folder
    :return: pandas.DataFrame or None
    """
    df_location = save_path(subj_path)
    if df_location.exists():
        return pd.read_csv(df_location)
    else:
        df_location.parent.mkdir(exist_ok=True, parents=True)
        return None
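
# Round-trip sketch of the three helpers above (hypothetical subject folder):
#
#     subj_path = '/mnt/s0/Data/Subjects/KS022'
#     df = load_existing_dataframe(subj_path)  # None on first run; the folder is created
#     ...                                      # add or update rows
#     save_dataframe(df, subj_path)            # writes <subj_path>/training.csv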


def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
    """
    Load trials data for a session. First attempts to load from the local session path; if this fails, attempts to
    download via ONE; if this also fails, attempts to re-extract locally.

    :param sess_path: session path
    :param one: ONE instance
    :param collections: list of collections within the session path to search for trials data
    :param force: when True and the session trials can't be found, will attempt to re-extract from disk
    :param mode: 'raise' or 'warn'; if 'raise', will error when forcing re-extraction of past sessions
    :return: trials object, or None
    """
    try:
        # Try to load all trials that are found locally in the session path
        if collections is None:
            trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.npy'))
        else:
            trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.npy') for c in collections]

        if len(trial_locations) > 1:
            trial_dict = {}
            for i, loc in enumerate(trial_locations):
                trial_dict[i] = alfio.load_object(loc.parent, 'trials', short_keys=True)
            trials = training.concatenate_trials(trial_dict)
        elif len(trial_locations) == 1:
            trials = alfio.load_object(trial_locations[0].parent, 'trials', short_keys=True)
        else:
            raise ALFObjectNotFound

        if 'probabilityLeft' not in trials.keys():
            raise ALFObjectNotFound
    except ALFObjectNotFound:
        # Next try to load the trials data through ONE
        try:
            if not force:
                return None
            eid = one.path2eid(sess_path)
            if collections is None:
                trial_collections = one.list_datasets(eid, '_ibl_trials.goCueTrigger_times.npy')
                if len(trial_collections) > 0:
                    trial_collections = ['/'.join(c.split('/')[:-1]) for c in trial_collections]
            else:
                trial_collections = collections

            if len(trial_collections) > 1:
                trial_dict = {}
                for i, collection in enumerate(trial_collections):
                    trial_dict[i] = one.load_object(eid, 'trials', collection=collection)
                trials = training.concatenate_trials(trial_dict)
            elif len(trial_collections) == 1:
                trials = one.load_object(eid, 'trials', collection=trial_collections[0])
            else:
                raise ALFObjectNotFound

            if 'probabilityLeft' not in trials.keys():
                raise ALFObjectNotFound
        except Exception:
            # Finally try to re-extract the trials data locally
            try:
                # Get the tasks that need to be run
                tasks = get_trials_task(sess_path, one)
                if len(tasks) > 0:
                    for task in tasks:
                        status = task.run()
                        if status == 0:
                            return load_trials(sess_path, collections=collections, one=one, force=False)
                        else:
                            return
                else:
                    trials = None
            except Exception as e:
                if mode == 'raise':
                    raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
                else:
                    logger.warning(f'Exhausted all possibilities for loading trials for {sess_path}')
                    return

    return trials
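
# Usage sketch (requires a configured ONE instance; the session path below is
# hypothetical): load trials for one session, warning rather than raising if
# every loading strategy fails.
#
#     from one.api import ONE
#     one = ONE()
#     trials = load_trials(Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001'), one, mode='warn')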


def load_combined_trials(sess_paths, one, force=True):
    """
    Load and concatenate trials for multiple sessions. Used when we want to concatenate trials for two sessions
    on the same day.

    :param sess_paths: list of paths to sessions
    :param one: ONE instance
    :param force: when True, attempt to re-extract trials that can't be found (see load_trials)
    :return: concatenated trials object
    """
    trials_dict = {}
    for sess_path in sess_paths:
        trials = load_trials(Path(sess_path), one, force=force)
        if trials is not None:
            trials_dict[Path(sess_path).stem] = trials

    return training.concatenate_trials(trials_dict)


def get_latest_training_information(sess_path, one):
    """
    Extracts the latest training status.

    Parameters
    ----------
    sess_path : pathlib.Path
        The session path from which to load the data.
    one : one.api.One
        An ONE instance.

    Returns
    -------
    pandas.DataFrame
        A table of training information.
    """

    subj_path = sess_path.parent.parent
    sub = subj_path.parts[-1]
    if one.mode != 'local':
        lab = one.alyx.rest('subjects', 'list', nickname=sub)[0]['lab']
        df = get_training_table_from_aws(lab, sub)
    else:
        df = None

    if df is None:
        df = load_existing_dataframe(subj_path)

    # Find the dates and associated session paths where we don't have data stored in our dataframe
    missing_dates = check_up_to_date(subj_path, df)

    # Iterate through the dates to fill up our training dataframe
    for _, grp in missing_dates.groupby('date'):
        sess_dicts = get_training_info_for_session(grp.session_path.values, one)
        if len(sess_dicts) == 0:
            continue

        for sess_dict in sess_dicts:
            if df is None:
                df = pd.DataFrame.from_dict(sess_dict)
            else:
                df = pd.concat([df, pd.DataFrame.from_dict(sess_dict)])

    # Sort values by date and reset the index
    df = df.sort_values('date')
    df = df.reset_index(drop=True)
    # Save our dataframe
    save_dataframe(df, subj_path)

    # Now work through the backlog. Training status is cumulative, so if a session was missing
    # we need to recompute the status for every session from the earliest missing date onwards
    missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True))
    for date in missing_status:
        df = compute_training_status(df, date, one)

    df_lim = df.drop_duplicates(subset='session_path', keep='first')

    # Detect untrainable
    if 'untrainable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['training_status'] == 'in training'].sort_values('date')
        if len(un_df) >= 40:
            sess = un_df.iloc[39].session_path
            df.loc[df['session_path'] == sess, 'training_status'] = 'untrainable'

    # Detect unbiasable
    if 'unbiasable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['task_protocol'] == 'biased'].sort_values('date')
        if len(un_df) >= 40:
            tr_st = un_df[0:40].training_status.unique()
            if 'ready4ephysrig' not in tr_st:
                sess = un_df.iloc[39].session_path
                df.loc[df['session_path'] == sess, 'training_status'] = 'unbiasable'

    save_dataframe(df, subj_path)

    if one.mode != 'local':
        upload_training_table_to_aws(lab, sub)

    return df
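
# A minimal driver sketch (hypothetical session path; requires a configured ONE
# instance): refresh a subject's training table after a session.
#
#     from one.api import ONE
#     one = ONE()
#     sess_path = Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001')
#     df = get_latest_training_information(sess_path, one)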


def find_earliest_recompute_date(df):
    """
    Find the earliest date from which we need to recompute the training status. Training status depends on previous
    sessions, so if a session was missing and has now been added we need to recompute everything from that date onwards.

    :param df: training dataframe
    :return: array of dates to recompute (empty list if nothing is missing)
    """
    missing_df = df[df['training_status'] == 'not_computed']
    if len(missing_df) == 0:
        return []
    missing_df = missing_df.sort_values('date')
    first_index = missing_df.index[0]

    return df[first_index:].date.values
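
# Toy example (illustrative): with an uncomputed session inserted on 2023-01-02,
# everything from that date onwards is returned for recomputation.
#
#     toy = pd.DataFrame({'date': ['2023-01-01', '2023-01-02', '2023-01-03'],
#                         'training_status': ['in training', 'not_computed', 'in training']})
#     find_earliest_recompute_date(toy)  # -> array(['2023-01-02', '2023-01-03'], ...)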


def compute_training_status(df, compute_date, one, force=True, task_collection='raw_behavior_data'):
    """
    Compute the training status for a given date, based on the training data from that session and the two
    previous days.

    :param df: training dataframe
    :param compute_date: date to compute training on
    :param one: ONE instance
    :param force: when True, attempt to re-extract trials that can't be found (see load_trials)
    :param task_collection: location of the raw task data within the session path (currently unused)
    :return: updated training dataframe
    """

    df_temp = df[df['date'] <= compute_date]
    df_temp = df_temp.drop_duplicates(subset=['session_path', 'task_protocol'])
    df_temp = df_temp.sort_values('date')

    dates = df_temp.date.values

    n_sess_for_date = len(np.where(dates == compute_date)[0])
    n_dates = np.min([2 + n_sess_for_date, len(dates)]).astype(int)
    compute_dates = dates[(-1 * n_dates):]
    if n_sess_for_date > 1:
        compute_dates = compute_dates[:(-1 * (n_sess_for_date - 1))]

    assert compute_dates[-1] == compute_date

    df_temp_group = df_temp.groupby('date')

    trials = {}
    n_delay = 0
    ephys_sessions = []
    protocol = []
    status = []
    for date in compute_dates:

        df_date = df_temp_group.get_group(date)

        # If habituation skip
        if df_date.iloc[-1]['task_protocol'] == 'habituation':
            continue
        # Here we should split by protocol in an ideal world but that world isn't today. This is only really
        # relevant for chained protocols
        trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one, force=force)
        protocol.append(df_date.iloc[-1]['task_protocol'])
        status.append(df_date.iloc[-1]['training_status'])
        if df_date.iloc[-1]['combined_n_delay'] >= 900:  # delay of 15 mins
            n_delay += 1
        if df_date.iloc[-1]['location'] == 'ephys_rig':
            ephys_sessions.append(df_date.iloc[-1]['date'])

    n_status = np.max([-2, -1 * len(status)])
    training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
    training_status = pass_through_training_hierachy(training_status, status[n_status])
    df.loc[df['date'] == compute_date, 'training_status'] = training_status

    return df


def pass_through_training_hierachy(status_new, status_old):
    """
    Ensure that the new training status is not lower in the training hierarchy than the previous day's status,
    i.e. a subject cannot regress in performance.

    :param status_new: latest training status
    :param status_old: previous training status
    :return: the higher-ranked of the two statuses
    """

    if TRAINING_STATUS[status_old][0] > TRAINING_STATUS[status_new][0]:
        return status_old
    else:
        return status_new
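
# For example, a session that computes as 'in training' the day after the
# subject reached 'trained 1a' keeps the higher status:
#
#     pass_through_training_hierachy('in training', 'trained 1a')  # -> 'trained 1a'
#     pass_through_training_hierachy('trained 1b', 'trained 1a')   # -> 'trained 1b'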


def compute_session_duration_delay_location(sess_path, collections=None, **kwargs):
    """
    Get meta information about a task. Extracts the session duration, the delay before session start and the
    location of the session.

    Parameters
    ----------
    sess_path : pathlib.Path, str
        The session path with the pattern subject/yyyy-mm-dd/nnn.
    collections : list
        The location within the session path directory of task settings and data.

    Returns
    -------
    int
        The session duration in minutes, rounded to the nearest minute.
    int
        The delay between session start time and the first trial in seconds.
    str {'ephys_rig', 'training_rig'}
        The location of the session.
    """
    if collections is None:
        collections, _ = get_data_collection(sess_path)

    session_duration = 0
    session_delay = 0
    session_location = 'training_rig'
    for collection in collections:
        md, sess_data = load_bpod(sess_path, task_collection=collection)
        if md is None:
            continue
        try:
            start_time, end_time = _get_session_times(sess_path, md, sess_data)
            session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
            session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0)
        except Exception:
            session_duration = session_duration + 0
            session_delay = session_delay + 0

        # Default to an empty string so a missing PYBPOD_BOARD key doesn't raise a TypeError
        if 'ephys' in md.get('PYBPOD_BOARD', ''):
            session_location = 'ephys_rig'
        else:
            session_location = 'training_rig'

    return session_duration, session_delay, session_location


def get_data_collection(session_path):
    """
    Return the location of the raw behavioral data and extracted trials data for a session. If there are multiple
    locations within one session (e.g. for dynamic pipelines), all of them are returned.

    :param session_path: path of session
    :return: list of raw data collections, list of alf collections (or None if they can't be determined)
    """
    experiment_description_file = read_params(session_path)
    if experiment_description_file is not None:
        pipeline = dyn.make_pipeline(session_path)
        trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
        collections = [pipeline.tasks.get(task).kwargs['collection'] for task in trials_tasks]
        if len(collections) == 1 and collections[0] == 'raw_behavior_data':
            alf_collections = ['alf']
        elif all(['raw_task_data' in c for c in collections]):
            alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
        else:
            alf_collections = None
    else:
        collections = ['raw_behavior_data']
        alf_collections = ['alf']

    return collections, alf_collections
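
# Illustrative return values under the two session layouts handled above:
#
#     # legacy session (no experiment description file)
#     (['raw_behavior_data'], ['alf'])
#     # dynamic session with two chained protocols
#     (['raw_task_data_00', 'raw_task_data_01'], ['alf/task_00', 'alf/task_01'])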


def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collections=None, force=True):
    """
    Compute the training metrics for a single session and return them as a dict (one row of the training table).
    Returns None if the trials data cannot be loaded.
    """
    sess_dict = {}
    sess_dict['date'] = str(one.path2ref(session_path)['date'])
    sess_dict['session_path'] = str(session_path)
    sess_dict['task_protocol'] = protocol

    if sess_dict['task_protocol'] == 'habituation':
        nan_array = np.array([np.nan])
        sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
        sess_dict['performance_easy'] = np.nan
        sess_dict['reaction_time'] = np.nan
        sess_dict['n_trials'] = np.nan
        sess_dict['sess_duration'] = np.nan
        sess_dict['n_delay'] = np.nan
        sess_dict['location'] = np.nan
        sess_dict['training_status'] = 'habituation'
        sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
            (np.nan, np.nan, np.nan, np.nan)

    else:
        # If we can't load the trials then we can't compute the session dict
        trials = load_trials(session_path, one, collections=alf_collections, force=force, mode='warn')
        if trials is None:
            return

        sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
        if sess_dict['task_protocol'] == 'training':
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                (np.nan, np.nan, np.nan, np.nan)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                (np.nan, np.nan, np.nan, np.nan)
        else:
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials, block=0.5)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                training.compute_psychometric(trials, block=0.2)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                training.compute_psychometric(trials, block=0.8)

        sess_dict['performance_easy'] = training.compute_performance_easy(trials)
        sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
        sess_dict['n_trials'] = training.compute_n_trials(trials)
        sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
            compute_session_duration_delay_location(session_path, collections=raw_collections)
        sess_dict['training_status'] = 'not_computed'

    return sess_dict


def get_training_info_for_session(session_paths, one, force=True):
    """
    Extract the training information needed for the plots for each session.

    :param session_paths: list of session paths on the same date
    :param one: ONE instance
    :param force: when True, attempt to re-extract trials that can't be found (see load_trials)
    :return: list of dicts, one per session
    """

    # Return a list of dicts to add
    sess_dicts = []
    for session_path in session_paths:
        collections, alf_collections = get_data_collection(session_path)
        session_path = Path(session_path)
        protocols = []
        for c in collections:
            protocols.append(get_session_extractor_type(session_path, task_collection=c))

        un_protocols = np.unique(protocols)
        # For example with training, training, biased: the training sessions would be combined, the biased would not
        if len(un_protocols) != 1:
            print(f'Different protocols in same session {session_path} : {protocols}')
            for prot in un_protocols:
                if prot is False:
                    continue
                try:
                    alf = alf_collections[np.where(protocols == prot)[0]]
                    raw = collections[np.where(protocols == prot)[0]]
                except TypeError:
                    alf = None
                    raw = None
                sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf, raw_collections=raw, force=force)
        else:
            prot = un_protocols[0]
            sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf_collections, raw_collections=collections,
                                      force=force)

        if sess_dict is not None:
            sess_dicts.append(sess_dict)

    protocols = [s['task_protocol'] for s in sess_dicts]

    if len(protocols) > 0 and len(set(protocols)) != 1:
        print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')

    # Only combine sessions if all protocols are the same and are not habituation
    if len(sess_dicts) > 1 and len(set(protocols)) == 1 and protocols[0] != 'habituation':
        print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
        combined_trials = load_combined_trials(session_paths, one, force=force)
        performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
        psychs = {}
        psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
        psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
        psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)

        performance_easy = training.compute_performance_easy(combined_trials)
        reaction_time = training.compute_median_reaction_time(combined_trials)
        n_trials = training.compute_n_trials(combined_trials)

        sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
        n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])

        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = performance
            sess_dict['combined_contrasts'] = contrasts
            sess_dict['combined_performance_easy'] = performance_easy
            sess_dict['combined_reaction_time'] = reaction_time
            sess_dict['combined_n_trials'] = n_trials
            sess_dict['combined_sess_duration'] = sess_duration
            sess_dict['combined_n_delay'] = n_delay

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
                sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
                sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
                sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]

            # Case where two sessions on the same day have a different number of contrasts: pad the
            # per-session arrays with NaN to match the size of the combined ones
            if sess_dict['combined_performance'].size != sess_dict['performance'].size:
                sess_dict['performance'] = \
                    np.r_[sess_dict['performance'],
                          np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
                sess_dict['contrasts'] = \
                    np.r_[sess_dict['contrasts'],
                          np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]

    else:
        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = sess_dict['performance']
            sess_dict['combined_contrasts'] = sess_dict['contrasts']
            sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
            sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
            sess_dict['combined_n_trials'] = sess_dict['n_trials']
            sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
            sess_dict['combined_n_delay'] = sess_dict['n_delay']

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
                sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
                sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
                sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']

    return sess_dicts


def check_up_to_date(subj_path, df):
    """
    Check which sessions on the local file system are missing from the computed training table.

    Parameters
    ----------
    subj_path : pathlib.Path
        The path to the subject's dated session folders.
    df : pandas.DataFrame
        The computed training table.

    Returns
    -------
    pandas.DataFrame
        A table of dates and session paths that are missing from the computed training table.
    """
    df_session = pd.DataFrame()

    for session in alfio.iter_sessions(subj_path):
        s_df = pd.DataFrame({'date': session.parts[-2], 'session_path': str(session)}, index=[0])
        df_session = pd.concat([df_session, s_df], ignore_index=True)

    if df is None or 'combined_thres_50' not in df.columns:
        return df_session
    else:
        isin, _ = ismember(df_session.date.unique(), df.date.unique())
        missing_dates = df_session.date.unique()[~isin]
        return df_session[df_session['date'].isin(missing_dates)].sort_values('date')


def plot_trial_count_and_session_duration(df, subject):

    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_n_trials',
          'title': 'Trial counts',
          'lim': None,
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_sess_duration',
          'title': 'Session duration (mins)',
          'lim': None,
          'color': 'r',
          'log': False,
          'join': True}

    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_performance_easy_median_reaction_time(df, subject):
    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_performance_easy',
          'title': 'Performance on easy trials',
          'lim': [0, 1.05],
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_reaction_time',
          'title': 'Median reaction time (s)',
          'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
          'color': 'r',
          'log': True,
          'join': True}
    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_fit_params(df, subject):
    fig, axs = plt.subplots(2, 2, figsize=(12, 6))
    axs = axs.ravel()

    df = df.drop_duplicates('date').reset_index(drop=True)

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")

    y50 = {'column': 'combined_bias_50',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[1],
           'join': False}

    y80 = {'column': 'combined_bias_80',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[2],
           'join': False}

    y20 = {'column': 'combined_bias_20',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[0],
           'join': False}

    plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
    axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
    axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_thres_50'
    y50['title'] = 'Threshold'
    y50['lim'] = [0, 100]
    y80['column'] = 'combined_thres_80'
    y80['title'] = 'Threshold'
    y80['lim'] = [0, 100]
    y20['column'] = 'combined_thres_20'
    y20['title'] = 'Threshold'
    y20['lim'] = [0, 100]

    plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
    axs[1].axhline(19, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapselow_50'
    y50['title'] = 'Lapse Low'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapselow_80'
    y80['title'] = 'Lapse Low'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapselow_20'
    y20['title'] = 'Lapse Low'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
    axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapsehigh_50'
    y50['title'] = 'Lapse High'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapsehigh_80'
    y80['title'] = 'Lapse High'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapsehigh_20'
    y20['title'] = 'Lapse High'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
    plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
    plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
    axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')

    fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
    lines, labels = axs[3].get_legend_handles_labels()
    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
    legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
    fig.add_artist(legend2)

    return axs


def plot_psychometric_curve(df, subject, one):
    df = df.drop_duplicates('date').reset_index(drop=True)
    sess_path = Path(df.iloc[-1]["session_path"])
    trials = load_trials(sess_path, one)

    fig, ax1 = plt.subplots(figsize=(8, 6))

    training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    return ax1


def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
    """
    Plot one or two training metrics over days. y1 (and the optional y2, drawn on a twin axis) are dicts with
    keys 'column' (dataframe column to plot), 'title' (axis label), 'lim' (y-limits), 'color' and 'join' (whether
    to connect the points); y2 additionally takes 'log' (log-scale the right axis).
    """

    if ax is None:
        fig, ax1 = plt.subplots(figsize=(12, 6))
    else:
        ax1 = ax

    dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
    if y1['join']:
        ax1.plot(dates, df[y1['column']], color=y1['color'])
    ax1.scatter(dates, df[y1['column']], color=y1['color'])
    ax1.set_ylabel(y1['title'])
    ax1.set_ylim(y1['lim'])

    if y2 is not None:
        ax2 = ax1.twinx()
        if y2['join']:
            ax2.plot(dates, df[y2['column']], color=y2['color'])
        ax2.scatter(dates, df[y2['column']], color=y2['color'])
        ax2.set_ylabel(y2['title'])
        ax2.yaxis.label.set_color(y2['color'])
        ax2.tick_params(axis='y', colors=y2['color'])
        ax2.set_ylim(y2['lim'])
        if y2['log']:
            ax2.set_yscale('log')

        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.spines['left'].set_visible(False)

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    if training_lines:
        ax1 = add_training_lines(df, ax1)

    if title:
        ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    # Put a legend below the current axis
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
                      box.width, box.height * 0.9])
    if legend:
        ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
                   fancybox=True, shadow=True, ncol=5)

    return ax1


def add_training_lines(df, ax):
    """
    Add a vertical line to the axis for the first date on which each training status was reached.
    """
    status = df.drop_duplicates(subset='training_status', keep='first')
    for _, st in status.iterrows():

        if st['training_status'] in ['untrainable', 'unbiasable']:
            continue

        # Only mark statuses from 'trained 1a' onwards; lower ranks are assigned a transparent colour
        if TRAINING_STATUS[st['training_status']][0] <= 0:
            continue

        ax.axvline(datetime.strptime(st['date'], '%Y-%m-%d'), linewidth=2,
                   color=np.array(TRAINING_STATUS[st['training_status']][1]) / 255, label=st['training_status'])

    return ax


def plot_heatmap_performance_over_days(df, subject):

    df = df.drop_duplicates(subset=['date', 'combined_contrasts'])
    df_perf = df.pivot(index=['date'], columns=['combined_contrasts'], values=['combined_performance']).sort_values(
        by='combined_contrasts', axis=1, ascending=False)
    df_perf.index = pd.to_datetime(df_perf.index)
    full_date_range = pd.date_range(start=df_perf.index.min(), end=df_perf.index.max(), freq="D")
    df_perf = df_perf.reindex(full_date_range, fill_value=np.nan)

    n_contrasts = len(df.combined_contrasts.unique())

    dates = df_perf.index.to_pydatetime()
    dnum = mdates.date2num(dates)
    if len(dnum) > 1:
        start = dnum[0] - (dnum[1] - dnum[0]) / 2.
        stop = dnum[-1] + (dnum[1] - dnum[0]) / 2.
    else:
        start = dnum[0] + 0.5
        stop = dnum[0] + 1.5

    extent = [start, stop, 0, n_contrasts]

    fig, ax1 = plt.subplots(figsize=(12, 6))
    im = ax1.imshow(df_perf.T.values, extent=extent, aspect="auto", cmap='PuOr')

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')
    ax1.set_yticks(np.arange(0.5, n_contrasts + 0.5, 1))
    ax1.set_yticklabels(np.sort(df.combined_contrasts.unique()))
    ax1.set_ylabel('Contrast (%)')
    ax1.set_xlabel('Date')
    cbar = fig.colorbar(im)
    cbar.set_label('Rightward choice (%)')
    ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    return ax1


def make_plots(session_path, one, df=None, save=False, upload=False, task_collection='raw_behavior_data'):
    """
    Make all the training plots for a subject, optionally saving them to disk and uploading them to Alyx
    as report snapshots.
    """
    subject = one.path2ref(session_path)['subject']
    subj_path = session_path.parent.parent

    df = load_existing_dataframe(subj_path) if df is None else df

    df = df[df['task_protocol'] != 'habituation']

    if len(df) == 0:
        return

    ax1 = plot_trial_count_and_session_duration(df, subject)
    ax2 = plot_performance_easy_median_reaction_time(df, subject)
    ax3 = plot_heatmap_performance_over_days(df, subject)
    ax4 = plot_fit_params(df, subject)
    ax5 = plot_psychometric_curve(df, subject, one)

    outputs = []
    if save:
        save_path = Path(subj_path)
        save_name = save_path.joinpath('subj_trial_count_session_duration.png')
        outputs.append(save_name)
        ax1.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_performance_easy_reaction_time.png')
        outputs.append(save_name)
        ax2.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_performance_heatmap.png')
        outputs.append(save_name)
        ax3.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_psychometric_fit_params.png')
        outputs.append(save_name)
        ax4[0].get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_psychometric_curve.png')
        outputs.append(save_name)
        ax5.get_figure().savefig(save_name, bbox_inches='tight')

    if upload:
        subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
        snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
        snp.outputs = outputs
        snp.register_images(widths=['orig'])
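
# End-to-end sketch (hypothetical session path; requires a configured ONE
# instance and an Alyx connection for the upload step):
#
#     from one.api import ONE
#     one = ONE()
#     session_path = Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001')
#     df = get_latest_training_information(session_path, one)
#     make_plots(session_path, one, df=df, save=True, upload=True)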