Coverage for brainbox/behavior/dlc.py: 30%

288 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Set of functions to deal with dlc data.""" 

2import logging 

3import pandas as pd 

4import warnings 

5 

6import numpy as np 

7import matplotlib 

8import matplotlib.pyplot as plt 

9import scipy.interpolate as interpolate 

10from scipy.stats import zscore 

11 

12from ibldsp.smooth import smooth_interpolate_savgol 

13from iblutil.numerical import bincount2D 

14import brainbox.behavior.wheel as bbox_wheel 

15 

# Module-level logger, registered under the shared 'ibllib' logger name
logger = logging.getLogger('ibllib')

# Per-camera sampling value; get_smooth_pupil_diameter reads it as the camera frame rate
# (presumably frames per second - TODO confirm units)
SAMPLING = {'left': 60,
            'right': 150,
            'body': 30}
# Per-camera factor: get_speed divides pixel coordinates by it and plot_trace_on_frame
# scales its zoom windows with it
RESOLUTION = {'left': 2,
              'right': 1,
              'body': 1}

T_BIN = 0.02  # sec, bin width used for the wheel and lick traces
WINDOW_LEN = 2  # sec, offset added to the event time for the window end (see plt_window)
WINDOW_LAG = -0.5  # sec, offset added to the event time for the window start (see plt_window)

28 

29 

# Plot windows are anchored on the event the data is aligned to: they start at the event time
# shifted by WINDOW_LAG and stop at the event time shifted by WINDOW_LEN
def plt_window(x):
    """
    Return the start and end of the plotting window around event time(s) x.

    :param x: event time(s) in seconds; anything that supports adding a scalar
    :return: tuple (x + WINDOW_LAG, x + WINDOW_LEN)
    """
    window_start = x + WINDOW_LAG
    window_end = x + WINDOW_LEN
    return window_start, window_end

33 

34 

def insert_idx(array, values):
    """
    Find, for each value, the index of the closest entry in a sorted array.

    :param array: np.array, sorted values to search in
    :param values: array-like, values to locate in array
    :return: np.array of indices into array, one per value
    :raises ValueError: if every resulting index is 0
    """
    nearest = np.searchsorted(array, values, side="left")
    # An insertion point past the last element is clamped onto the last element
    n_samples = len(array)
    nearest[nearest == n_samples] -= 1
    # Step down one index wherever the lower neighbour is strictly closer than the upper one
    dist_below = abs(values - array[nearest - 1])
    dist_above = abs(values - array[nearest])
    nearest[np.where(dist_below < dist_above)] -= 1
    # Index 0 may have been stepped down to -1 (its "lower neighbour" wrapped around), undo that
    nearest[nearest == -1] = 0
    if np.all(nearest == 0):
        raise ValueError('Something is wrong, all values to insert are outside of the array.')
    return nearest

45 

46 

def likelihood_threshold(dlc, threshold=0.9):
    """
    Set dlc points with likelihood less than threshold to nan.

    Only features that have a '<feature>_likelihood' column are processed, so additional columns
    in the table (e.g. timestamps) are left untouched instead of raising a KeyError.
    The table is modified in place and also returned.

    FIXME Add unit test.
    :param dlc: dlc pqt object, with <feature>_x, <feature>_y and <feature>_likelihood columns
    :param threshold: likelihood threshold
    :return: the input dlc table, with low-likelihood x and y values set to nan
    """
    suffix = '_likelihood'
    features = sorted({col[:-len(suffix)] for col in dlc.keys()
                       if isinstance(col, str) and col.endswith(suffix)})
    for feat in features:
        nan_fill = dlc[f'{feat}{suffix}'] < threshold
        # A list (not a tuple) is the unambiguous way to select several columns with .loc
        dlc.loc[nan_fill, [f'{feat}_x', f'{feat}_y']] = np.nan
    return dlc

61 

62 

def get_speed(dlc, dlc_t, camera, feature='paw_r'):
    """
    Compute the frame-to-frame speed of one dlc feature, interpolated onto the dlc time points.

    FIXME Document and add unit test!

    :param dlc: dlc pqt table
    :param dlc_t: dlc time points
    :param camera: camera type e.g 'left', 'right', 'body'
    :param feature: dlc feature to compute speed over
    :return: speed evaluated at each entry of dlc_t, or None when fewer than two mid-point
        times are available to interpolate from
    """
    xs, ys = [dlc[f'{feature}_{axis}'] / RESOLUTION[camera] for axis in ('x', 'y')]

    # get speed in px/sec [half res]
    displacement = (np.diff(xs) ** 2 + np.diff(ys) ** 2) ** .5
    speed = displacement * SAMPLING[camera]

    # Each speed sample sits halfway between the two frames it was computed from
    midpoints = dlc_t[:-1] + np.diff(dlc_t) / 2

    if midpoints.size <= 1:
        return None
    # interpolate over original time scale
    to_frame_times = interpolate.interp1d(midpoints, speed, fill_value="extrapolate")
    return to_frame_times(dlc_t)

86 

87 

def get_speed_for_features(dlc, dlc_t, camera, features=('paw_r', 'paw_l', 'nose_tip')):
    """
    Wrapper to compute speed for a number of dlc features and add them to dlc table

    The table is modified in place: one '<feature>_speed' column is added per feature.

    :param dlc: dlc pqt table
    :param dlc_t: dlc time points
    :param camera: camera type e.g 'left', 'right', 'body'
    :param features: iterable of dlc features to compute speed for
        (immutable default, so it cannot be altered between calls)
    :return: the input dlc table with the added speed columns
    """
    for feat in features:
        dlc[f'{feat}_speed'] = get_speed(dlc, dlc_t, camera, feat)

    return dlc

101 

102 

def get_feature_event_times(dlc, dlc_t, features):
    """
    Detect events from the dlc traces. Based on the standard deviation between frames

    For every feature, a frame is marked as an event when the absolute frame-to-frame difference
    exceeds a quarter of the (nan-ignoring) standard deviation of those differences.
    Events of all features are merged and de-duplicated.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :param features: features to consider
    :return: entries of dlc_t at the detected event frames; empty if no features are given
    """
    crossings = []
    for feat in features:
        frame_diff = np.diff(dlc[feat])
        threshold = np.nanstd(frame_diff) / 4
        crossings.append(np.where(np.abs(frame_diff) > threshold)[0])

    if crossings:
        events = np.concatenate(crossings)
    else:
        # No features requested: return an empty selection rather than failing on an unbound name
        events = np.array([], dtype=int)

    return dlc_t[np.unique(events)]

121 

122 

def get_licks(dlc, dlc_t):
    """
    Compute lick times from the tongue dlc points

    Events are detected on the x and y traces of both tongue end points.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :return: lick times, as returned by get_feature_event_times
    """
    tongue_traces = [f'tongue_end_{side}_{axis}' for side in ('l', 'r') for axis in ('x', 'y')]
    return get_feature_event_times(dlc, dlc_t, tongue_traces)

133 

134 

def get_sniffs(dlc, dlc_t):
    """
    Compute sniff times from the nose tip

    Events are detected on the y trace of the nose tip only.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :return: sniff times, as returned by get_feature_event_times
    """
    nose_traces = ['nose_tip_y']
    return get_feature_event_times(dlc, dlc_t, nose_traces)

145 

146 

def get_dlc_everything(dlc_cam, camera):
    """
    Get out features of interest for dlc

    Thresholds the dlc traces, adds speed columns and stores lick times, sniff times and an
    alignment flag on the dlc object. If timestamps and dlc rows differ in number, both are
    truncated to the shorter length and the object is flagged as not aligned.

    :param dlc_cam: dlc object, exposing 'times' and 'dlc' as attributes and supporting item assignment
    :param camera: camera type e.g 'left', 'right'
    :return: the input dlc object, with 'licks', 'sniffs' and 'aligned' set
    """
    n_times = dlc_cam.times.shape[0]
    n_frames = dlc_cam.dlc.shape[0]
    aligned = n_times == n_frames
    if not aligned:
        # logger warning and print out status of the qc, specific serializer django!
        logger.warning('Dimension mismatch between dlc points and timestamps')
        n_common = min(n_times, n_frames)
        dlc_cam.times = dlc_cam.times[:n_common]
        dlc_cam.dlc = dlc_cam.dlc[:n_common]

    dlc_cam.dlc = likelihood_threshold(dlc_cam.dlc)
    dlc_cam.dlc = get_speed_for_features(dlc_cam.dlc, dlc_cam.times, camera)
    dlc_cam['licks'] = get_licks(dlc_cam.dlc, dlc_cam.times)
    dlc_cam['sniffs'] = get_sniffs(dlc_cam.dlc, dlc_cam.times)
    dlc_cam['aligned'] = aligned

    return dlc_cam

171 

172 

def get_pupil_diameter(dlc):
    """
    Estimate the pupil diameter as the median over six point-pair estimates.

    Two estimates are direct: top to bottom and left to right. The other four use adjacent
    points (top/left, top/right, bottom/left, bottom/right) and, assuming the pupil is a circle,
    scale their distance by sqrt(2).

    :param dlc: dlc pqt table with pupil estimates, should be likelihood thresholded (e.g. at 0.9)
    :return: np.array, pupil diameter estimate for each time point, shape (n_frames,)
    """
    # (2, n_frames) array of x,y coordinates for each of the four pupil points
    points = {name: np.vstack((dlc[f'pupil_{name}_r_x'], dlc[f'pupil_{name}_r_y']))
              for name in ('top', 'bottom', 'left', 'right')}

    def _distance(first, second):
        return np.linalg.norm(points[first] - points[second], axis=0)

    # Direct diameters across the pupil
    estimates = [_distance('top', 'bottom'), _distance('left', 'right')]
    # Adjacent points span a quarter circle: chord length * sqrt(2) gives the diameter
    adjacent = [('top', 'left'), ('top', 'right'), ('bottom', 'left'), ('bottom', 'right')]
    estimates.extend(_distance(first, second) * 2 ** 0.5 for first, second in adjacent)

    # Frames where every estimate is nan trigger a RuntimeWarning, which is silenced here
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmedian(estimates, axis=0)

199 

200 

def get_smooth_pupil_diameter(diameter_raw, camera, std_thresh=5, nan_thresh=1):
    """
    Smooth a raw pupil diameter trace, removing outliers and keeping long gaps as nan.

    The trace is smoothed once to find outliers, the outliers are set to nan, and the reduced
    trace is smoothed again. Runs of nan longer than nan_thresh seconds are set back to nan in
    the result, including runs at the very start or end of the trace.

    :param diameter_raw: np.array, raw pupil diameters, calculated from (thresholded) dlc traces
    :param camera: str ('left', 'right'), which camera to run the smoothing for
    :param std_thresh: threshold (in standard deviations) beyond which a point is labeled as an outlier
    :param nan_thresh: threshold (in seconds) above which we will not interpolate nans, but keep them
                       (for long stretches interpolation may not be appropriate)
    :return: np.array, smoothed pupil diameter
    :raises NotImplementedError: if camera is neither 'left' nor 'right'
    :raises ValueError: if more than 90% of the raw trace is nan
    """
    # set framerate of camera
    if camera == 'left':
        fr = SAMPLING['left']  # set by hardware
        window = 31  # works well empirically
    elif camera == 'right':
        fr = SAMPLING['right']  # set by hardware
        window = 75  # works well empirically
    else:
        raise NotImplementedError("camera has to be 'left' or 'right'")

    # Raise error if too many NaN time points, in this case it doesn't make sense to interpolate
    if np.mean(np.isnan(diameter_raw)) > 0.9:
        raise ValueError(f"Raw pupil diameter for {camera} is too often NaN, cannot smooth.")
    # run savitzy-golay filter on non-nan time points to denoise
    diameter_smoothed = smooth_interpolate_savgol(diameter_raw, window=window, order=3, interp_kind='linear')

    # find outliers and set them to nan
    difference = diameter_raw - diameter_smoothed
    outlier_thresh = std_thresh * np.nanstd(difference)
    without_outliers = np.copy(diameter_raw)
    without_outliers[(difference < -outlier_thresh) | (difference > outlier_thresh)] = np.nan
    # run savitzy-golay filter again on (possibly reduced) non-nan timepoints to denoise
    diameter_smoothed = smooth_interpolate_savgol(without_outliers, window=window, order=3, interp_kind='linear')

    # don't interpolate long strings of nans
    # Padding the nan mask with a non-nan sample on both sides guarantees that every run has
    # exactly one start and one end, so runs touching either edge of the trace are paired
    # correctly instead of shifting or dropping start/end pairs.
    is_nan = np.isnan(without_outliers).astype(int)
    edges = np.diff(np.concatenate(([0], is_nan, [0])))
    begs = np.where(edges == 1)[0]  # index of the first nan of each run
    ends = np.where(edges == -1)[0]  # index one past the last nan of each run
    for b, e in zip(begs, ends):
        if (e - b) > (fr * nan_thresh):
            diameter_smoothed[b:e] = np.nan

    return diameter_smoothed

245 

246 

def plot_trace_on_frame(frame, dlc_df, cam):
    """
    Scatter the dlc traces on top of a single video frame.

    For the left and right camera a whisker pad rectangle is drawn and two zoomed insets are
    added below the frame: one around the pupil and one around the tube / tongue.

    :param frame: np.array, single video frame to plot on
    :param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
    :param cam: str, which camera to process ('left', 'right', 'body')
    :returns: matplolib.axis
    """
    # One fixed color per feature
    colors = {
        'tail_start': '#636EFA',
        'nose_tip': '#636EFA',
        'paw_l': '#EF553B',
        'paw_r': '#00CC96',
        'pupil_bottom_r': '#AB63FA',
        'pupil_left_r': '#FFA15A',
        'pupil_right_r': '#19D3F3',
        'pupil_top_r': '#FF6692',
        'tongue_end_l': '#B6E880',
        'tongue_end_r': '#FF97FF',
    }
    # Low-likelihood points are set to nan before plotting
    dlc_df = likelihood_threshold(dlc_df)
    # Tube points are not scattered
    features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if 'tube' not in x])
    # Subsample so that every camera contributes a comparable number of points
    dlc_df_norm = pd.DataFrame()
    for feat in features:
        x_key, y_key = f'{feat}_x', f'{feat}_y'
        stride = int(SAMPLING[cam] / 10)
        dlc_df_norm[x_key] = dlc_df[x_key][0::stride]
        dlc_df_norm[y_key] = dlc_df[y_key][0::stride]
        plt.scatter(dlc_df_norm[x_key], dlc_df_norm[y_key], alpha=0.05, s=2, label=feat, c=colors[feat])

    plt.axis('off')
    plt.imshow(frame, cmap='gray')
    plt.tight_layout()

    ax = plt.gca()
    if cam == 'body':
        plt.title(f'{cam.capitalize()} camera')
        return ax

    def _add_zoom(anchor_x, origin, center, left, right, lower, upper):
        """Add an inset below the frame, zoomed on a region around center."""
        zoom_ax = ax.inset_axes([anchor_x, -0.5, 0.5, 0.5])
        zoom_ax.imshow(frame, cmap='gray', origin=origin)
        for name in features:
            zoom_ax.scatter(dlc_df_norm[f'{name}_x'], dlc_df_norm[f'{name}_y'], alpha=1, s=0.001, label=name,
                            c=colors[name])
        zoom_ax.set_xlim(int(center[0] - left * RESOLUTION[cam] / 2), int(center[0] + right * RESOLUTION[cam] / 2))
        zoom_ax.set_ylim(int(center[1] + lower * RESOLUTION[cam] / 2), int(center[1] - upper * RESOLUTION[cam] / 2))
        zoom_ax.axis('off')

    # For left and right cam plot whisker pad rectangle
    # heuristic: rectangle sized from the nose-pupil distance and anchored on the midpoint between them
    p_nose = np.array(dlc_df[['nose_tip_x', 'nose_tip_y']].mean())
    p_pupil = np.array(dlc_df[['pupil_top_r_x', 'pupil_top_r_y']].mean())
    p_anchor = np.mean([p_nose, p_pupil], axis=0)
    dist = np.linalg.norm(p_nose - p_pupil)
    corner = (int(p_anchor[0] - dist / 4), int(p_anchor[1]))
    rect = matplotlib.patches.Rectangle(corner, int(dist / 2), int(dist / 3),
                                        linewidth=1, edgecolor='lime', facecolor='none')
    ax.add_patch(rect)

    # Eye region zoom
    _add_zoom(0 if cam == 'right' else 0.5, "lower", p_pupil, 33, 33, 38, 28)

    # Tongue region zoom, centered on the mean of the two tube points
    p1 = np.array(dlc_df[['tube_top_x', 'tube_top_y']].mean())
    p2 = np.array(dlc_df[['tube_bottom_x', 'tube_bottom_y']].mean())
    p_tongue = np.nanmean([p1, p2], axis=0)
    _add_zoom(0 if cam == 'left' else 0.5, "upper", p_tongue, 60, 100, 60, 100)

    plt.title(f'{cam.capitalize()} camera')
    return ax

321 

322 

def plot_wheel_position(wheel_position, wheel_time, trials_df):
    """
    Plot the wheel position around stimulus onset for every trial, colored by chosen side.

    Each trial trace is expressed relative to its first sample in the window. A
    'wheel_position' column holding these traces is added to trials_df.

    :param wheel_position: np.array, interpolated wheel position
    :param wheel_time: np.array, interpolated wheel timestamps
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset times for each trial)
    :returns: matplotlib.axis
    """
    # Resample the wheel onto a regular grid with bin width T_BIN
    wheel_position, wheel_time = bbox_wheel.interpolate_position(wheel_time, wheel_position, freq=1 / T_BIN)
    # Window around stimulus onset, converted from times to sample indices
    start_window, end_window = plt_window(trials_df['stimOn_times'])
    start_idx = insert_idx(wheel_time, start_window)
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # Per-trial traces, each shifted so that its first sample is zero
    trials_df['wheel_position'] = [wheel_position[first:last] - wheel_position[first]
                                   for first, last in zip(start_idx, end_idx)]
    # Time axis derived from the length of the first trial trace
    times = np.arange(len(trials_df['wheel_position'].iloc[0])) * T_BIN + WINDOW_LAG
    choices = ((-1, 'right', 'darkred'), (1, 'left', '#1f77b4'))
    for side, label, color in choices:
        side_df = trials_df[trials_df['choice'] == side]
        # Thin line per trial, thick line for the average
        for trial in side_df.index:
            plt.plot(times, side_df.loc[trial, 'wheel_position'], c=color, alpha=0.5, linewidth=0.05)
        plt.plot(times, side_df['wheel_position'].mean(), c=color, linewidth=2, label=f'{label} turn')

    plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
    for level in (-0.26, 0.26):
        plt.axhline(y=level, linestyle='--', c='g', label='reward')
    plt.ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position diff to first value [rad]')
    plt.legend(loc='center right')
    plt.title('Wheel position trial avg\n(and individual trials)')
    plt.tight_layout()

    return plt.gca()

361 

362 

def _bin_window_licks(lick_times, trials_df):
    """
    Bin the lick times and cut out one window of bins per trial, aligned to feedback time.

    Columns 'lick_bins' and 'end_idx' are added to trials_df. Trials whose window extends past
    the last bin are dropped from the returned frame.

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial)
    :returns: pd.DataFrame with binned, windowed lick times for plotting
    :raises ValueError: re-raised from insert_idx after logging
    """
    # Bin the licks with bin width T_BIN
    lick_bins, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)), T_BIN)
    lick_bins = np.squeeze(lick_bins)
    # Window around feedback, converted from times to bin indices
    start_window, end_window = plt_window(trials_df['feedback_times'])
    try:
        start_idx = insert_idx(bin_times, start_window)
    except ValueError:
        logger.error('Lick time stamps are outside of the trials windows')
        raise
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # One slice of binned licks per trial
    trials_df['lick_bins'] = [lick_bins[first:last] for first, last in zip(start_idx, end_idx)]
    # Keep only trials whose window fits inside the available bins
    trials_df['end_idx'] = end_idx
    fits_in_bins = trials_df['end_idx'] <= len(lick_bins)
    return trials_df[fits_in_bins]

388 

389 

def plot_lick_hist(lick_times, trials_df):
    """
    Plots histogramm of lick events aligned to feedback time, separate for correct and incorrect trials

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
                      'feedbackType' (1 for correct, -1 for incorrect trials)
    :returns: matplotlib axis
    """
    licks_df = _bin_window_licks(lick_times, trials_df)
    # Time axis derived from the number of bins in the first trial window
    times = np.arange(len(licks_df['lick_bins'].iloc[0])) * T_BIN + WINDOW_LAG
    correct = licks_df[licks_df['feedbackType'] == 1]['lick_bins']
    incorrect = licks_df[licks_df['feedbackType'] == -1]['lick_bins']
    # Expand each series of arrays into a frame (one column per trial) before averaging over trials
    plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
             c='k', label='correct trial')
    # Incorrect trials are keyed by their own index; keying them by the correct trials' index
    # would silently truncate them to the number of correct trials
    plt.plot(times, pd.DataFrame.from_dict(dict(zip(incorrect.index, incorrect.values))).mean(axis=1),
             c='gray', label='incorrect trial')
    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
    plt.title('Lick events trial avg')
    plt.xlabel('time [sec]')
    plt.ylabel('lick events [a.u.]')
    plt.legend(loc='lower right')
    return plt.gca()

414 

415 

def plot_lick_raster(lick_times, trials_df):
    """
    Plots lick raster for correct trials

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
                      feedbackType (1 for correct, -1 for incorrect trials)
    :returns: matplotlib.axis
    """
    licks_df = _bin_window_licks(lick_times, trials_df)
    # One image row per correct trial, one column per time bin
    correct_licks = list(licks_df[licks_df['feedbackType'] == 1]['lick_bins'])
    # The y extent spans the number of plotted trials (the image rows), so that the 'trials' axis
    # shows trial counts rather than the number of time bins per trial
    plt.imshow(correct_licks, aspect='auto',
               extent=[-0.5, 1.5, len(correct_licks), 0], cmap='gray_r')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('trials')
    plt.xlabel('time [sec]')
    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
    plt.title('Lick events per correct trial')
    plt.tight_layout()
    return plt.gca()

435 

436 

def plot_motion_energy_hist(camera_dict, trials_df):
    """
    Plot the trial-averaged, z-scored motion energy of each camera around stimulus onset.

    The shaded band around each mean is the standard deviation across trials divided by the
    square root of the number of trials. Cameras whose data is missing, empty or cannot be
    windowed are skipped and named in a red note on the axis.

    :param camera_dict: dict, one key for each camera to be plotted (e.g. 'left'), value is another dict with items
                        'motion_energy' (np.array, motion energy calculated from this camera) and
                        'times' (np.array, camera timestamps)
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
    :returns: matplotlib.axis
    """
    colors = {
        'left': '#bd7a98',
        'right': '#2b6f39',
        'body': '#035382',
    }

    start_window, end_window = plt_window(trials_df['stimOn_times'])
    missing_data = []
    for cam in camera_dict.keys():
        cam_data = camera_dict[cam]
        usable = (cam_data['motion_energy'] is not None and len(cam_data['motion_energy']) > 0
                  and cam_data['times'] is not None and len(cam_data['times']) > 0)
        if not usable:
            logger.warning(f"Data missing or empty for motion energy and/or times data for {cam} camera")
            missing_data.append(cam)
            continue

        try:
            motion_energy = zscore(cam_data['motion_energy'], nan_policy='omit')
            try:
                start_idx = insert_idx(cam_data['times'], start_window)
                end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
                # One motion energy segment per trial, empty segments are discarded
                segments = (motion_energy[first:last] for first, last in zip(start_idx, end_idx))
                me_all = [segment for segment in segments if len(segment) > 0]
                times = np.arange(len(me_all[0])) / SAMPLING[cam] + WINDOW_LAG
                me_mean = np.mean(me_all, axis=0)
                me_std = np.std(me_all, axis=0) / np.sqrt(len(me_all))
                plt.plot(times, me_mean, label=f'{cam} cam', color=colors[cam], linewidth=2)
                plt.fill_between(times, me_mean + me_std, me_mean - me_std, color=colors[cam], alpha=0.2)
            except ValueError:
                logger.error(f"{cam}Camera camera.times are outside of the trial windows")
                missing_data.append(cam)
        except AttributeError:
            logger.warning(f"Cannot load motion energy and/or times data for {cam} camera")
            missing_data.append(cam)

    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('z-scored motion energy [a.u.]')
    plt.xlabel('time [sec]')
    plt.axvline(x=0, label='stimOn', linestyle='--', c='k')
    plt.legend(loc='lower right')
    plt.title('Motion Energy trial avg\n(+/- std)')
    if len(missing_data) > 0:
        # Flag the skipped cameras directly on the plot
        ax = plt.gca()
        ax.text(.95, .35, f"Data incomplete for\n{' and '.join(missing_data)} camera", color='r', fontsize=10,
                fontweight='bold', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
    return plt.gca()

489 

490 

def plot_speed_hist(dlc_df, cam_times, trials_df, feature='paw_r', cam='left', legend=True):
    """
    Plot the trial-averaged speed of one dlc feature around stimulus onset, split by trial outcome.

    The dlc traces are likelihood thresholded first. A 'speed_<feature>' column holding the
    per-trial speed windows is added to trials_df.

    :param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
    :param cam_times: np.array, camera timestamps
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
    :param feature: str, feature with trace in dlc_df for which to plot speed hist, default is 'paw_r'
    :param cam: str, camera to use ('body', 'left', 'right') default is 'left'
    :param legend: bool, whether to add legend to the plot, default is True
    :returns: matplotlib.axis
    :raises ValueError: if there are fewer camera timestamps than dlc rows
    """
    # Threshold the dlc traces
    dlc_df = likelihood_threshold(dlc_df)
    # For pre-GPIO sessions, remove the first few timestamps to match the number of frames
    cam_times = cam_times[-len(dlc_df):]
    if len(cam_times) != len(dlc_df):
        raise ValueError("Camera times length and DLC length are inconsistent")
    speeds = get_speed(dlc_df, cam_times, camera=cam, feature=feature)
    # Window around stimulus onset, converted from times to frame indices
    start_window, end_window = plt_window(trials_df['stimOn_times'])
    start_idx = insert_idx(cam_times, start_window)
    end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
    # One speed segment per trial, stored on trials_df
    speed_column = f'speed_{feature}'
    trials_df[speed_column] = [speeds[first:last] for first, last in zip(start_idx, end_idx)]
    times = np.arange(len(trials_df[speed_column].iloc[0])) / SAMPLING[cam] + WINDOW_LAG
    # Need to expand the series of lists into a dataframe first, for the nan skipping to work
    outcomes = ((1, 'k', 'correct trial'), (-1, 'gray', 'incorrect trial'))
    for feedback_type, color, label in outcomes:
        segments = trials_df[trials_df['feedbackType'] == feedback_type][speed_column]
        per_trial = pd.DataFrame.from_dict(dict(zip(segments.index, segments.values)))
        plt.plot(times, per_trial.mean(axis=1), c=color, label=label)
    plt.axvline(x=0, label='stimOn', linestyle='--', c='r')
    plt.title(f'{feature.capitalize()} speed trial avg\n({cam.upper()} cam)')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.xlabel('time [sec]')
    plt.ylabel('speed [px/sec]')
    if legend:
        plt.legend()

    return plt.gca()

535 

536 

def plot_pupil_diameter_hist(pupil_diameter, cam_times, trials_df, cam='left'):
    """
    Plots histogram of pupil diameter aligned to simulus onset and feedback time.

    For each alignment, every trial window is z-scored and shifted so that its first sample is
    zero. The shaded band is the (nan-ignoring) standard deviation across trials divided by the
    square root of the number of trials.

    :param pupil_diameter: np.array, (smoothed) pupil diameter estimate
    :param cam_times: np.array, camera timestamps
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial) and
                      feedback_times (time of feedback for each trial)
    :param cam: str, camera to use ('body', 'left', 'right') default is 'left'
    :returns: matplotlib.axis
    """
    for align_to, color in zip(['stimOn_times', 'feedback_times'], ['red', 'purple']):
        start_window, end_window = plt_window(trials_df[align_to])
        start_idx = insert_idx(cam_times, start_window)
        end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
        # Per trial norm
        pupil_all = [zscore(list(pupil_diameter[start_idx[i]:end_idx[i]])) for i in range(len(start_idx))]
        pupil_all_norm = [trial - trial[0] for trial in pupil_all]

        pupil_mean = np.nanmean(pupil_all_norm, axis=0)
        pupil_std = np.nanstd(pupil_all_norm, axis=0) / np.sqrt(len(pupil_all_norm))
        times = np.arange(len(pupil_all_norm[0])) / SAMPLING[cam] + WINDOW_LAG

        plt.plot(times, pupil_mean, label=align_to.split("_")[0], color=color)
        plt.fill_between(times, pupil_mean + pupil_std, pupil_mean - pupil_std, color=color, alpha=0.5)
    plt.axvline(x=0, linestyle='--', c='k')
    plt.title(f'Pupil diameter trial avg\n({cam.upper()} cam)')
    plt.xlabel('time [sec]')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('z-scored smoothed pupil diameter [px]')
    plt.legend(loc='lower right', title='aligned to')
    # Return the axis as documented, consistent with the other plotting functions in this module
    return plt.gca()