Coverage for brainbox/behavior/dlc.py: 25%
288 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""Set of functions to deal with dlc data."""
2import logging
3import pandas as pd
4import warnings
6import numpy as np
7import matplotlib
8import matplotlib.pyplot as plt
9import scipy.interpolate as interpolate
10from scipy.stats import zscore
12from ibldsp.smooth import smooth_interpolate_savgol
13from iblutil.numerical import bincount2D
14import brainbox.behavior.wheel as bbox_wheel
logger = logging.getLogger('ibllib')

# Nominal camera frame rates in frames per second, keyed by camera label
SAMPLING = dict(left=60, right=150, body=30)
# Factor by which DLC pixel coordinates are divided before speeds are computed
# (get_speed describes the resulting unit as "half res" pixels)
RESOLUTION = dict(left=2, right=1, body=1)

# Peri-event windowing parameters, all in seconds
T_BIN = 0.02  # bin width used for the resampled wheel trace and the binned licks
WINDOW_LEN = 2  # duration of each analysis window
WINDOW_LAG = -0.5  # window start relative to the aligning event (negative = before the event)
# Plotting windows are placed around the event the data is aligned to: they start WINDOW_LAG
# relative to the event (a negative lag means before it) and end WINDOW_LEN after the event.
def plt_window(x):
    """
    Compute the bounds of the plotting window around event time(s).

    :param x: event time(s) in seconds; scalar, np.array or pd.Series
    :return: tuple (start, end) with start = x + WINDOW_LAG and end = x + WINDOW_LEN
    """
    # NOTE(review): callers in this module only use ``start`` and derive their own window end as
    # start + WINDOW_LEN (= x + WINDOW_LAG + WINDOW_LEN), which differs from ``end`` returned here.
    window_start = x + WINDOW_LAG
    window_end = x + WINDOW_LEN
    return window_start, window_end
def insert_idx(array, values):
    """
    Find, for each value, the index of the closest element of a sorted array.

    :param array: 1D array sorted in ascending order (e.g. camera or bin timestamps)
    :param values: 1D array-like of values to locate (e.g. window start times)
    :return: np.array of int, index into ``array`` of the element nearest to each value;
        values before the first / after the last element map to the first / last index
    :raises ValueError: if ``array`` is empty, or if every value lies before ``array[0]``
        (this includes an empty ``values``)
    """
    array = np.asarray(array)
    values = np.asarray(values)
    if array.size == 0:
        raise ValueError('Cannot insert values into an empty array.')
    idx = np.searchsorted(array, values, side="left")
    # Values past the last element get insertion index len(array); clamp them to the last index
    idx[idx == len(array)] -= 1
    # Step back one index where the lower neighbour is closer. Where idx == 0, array[idx - 1] wraps
    # to the last element which, array being sorted, is never strictly closer than array[0].
    idx[np.abs(values - array[idx - 1]) < np.abs(values - array[idx])] -= 1
    # Safety net: never return a negative index
    idx[idx < 0] = 0
    # Fail only when every value is genuinely before the array. A value that is merely closest to
    # the first element (or a single-element array) is a valid result, not an error.
    if np.all(values < array[0]):
        raise ValueError('Something is wrong, all values to insert are outside of the array.')
    return idx
def likelihood_threshold(dlc, threshold=0.9):
    """
    Set dlc points with likelihood less than threshold to nan.

    Features are identified by their ``<feature>_likelihood`` column. For each of them, the
    ``<feature>_x`` and ``<feature>_y`` values of rows whose likelihood is below ``threshold`` are
    set to NaN. Columns that do not belong to a feature with a likelihood column (e.g. derived
    columns such as times or speeds) are left untouched.

    The DataFrame is modified in place and also returned.

    :param dlc: dlc pqt object (pd.DataFrame with <feature>_x, <feature>_y, <feature>_likelihood columns)
    :param threshold: likelihood threshold; points with likelihood strictly below it are set to NaN
    :return: the same DataFrame, thresholded
    """
    suffix = '_likelihood'
    # Derive features from the likelihood columns only, so unrelated columns cannot cause a KeyError
    features = [col[:-len(suffix)] for col in dlc.keys() if col.endswith(suffix)]
    for feat in features:
        nan_fill = dlc[f'{feat}{suffix}'] < threshold
        dlc.loc[nan_fill, [f'{feat}_x', f'{feat}_y']] = np.nan
    return dlc
def get_speed(dlc, dlc_t, camera, feature='paw_r'):
    """
    Compute the speed of a dlc feature, resampled onto the camera timestamps.

    Coordinates are divided by ``RESOLUTION[camera]``. Frame-to-frame displacements are converted
    to speed using the camera's nominal frame rate ``SAMPLING[camera]`` (not the actual timestamp
    differences). Each of the resulting speeds (one per pair of consecutive frames) is assigned to
    the midpoint time between its two frames. The speeds are then linearly interpolated, and
    extrapolated at the edges, back onto ``dlc_t``.

    :param dlc: dlc pqt table with '<feature>_x' and '<feature>_y' columns
    :param dlc_t: dlc time points, one per row of ``dlc``
    :param camera: camera type e.g 'left', 'right', 'body'
    :param feature: dlc feature to compute speed over
    :return: np.array of speeds [px/sec], same length as ``dlc_t``; all NaN when there are fewer
        than three time points, because at least two speed samples are needed to interpolate
    """
    dlc_t = np.asarray(dlc_t)
    x = dlc[f'{feature}_x'] / RESOLUTION[camera]
    y = dlc[f'{feature}_y'] / RESOLUTION[camera]

    # get speed in px/sec [half res]
    s = ((np.diff(x) ** 2 + np.diff(y) ** 2) ** .5) * SAMPLING[camera]

    # each speed sample belongs to the midpoint between two consecutive frames
    dt = np.diff(dlc_t)
    tv = dlc_t[:-1] + dt / 2

    # interp1d needs at least two points; return NaNs rather than falling through to an implicit None
    if tv.size <= 1:
        return np.full(dlc_t.shape, np.nan)

    # interpolate over original time scale
    ifcn = interpolate.interp1d(tv, s, fill_value="extrapolate")
    return ifcn(dlc_t)
def get_speed_for_features(dlc, dlc_t, camera, features=('paw_r', 'paw_l', 'nose_tip')):
    """
    Wrapper to compute speed for a number of dlc features and add them to dlc table.

    For every feature a '<feature>_speed' column (see get_speed) is written into ``dlc`` in place.

    :param dlc: dlc pqt table
    :param dlc_t: dlc time points
    :param camera: camera type e.g 'left', 'right', 'body'
    :param features: iterable of dlc features to compute speed for. The default is an immutable
        tuple, which avoids a shared mutable default argument.
    :return: the dlc table with the added speed columns
    """
    for feat in features:
        dlc[f'{feat}_speed'] = get_speed(dlc, dlc_t, camera, feat)
    return dlc
def get_feature_event_times(dlc, dlc_t, features):
    """
    Detect events from the dlc traces, based on the frame-to-frame differences.

    For each feature, frame i is flagged when |trace[i + 1] - trace[i]| exceeds a quarter of the
    NaN-ignoring standard deviation of that feature's frame-to-frame differences. Events from all
    features are pooled and de-duplicated.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times, one per row of ``dlc``
    :param features: column names of ``dlc`` to consider
    :return: ``dlc_t`` at the sorted, unique flagged frame indices (empty if ``features`` is empty)
    """
    event_idx = []
    for feat in features:
        step = np.diff(dlc[feat])
        threshold = np.nanstd(step) / 4
        # NaN steps compare False and are therefore never flagged
        event_idx.append(np.flatnonzero(np.abs(step) > threshold))
    # np.concatenate fails on an empty list, and an empty float array cannot index dlc_t
    events = np.unique(np.concatenate(event_idx)) if event_idx else np.array([], dtype=int)
    return dlc_t[events]
def get_licks(dlc, dlc_t):
    """
    Compute lick times from the tongue dlc points.

    Events are detected with get_feature_event_times on the x and y traces of both
    tongue-end points (left first, then right).

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :return: times of detected lick events
    """
    tongue_traces = [f'tongue_end_{side}_{coord}' for side in ('l', 'r') for coord in ('x', 'y')]
    return get_feature_event_times(dlc, dlc_t, tongue_traces)
def get_sniffs(dlc, dlc_t):
    """
    Compute sniff times from the vertical trace of the nose tip.

    :param dlc: dlc pqt table
    :param dlc_t: dlc times
    :return: times of detected sniff events (see get_feature_event_times)
    """
    return get_feature_event_times(dlc, dlc_t, features=['nose_tip_y'])
def get_dlc_everything(dlc_cam, camera):
    """
    Get out features of interest for dlc.

    Steps:
    - Truncate timestamps and dlc points to a common length if their lengths disagree.
    - Threshold the dlc points on likelihood.
    - Add speed columns.
    - Store lick times, sniff times and an alignment flag on ``dlc_cam``.

    :param dlc_cam: dlc object exposing ``times`` and ``dlc`` attributes and supporting item
        assignment (presumably an ALF Bunch -- TODO confirm)
    :param camera: camera type e.g 'left', 'right'
    :return: the same ``dlc_cam``, updated in place, with keys 'licks', 'sniffs' and 'aligned'
        ('aligned' is False when timestamps and dlc points had to be truncated)
    """
    n_times = dlc_cam.times.shape[0]
    n_points = dlc_cam.dlc.shape[0]
    aligned = n_times == n_points
    if not aligned:
        logger.warning('Dimension mismatch between dlc points and timestamps')
        # Cut both to the shorter length so that downstream functions can index them together
        n_keep = min(n_times, n_points)
        dlc_cam.times = dlc_cam.times[:n_keep]
        dlc_cam.dlc = dlc_cam.dlc[:n_keep]

    # likelihood_threshold works in place and returns the same table, which then gains speed columns
    dlc_cam.dlc = get_speed_for_features(likelihood_threshold(dlc_cam.dlc), dlc_cam.times, camera)
    dlc_cam['licks'] = get_licks(dlc_cam.dlc, dlc_cam.times)
    dlc_cam['sniffs'] = get_sniffs(dlc_cam.dlc, dlc_cam.times)
    dlc_cam['aligned'] = aligned
    return dlc_cam
def get_pupil_diameter(dlc):
    """
    Estimate pupil diameter per frame as the median of several geometric estimates.

    Two estimates are direct: top-bottom distance and left-right distance. Four more come from
    adjacent point pairs (top/bottom with left/right). Assuming a circular pupil, such points are
    a quarter circle apart, so their chord is r * sqrt(2) and the diameter is chord * sqrt(2).

    :param dlc: dlc pqt table with pupil estimates, should be likelihood thresholded (e.g. at 0.9)
    :return: np.array, pupil diameter estimate for each time point, shape (n_frames,)
    """
    # 2 x n_frames coordinate array for each of the four pupil points
    pts = {name: np.vstack((dlc[f'pupil_{name}_r_x'], dlc[f'pupil_{name}_r_y']))
           for name in ('top', 'bottom', 'left', 'right')}
    direct = [np.linalg.norm(pts['top'] - pts['bottom'], axis=0),
              np.linalg.norm(pts['left'] - pts['right'], axis=0)]
    via_chords = [np.linalg.norm(pts[a] - pts[b], axis=0) * 2 ** 0.5
                  for a, b in (('top', 'left'), ('top', 'right'), ('bottom', 'left'), ('bottom', 'right'))]
    # Frames where every estimate is NaN trigger an all-NaN RuntimeWarning; silence it
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmedian(direct + via_chords, axis=0)
def get_smooth_pupil_diameter(diameter_raw, camera, std_thresh=5, nan_thresh=1):
    """
    Denoise a raw pupil diameter trace, removing outliers and keeping long NaN gaps.

    Steps:
    - Run a Savitzky-Golay filter (via smooth_interpolate_savgol) on the raw trace.
    - Set points that deviate from the smoothed trace by more than ``std_thresh`` standard
      deviations to NaN, then run the filter again.
    - Runs of consecutive NaNs in the outlier-cleaned trace lasting longer than ``nan_thresh``
      seconds are set to NaN in the output rather than interpolated. This also applies to runs
      at the very start or end of the trace.

    :param diameter_raw: np.array, raw pupil diameters, calculated from (thresholded) dlc traces
    :param camera: str ('left', 'right'), which camera to run the smoothing for
    :param std_thresh: threshold (in standard deviations) beyond which a point is labeled as an outlier
    :param nan_thresh: threshold (in seconds) above which we will not interpolate nans, but keep them
        (for long stretches interpolation may not be appropriate)
    :return: np.array, smoothed pupil diameter, same length as ``diameter_raw``
    :raises NotImplementedError: if ``camera`` is not 'left' or 'right'
    :raises ValueError: if more than 90% of ``diameter_raw`` is NaN
    """
    # frame rate of camera (set by hardware) and filter window (works well empirically)
    if camera == 'left':
        fr, window = SAMPLING['left'], 31
    elif camera == 'right':
        fr, window = SAMPLING['right'], 75
    else:
        raise NotImplementedError("camera has to be 'left' or 'right'")

    # Raise error if too many NaN time points, in this case it doesn't make sense to interpolate
    if np.mean(np.isnan(diameter_raw)) > 0.9:
        raise ValueError(f"Raw pupil diameter for {camera} is too often NaN, cannot smooth.")
    # run savitzky-golay filter on non-nan time points to denoise
    diameter_smoothed = smooth_interpolate_savgol(diameter_raw, window=window, order=3, interp_kind='linear')

    # find outliers and set them to nan
    difference = diameter_raw - diameter_smoothed
    outlier_thresh = std_thresh * np.nanstd(difference)
    without_outliers = np.copy(diameter_raw)
    without_outliers[(difference < -outlier_thresh) | (difference > outlier_thresh)] = np.nan
    # run savitzky-golay filter again on (possibly reduced) non-nan timepoints to denoise
    diameter_smoothed = smooth_interpolate_savgol(without_outliers, window=window, order=3, interp_kind='linear')

    # don't interpolate long strings of nans. Padding the NaN mask with False on both sides guarantees
    # that every run has exactly one start and one stop edge, including runs touching either end, so
    # starts and stops can be paired positionally.
    nan_edges = np.diff(np.r_[False, np.isnan(without_outliers), False].astype(int))
    run_starts = np.flatnonzero(nan_edges == 1)  # index of the first NaN of each run
    run_stops = np.flatnonzero(nan_edges == -1)  # index just past the last NaN of each run
    for start, stop in zip(run_starts, run_stops):
        if (stop - start) > (fr * nan_thresh):
            diameter_smoothed[start:stop] = np.nan

    return diameter_smoothed
def plot_trace_on_frame(frame, dlc_df, cam):
    """
    Plots dlc traces as scatter plots on a frame of the video.
    For left and right video also plots whisker pad and eye and tongue zoom.

    The caller's ``dlc_df`` is not modified; likelihood thresholding is applied to a copy.

    :param frame: np.array, single video frame to plot on
    :param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
    :param cam: str, which camera to process ('left', 'right', 'body')
    :returns: matplolib.axis
    """
    # Define colors
    colors = {'tail_start': '#636EFA',
              'nose_tip': '#636EFA',
              'paw_l': '#EF553B',
              'paw_r': '#00CC96',
              'pupil_bottom_r': '#AB63FA',
              'pupil_left_r': '#FFA15A',
              'pupil_right_r': '#19D3F3',
              'pupil_top_r': '#FF6692',
              'tongue_end_l': '#B6E880',
              'tongue_end_r': '#FF97FF'}
    # Threshold a copy of the dlc traces: likelihood_threshold works in place and would otherwise
    # overwrite the caller's data with NaNs
    dlc_df = likelihood_threshold(dlc_df.copy())
    # Features without tube
    features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if 'tube' not in x])
    # Normalize the number of points across cameras: keeping every (SAMPLING[cam] / 10)-th frame
    # gives roughly 10 points per second of video for every camera
    step = int(SAMPLING[cam] / 10)
    dlc_df_norm = pd.DataFrame()
    for feat in features:
        dlc_df_norm[f'{feat}_x'] = dlc_df[f'{feat}_x'][0::step]
        dlc_df_norm[f'{feat}_y'] = dlc_df[f'{feat}_y'][0::step]
        # Scatter
        plt.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=0.05, s=2, label=feat, c=colors[feat])

    plt.axis('off')
    plt.imshow(frame, cmap='gray')
    plt.tight_layout()

    ax = plt.gca()
    if cam == 'body':
        plt.title(f'{cam.capitalize()} camera')
        return ax
    # For left and right cam plot whisker pad rectangle
    # heuristic: rectangle of width half and height a third of the nose-pupil distance, horizontally
    # centred on the nose-pupil midpoint and extending from it towards larger y
    p_nose = np.array(dlc_df[['nose_tip_x', 'nose_tip_y']].mean())
    p_pupil = np.array(dlc_df[['pupil_top_r_x', 'pupil_top_r_y']].mean())
    p_anchor = np.mean([p_nose, p_pupil], axis=0)
    dist = np.linalg.norm(p_nose - p_pupil)
    rect = matplotlib.patches.Rectangle((int(p_anchor[0] - dist / 4), int(p_anchor[1])), int(dist / 2), int(dist / 3),
                                        linewidth=1, edgecolor='lime', facecolor='none')
    ax.add_patch(rect)
    # Plot eye region zoom
    inset_anchor = 0 if cam == 'right' else 0.5
    ax_ins = ax.inset_axes([inset_anchor, -0.5, 0.5, 0.5])
    ax_ins.imshow(frame, cmap='gray', origin="lower")
    for feat in features:
        ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat, c=colors[feat])
    ax_ins.set_xlim(int(p_pupil[0] - 33 * RESOLUTION[cam] / 2), int(p_pupil[0] + 33 * RESOLUTION[cam] / 2))
    ax_ins.set_ylim(int(p_pupil[1] + 38 * RESOLUTION[cam] / 2), int(p_pupil[1] - 28 * RESOLUTION[cam] / 2))
    ax_ins.axis('off')
    # Plot tongue region zoom, centred on the mean position of the lick tube
    p1 = np.array(dlc_df[['tube_top_x', 'tube_top_y']].mean())
    p2 = np.array(dlc_df[['tube_bottom_x', 'tube_bottom_y']].mean())
    p_tongue = np.nanmean([p1, p2], axis=0)
    inset_anchor = 0 if cam == 'left' else 0.5
    ax_ins = ax.inset_axes([inset_anchor, -0.5, 0.5, 0.5])
    ax_ins.imshow(frame, cmap='gray', origin="upper")
    for feat in features:
        ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat, c=colors[feat])
    ax_ins.set_xlim(int(p_tongue[0] - 60 * RESOLUTION[cam] / 2), int(p_tongue[0] + 100 * RESOLUTION[cam] / 2))
    ax_ins.set_ylim(int(p_tongue[1] + 60 * RESOLUTION[cam] / 2), int(p_tongue[1] - 100 * RESOLUTION[cam] / 2))
    ax_ins.axis('off')

    plt.title(f'{cam.capitalize()} camera')
    return ax
def plot_wheel_position(wheel_position, wheel_time, trials_df):
    """
    Plots wheel position across trials, color by which side was chosen

    How the plot is built:
    - The wheel trace is resampled at 1 / T_BIN Hz.
    - The trace is cut into windows that start WINDOW_LAG around each stimulus onset and last
      WINDOW_LEN seconds.
    - Each window is expressed relative to its first value.
    - Individual trials are drawn as thin lines, and the per-choice average as a thick line.

    :param wheel_position: np.array, wheel position
    :param wheel_time: np.array, wheel timestamps
    :param trials_df: pd.DataFrame, with columns 'stimOn_times' (time of stimulus onset for each trial)
        and 'choice' (-1 drawn as right turn, 1 as left turn); the caller's DataFrame is not modified
    :returns: matplotlib.axis
    """
    # Work on a copy so the caller's trials table does not gain a 'wheel_position' column
    trials_df = trials_df.copy()
    # Interpolate wheel data
    wheel_position, wheel_time = bbox_wheel.interpolate_position(wheel_time, wheel_position, freq=1 / T_BIN)
    # Create a window around the stimulus onset
    start_window, _ = plt_window(trials_df['stimOn_times'])
    # Translating the time window into an index window
    start_idx = insert_idx(wheel_time, start_window)
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # Getting the wheel position for each window, normalize to first value of each window
    trials_df['wheel_position'] = [wheel_position[start_idx[w]: end_idx[w]] - wheel_position[start_idx[w]]
                                   for w in range(len(start_idx))]
    # Plotting
    times = np.arange(len(trials_df['wheel_position'].iloc[0])) * T_BIN + WINDOW_LAG
    for side, label, color in zip([-1, 1], ['right', 'left'], ['darkred', '#1f77b4']):
        side_df = trials_df[trials_df['choice'] == side]
        for idx in side_df.index:
            plt.plot(times, side_df.loc[idx, 'wheel_position'], c=color, alpha=0.5, linewidth=0.05)
        plt.plot(times, side_df['wheel_position'].mean(), c=color, linewidth=2, label=f'{label} turn')

    plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
    plt.axhline(y=-0.26, linestyle='--', c='g', label='reward')
    # Symmetric counterpart of the line above, left unlabelled so 'reward' appears only once in the legend
    plt.axhline(y=0.26, linestyle='--', c='g')
    plt.ylim([-0.27, 0.27])
    plt.xlabel('time [sec]')
    plt.ylabel('wheel position diff to first value [rad]')
    plt.legend(loc='center right')
    plt.title('Wheel position trial avg\n(and individual trials)')
    plt.tight_layout()

    return plt.gca()
def _bin_window_licks(lick_times, trials_df):
    """
    Helper function to bin and window the lick times and get them into trials df for plotting

    How the windows are built:
    - Licks are counted in T_BIN-wide bins.
    - For each trial, the window starts at the bin closest to feedback_times + WINDOW_LAG and
      spans WINDOW_LEN / T_BIN bins.
    - Trials whose window runs past the last bin are dropped.

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial);
        not modified
    :returns: pd.DataFrame, copy of trials_df with 'lick_bins' (binned licks per window) and 'end_idx'
        columns, restricted to trials whose window fits within the binned lick range
    :raises ValueError: if the windows lie entirely before the binned lick times (from insert_idx)
    """
    # Work on a copy so the caller's trials table does not gain 'lick_bins' / 'end_idx' columns
    trials_df = trials_df.copy()
    # Bin the licks
    lick_bins, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)), T_BIN)
    lick_bins = np.squeeze(lick_bins)
    start_window, _ = plt_window(trials_df['feedback_times'])
    # Translating the time window into an index window
    try:
        start_idx = insert_idx(bin_times, start_window)
    except ValueError:
        logger.error('Lick time stamps are outside of the trials windows')
        raise
    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
    # Get the binned licks for each window
    trials_df['lick_bins'] = [lick_bins[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
    # Remove windows that exceed the bins
    trials_df['end_idx'] = end_idx
    return trials_df[trials_df['end_idx'] <= len(lick_bins)]
def plot_lick_hist(lick_times, trials_df):
    """
    Plots histogram of lick events aligned to feedback time, separate for correct and incorrect trials

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
                      'feedbackType' (1 for correct, -1 for incorrect trials)
    :returns: matplotlib axis
    """
    licks_df = _bin_window_licks(lick_times, trials_df)
    # Plot
    times = np.arange(len(licks_df['lick_bins'].iloc[0])) * T_BIN + WINDOW_LAG
    correct = licks_df[licks_df['feedbackType'] == 1]['lick_bins']
    incorrect = licks_df[licks_df['feedbackType'] == -1]['lick_bins']
    # Expand each series of per-trial arrays into a (time x trial) DataFrame and average over trials.
    # Each series must be zipped with its own index: pairing it with another series' index would
    # truncate to the shorter of the two and silently drop trials.
    plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
             c='k', label='correct trial')
    plt.plot(times, pd.DataFrame.from_dict(dict(zip(incorrect.index, incorrect.values))).mean(axis=1),
             c='gray', label='incorrect trial')
    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
    plt.title('Lick events trial avg')
    plt.xlabel('time [sec]')
    plt.ylabel('lick events [a.u.]')
    plt.legend(loc='lower right')
    return plt.gca()
def plot_lick_raster(lick_times, trials_df):
    """
    Plots lick raster for correct trials

    Each image row is one correct trial. Columns are lick counts in T_BIN bins, spanning the
    plotting window from WINDOW_LAG to WINDOW_LAG + WINDOW_LEN seconds around feedback.

    :param lick_times: np.array, timestamps of lick events
    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
        feedbackType (1 for correct, -1 for incorrect trials)
    :returns: matplotlib.axis
    """
    licks_df = _bin_window_licks(lick_times, trials_df)
    correct_bins = list(licks_df[licks_df['feedbackType'] == 1]['lick_bins'])
    # The y extent must span the number of image rows (correct trials), not the number of bins per
    # window, so that the 'trials' axis is labelled correctly
    plt.imshow(correct_bins, aspect='auto',
               extent=[WINDOW_LAG, WINDOW_LAG + WINDOW_LEN, len(correct_bins), 0], cmap='gray_r')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('trials')
    plt.xlabel('time [sec]')
    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
    plt.title('Lick events per correct trial')
    plt.tight_layout()
    return plt.gca()
def plot_motion_energy_hist(camera_dict, trials_df):
    """
    Plots mean motion energy of given cameras, aligned to stimulus onset.

    For each camera:
    - Motion energy is z-scored (NaNs omitted).
    - It is cut into windows that start at the frame closest to stimOn_times + WINDOW_LAG and
      last WINDOW_LEN seconds.
    - Only complete windows are averaged, and the mean +/- standard error is plotted.

    Cameras are listed as incomplete on the plot when their data is missing or empty, cannot be
    loaded, or does not cover any complete trial window.

    :param camera_dict: dict, one key for each camera to be plotted (e.g. 'left'), value is another dict with items
        'motion_energy' (np.array, motion energy calculated from this camera) and
        'times' (np.array, camera timestamps)
    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
    :returns: matplotlib.axis
    """
    colors = {'left': '#bd7a98',
              'right': '#2b6f39',
              'body': '#035382'}

    start_window, _ = plt_window(trials_df['stimOn_times'])
    missing_data = []
    for cam in camera_dict.keys():
        if (camera_dict[cam]['motion_energy'] is not None and len(camera_dict[cam]['motion_energy']) > 0
                and camera_dict[cam]['times'] is not None and len(camera_dict[cam]['times']) > 0):
            try:
                motion_energy = zscore(camera_dict[cam]['motion_energy'], nan_policy='omit')
                try:
                    start_idx = insert_idx(camera_dict[cam]['times'], start_window)
                    n_samples = int(WINDOW_LEN * SAMPLING[cam])
                    end_idx = np.array(start_idx + n_samples, dtype='int64')
                    me_all = [motion_energy[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
                    # Keep only complete windows: windows cut short by the end of the recording make
                    # the list ragged, and averaging a ragged list fails
                    me_all = [m for m in me_all if len(m) == n_samples]
                    if not me_all:
                        raise ValueError('No complete motion energy window')
                    times = np.arange(n_samples) / SAMPLING[cam] + WINDOW_LAG
                    me_mean = np.mean(me_all, axis=0)
                    me_std = np.std(me_all, axis=0) / np.sqrt(len(me_all))
                    plt.plot(times, me_mean, label=f'{cam} cam', color=colors[cam], linewidth=2)
                    plt.fill_between(times, me_mean + me_std, me_mean - me_std, color=colors[cam], alpha=0.2)
                except ValueError:
                    logger.error(f"{cam}Camera camera.times are outside of the trial windows")
                    missing_data.append(cam)
            except AttributeError:
                logger.warning(f"Cannot load motion energy and/or times data for {cam} camera")
                missing_data.append(cam)
        else:
            logger.warning(f"Data missing or empty for motion energy and/or times data for {cam} camera")
            missing_data.append(cam)

    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.ylabel('z-scored motion energy [a.u.]')
    plt.xlabel('time [sec]')
    plt.axvline(x=0, label='stimOn', linestyle='--', c='k')
    plt.legend(loc='lower right')
    plt.title('Motion Energy trial avg\n(+/- std)')
    if len(missing_data) > 0:
        ax = plt.gca()
        ax.text(.95, .35, f"Data incomplete for\n{' and '.join(missing_data)} camera", color='r', fontsize=10,
                fontweight='bold', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
    return plt.gca()
def plot_speed_hist(dlc_df, cam_times, trials_df, feature='paw_r', cam='left', legend=True):
    """
    Plots speed histogram of a given dlc feature, aligned to stimulus onset, separate for correct and incorrect trials

    Neither ``dlc_df`` nor ``trials_df`` is modified: the likelihood thresholding and the per-trial
    speed column are applied to copies.

    :param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
    :param cam_times: np.array, camera timestamps
    :param trials_df: pd.DataFrame, with columns 'stimOn_times' (time of stimulus onset for each trial)
        and 'feedbackType' (1 for correct, -1 for incorrect trials)
    :param feature: str, feature with trace in dlc_df for which to plot speed hist, default is 'paw_r'
    :param cam: str, camera to use ('body', 'left', 'right') default is 'left'
    :param legend: bool, whether to add legend to the plot, default is True
    :returns: matplotlib.axis
    :raises ValueError: if the camera timestamps cannot be matched to the number of dlc frames
    """
    # Threshold a copy of the dlc traces: likelihood_threshold works in place and would otherwise
    # overwrite the caller's data with NaNs
    dlc_df = likelihood_threshold(dlc_df.copy())
    # Work on a copy so the caller's trials table does not gain a speed column
    trials_df = trials_df.copy()
    # For pre-GPIO sessions, remove the first few timestamps to match the number of frames
    cam_times = cam_times[-len(dlc_df):]
    if len(cam_times) != len(dlc_df):
        raise ValueError("Camera times length and DLC length are inconsistent")
    # Get speeds
    speeds = get_speed(dlc_df, cam_times, camera=cam, feature=feature)
    # Windows aligned to stimulus onset
    start_window, _ = plt_window(trials_df['stimOn_times'])
    start_idx = insert_idx(cam_times, start_window)
    end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
    # Add speeds to trials_df
    trials_df[f'speed_{feature}'] = [speeds[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
    # Plot
    times = np.arange(len(trials_df[f'speed_{feature}'].iloc[0])) / SAMPLING[cam] + WINDOW_LAG
    # Need to expand the series of lists into a dataframe first, for the nan skipping to work
    correct = trials_df[trials_df['feedbackType'] == 1][f'speed_{feature}']
    incorrect = trials_df[trials_df['feedbackType'] == -1][f'speed_{feature}']
    plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
             c='k', label='correct trial')
    plt.plot(times, pd.DataFrame.from_dict(dict(zip(incorrect.index, incorrect.values))).mean(axis=1),
             c='gray', label='incorrect trial')
    plt.axvline(x=0, label='stimOn', linestyle='--', c='r')
    plt.title(f'{feature.capitalize()} speed trial avg\n({cam.upper()} cam)')
    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
    plt.xlabel('time [sec]')
    plt.ylabel('speed [px/sec]')
    if legend:
        plt.legend()

    return plt.gca()
537def plot_pupil_diameter_hist(pupil_diameter, cam_times, trials_df, cam='left'):
538 """
539 Plots histogram of pupil diameter aligned to simulus onset and feedback time.
541 :param pupil_diameter: np.array, (smoothed) pupil diameter estimate
542 :param cam_times: np.array, camera timestamps
543 :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial) and
544 feedback_times (time of feedback for each trial)
545 :param cam: str, camera to use ('body', 'left', 'right') default is 'left'
546 :returns: matplotlib.axis
547 """
548 for align_to, color in zip(['stimOn_times', 'feedback_times'], ['red', 'purple']):
549 start_window, end_window = plt_window(trials_df[align_to])
550 start_idx = insert_idx(cam_times, start_window)
551 end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
552 # Per trial norm
553 pupil_all = [zscore(list(pupil_diameter[start_idx[i]:end_idx[i]])) for i in range(len(start_idx))]
554 pupil_all_norm = [trial - trial[0] for trial in pupil_all]
556 pupil_mean = np.nanmean(pupil_all_norm, axis=0)
557 pupil_std = np.nanstd(pupil_all_norm, axis=0) / np.sqrt(len(pupil_all_norm))
558 times = np.arange(len(pupil_all_norm[0])) / SAMPLING[cam] + WINDOW_LAG
560 plt.plot(times, pupil_mean, label=align_to.split("_")[0], color=color)
561 plt.fill_between(times, pupil_mean + pupil_std, pupil_mean - pupil_std, color=color, alpha=0.5)
562 plt.axvline(x=0, linestyle='--', c='k')
563 plt.title(f'Pupil diameter trial avg\n({cam.upper()} cam)')
564 plt.xlabel('time [sec]')
565 plt.xticks([-0.5, 0, 0.5, 1, 1.5])
566 plt.ylabel('z-scored smoothed pupil diameter [px]')
567 plt.legend(loc='lower right', title='aligned to')