Coverage for brainbox/behavior/dlc.py: 84%
289 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""
2Set of functions to deal with dlc data
3"""
4import logging
5import pandas as pd
6import warnings
8import numpy as np
9import matplotlib
10import matplotlib.pyplot as plt
11import scipy.interpolate as interpolate
12from scipy.stats import zscore
14from neurodsp.smooth import smooth_interpolate_savgol
15from iblutil.numerical import bincount2D
16import brainbox.behavior.wheel as bbox_wheel
logger = logging.getLogger('ibllib')

# Frame rate of each camera in Hz (set by hardware)
SAMPLING = dict(left=60, right=150, body=30)
# Divisor applied to each camera's pixel coordinates before speeds are computed (see get_speed)
RESOLUTION = dict(left=2, right=1, body=1)

# Parameters of the event-aligned windows used by the plotting functions (all in seconds)
T_BIN = 0.02  # bin width used to resample wheel data and to bin lick events
WINDOW_LEN = 2  # total duration of a window
WINDOW_LAG = -0.5  # start of a window relative to the aligning event (negative: before it)
# For plotting we use a window that starts WINDOW_LAG relative to the event the data is aligned to
# (negative: before the event) and lasts WINDOW_LEN in total
def plt_window(x):
    """
    Return the start and end of the plotting window around event time(s) ``x``.

    The window covers ``[x + WINDOW_LAG, x + WINDOW_LAG + WINDOW_LEN]``. This matches the index
    windows built by the plotting functions (``start_idx + WINDOW_LEN * rate``) and their time
    axes, which start at ``WINDOW_LAG``.

    :param x: scalar, np.array or pd.Series of event times in seconds
    :return: tuple (start, end) of window bounds, same type as x
    """
    start = x + WINDOW_LAG
    return start, start + WINDOW_LEN
def insert_idx(array, values):
    """
    For each value, find the index of the closest element of a sorted array.

    :param array: sorted 1D array-like (e.g. timestamps) to search in; lists are accepted
    :param values: 1D array-like of values (e.g. event times) to locate
    :return: np.array of int, index into ``array`` of the element closest to each value
    :raises ValueError: if every value maps to index 0, i.e. all values lie at or before the
        start of ``array``
    """
    # Fancy indexing below requires numpy arrays (a plain list cannot be indexed by an ndarray)
    array = np.asarray(array)
    values = np.asarray(values)
    idx = np.searchsorted(array, values, side="left")
    # Values past the last element have no upper neighbour: clamp them to the last index
    idx[idx == len(array)] -= 1
    # Choose the lower neighbour when it is closer to the value
    idx[np.where(abs(values - array[idx - 1]) < abs(values - array[idx]))] -= 1
    # If a 0 index was reduced (idx - 1 wrapped around to array[-1]), revert
    idx[idx == -1] = 0
    if np.all(idx == 0):
        raise ValueError('Something is wrong, all values to insert are outside of the array.')
    return idx
def likelihood_threshold(dlc, threshold=0.9):
    """
    Set dlc points with likelihood less than threshold to nan.

    The table is modified in place and also returned.

    :param dlc: dlc pqt table (pd.DataFrame, or a mapping of np.arrays) with '<feature>_x',
        '<feature>_y' and '<feature>_likelihood' entries
    :param threshold: likelihood threshold
    :return: the same table, with the x/y coordinates of low-likelihood points set to nan
    """
    features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc.keys()])
    for feat in features:
        nan_fill = np.asarray(dlc[f'{feat}_likelihood'] < threshold)
        for coord in ('x', 'y'):
            column = f'{feat}_{coord}'
            if isinstance(dlc, pd.DataFrame):
                # A single .loc assignment: chained indexing (dlc[col][mask] = ...) writes into a
                # temporary copy under pandas Copy-on-Write and leaves dlc unchanged
                dlc.loc[nan_fill, column] = np.nan
            else:
                dlc[column][nan_fill] = np.nan
    return dlc
def get_speed(dlc, dlc_t, camera, feature='paw_r'):
    """
    Compute the speed of a dlc feature and resample it onto the camera timestamps.

    Speed is the frame-to-frame Euclidean displacement of the feature (coordinates divided by
    ``RESOLUTION[camera]``) multiplied by the nominal frame rate ``SAMPLING[camera]``. Each value is
    placed at the midpoint between consecutive timestamps, then linearly interpolated (and
    extrapolated at the edges) back onto ``dlc_t``.

    :param dlc: dlc pqt table
    :param dlc_t: dlc time points, same length as dlc
    :param camera: camera type e.g 'left', 'right', 'body'
    :param feature: dlc feature to compute speed over
    :return: np.array of speeds in px/sec, same length as dlc_t; all nan when fewer than three
        time points are available (too few to interpolate)
    """
    x = dlc[f'{feature}_x'] / RESOLUTION[camera]
    y = dlc[f'{feature}_y'] / RESOLUTION[camera]

    # get speed in px/sec [half res]
    s = ((np.diff(x) ** 2 + np.diff(y) ** 2) ** .5) * SAMPLING[camera]

    # speeds are defined between frames: place them at the midpoints of the timestamps
    dt = np.diff(dlc_t)
    tv = dlc_t[:-1] + dt / 2

    # interp1d needs at least two midpoints; previously this case silently returned None
    if tv.size <= 1:
        return np.full(np.size(dlc_t), np.nan)

    # interpolate over original time scale
    ifcn = interpolate.interp1d(tv, s, fill_value="extrapolate")
    return ifcn(dlc_t)
def get_speed_for_features(dlc, dlc_t, camera, features=('paw_r', 'paw_l', 'nose_tip')):
    """
    Wrapper to compute speed for a number of dlc features and add them to dlc table.

    For each feature a '<feature>_speed' column is written into ``dlc`` in place.

    :param dlc: dlc pqt table
    :param dlc_t: dlc time points
    :param camera: camera type e.g 'left', 'right', 'body'
    :param features: iterable of dlc features to compute speed for (immutable tuple default,
        avoiding a shared mutable list default)
    :return: the dlc table with the added speed columns
    """
    for feat in features:
        dlc[f'{feat}_speed'] = get_speed(dlc, dlc_t, camera, feat)

    return dlc
def get_feature_event_times(dlc, dlc_t, features):
    """
    Detect events from the dlc traces, based on the changes between consecutive frames.

    For each feature, a frame transition counts as an event when the absolute change between
    consecutive frames exceeds a quarter of the (nan-ignoring) standard deviation of those
    changes. Events of all features are pooled and de-duplicated.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :param features: iterable of column names (traces) to consider
    :return: sorted unique times, taken from dlc_t at the first frame of each detected transition;
        empty when no features are given
    """
    events = []
    for feat in features:
        change = np.diff(dlc[feat])
        threshold = np.nanstd(change) / 4
        events.append(np.where(np.abs(change) > threshold)[0])

    # With no features there is nothing to pool (previously an UnboundLocalError)
    event_idx = np.unique(np.concatenate(events)) if events else np.array([], dtype=np.int64)
    return dlc_t[event_idx]
def get_licks(dlc, dlc_t):
    """
    Compute lick times from the tongue dlc points.

    The x and y traces of both tongue points ('tongue_end_l' and 'tongue_end_r') are passed to
    get_feature_event_times and their events are pooled.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :return: lick times
    """
    tongue_traces = [f'tongue_end_{side}_{coord}' for side in ('l', 'r') for coord in ('x', 'y')]
    return get_feature_event_times(dlc, dlc_t, tongue_traces)
def get_sniffs(dlc, dlc_t):
    """
    Compute sniff times from the vertical trace of the nose tip.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :return: sniff times, as detected by get_feature_event_times on 'nose_tip_y'
    """
    return get_feature_event_times(dlc, dlc_t, features=['nose_tip_y'])
def get_dlc_everything(dlc_cam, camera):
    """
    Get out features of interest for dlc.

    The steps are:
    - trim points and timestamps to a common length if they disagree;
    - set low-likelihood points to nan;
    - add speed columns;
    - detect lick and sniff times.

    :param dlc_cam: dlc object with attributes ``dlc`` (dlc pqt table) and ``times`` (timestamps);
        it must support both attribute access and item assignment
    :param camera: camera type e.g 'left', 'right'
    :return: dlc_cam, updated in place, with items 'licks', 'sniffs' and 'aligned' (False when the
        number of timestamps and dlc points did not match)
    """
    n_times, n_points = dlc_cam.times.shape[0], dlc_cam.dlc.shape[0]
    aligned = n_times == n_points
    if not aligned:
        logger.warning('Dimension mismatch between dlc points and timestamps')
        n_keep = min(n_times, n_points)
        dlc_cam.times = dlc_cam.times[:n_keep]
        dlc_cam.dlc = dlc_cam.dlc[:n_keep]

    dlc_cam.dlc = likelihood_threshold(dlc_cam.dlc)
    dlc_cam.dlc = get_speed_for_features(dlc_cam.dlc, dlc_cam.times, camera)
    dlc_cam['licks'] = get_licks(dlc_cam.dlc, dlc_cam.times)
    dlc_cam['sniffs'] = get_sniffs(dlc_cam.dlc, dlc_cam.times)
    dlc_cam['aligned'] = aligned

    return dlc_cam
def get_pupil_diameter(dlc):
    """
    Estimates pupil diameter by taking median of different computations.

    Two estimates are direct: top - bottom and left - right. Assuming the pupil is a circle, each
    of the four adjacent point pairs (top/left, top/right, bottom/left, bottom/right) gives another
    estimate: sqrt(2) times their distance. The per-frame result is the nan-ignoring median of
    these six estimates.

    :param dlc: dlc pqt table with pupil estimates, should be likelihood thresholded (e.g. at 0.9)
    :return: np.array, pupil diameter estimate for each time point, shape (n_frames,)
    """
    def _point(name):
        # 2 x n_frames array of the x and y coordinates of one pupil point
        return np.vstack((dlc[f'pupil_{name}_r_x'], dlc[f'pupil_{name}_r_y']))

    top, bottom, left, right = (_point(name) for name in ('top', 'bottom', 'left', 'right'))

    estimates = [np.linalg.norm(a - b, axis=0) for a, b in ((top, bottom), (left, right))]
    estimates += [np.linalg.norm(a - b, axis=0) * 2 ** 0.5
                  for a, b in ((top, left), (top, right), (bottom, left), (bottom, right))]

    with warnings.catch_warnings():
        # frames where every estimate is nan would otherwise emit a RuntimeWarning
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmedian(estimates, axis=0)
def get_smooth_pupil_diameter(diameter_raw, camera, std_thresh=5, nan_thresh=1):
    """
    Denoise a raw pupil diameter trace and remove outliers.

    The steps are:
    - apply a Savitzky-Golay filter (order 3) to the raw trace via smooth_interpolate_savgol;
    - set to nan the points whose residual from that fit exceeds ``std_thresh`` standard deviations;
    - run the filter again on the cleaned trace.

    Runs of nans longer than ``nan_thresh`` seconds are kept as nan in the output instead of being
    interpolated, including runs at the very start or end of the trace.

    :param diameter_raw: np.array, raw pupil diameters, calculated from (thresholded) dlc traces
    :param camera: str ('left', 'right'), which camera to run the smoothing for
    :param std_thresh: threshold (in standard deviations) beyond which a point is labeled as an outlier
    :param nan_thresh: threshold (in seconds) above which we will not interpolate nans, but keep them
        (for long stretches interpolation may not be appropriate)
    :return: np.array, smoothed pupil diameter, same length as diameter_raw
    :raises NotImplementedError: if camera is not 'left' or 'right'
    :raises ValueError: if more than 90% of diameter_raw is nan
    """
    # set framerate of camera (set by hardware) and filter window (works well empirically)
    if camera == 'left':
        fr = SAMPLING['left']
        window = 31
    elif camera == 'right':
        fr = SAMPLING['right']
        window = 75
    else:
        raise NotImplementedError("camera has to be 'left' or 'right'")

    # Raise error if too many NaN time points, in this case it doesn't make sense to interpolate
    if np.mean(np.isnan(diameter_raw)) > 0.9:
        raise ValueError(f"Raw pupil diameter for {camera} is too often NaN, cannot smooth.")
    # run savitzy-golay filter on non-nan time points to denoise
    diameter_smoothed = smooth_interpolate_savgol(diameter_raw, window=window, order=3, interp_kind='linear')

    # find outliers and set them to nan
    difference = diameter_raw - diameter_smoothed
    outlier_thresh = std_thresh * np.nanstd(difference)
    without_outliers = np.copy(diameter_raw)
    without_outliers[(difference < -outlier_thresh) | (difference > outlier_thresh)] = np.nan
    # run savitzy-golay filter again on (possibly reduced) non-nan timepoints to denoise
    diameter_smoothed = smooth_interpolate_savgol(without_outliers, window=window, order=3, interp_kind='linear')

    # Don't interpolate long strings of nans. Padding the nan mask with 0 on both sides gives every
    # run, including a leading or trailing one, exactly one start (+1) and one end (-1), so starts
    # and ends pair up correctly.
    t = np.diff(np.r_[0, np.isnan(without_outliers).astype(int), 0])
    begs = np.where(t == 1)[0]  # index of the first nan of each run
    ends = np.where(t == -1)[0]  # index one past the last nan of each run
    for b, e in zip(begs, ends):
        if (e - b) > (fr * nan_thresh):
            diameter_smoothed[b:e] = np.nan

    return diameter_smoothed
def _plot_zoom_inset(ax, frame, dlc_df_norm, features, colors, anchor, origin, xlim, ylim):
    """Add a zoomed inset below ``ax`` showing ``frame`` and the dlc points within xlim/ylim."""
    ax_ins = ax.inset_axes([anchor, -0.5, 0.5, 0.5])
    ax_ins.imshow(frame, cmap='gray', origin=origin)
    for feat in features:
        ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat,
                       c=colors.get(feat))
    ax_ins.set_xlim(*xlim)
    ax_ins.set_ylim(*ylim)
    ax_ins.axis('off')
    return ax_ins


def plot_trace_on_frame(frame, dlc_df, cam):
    """
    Plots dlc traces as scatter plots on a frame of the video.
    For left and right video also plots whisker pad and eye and tongue zoom.

    Points with likelihood below 0.9 are dropped (dlc_df is thresholded in place), tube points are
    not scattered, and each trace is subsampled every int(SAMPLING[cam] / 10) frames. Features
    without an entry in the colour map are drawn with matplotlib's default colour cycle.

    :param frame: np.array, single video frame to plot on
    :param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
    :param cam: str, which camera to process ('left', 'right', 'body')
    :returns: matplotlib.axis
    """
    # Define colors
    colors = {'tail_start': '#636EFA',
              'nose_tip': '#636EFA',
              'paw_l': '#EF553B',
              'paw_r': '#00CC96',
              'pupil_bottom_r': '#AB63FA',
              'pupil_left_r': '#FFA15A',
              'pupil_right_r': '#19D3F3',
              'pupil_top_r': '#FF6692',
              'tongue_end_l': '#B6E880',
              'tongue_end_r': '#FF97FF'}
    # Threshold the dlc traces
    dlc_df = likelihood_threshold(dlc_df)
    # Features without tube
    features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if 'tube' not in x])
    # Normalize the number of points across cameras
    step = int(SAMPLING[cam] / 10)
    dlc_df_norm = pd.DataFrame()
    for feat in features:
        dlc_df_norm[f'{feat}_x'] = dlc_df[f'{feat}_x'][0::step]
        dlc_df_norm[f'{feat}_y'] = dlc_df[f'{feat}_y'][0::step]
        # .get: an unknown feature falls back to the default colour cycle instead of a KeyError
        plt.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=0.05, s=2, label=feat,
                    c=colors.get(feat))

    plt.axis('off')
    plt.imshow(frame, cmap='gray')
    plt.tight_layout()

    ax = plt.gca()
    if cam == 'body':
        plt.title(f'{cam.capitalize()} camera')
        return ax

    # For left and right cam plot whisker pad rectangle
    # heuristic: square with side length half the distance between nose and pupil and anchored on midpoint
    p_nose = np.array(dlc_df[['nose_tip_x', 'nose_tip_y']].mean())
    p_pupil = np.array(dlc_df[['pupil_top_r_x', 'pupil_top_r_y']].mean())
    p_anchor = np.mean([p_nose, p_pupil], axis=0)
    dist = np.linalg.norm(p_nose - p_pupil)
    rect = matplotlib.patches.Rectangle((int(p_anchor[0] - dist / 4), int(p_anchor[1])), int(dist / 2), int(dist / 3),
                                        linewidth=1, edgecolor='lime', facecolor='none')
    ax.add_patch(rect)

    # Plot eye region zoom
    res = RESOLUTION[cam]
    _plot_zoom_inset(ax, frame, dlc_df_norm, features, colors,
                     anchor=0 if cam == 'right' else 0.5, origin="lower",
                     xlim=(int(p_pupil[0] - 33 * res / 2), int(p_pupil[0] + 33 * res / 2)),
                     ylim=(int(p_pupil[1] + 38 * res / 2), int(p_pupil[1] - 28 * res / 2)))

    # Plot tongue region zoom, centred on the mean of the two tube points
    p1 = np.array(dlc_df[['tube_top_x', 'tube_top_y']].mean())
    p2 = np.array(dlc_df[['tube_bottom_x', 'tube_bottom_y']].mean())
    p_tongue = np.nanmean([p1, p2], axis=0)
    _plot_zoom_inset(ax, frame, dlc_df_norm, features, colors,
                     anchor=0 if cam == 'left' else 0.5, origin="upper",
                     xlim=(int(p_tongue[0] - 60 * res / 2), int(p_tongue[0] + 100 * res / 2)),
                     ylim=(int(p_tongue[1] + 60 * res / 2), int(p_tongue[1] - 100 * res / 2)))

    plt.title(f'{cam.capitalize()} camera')
    return ax
def plot_wheel_position(wheel_position, wheel_time, trials_df):
    """
    Plots wheel position across trials, color by which side was chosen.

    The wheel trace is interpolated at 1 / T_BIN Hz and cut into windows starting WINDOW_LAG relative
    to stimulus onset and lasting WINDOW_LEN. Each window is expressed relative to its first value.
    Windows that would run past the end of the wheel data are not plotted. A side's trial average is
    only drawn if that side has trials.

    :param wheel_position: np.array, wheel position (interpolated onto a regular grid here)
    :param wheel_time: np.array, wheel timestamps
    :param trials_df: pd.DataFrame, with columns 'stimOn_times' (time of stimulus onset for each trial)
        and 'choice' (-1 is plotted as a right turn, 1 as a left turn; other values are not plotted).
        A 'wheel_position' column holding each trial's window is added in place.
    :returns: matplotlib.axis
    """
    # Interpolate wheel data
    wheel_position, wheel_time = bbox_wheel.interpolate_position(wheel_time, wheel_position, freq=1 / T_BIN)
    # Create a window around the stimulus onset
    start_window, _ = plt_window(trials_df['stimOn_times'])
    # Translating the time window into an index window
    start_idx = insert_idx(wheel_time, start_window)
    n_samples = int(WINDOW_LEN / T_BIN)
    end_idx = np.array(start_idx + n_samples, dtype='int64')
    # Getting the wheel position for each window, normalize to first value of each window
    trials_df['wheel_position'] = [wheel_position[start_idx[w]: end_idx[w]] - wheel_position[start_idx[w]]
                                   for w in range(len(start_idx))]
    # Windows truncated by the end of the recording cannot be averaged with the full-length ones
    complete_df = trials_df[end_idx <= len(wheel_position)]

    # Plotting
    times = np.arange(n_samples) * T_BIN + WINDOW_LAG
    for side, label, color in zip([-1, 1], ['right', 'left'], ['darkred', '#1f77b4']):
        side_windows = complete_df.loc[complete_df['choice'] == side, 'wheel_position']
        for window in side_windows:
            plt.plot(times, window, c=color, alpha=0.5, linewidth=0.05)
        # The mean of an empty selection is a scalar nan, which plt.plot cannot pair with `times`
        if len(side_windows) > 0:
            side_mean = np.mean(np.vstack(side_windows.to_list()), axis=0)
            plt.plot(times, side_mean, c=color, linewidth=2, label=f'{label} turn')

    plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
    plt.axhline(y=-0.26, linestyle='--', c='g', label='reward')
    plt.axhline(y=0.26, linestyle='--', c='g', label='reward')
    plt.ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position diff to first value [rad]')
    plt.legend(loc='center right')
    plt.title('Wheel position trial avg\n(and individual trials)')
    plt.tight_layout()

    return plt.gca()
def _bin_window_licks(lick_times, trials_df):
    """
    Helper function to bin and window the lick times and get them into trials df for plotting.

    Licks are counted in T_BIN-wide bins. Each trial receives the WINDOW_LEN / T_BIN bins that start
    at the bin closest to feedback_time + WINDOW_LAG. Columns 'lick_bins' and 'end_idx' are added to
    trials_df in place.

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial)
    :returns: pd.DataFrame, the trials whose window ends within the binned range, with their
        binned lick counts in 'lick_bins'
    :raises ValueError: if no feedback window overlaps the lick time stamps
    """
    counts, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)), T_BIN)
    counts = np.squeeze(counts)
    start_window, _ = plt_window(trials_df['feedback_times'])

    try:
        first_bin = insert_idx(bin_times, start_window)
    except ValueError:
        logger.error('Lick time stamps are outside of the trials windows')
        raise
    last_bin = np.array(first_bin + int(WINDOW_LEN / T_BIN), dtype='int64')

    trials_df['lick_bins'] = [counts[lo:hi] for lo, hi in zip(first_bin, last_bin)]
    # Keep only trials whose window does not extend past the last bin
    trials_df['end_idx'] = last_bin
    return trials_df[trials_df['end_idx'] <= len(counts)]
def plot_lick_hist(lick_times, trials_df):
    """
    Plots histogramm of lick events aligned to feedback time, separate for correct and incorrect trials.

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
        'feedbackType' (1 for correct, -1 for incorrect trials)
    :returns: matplotlib axis
    """
    licks_df = _bin_window_licks(lick_times, trials_df)
    # Plot
    times = np.arange(len(licks_df['lick_bins'].iloc[0])) * T_BIN + WINDOW_LAG
    for feedback_type, color, label in ((1, 'k', 'correct trial'), (-1, 'gray', 'incorrect trial')):
        bins = licks_df.loc[licks_df['feedbackType'] == feedback_type, 'lick_bins']
        if bins.empty:
            # no trials of this type: plot an empty (all-nan) trace rather than crash on a shape mismatch
            trial_avg = np.full(times.size, np.nan)
        else:
            # One column per trial; keys and values must come from the same subset (the incorrect
            # curve previously used the correct trials' index, dropping/mislabelling trials)
            trial_avg = pd.DataFrame.from_dict(dict(zip(bins.index, bins.values))).mean(axis=1)
        plt.plot(times, trial_avg, c=color, label=label)
    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
    plt.title('Lick events trial avg')
    plt.xlabel('time [sec]')
    plt.ylabel('lick events [a.u.]')
    plt.legend(loc='lower right')
    return plt.gca()
def plot_lick_raster(lick_times, trials_df):
    """
    Plots lick raster for correct trials.

    Each image row is one correct trial and each column one T_BIN-wide bin of lick counts, spanning
    WINDOW_LAG to WINDOW_LAG + WINDOW_LEN seconds around feedback.

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
        feedbackType (1 for correct, -1 for incorrect trials)
    :returns: matplotlib.axis
    """
    licks_df = _bin_window_licks(lick_times, trials_df)
    correct_bins = list(licks_df.loc[licks_df['feedbackType'] == 1, 'lick_bins'])
    # The y extent must span the number of trials (image rows), not the number of bins per trial,
    # for the 'trials' axis to be correct
    plt.imshow(correct_bins, aspect='auto',
               extent=[WINDOW_LAG, WINDOW_LAG + WINDOW_LEN, len(correct_bins), 0], cmap='gray_r')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('trials')
    plt.xlabel('time [sec]')
    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
    plt.title('Lick events per correct trial')
    plt.tight_layout()
    return plt.gca()
def plot_motion_energy_hist(camera_dict, trials_df):
    """
    Plots mean motion energy of given cameras, aligned to stimulus onset.

    Motion energy is z-scored per camera and cut into windows starting WINDOW_LAG relative to each
    stimulus onset and lasting WINDOW_LEN. Only complete windows are averaged: windows truncated by
    the end of the recording are dropped. Cameras with missing or unusable data are listed in red on
    the plot.

    :param camera_dict: dict, one key for each camera to be plotted (e.g. 'left'), value is another dict with items
        'motion_energy' (np.array, motion energy calculated from this camera) and
        'times' (np.array, camera timestamps)
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
    :returns: matplotlib.axis
    """
    colors = {'left': '#bd7a98',
              'right': '#2b6f39',
              'body': '#035382'}

    start_window, _ = plt_window(trials_df['stimOn_times'])
    missing_data = []
    for cam, cam_data in camera_dict.items():
        if (cam_data['motion_energy'] is not None and len(cam_data['motion_energy']) > 0
                and cam_data['times'] is not None and len(cam_data['times']) > 0):
            try:
                motion_energy = zscore(cam_data['motion_energy'], nan_policy='omit')
                try:
                    start_idx = insert_idx(cam_data['times'], start_window)
                    n_samples = int(WINDOW_LEN * SAMPLING[cam])
                    end_idx = np.array(start_idx + n_samples, dtype='int64')
                    me_all = [motion_energy[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
                    # Keep only complete windows: truncated ones make the stack ragged, so np.mean
                    # would fail and be misreported as times outside of the trial windows
                    me_all = [m for m in me_all if len(m) == n_samples]
                    if not me_all:
                        raise ValueError(f'No complete motion energy window for {cam} camera')
                    times = np.arange(n_samples) / SAMPLING[cam] + WINDOW_LAG
                    me_mean = np.mean(me_all, axis=0)
                    # shaded band: standard error of the mean across trials
                    me_std = np.std(me_all, axis=0) / np.sqrt(len(me_all))
                    plt.plot(times, me_mean, label=f'{cam} cam', color=colors[cam], linewidth=2)
                    plt.fill_between(times, me_mean + me_std, me_mean - me_std, color=colors[cam], alpha=0.2)
                except ValueError:
                    logger.error(f"{cam}Camera camera.times are outside of the trial windows")
                    missing_data.append(cam)
            except AttributeError:
                logger.warning(f"Cannot load motion energy and/or times data for {cam} camera")
                missing_data.append(cam)
        else:
            logger.warning(f"Data missing or empty for motion energy and/or times data for {cam} camera")
            missing_data.append(cam)

    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('z-scored motion energy [a.u.]')
    plt.xlabel('time [sec]')
    plt.axvline(x=0, label='stimOn', linestyle='--', c='k')
    plt.legend(loc='lower right')
    plt.title('Motion Energy trial avg\n(+/- std)')
    if len(missing_data) > 0:
        ax = plt.gca()
        ax.text(.95, .35, f"Data incomplete for\n{' and '.join(missing_data)} camera", color='r', fontsize=10,
                fontweight='bold', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
    return plt.gca()
def _nanmean_windows(windows, n_samples):
    """Average 1D windows sample by sample, ignoring nans; shorter windows are nan-padded to n_samples."""
    padded = np.full((len(windows), n_samples), np.nan)
    for row, window in enumerate(windows):
        window = np.asarray(window, dtype=float)[:n_samples]
        padded[row, :window.size] = window
    with warnings.catch_warnings():
        # an all-nan column (e.g. no trials in this group) would warn "Mean of empty slice"
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmean(padded, axis=0)


def plot_speed_hist(dlc_df, cam_times, trials_df, feature='paw_r', cam='left', legend=True):
    """
    Plots speed histogram of a given dlc feature, aligned to stimulus onset, separate for correct and incorrect trials.

    Windows truncated by the end of the recording are nan-padded, and a group with no trials yields
    an empty (all-nan) trace instead of an error.

    :param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
        (likelihood-thresholded in place)
    :param cam_times: np.array, camera timestamps
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
        and 'feedbackType' (1 for correct, -1 for incorrect trials); a 'speed_<feature>' column is
        added in place
    :param feature: str, feature with trace in dlc_df for which to plot speed hist, default is 'paw_r'
    :param cam: str, camera to use ('body', 'left', 'right') default is 'left'
    :param legend: bool, whether to add legend to the plot, default is True
    :returns: matplotlib.axis
    :raises ValueError: if cam_times is shorter than dlc_df
    """
    # Threshold the dlc traces
    dlc_df = likelihood_threshold(dlc_df)
    # For pre-GPIO sessions, remove the first few timestamps to match the number of frames
    cam_times = cam_times[-len(dlc_df):]
    if len(cam_times) != len(dlc_df):
        raise ValueError("Camera times length and DLC length are inconsistent")
    # Get speeds
    speeds = get_speed(dlc_df, cam_times, camera=cam, feature=feature)
    # Windows aligned to stimulus onset
    start_window, _ = plt_window(trials_df['stimOn_times'])
    start_idx = insert_idx(cam_times, start_window)
    n_samples = int(WINDOW_LEN * SAMPLING[cam])
    end_idx = np.array(start_idx + n_samples, dtype='int64')
    # Add speeds to trials_df
    trials_df[f'speed_{feature}'] = [speeds[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
    # Plot
    times = np.arange(n_samples) / SAMPLING[cam] + WINDOW_LAG
    for feedback_type, color, label in ((1, 'k', 'correct trial'), (-1, 'gray', 'incorrect trial')):
        windows = trials_df.loc[trials_df['feedbackType'] == feedback_type, f'speed_{feature}']
        plt.plot(times, _nanmean_windows(windows, n_samples), c=color, label=label)
    plt.axvline(x=0, label='stimOn', linestyle='--', c='r')
    plt.title(f'{feature.capitalize()} speed trial avg\n({cam.upper()} cam)')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.xlabel('time [sec]')
    plt.ylabel('speed [px/sec]')
    if legend:
        plt.legend()

    return plt.gca()
538def plot_pupil_diameter_hist(pupil_diameter, cam_times, trials_df, cam='left'):
539 """
540 Plots histogram of pupil diameter aligned to simulus onset and feedback time.
542 :param pupil_diameter: np.array, (smoothed) pupil diameter estimate
543 :param cam_times: np.array, camera timestamps
544 :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial) and
545 feedback_times (time of feedback for each trial)
546 :param cam: str, camera to use ('body', 'left', 'right') default is 'left'
547 :returns: matplotlib.axis
548 """
549 for align_to, color in zip(['stimOn_times', 'feedback_times'], ['red', 'purple']): 1a
550 start_window, end_window = plt_window(trials_df[align_to]) 1a
551 start_idx = insert_idx(cam_times, start_window) 1a
552 end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64') 1a
553 # Per trial norm
554 pupil_all = [zscore(list(pupil_diameter[start_idx[i]:end_idx[i]])) for i in range(len(start_idx))] 1a
555 pupil_all_norm = [trial - trial[0] for trial in pupil_all] 1a
557 pupil_mean = np.nanmean(pupil_all_norm, axis=0)
558 pupil_std = np.nanstd(pupil_all_norm, axis=0) / np.sqrt(len(pupil_all_norm))
559 times = np.arange(len(pupil_all_norm[0])) / SAMPLING[cam] + WINDOW_LAG
561 plt.plot(times, pupil_mean, label=align_to.split("_")[0], color=color)
562 plt.fill_between(times, pupil_mean + pupil_std, pupil_mean - pupil_std, color=color, alpha=0.5)
563 plt.axvline(x=0, linestyle='--', c='k')
564 plt.title(f'Pupil diameter trial avg\n({cam.upper()} cam)')
565 plt.xlabel('time [sec]')
566 plt.xticks([-0.5, 0, 0.5, 1, 1.5])
567 plt.ylabel('z-scored smoothed pupil diameter [px]')
568 plt.legend(loc='lower right', title='aligned to')