Coverage for ibllib/pipes/training_status.py: 90% (573 statements), coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

import logging
from pathlib import Path
from datetime import datetime
from itertools import chain

import numpy as np
import pandas as pd
from iblutil.numerical import ismember
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError

from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_session_extractor_type, get_bpod_extractor_class
from ibllib.io.session_params import read_params
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.plots.snapshot import ReportSnapshot
from brainbox.behavior import training

logger = logging.getLogger(__name__)

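# Training statuses mapped to a (rank, RGBA colour) tuple; the rank defines the ordering used
# to ensure a subject never regresses to a lower status (see pass_through_training_hierachy).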

TRAINING_STATUS = {'untrainable': (-4, (0, 0, 0, 0)),
                   'unbiasable': (-3, (0, 0, 0, 0)),
                   'not_computed': (-2, (0, 0, 0, 0)),
                   'habituation': (-1, (0, 0, 0, 0)),
                   'in training': (0, (0, 0, 0, 0)),
                   'trained 1a': (1, (195, 90, 80, 255)),
                   'trained 1b': (2, (255, 153, 20, 255)),
                   'ready4ephysrig': (3, (28, 20, 255, 255)),
                   'ready4delay': (4, (117, 117, 117, 255)),
                   'ready4recording': (5, (20, 255, 91, 255))}


def get_training_table_from_aws(lab, subject):
    """
    If AWS credentials exist on the local server, download the latest training table from the
    AWS S3 private bucket.

    :param lab: lab name of the subject
    :param subject: subject nickname
    :return: the training dataframe, or None if the credentials or remote file are not found
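
    Example (hypothetical lab and subject names; requires an 'ibl_training' AWS profile):

    >>> df = get_training_table_from_aws('cortexlab', 'KS022')  # doctest: +SKIP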

47 """ 

    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.download_file(f'resources/training/{lab}/{subject}/training.csv',
                             local_file_path)
        df = pd.read_csv(local_file_path)
    except ClientError:
        return

    return df


def upload_training_table_to_aws(lab, subject):
    """
    If AWS credentials exist on the local server, upload the training table to the AWS S3
    private bucket.

    :param lab: lab name of the subject
    :param subject: subject nickname
    :return:
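
    Example (hypothetical lab and subject names; requires an 'ibl_training' AWS profile):

    >>> upload_training_table_to_aws('cortexlab', 'KS022')  # doctest: +SKIP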

73 """ 

    try:
        session = boto3.Session(profile_name='ibl_training')
    except ProfileNotFound:
        return

    local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
    dst_bucket_name = 'ibl-brain-wide-map-private'
    try:
        s3 = session.resource('s3')
        bucket = s3.Bucket(name=dst_bucket_name)
        bucket.upload_file(local_file_path,
                           f'resources/training/{lab}/{subject}/training.csv')
    except (ClientError, FileNotFoundError):
        return


def save_path(subj_path):
    """Return the path to the training table for a subject folder."""
    return Path(subj_path).joinpath('training.csv')


def save_dataframe(df, subj_path):
    """Save training dataframe to disk.

    :param df: dataframe to save
    :param subj_path: path to subject folder
    :return:
    """
    df.to_csv(save_path(subj_path), index=False)


def load_existing_dataframe(subj_path):
    """Load training dataframe from disk; returns None if the dataframe does not exist.

    :param subj_path: path to subject folder
    :return: the training dataframe, or None
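
    Example (hypothetical subject path; skipped as it touches the file system):

    >>> df = load_existing_dataframe(Path('/mnt/s0/Data/Subjects/KS022'))  # doctest: +SKIP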

109 """ 

    df_location = save_path(subj_path)
    if df_location.exists():
        return pd.read_csv(df_location)
    else:
        df_location.parent.mkdir(exist_ok=True, parents=True)
        return None


def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
    """
    Load trials data for a session.

    First attempts to load the trials from the local session path; if this fails, attempts to
    download them via ONE; if this also fails, attempts to re-extract them locally.

    :param sess_path: session path
    :param one: ONE instance
    :param collections: list of collections within the session path to load the trials from
    :param force: when True, if the session trials can't be found, attempt to re-extract them from disk
    :param mode: 'raise' or 'warn'; if 'raise', error when forced re-extraction of past sessions fails
    :return: the trials object, or None if the trials could not be loaded
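
    Example (hypothetical session path; assumes an instantiated ONE):

    >>> trials = load_trials(Path('/data/Subjects/KS022/2023-01-01/001'), one)  # doctest: +SKIP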

127 """ 

    try:
        # try and load all trials that are found locally in the session path
        if collections is None:
            trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.*npy'))
        else:
            trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.*npy') for c in collections]

        if len(trial_locations) > 1:
            trial_dict = {}
            for i, loc in enumerate(trial_locations):
                trial_dict[i] = alfio.load_object(loc.parent, 'trials', short_keys=True)
            trials = training.concatenate_trials(trial_dict)
        elif len(trial_locations) == 1:
            trials = alfio.load_object(trial_locations[0].parent, 'trials', short_keys=True)
        else:
            raise ALFObjectNotFound

        if 'probabilityLeft' not in trials.keys():
            raise ALFObjectNotFound
    except ALFObjectNotFound:
        # Next try and load the trials data through ONE
        try:
            if not force:
                return None
            eid = one.path2eid(sess_path)
            if collections is None:
                trial_collections = one.list_datasets(eid, '_ibl_trials.goCueTrigger_times.npy')
                if len(trial_collections) > 0:
                    trial_collections = ['/'.join(c.split('/')[:-1]) for c in trial_collections]
            else:
                trial_collections = collections

            if len(trial_collections) > 1:
                trial_dict = {}
                for i, collection in enumerate(trial_collections):
                    trial_dict[i] = one.load_object(eid, 'trials', collection=collection)
                trials = training.concatenate_trials(trial_dict)
            elif len(trial_collections) == 1:
                trials = one.load_object(eid, 'trials', collection=trial_collections[0])
            else:
                raise ALFObjectNotFound

            if 'probabilityLeft' not in trials.keys():
                raise ALFObjectNotFound
        except Exception:
            # Finally try to re-extract the trials data locally
            try:
                raw_collections, _ = get_data_collection(sess_path)

                if len(raw_collections) == 0:
                    return None

                trials_dict = {}
                for i, collection in enumerate(raw_collections):
                    extractor = get_bpod_extractor(sess_path, task_collection=collection)
                    trials_data, _ = extractor.extract(task_collection=collection, save=False)
                    trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table'])

                if len(trials_dict) > 1:
                    trials = training.concatenate_trials(trials_dict)
                else:
                    trials = trials_dict[0]

            except Exception as e:
                if mode == 'raise':
                    raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
                else:
                    logger.warning(f'Exhausted all possibilities for loading trials for {sess_path}')
                    return

    return trials


def load_combined_trials(sess_paths, one, force=True):
    """
    Load and concatenate trials for multiple sessions. Used when we want to concatenate the
    trials of two or more sessions run on the same day.

    :param sess_paths: list of paths to sessions
    :param one: ONE instance
    :param force: when True, if the session trials can't be found, attempt to re-extract them from disk
    :return: the concatenated trials object
    """

    trials_dict = {}
    for sess_path in sess_paths:
        trials = load_trials(Path(sess_path), one, force=force)
        if trials is not None:
            # reuse the already loaded trials rather than loading them a second time
            trials_dict[Path(sess_path).stem] = trials

    return training.concatenate_trials(trials_dict)


def get_latest_training_information(sess_path, one):
    """
    Extracts the latest training status.

    Parameters
    ----------
    sess_path : pathlib.Path
        The session path from which to load the data.
    one : one.api.One
        An ONE instance.

    Returns
    -------
    pandas.DataFrame
        A table of training information.
    """

    subj_path = sess_path.parent.parent
    sub = subj_path.parts[-1]
    if one.mode != 'local':
        lab = one.alyx.rest('subjects', 'list', nickname=sub)[0]['lab']
        df = get_training_table_from_aws(lab, sub)
    else:
        df = None

    if df is None:
        df = load_existing_dataframe(subj_path)

    # Find the dates and associated session paths where we don't have data stored in our dataframe
    missing_dates = check_up_to_date(subj_path, df)

    # Iterate through the dates to fill up our training dataframe
    for _, grp in missing_dates.groupby('date'):
        sess_dicts = get_training_info_for_session(grp.session_path.values, one)
        if len(sess_dicts) == 0:
            continue

        for sess_dict in sess_dicts:
            if df is None:
                df = pd.DataFrame.from_dict(sess_dict)
            else:
                df = pd.concat([df, pd.DataFrame.from_dict(sess_dict)])

    # Sort values by date and reset the index
    df = df.sort_values('date')
    df = df.reset_index(drop=True)
    # Save our dataframe
    save_dataframe(df, subj_path)

    # Now go through the backlog and compute the training status for sessions. Because the status
    # is cumulative, if for example a session was missing, we need to recompute everything from
    # the earliest missing date onwards.
    missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True))
    for date in missing_status:
        df = compute_training_status(df, date, one)

    df_lim = df.drop_duplicates(subset='session_path', keep='first')

    # Detect untrainable
    if 'untrainable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['training_status'] == 'in training'].sort_values('date')
        if len(un_df) >= 40:
            sess = un_df.iloc[39].session_path
            df.loc[df['session_path'] == sess, 'training_status'] = 'untrainable'

    # Detect unbiasable
    if 'unbiasable' not in df_lim.training_status.values:
        un_df = df_lim[df_lim['task_protocol'] == 'biased'].sort_values('date')
        if len(un_df) >= 40:
            tr_st = un_df[0:40].training_status.unique()
            if 'ready4ephysrig' not in tr_st:
                sess = un_df.iloc[39].session_path
                df.loc[df['session_path'] == sess, 'training_status'] = 'unbiasable'

    save_dataframe(df, subj_path)

    if one.mode != 'local':
        upload_training_table_to_aws(lab, sub)

    return df


def find_earliest_recompute_date(df):
    """
    Find the earliest date that we need to recompute the training status from. Training status
    depends on previous sessions, so if a session was missing and has now been added, we need to
    recompute everything from that date onwards.

    :param df: training dataframe
    :return: array of dates from which the status must be recomputed
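
    Example (a minimal illustrative frame with hypothetical dates):

    >>> df = pd.DataFrame({'date': ['2023-01-01', '2023-01-02', '2023-01-03'],
    ...                    'training_status': ['trained 1a', 'not_computed', 'not_computed']})
    >>> find_earliest_recompute_date(df)
    array(['2023-01-02', '2023-01-03'], dtype=object)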

305 """ 

    missing_df = df[df['training_status'] == 'not_computed']
    if len(missing_df) == 0:
        return []
    missing_df = missing_df.sort_values('date')
    first_index = missing_df.index[0]

    return df[first_index:].date.values


def compute_training_status(df, compute_date, one, force=True):
    """
    Compute the training status for a given date based on the training from that session and
    the two previous days.

    Parameters
    ----------
    df : pandas.DataFrame
        A training data frame, e.g. one generated from :func:`get_training_info_for_session`.
    compute_date : str, datetime.datetime, pandas.Timestamp
        The date to compute training on.
    one : one.api.One
        An instance of ONE for loading trials data.
    force : bool
        When true and if the session trials can't be found, will attempt to re-extract from disk.

    Returns
    -------
    pandas.DataFrame
        The input data frame with a 'training_status' column populated for `compute_date`.
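
    Example (hypothetical date; assumes an instantiated ONE and a populated training frame):

    >>> df = compute_training_status(df, '2023-08-28', one)  # doctest: +SKIP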

337 """ 

    df_temp = df[df['date'] <= compute_date]
    df_temp = df_temp.drop_duplicates(subset=['session_path', 'task_protocol'])
    df_temp = df_temp.sort_values('date')

    dates = df_temp.date.values

    n_sess_for_date = len(np.where(dates == compute_date)[0])
    n_dates = np.min([2 + n_sess_for_date, len(dates)]).astype(int)
    compute_dates = dates[(-1 * n_dates):]
    if n_sess_for_date > 1:
        compute_dates = compute_dates[:(-1 * (n_sess_for_date - 1))]

    assert compute_dates[-1] == compute_date

    df_temp_group = df_temp.groupby('date')

    trials = {}
    n_delay = 0
    ephys_sessions = []
    protocol = []
    status = []
    for date in compute_dates:

        df_date = df_temp_group.get_group(date)

        # If habituation skip
        if df_date.iloc[-1]['task_protocol'] == 'habituation':
            continue
        # Here we should split by protocol in an ideal world but that world isn't today. This is
        # only really relevant for chained protocols
        trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one, force=force)
        protocol.append(df_date.iloc[-1]['task_protocol'])
        status.append(df_date.iloc[-1]['training_status'])
        if df_date.iloc[-1]['combined_n_delay'] >= 900:  # delay of 15 mins
            n_delay += 1
        if df_date.iloc[-1]['location'] == 'ephys_rig':
            ephys_sessions.append(df_date.iloc[-1]['date'])

    n_status = np.max([-2, -1 * len(status)])
    training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
    training_status = pass_through_training_hierachy(training_status, status[n_status])
    df.loc[df['date'] == compute_date, 'training_status'] = training_status

    return df


def pass_through_training_hierachy(status_new, status_old):
    """
    Makes sure that the new training status is not less than the one from the previous day,
    e.g. a subject cannot regress in performance.

    :param status_new: latest training status
    :param status_old: previous training status
    :return: the higher-ranked of the two statuses
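
    Examples:

    >>> pass_through_training_hierachy('in training', 'trained 1a')
    'trained 1a'
    >>> pass_through_training_hierachy('trained 1b', 'trained 1a')
    'trained 1b'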

393 """ 

394 

395 if TRAINING_STATUS[status_old][0] > TRAINING_STATUS[status_new][0]: 1bach

396 return status_old 1c

397 else: 

398 return status_new 1bach

399 

400 

def compute_session_duration_delay_location(sess_path, collections=None, **kwargs):
    """
    Get meta information about the task. Extracts session duration, delay before session start
    and location of the session.

    Parameters
    ----------
    sess_path : pathlib.Path, str
        The session path with the pattern subject/yyyy-mm-dd/nnn.
    collections : list
        The location within the session path directory of task settings and data.

    Returns
    -------
    int
        The session duration in minutes, rounded to the nearest minute.
    int
        The delay between session start time and the first trial in seconds.
    str {'ephys_rig', 'training_rig'}
        The location of the session.
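
    Example (hypothetical session path; return values are illustrative):

    >>> compute_session_duration_delay_location(Path('/data/Subjects/KS022/2023-01-01/001'))  # doctest: +SKIP
    (60, 0, 'training_rig')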

420 """ 

    if collections is None:
        collections, _ = get_data_collection(sess_path)

    session_duration = 0
    session_delay = 0
    session_location = 'training_rig'
    for collection in collections:
        md, sess_data = load_bpod(sess_path, task_collection=collection)
        if md is None:
            continue
        try:
            start_time, end_time = _get_session_times(sess_path, md, sess_data)
            session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
            session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0)
        except Exception:
            session_duration = session_duration + 0
            session_delay = session_delay + 0

        if 'ephys' in md.get('PYBPOD_BOARD', ''):  # default to '' so a missing key can't raise
            session_location = 'ephys_rig'
        else:
            session_location = 'training_rig'

    return session_duration, session_delay, session_location


def get_data_collection(session_path):
    """Return the location of the raw behavioral data and extracted trials data for a given session.

    For multiple locations in one session (e.g. chained protocols), returns all collections.
    Passive protocols are excluded.

    Parameters
    ----------
    session_path : pathlib.Path
        A session path in the form subject/date/number.

    Returns
    -------
    list of str
        A list of sub-directory names that contain raw behaviour data.
    list of str
        A list of sub-directory names that contain ALF trials data.

    Examples
    --------
    An iblrig v7 session

    >>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    (['raw_behavior_data'], ['alf'])

    An iblrig v8 session where two protocols were run

    >>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    (['raw_task_data_00', 'raw_task_data_01'], ['alf/task_00', 'alf/task_01'])
    """

    experiment_description = read_params(session_path)
    collections = []
    if experiment_description is not None:
        task_protocols = experiment_description.get('tasks', [])
        for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
            if 'passiveChoiceWorld' in protocol:
                continue
            collection = task_info.get('collection', f'raw_task_data_{i:02}')
            if collection == 'raw_passive_data':
                continue
            collections.append(collection)
    else:
        settings = Path(session_path).rglob('_iblrig_taskSettings.raw.json')
        for setting in settings:
            if setting.parent.name != 'raw_passive_data':
                collections.append(setting.parent.name)

    if len(collections) == 1 and collections[0] == 'raw_behavior_data':
        alf_collections = ['alf']
    elif all(['raw_task_data' in c for c in collections]):
        alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
    else:
        alf_collections = None

    return collections, alf_collections


def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collections=None, force=True):
    """Build the dictionary of performance and psychometric information for a single session."""

    sess_dict = {}
    sess_dict['date'] = str(one.path2ref(session_path)['date'])
    sess_dict['session_path'] = str(session_path)
    sess_dict['task_protocol'] = protocol

    if sess_dict['task_protocol'] == 'habituation':
        nan_array = np.array([np.nan])
        sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
        sess_dict['performance_easy'] = np.nan
        sess_dict['reaction_time'] = np.nan
        sess_dict['n_trials'] = np.nan
        sess_dict['sess_duration'] = np.nan
        sess_dict['n_delay'] = np.nan
        sess_dict['location'] = np.nan
        sess_dict['training_status'] = 'habituation'
        sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
            (np.nan, np.nan, np.nan, np.nan)
        sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
            (np.nan, np.nan, np.nan, np.nan)

    else:
        # if we can't compute trials then we need to pass
        trials = load_trials(session_path, one, collections=alf_collections, force=force, mode='warn')
        if trials is None:
            return

        sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
        if sess_dict['task_protocol'] == 'training':
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                (np.nan, np.nan, np.nan, np.nan)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                (np.nan, np.nan, np.nan, np.nan)
        else:
            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
                training.compute_psychometric(trials, block=0.5)
            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
                training.compute_psychometric(trials, block=0.2)
            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
                training.compute_psychometric(trials, block=0.8)

        sess_dict['performance_easy'] = training.compute_performance_easy(trials)
        sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
        sess_dict['n_trials'] = training.compute_n_trials(trials)
        sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
            compute_session_duration_delay_location(session_path, collections=raw_collections)
        sess_dict['training_status'] = 'not_computed'

    return sess_dict


def get_training_info_for_session(session_paths, one, force=True):
    """
    Extract the training information needed for plots for each session.

    Parameters
    ----------
    session_paths : list of pathlib.Path
        List of session paths on the same date.
    one : one.api.One
        An ONE instance.
    force : bool
        When true and if the session trials can't be found, will attempt to re-extract from disk.

    Returns
    -------
    list of dict
        A list of dictionaries the length of `session_paths` containing individual and aggregate
        performance information.
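
    Example (hypothetical session paths; assumes an instantiated ONE):

    >>> paths = [Path('/data/Subjects/KS022/2023-01-01/001')]
    >>> sess_dicts = get_training_info_for_session(paths, one)  # doctest: +SKIP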

578 """ 

    # return list of dicts to add
    sess_dicts = []
    for session_path in session_paths:
        collections, alf_collections = get_data_collection(session_path)
        session_path = Path(session_path)
        protocols = []
        for c in collections:
            try:
                prot = get_bpod_extractor_class(session_path, task_collection=c)
                prot = prot[:-6].lower()
            except Exception:
                prot = get_session_extractor_type(session_path, task_collection=c)
            protocols.append(prot)

        un_protocols = np.unique(protocols)
        # e.g. for [training, training, biased] the two training sessions would be combined,
        # the biased one would not
        sess_dict = None
        if len(un_protocols) != 1:
            print(f'Different protocols in same session {session_path} : {protocols}')
            for prot in un_protocols:
                if prot is False:
                    continue
                # match the collections belonging to this protocol (compare via a numpy array;
                # a plain `list == str` comparison is always False)
                idx = np.where(np.array(protocols) == prot)[0]
                alf = [alf_collections[i] for i in idx] if alf_collections is not None else None
                raw = [collections[i] for i in idx]
                sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf, raw_collections=raw, force=force)
        else:
            prot = un_protocols[0]
            sess_dict = get_sess_dict(
                session_path, one, prot, alf_collections=alf_collections, raw_collections=collections, force=force)

        if sess_dict is not None:
            sess_dicts.append(sess_dict)

    protocols = [s['task_protocol'] for s in sess_dicts]

    if len(protocols) > 0 and len(set(protocols)) != 1:
        print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')

    # Only if all protocols are the same and are not habituation
    if len(sess_dicts) > 1 and len(set(protocols)) == 1 and protocols[0] != 'habituation':
        print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
        combined_trials = load_combined_trials(session_paths, one, force=force)
        performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
        psychs = {}
        psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
        psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
        psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)

        performance_easy = training.compute_performance_easy(combined_trials)
        reaction_time = training.compute_median_reaction_time(combined_trials)
        n_trials = training.compute_n_trials(combined_trials)

        sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
        n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])

        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = performance
            sess_dict['combined_contrasts'] = contrasts
            sess_dict['combined_performance_easy'] = performance_easy
            sess_dict['combined_reaction_time'] = reaction_time
            sess_dict['combined_n_trials'] = n_trials
            sess_dict['combined_sess_duration'] = sess_duration
            sess_dict['combined_n_delay'] = n_delay

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
                sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
                sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
                sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]

            # Case where two sessions on the same day have a different number of contrasts! Oh boy
            if sess_dict['combined_performance'].size != sess_dict['performance'].size:
                sess_dict['performance'] = \
                    np.r_[sess_dict['performance'],
                          np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
                sess_dict['contrasts'] = \
                    np.r_[sess_dict['contrasts'],
                          np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]

    else:
        for sess_dict in sess_dicts:
            sess_dict['combined_performance'] = sess_dict['performance']
            sess_dict['combined_contrasts'] = sess_dict['contrasts']
            sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
            sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
            sess_dict['combined_n_trials'] = sess_dict['n_trials']
            sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
            sess_dict['combined_n_delay'] = sess_dict['n_delay']

            for bias in [50, 20, 80]:
                sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
                sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
                sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
                sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']

    return sess_dicts


def check_up_to_date(subj_path, df):
    """
    Check which sessions on the local file system are missing from the computed training table.

    Parameters
    ----------
    subj_path : pathlib.Path
        The path to the subject's dated session folders.
    df : pandas.DataFrame
        The computed training table.

    Returns
    -------
    pandas.DataFrame
        A table of dates and session paths that are missing from the computed training table.
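
    Example (hypothetical subject path and an existing training table `df`):

    >>> missing = check_up_to_date(Path('/data/Subjects/KS022'), df)  # doctest: +SKIP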

697 """ 

    df_session = pd.DataFrame(columns=['date', 'session_path'])

    for session in alfio.iter_sessions(subj_path, pattern='????-??-??/*'):
        s_df = pd.DataFrame({'date': session.parts[-2], 'session_path': str(session)}, index=[0])
        df_session = pd.concat([df_session, s_df], ignore_index=True)

    if df is None or 'combined_thres_50' not in df.columns:
        return df_session
    else:
        isin, _ = ismember(df_session.date.unique(), df.date.unique())
        missing_dates = df_session.date.unique()[~isin]
        return df_session[df_session['date'].isin(missing_dates)].sort_values('date')


def plot_trial_count_and_session_duration(df, subject):
    """Plot trial counts and session duration over days for a subject."""

    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_n_trials',
          'title': 'Trial counts',
          'lim': None,
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_sess_duration',
          'title': 'Session duration (mins)',
          'lim': None,
          'color': 'r',
          'log': False,
          'join': True}

    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_performance_easy_median_reaction_time(df, subject):
    """Plot performance on easy trials and median reaction time over days for a subject."""
    df = df.drop_duplicates('date').reset_index(drop=True)

    y1 = {'column': 'combined_performance_easy',
          'title': 'Performance on easy trials',
          'lim': [0, 1.05],
          'color': 'k',
          'join': True}

    y2 = {'column': 'combined_reaction_time',
          'title': 'Median reaction time (s)',
          'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
          'color': 'r',
          'log': True,
          'join': True}
    ax = plot_over_days(df, subject, y1, y2)

    return ax


def plot_fit_params(df, subject):
    """Plot the psychometric fit parameters (bias, threshold, lapses) over days for a subject."""
    fig, axs = plt.subplots(2, 2, figsize=(12, 6))
    axs = axs.ravel()

    df = df.drop_duplicates('date').reset_index(drop=True)

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")

    y50 = {'column': 'combined_bias_50',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[1],
           'join': False}

    y80 = {'column': 'combined_bias_80',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[2],
           'join': False}

    y20 = {'column': 'combined_bias_20',
           'title': 'Bias',
           'lim': [-100, 100],
           'color': cmap[0],
           'join': False}

    plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
    axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
    axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')

    # Keep each dict plotting the block it is coloured for (y80 -> 80 block, y20 -> 20 block);
    # the 20/80 columns were previously swapped relative to the legend colours
    y50['column'] = 'combined_thres_50'
    y50['title'] = 'Threshold'
    y50['lim'] = [0, 100]
    y80['column'] = 'combined_thres_80'
    y80['title'] = 'Threshold'
    y80['lim'] = [0, 100]
    y20['column'] = 'combined_thres_20'
    y20['title'] = 'Threshold'
    y20['lim'] = [0, 100]

    plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
    axs[1].axhline(19, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapselow_50'
    y50['title'] = 'Lapse Low'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapselow_80'
    y80['title'] = 'Lapse Low'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapselow_20'
    y20['title'] = 'Lapse Low'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
    plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
    axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')

    y50['column'] = 'combined_lapsehigh_50'
    y50['title'] = 'Lapse High'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapsehigh_80'
    y80['title'] = 'Lapse High'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapsehigh_20'
    y20['title'] = 'Lapse High'
    y20['lim'] = [0, 1]

    plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
    plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
    plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
    axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')

    fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
    lines, labels = axs[3].get_legend_handles_labels()
    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
                       Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
    legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
    fig.add_artist(legend2)

    return axs


def plot_psychometric_curve(df, subject, one):
    """Plot the psychometric curve of the most recent session for a subject."""
    df = df.drop_duplicates('date').reset_index(drop=True)
    sess_path = Path(df.iloc[-1]["session_path"])
    trials = load_trials(sess_path, one)

    fig, ax1 = plt.subplots(figsize=(8, 6))

    training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    return ax1


def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
    """Plot one or two metrics over training days.

    `y1` and `y2` are dicts with keys 'column' (dataframe column to plot), 'title' (axis label),
    'lim' (y-limits or None), 'color' and 'join' (whether to connect points); `y2`, plotted on a
    twin axis, additionally takes 'log' (whether to use a log scale).
    """
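    # Example y-dict (illustrative; 'KS022' is a hypothetical subject):
    #   y1 = {'column': 'combined_n_trials', 'title': 'Trial counts',
    #         'lim': None, 'color': 'k', 'join': True}
    #   ax = plot_over_days(df, 'KS022', y1)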

    if ax is None:
        fig, ax1 = plt.subplots(figsize=(12, 6))
    else:
        ax1 = ax

    dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
    if y1['join']:
        ax1.plot(dates, df[y1['column']], color=y1['color'])
    ax1.scatter(dates, df[y1['column']], color=y1['color'])
    ax1.set_ylabel(y1['title'])
    ax1.set_ylim(y1['lim'])

    if y2 is not None:
        ax2 = ax1.twinx()
        if y2['join']:
            ax2.plot(dates, df[y2['column']], color=y2['color'])
        ax2.scatter(dates, df[y2['column']], color=y2['color'])
        ax2.set_ylabel(y2['title'])
        ax2.yaxis.label.set_color(y2['color'])
        ax2.tick_params(axis='y', colors=y2['color'])
        ax2.set_ylim(y2['lim'])
        if y2['log']:
            ax2.set_yscale('log')

        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.spines['left'].set_visible(False)

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    if training_lines:
        ax1 = add_training_lines(df, ax1)

    if title:
        ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    # Put a legend below current axis
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
                      box.width, box.height * 0.9])
    if legend:
        ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
                   fancybox=True, shadow=True, ncol=5)

    return ax1


def add_training_lines(df, ax):
    """Add a vertical line to the axis for each date a new training status was reached."""

    status = df.drop_duplicates(subset='training_status', keep='first')
    for _, st in status.iterrows():

        if st['training_status'] in ['untrainable', 'unbiasable']:
            continue

        if TRAINING_STATUS[st['training_status']][0] <= 0:
            continue

        ax.axvline(datetime.strptime(st['date'], '%Y-%m-%d'), linewidth=2,
                   color=np.array(TRAINING_STATUS[st['training_status']][1]) / 255, label=st['training_status'])

    return ax


def plot_heatmap_performance_over_days(df, subject):
    """Plot a heatmap of performance per contrast over days for a subject."""

    df = df.drop_duplicates(subset=['date', 'combined_contrasts'])
    df_perf = df.pivot(index=['date'], columns=['combined_contrasts'], values=['combined_performance']).sort_values(
        by='combined_contrasts', axis=1, ascending=False)
    df_perf.index = pd.to_datetime(df_perf.index)
    full_date_range = pd.date_range(start=df_perf.index.min(), end=df_perf.index.max(), freq="D")
    df_perf = df_perf.reindex(full_date_range, fill_value=np.nan)

    n_contrasts = len(df.combined_contrasts.unique())

    dates = df_perf.index.to_pydatetime()
    dnum = mdates.date2num(dates)
    if len(dnum) > 1:
        start = dnum[0] - (dnum[1] - dnum[0]) / 2.
        stop = dnum[-1] + (dnum[1] - dnum[0]) / 2.
    else:
        start = dnum[0] + 0.5
        stop = dnum[0] + 1.5

    extent = [start, stop, 0, n_contrasts]

    fig, ax1 = plt.subplots(figsize=(12, 6))
    im = ax1.imshow(df_perf.T.values, extent=extent, aspect="auto", cmap='PuOr')

    month_format = mdates.DateFormatter('%b %Y')
    month_locator = mdates.MonthLocator()
    ax1.xaxis.set_major_locator(month_locator)
    ax1.xaxis.set_major_formatter(month_format)
    week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
    ax1.xaxis.set_minor_locator(week_locator)
    ax1.grid(True, which='minor', axis='x', linestyle='--')
    ax1.set_yticks(np.arange(0.5, n_contrasts + 0.5, 1))
    ax1.set_yticklabels(np.sort(df.combined_contrasts.unique()))
    ax1.set_ylabel('Contrast (%)')
    ax1.set_xlabel('Date')
    cbar = fig.colorbar(im)
    cbar.set_label('Rightward choice (%)')
    ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')

    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)

    return ax1


def make_plots(session_path, one, df=None, save=False, upload=False, task_collection='raw_behavior_data'):
    """Generate and optionally save and upload all the training plots for a subject."""

    subject = one.path2ref(session_path)['subject']
    subj_path = session_path.parent.parent

    df = load_existing_dataframe(subj_path) if df is None else df

    df = df[df['task_protocol'] != 'habituation']

    if len(df) == 0:
        return

    ax1 = plot_trial_count_and_session_duration(df, subject)
    ax2 = plot_performance_easy_median_reaction_time(df, subject)
    ax3 = plot_heatmap_performance_over_days(df, subject)
    ax4 = plot_fit_params(df, subject)
    ax5 = plot_psychometric_curve(df, subject, one)

    outputs = []
    if save:
        save_path = Path(subj_path)
        save_name = save_path.joinpath('subj_trial_count_session_duration.png')
        outputs.append(save_name)
        ax1.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_performance_easy_reaction_time.png')
        outputs.append(save_name)
        ax2.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_performance_heatmap.png')
        outputs.append(save_name)
        ax3.get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_psychometric_fit_params.png')
        outputs.append(save_name)
        ax4[0].get_figure().savefig(save_name, bbox_inches='tight')

        save_name = save_path.joinpath('subj_psychometric_curve.png')
        outputs.append(save_name)
        ax5.get_figure().savefig(save_name, bbox_inches='tight')

    if upload:
        subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
        snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
        snp.outputs = outputs
        snp.register_images(widths=['orig'])