Coverage for ibllib/io/extractors/video_motion.py: 45%
627 statements
coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""
2A module for aligning video motion energy with the wheel's rotary encoder signal. Currently used
3by the camera QC to check timestamp alignment.
4"""
5import matplotlib
6import matplotlib.pyplot as plt
7import matplotlib.gridspec as gridspec
8from matplotlib.widgets import RectangleSelector
9import numpy as np
10from scipy import signal, ndimage, interpolate
11import cv2
12from itertools import cycle
13import matplotlib.animation as animation
14import logging
15from pathlib import Path
16from joblib import Parallel, delayed, cpu_count
18from ibldsp.utils import WindowGenerator
19from one.api import ONE
20import ibllib.io.video as vidio
21from iblutil.util import Bunch
22from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map
23import ibllib.io.raw_data_loaders as raw
24import ibllib.io.extractors.camera as cam
25from ibllib.plots.snapshot import ReportSnapshot
26import brainbox.video as video
27import brainbox.behavior.wheel as wh
28from brainbox.singlecell import bin_spikes
29from brainbox.behavior.dlc import likelihood_threshold, get_speed
30from brainbox.task.trials import find_trial_ids
31import one.alf.io as alfio
32from one.alf.exceptions import ALFObjectNotFound
33from one.alf.spec import is_session_path, is_uuid_string
36def find_nearest(array, value):
37 array = np.asarray(array) 1deb
38 idx = (np.abs(array - value)).argmin() 1deb
39 return idx 1deb
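# Illustrative example (not part of the module): find_nearest(np.array([0.0, 0.5, 1.0]), 0.6) -> 1,
# the index of the closest value (0.5).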
42class MotionAlignment:
43 roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))}
45 def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), stream=False, **kwargs):
46 self.one = one or ONE() 1deb
47 self.eid = eid 1deb
48 self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid) 1deb
49 self.ref = self.one.dict2ref(self.one.path2ref(self.session_path)) 1deb
50 self.log = log 1deb
51 self.trials = self.wheel = self.camera_times = None 1deb
52 raw_cam_path = self.session_path.joinpath('raw_video_data') 1deb
53 camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*')) 1deb
54 if stream: 1deb
55 self.video_paths = vidio.url_from_eid(self.eid)
56 else:
57 self.video_paths = {vidio.label_from_path(x): x for x in camera_path} 1deb
58 self.data = Bunch() 1deb
59 self.alignment = Bunch() 1deb
61 def align_all_trials(self, side='all'):
62 """Align all wheel motion for all trials"""
63 if self.trials is None:
64 self.load_data()
65 if side == 'all':
66 side = self.video_paths.keys()
67 if not isinstance(side, str):
68 # Try to iterate over sides
69 return [self.align_all_trials(s) for s in side]
70 if side not in self.video_paths:
71 raise ValueError(f'{side} camera video file not found')
72 # Align each trial sequentially
73 for i in np.arange(self.trials['intervals'].shape[0]):
74 self.align_motion(period=self.trials['intervals'][i], display=False)
76 @staticmethod
77 def set_roi(video_path):
78 """Manually set the ROIs for a given set of videos
79 TODO Improve docstring
80 TODO A method for setting ROIs by label
81 """
82 frame = vidio.get_video_frame(str(video_path), 0)
84 def line_select_callback(eclick, erelease):
85 """
86 Callback for line selection.
88 *eclick* and *erelease* are the press and release events.
89 """
90 x1, y1 = eclick.xdata, eclick.ydata
91 x2, y2 = erelease.xdata, erelease.ydata
92 print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
93 return np.array([[x1, x2], [y1, y2]])
95 plt.imshow(frame)
96 roi = RectangleSelector(plt.gca(), line_select_callback, useblit=True, button=[1, 3],
97 # don't use middle button
98 minspanx=5, minspany=5, spancoords='pixels', interactive=True)
99 plt.show()
100 ((x1, x2, *_), (y1, *_, y2)) = roi.corners
101 col = np.arange(round(x1), round(x2), dtype=int)
102 row = np.arange(round(y1), round(y2), dtype=int)
103 return col, row
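# Usage sketch (interactive; hypothetical video path). The returned (col, row) ranges can be packed into
# the same (rows, columns) layout used by MotionAlignment.roi:
#     col, row = MotionAlignment.set_roi('/data/raw_video_data/_iblrig_leftCamera.raw.mp4')
#     roi = ((row[0], row[-1] + 1), (col[0], col[-1] + 1))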
105 def load_data(self, download=False):
106 """
107 Load wheel, trial and camera timestamp data
108 :return: None; populates self.data with wheel, trials and camera_times
109 """
110 if download:
111 self.data.wheel = self.one.load_object(self.eid, 'wheel')
112 self.data.trials = self.one.load_object(self.eid, 'trials')
113 cam, det = self.one.load_datasets(self.eid, ['*Camera.times*'])
114 self.data.camera_times = {vidio.label_from_path(d['rel_path']): ts for ts, d in zip(cam, det)}
115 else:
116 alf_path = self.session_path / 'alf'
117 wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
118 self.data.wheel = alfio.load_object(wheel_path, 'wheel', short_keys=True)
119 trials_path = next(alf_path.rglob('*trials.table*')).parent
120 self.data.trials = alfio.load_object(trials_path, 'trials')
121 self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) for x in
122 alf_path.rglob('*Camera.times*')}
123 assert all(x is not None for x in self.data.values())
125 def _set_eid_or_path(self, session_path_or_eid):
126 """Parse a given eID or session path
127 If a session UUID is given, resolves and stores the local path and vice versa
128 :param session_path_or_eid: A session eid or path
129 :return:
130 """
131 self.eid = None
132 if is_uuid_string(str(session_path_or_eid)):
133 self.eid = session_path_or_eid
134 # Try to set session_path if data is found locally
135 self.session_path = self.one.eid2path(self.eid)
136 elif is_session_path(session_path_or_eid):
137 self.session_path = Path(session_path_or_eid)
138 if self.one is not None:
139 self.eid = self.one.path2eid(self.session_path)
140 if not self.eid:
141 self.log.warning('Failed to determine eID from session path')
142 else:
143 self.log.error('Cannot run alignment: an experiment uuid or session path is required')
144 raise ValueError("'session' must be a valid session path or uuid")
146 def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
147 """
148 Align video to the wheel using cross-correlation of the video motion signal and the rotary
149 encoder.
151 Parameters
152 ----------
153 period : (float, float)
154 The time period over which to do the alignment.
155 side : {'left', 'right'}
156 With which camera to perform the alignment.
157 sd_thresh : float
158 For plotting where the motion energy goes above this standard deviation threshold.
159 display : bool
160 When true, displays the aligned wheel motion energy along with the rotary encoder
161 signal.
163 Returns
164 -------
165 int
166 Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
167 signal. Negative values mean the video was shifted backwards with respect to the wheel
168 timestamps.
169 float
170 The peak cross-correlation.
171 numpy.ndarray
172 The motion energy used in the cross-correlation, i.e. the frame difference for the
173 period given.
174 """
175 # Get data samples within period
176 wheel = self.data['wheel'] 1deb
177 self.alignment.label = side 1deb
178 self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1]) 1deb
179 camera_times = self.data['camera_times'][side] 1deb
180 cam_mask = self.alignment.to_mask(camera_times) 1deb
181 frame_numbers, = np.where(cam_mask) 1deb
183 if frame_numbers.size == 0: 1deb
184 raise ValueError('No frames during given period') 1b
186 # Motion Energy
187 camera_path = self.video_paths[side] 1deb
188 roi = (*[slice(*r) for r in self.roi[side]], 0) 1deb
189 try: 1deb
190 # TODO Add function arg to make grayscale
191 self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi) 1deb
192 assert self.alignment.frames.size != 0 1deb
193 except AssertionError:
194 self.log.error('Failed to open video')
195 return None, None, None
196 self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2) 1deb
197 self.alignment.period = period # For plotting 1deb
199 # Calculate rotary encoder velocity trace
200 x = camera_times[cam_mask] 1deb
201 Fs = 1000 1deb
202 pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs) 1deb
203 v, _ = wh.velocity_filtered(pos, Fs) 1deb
204 interp_mask = self.alignment.to_mask(t) 1deb
205 # Convert to normalized speed
206 xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x]) 1deb
207 vs = np.abs(v[interp_mask][xs]) 1deb
208 vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) 1deb
210 # FIXME This can be used as a goodness of fit measure
211 USE_CV2 = False 1deb
212 if USE_CV2: 1deb
213 # convert from numpy format to openCV format
214 dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
215 reCV = np.float32(vs.reshape((-1, 1)))
217 # perform cross correlation
218 resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)
220 # convert result back to numpy array
221 xcorr = np.asarray(resultCv)
222 else:
223 xcorr = signal.correlate(self.alignment.df, vs) 1deb
225 # Cross correlate wheel speed trace with the motion energy
226 CORRECTION = 2 1deb
227 self.alignment.c = max(xcorr) 1deb
228 self.alignment.xcorr = np.argmax(xcorr) 1deb
229 self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION 1deb
230 self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames') 1deb
232 if display: 1deb
233 # Plot the motion energy
234 fig, ax = plt.subplots(2, 1, sharex='all')
235 y = np.pad(self.alignment.df, 1, 'edge')
236 ax[0].plot(x, y, '-x', label='wheel motion energy')
237 thresh = stDev > sd_thresh
238 ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, linewidth=0.5, linestyle=':',
239 label=f'>{sd_thresh} s.d. diff')
240 ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))
242 # Plot other stuff
243 dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
244 fps = 1 / np.diff(camera_times).mean()
245 ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
246 ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
247 ax[0].set_ylabel('rate of change (a.u.)')
248 ax[0].legend()
249 ax[1].set_ylabel('wheel speed (rad / s)')
250 ax[1].set_xlabel('Time (s)')
252 title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
253 fig.suptitle(title, fontsize=16)
254 fig.set_size_inches(19.2, 9.89)
256 return self.alignment.dt_i, self.alignment.c, self.alignment.df 1deb
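# Toy illustration of the shift recovery used above (synthetic signals, not part of the pipeline). The lag
# at which the cross-correlation peaks gives the frame offset; CORRECTION accounts for the frame
# differencing applied when computing the motion energy:
#     speed = np.zeros(100); speed[40:45] = 1                   # wheel speed with a single movement
#     me = np.roll(speed, 3)                                     # motion energy lagging by 3 samples
#     np.argmax(signal.correlate(me, speed)) - speed.size + 1    # -> 3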
258 def plot_alignment(self, energy=True, save=False):
259 if not self.alignment: 1b
260 self.log.error('No alignment data, run `align_motion` first')
261 return
262 # Change backend based on save flag
263 backend = matplotlib.get_backend().lower() 1b
264 if (save and backend != 'agg') or (not save and backend == 'agg'): 1b
265 new_backend = 'Agg' if save else 'Qt5Agg'
266 self.log.warning('Switching backend from %s to %s', backend, new_backend)
267 matplotlib.use(new_backend)
268 from matplotlib import pyplot as plt 1b
270 # Main animated plots
271 fig, axes = plt.subplots(nrows=2) 1b
272 title = f'{self.ref}' # ', from {period[0]:.1f}s - {period[1]:.1f}s' 1b
273 fig.suptitle(title, fontsize=16) 1b
275 wheel = self.data['wheel'] 1b
276 wheel_mask = self.alignment['to_mask'](wheel.timestamps) 1b
277 ts = self.data['camera_times'][self.alignment['label']] 1b
278 frame_numbers, = np.where(self.alignment['to_mask'](ts)) 1b
279 if energy: 1b
280 self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2) 1b
281 frame_numbers = frame_numbers[1:-1] 1b
282 data = {'frame_ids': frame_numbers} 1b
284 def init_plot(): 1b
285 """
286 Plot the wheel data for the current trial
287 :return: None
288 """
289 data['im'] = axes[0].imshow(self.alignment['frames'][0]) 1b
290 axes[0].axis('off') 1b
291 axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames') 1b
293 # Plot the wheel position
294 ax = axes[1] 1b
295 ax.clear() 1b
296 ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x') 1b
298 ts_0 = frame_numbers[0] 1b
299 data['idx_0'] = ts_0 - self.alignment['dt_i'] 1b
300 ts_0 = ts[ts_0 + self.alignment['dt_i']] 1b
301 data['ln'] = ax.axvline(x=ts_0, color='k') 1b
302 ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)]) 1b
303 data['frame_num'] = 0 1b
304 mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0) 1b
306 data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x') 1b
307 ax.set_ylabel('Wheel position (rad)') 1b
308 ax.set_xlabel('Time (s)') 1b
309 return 1b
311 def animate(i): 1b
312 """
313 Callback for figure animation. Sets image data for current frame and moves pointer
314 along axis
315 :param i: unused; the current time step of the calling method
316 :return: None
317 """
318 if i < 0: 1b
319 data['frame_num'] -= 1
320 if data['frame_num'] < 0:
321 data['frame_num'] = len(self.alignment['frames']) - 1
322 else:
323 data['frame_num'] += 1 1b
324 if data['frame_num'] >= len(self.alignment['frames']): 1b
325 data['frame_num'] = 0 1b
326 i = data['frame_num'] # NB: This is index for current trial's frame list 1b
328 frame = self.alignment['frames'][i] 1b
329 t_x = ts[data['idx_0'] + i] 1b
330 data['ln'].set_xdata([t_x, t_x]) 1b
331 axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)]) 1b
332 data['im'].set_data(frame) 1b
334 mkr = find_nearest(wheel.timestamps[wheel_mask], t_x) 1b
335 data['marker'].set_data([wheel.timestamps[wheel_mask][mkr]], [wheel.position[wheel_mask][mkr]]) 1b
337 return data['im'], data['ln'], data['marker'] 1b
339 anim = animation.FuncAnimation(fig, animate, init_func=init_plot, 1b
340 frames=(range(len(self.alignment.df)) if save else cycle(range(60))), interval=20,
341 blit=False, repeat=not save, cache_frame_data=False)
342 anim.running = False 1b
344 def process_key(event): 1b
345 """
346 Callback for key presses.
347 :param event: a figure key_press_event
348 :return: None
349 """
350 if event.key.isspace():
351 if anim.running:
352 anim.event_source.stop()
353 else:
354 anim.event_source.start()
355 anim.running = not anim.running
356 elif event.key == 'right':
357 if anim.running:
358 anim.event_source.stop()
359 anim.running = False
360 animate(1)
361 fig.canvas.draw()
362 elif event.key == 'left':
363 if anim.running:
364 anim.event_source.stop()
365 anim.running = False
366 animate(-1)
367 fig.canvas.draw()
369 fig.canvas.mpl_connect('key_press_event', process_key) 1b
371 # init_plot()
372 # while True:
373 # animate(0)
374 if save: 1b
375 filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0]) 1b
376 if isinstance(save, (str, Path)): 1b
377 filename = Path(save).joinpath(filename) 1b
378 self.log.info(f'Saving to {filename}') 1b
379 # Set up formatting for the movie files
380 Writer = animation.writers['ffmpeg'] 1b
381 writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800) 1b
382 anim.save(str(filename), writer=writer) 1b
383 else:
384 plt.show()
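# Example end-to-end usage of MotionAlignment (hypothetical eid placeholder; requires the session data
# locally or a configured ONE instance):
#     ma = MotionAlignment(eid='<session eid>', one=ONE())
#     ma.load_data(download=False)
#     offset, peak, me = ma.align_motion(period=(600, 660), side='left', display=True)
#     ma.plot_alignment(save=False)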
387class MotionAlignmentFullSession:
388 def __init__(self, session_path, label, **kwargs):
389 """
390 Class to extract camera times using video motion energy wheel alignment
391 :param session_path: path of the session
392 :param label: video label, only 'left' and 'right' videos are supported
393 :param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20)
394 upload - whether to upload summary figure to alyx (default False)
395 twin - the window length used when computing the shifts between the wheel and video
396 nprocess - the number of CPU processes to use (defaults to 3/4 of the available cores)
397 sync - the type of sync scheme used (options 'nidq' or 'bpod')
398 location - whether the code is being run on SDSC or not (options 'SDSC' or None)
399 """
400 self.session_path = session_path 1gfa
401 self.label = label 1gfa
402 self.threshold = kwargs.get('threshold', 20) 1gfa
403 self.upload = kwargs.get('upload', False) 1gfa
404 self.twin = kwargs.get('twin', 150) 1gfa
405 self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) 1gfa
407 self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None)) 1gfa
408 self.roi, self.mask = self.get_roi_mask() 1a
410 if self.upload: 1a
411 self.one = ONE(mode='remote')
412 self.one.alyx.authenticate()
413 self.eid = self.one.path2eid(self.session_path)
415 def load_data(self, sync='nidq', location=None):
416 """
417 Loads relevant data from disk to perform motion alignment
418 :param sync: type of sync used, 'nidq' or 'bpod'
419 :param location: where the code is being run, if location='SDSC', the dataset uuids are removed
420 when loading the data
421 :return:
422 """
423 def fix_keys(alf_object): 1gfa
424 """
425 Given an alf object, remove the dataset uuids from the keys
426 :param alf_object: alf Bunch whose keys include dataset uuids
427 :return: Bunch with the uuids stripped from the keys
428 """
429 ob = Bunch()
430 for key in alf_object.keys():
431 vals = alf_object[key]
432 ob[key.split('.')[0]] = vals
433 return ob
435 alf_path = self.session_path.joinpath('alf') 1gfa
436 wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent 1gfa
437 wheel = (fix_keys(alfio.load_object(wheel_path, 'wheel')) if location == 'SDSC' 1fa
438 else alfio.load_object(wheel_path, 'wheel'))
439 self.wheel_timestamps = wheel.timestamps 1fa
440 # Compute interpolated wheel position and wheel times
441 wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) 1fa
442 # Compute wheel velocity
443 self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) 1fa
444 # Load in original camera times
445 self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4'))) 1fa
446 self.camera_meta = vidio.get_video_meta(self.camera_path) 1a
448 # TODO should read in the description file to get the correct sync location
449 if sync == 'nidq': 1a
450 # If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object
451 sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data') 1a
452 sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) 1a
453 self.ttls = sr.times[::2] 1a
454 else:
455 # Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data
456 cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path)
457 cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data')
458 self.ttls = cam_extractor._times_from_bpod()
460 # Check if the ttl and video sizes match up
461 self.tdiff = self.ttls.size - self.camera_meta['length'] 1a
463 # Load in original camera times if available otherwise set to ttls
464 camera_times = next(alf_path.rglob(f'_ibl_{self.label}Camera.times*.npy'), None) 1a
465 self.camera_times = alfio.load_file_content(camera_times) if camera_times else self.ttls 1a
467 if self.tdiff < 0: 1a
468 # In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with
469 # nans, but if too many are missing we reject the wheel alignment based on the qc
470 self.ttl_times = self.ttls
471 self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)]
472 if self.camera_times.size != self.camera_meta['length']:
473 self.camera_times = np.r_[self.camera_times, np.full((np.abs(self.tdiff)), np.nan)]
474 self.short_flag = True
475 else: 1a
476 # In this case there are at least as many ttls as camera frames (tdiff >= 0). This happens often; we
477 # remove the first tdiff ttls from the ttls
478 self.ttl_times = self.ttls[self.tdiff:] 1a
479 self.times = self.ttls[self.tdiff:] 1a
480 if self.camera_times.size != self.camera_meta['length']: 1a
481 self.camera_times = self.camera_times[self.tdiff:] 1a
482 self.short_flag = False 1a
484 # Compute the frame rate of the camera
485 self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) 1a
487 # We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots;
488 # trial-aligned paw velocity (from the dlc) is a nice sanity check that the alignment went well
489 try: 1a
490 self.trials = alfio.load_file_content(next(alf_path.rglob('_ibl_trials.table*.pqt'))) 1a
491 self.dlc = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.dlc*.pqt')))
492 self.dlc = likelihood_threshold(self.dlc)
493 self.behavior = True
494 except (ALFObjectNotFound, StopIteration): 1a
495 self.behavior = False 1a
497 # Load in a single frame that we will use for the summary plot
498 self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) 1a
500 def get_roi_mask(self):
501 """
502 Compute the region of interest mask for a given camera. This corresponds to a box in the video that we will
503 use to compute the wheel motion energy
504 :return:
505 """
507 if self.label == 'right': 1a
508 roi = ((450, 512), (120, 200)) 1a
509 else:
510 roi = ((900, 1024), (850, 1010))
511 roi_mask = (*[slice(*r) for r in roi], 0) 1a
513 return roi, roi_mask 1a
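# The mask returned above is a (row-slice, column-slice, channel) tuple that is passed to
# get_video_frames_preload and indexes each (height, width, 3) frame, e.g. for the right camera
# (illustrative):
#     roi, mask = self.get_roi_mask()    # roi = ((450, 512), (120, 200))
#     frame[mask].shape                  # -> (62, 80): rows 450:512, columns 120:200, first colour channel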
515 def find_contaminated_frames(self, video_frames, threshold=20, normalise=True):
516 """
517 Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine
518 frames with contamination an Otsu thresholding is applied to each frame to detect the artefact from the
519 background image
520 :param video_frames: np array of video frames (nframes, nwidth, nheight)
521 :param threshold: threshold to differentiate artefact from background
522 :param normalise: whether to normalise the threshold values for each frame to the baseline
523 :return: mask of frames that are contaminated
524 """
525 high = np.zeros((video_frames.shape[0]))
526 # Iterate through each frame and compute and store the otsu threshold value for each frame
527 for idx, frame in enumerate(video_frames):
528 ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
529 high[idx] = ret
531 # If normalise is True, we subtract the minimum value from the threshold values of all frames
532 if normalise:
533 high -= np.min(high)
535 # Identify the frames that have a threshold value greater than the specified threshold cutoff
536 contaminated_frames = np.where(high > threshold)[0]
538 return contaminated_frames
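# Minimal sketch of the Otsu step applied to a single frame above (assumes an 8-bit grayscale frame):
#     blur = cv2.GaussianBlur(frame, (5, 5), 0)
#     otsu_value, _ = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
#     # frames whose otsu_value sits well above the session baseline are flagged as contaminated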
540 def compute_motion_energy(self, first, last, wg, iw):
541 """
542 Computes the video motion energy for frame indexes between first and last. This function is written to be run
543 in a parallel fashion using joblib.Parallel
544 :param first: first frame index of frame interval to consider
545 :param last: last frame index of frame interval to consider
546 :param wg: WindowGenerator
547 :param iw: iteration of the WindowGenerator
548 :return:
549 """
551 if iw == wg.nwin - 1:
552 return
554 # Open the video and read in the relevant video frames between first idx and last idx
555 cap = cv2.VideoCapture(self.camera_path)
556 frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask)
557 # Identify if any of the frames have artefacts in them
558 idx = self.find_contaminated_frames(frames, self.threshold)
560 # If some of the frames are contaminated we find all the continuous intervals of contamination
561 # and set the value for contaminated pixels for these frames to the average of the first frame before and after
562 # this contamination interval
563 if len(idx) != 0:
565 before_status = False
566 after_status = False
568 counter = 0
569 n_frames = 200
570 # If it is the first frame that is contaminated, we need to read in a bit more of the video to find a
571 # frame prior to contamination. We attempt this 20 times, after that we just take the value for the first
572 # frame
573 while np.any(idx == 0) and counter < 20 and iw != 0:
574 n_before_offset = (counter + 1) * n_frames
575 first -= n_frames
576 extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first - n_frames, first),
577 mask=self.mask)
578 frames = np.concatenate([extra_frames, frames], axis=0)
580 idx = self.find_contaminated_frames(frames, self.threshold)
581 before_status = True
582 counter += 1
583 if counter > 0:
584 print(f'In before: {counter}')
586 counter = 0
587 # If it is the last frame that is contaminated, we need to read in a bit more of the video to find a
588 # frame after the contamination. We attempt this 20 times, after that we just take the value for the last
589 # frame
590 while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1:
591 n_after_offset = (counter + 1) * n_frames
592 last += n_frames
593 extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), mask=self.mask)
594 frames = np.concatenate([frames, extra_frames], axis=0)
595 idx = self.find_contaminated_frames(frames, self.threshold)
596 after_status = True
597 counter += 1
599 if counter > 0:
600 print(f'In after: {counter}')
602 # We find all the continuous intervals that contain contamination and fix the affected pixels
603 # by taking the average value of the frame prior and after contamination
604 intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
605 for ints in intervals:
606 if len(ints) > 0 and ints[0] == 0:
607 ints = ints[1:]
608 if len(ints) > 0 and ints[-1] == frames.shape[0] - 1:
609 ints = ints[:-1]
610 th_all = np.zeros_like(frames[0])
611 # We find all affected pixels
612 for idx in ints:
613 img = np.copy(frames[idx])
614 blur = cv2.GaussianBlur(img, (5, 5), 0)
615 ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
616 th = cv2.GaussianBlur(th, (5, 5), 10)
617 th_all += th
618 # Compute the average image of the frame prior and after the interval
619 vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1)
620 # For each frame set the affected pixels to the value of the clean average image
621 for idx in ints:
622 img = frames[idx]
623 img[th_all > 0] = vals[th_all > 0]
625 # If we have read in extra video frames we need to cut these off and make sure we only
626 # consider the frames between the interval first and last given as args
627 if before_status:
628 frames = frames[n_before_offset:]
629 if after_status:
630 frames = frames[:(-1 * n_after_offset)]
632 # Once the frames have been cleaned we compute the motion energy between frames
633 frame_me, _ = video.motion_energy(frames, diff=2, normalize=False)
635 cap.release()
637 return frame_me[2:]
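# Conceptually, the motion energy is the absolute difference between frames a fixed lag apart, averaged
# over pixels. A rough NumPy sketch of the diff=2 call above (assuming diff is the frame lag;
# brainbox.video.motion_energy additionally handles smoothing/normalisation options):
#     me = np.mean(np.abs(frames[2:].astype(float) - frames[:-2].astype(float)), axis=(1, 2))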
639 def compute_shifts(self, times, me, first, last, iw, wg):
640 """
641 Compute the cross-correlation between the video motion energy and the wheel velocity to find the mismatch
642 between the camera ttls and the video frames. This function is written to run in a parallel manner using
643 joblib.parallel
645 :param times: the times of the video frames across the whole session (ttls)
646 :param me: the video motion energy computed across the whole session
647 :param first: first time idx to consider
648 :param last: last time idx to consider
649 :param wg: WindowGenerator
650 :param iw: iteration of the WindowGenerator
651 :return:
652 """
654 # If we are in the last window we exit
655 if iw == wg.nwin - 1: 1a
656 return np.nan, np.nan 1a
658 # Find the time interval we are interested in
659 t_first = times[first] 1a
660 t_last = times[last] 1a
662 # If both times during this interval are nan exit
663 if np.isnan(t_last) and np.isnan(t_first): 1a
664 return np.nan, np.nan
665 # If only the last time is nan, we find the last non nan time value
666 elif np.isnan(t_last): 1a
667 t_last = times[np.where(~np.isnan(times))[0][-1]]
669 # Find the mask of timepoints that fall in this interval
670 mask = np.logical_and(times >= t_first, times <= t_last) 1a
671 # Restrict the video motion energy to this interval and normalise the values
672 align_me = me[np.where(mask)[0]] 1a
673 align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me)) 1a
675 # Find closest timepoints in wheel that match the time interval
676 wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last) 1a
677 if np.sum(wh_mask) == 0: 1a
678 return np.nan, np.nan
679 # Find the mask for the wheel times
680 xs = np.searchsorted(self.wheel_time[wh_mask], times[mask]) 1a
681 xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1 1a
682 # Convert to normalized speed
683 vs = np.abs(self.wheel_vel[wh_mask][xs]) 1a
684 vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) 1a
686 # Account for nan values in the video motion energy
687 isnan = np.isnan(align_me) 1a
688 if np.sum(isnan) > 0: 1a
689 where_nan = np.where(isnan)[0] 1a
690 assert where_nan[0] == 0 1a
691 assert where_nan[-1] == np.sum(isnan) - 1 1a
693 if np.all(isnan): 1a
694 return np.nan, np.nan
696 # Compute the cross correlation between the video motion energy and the wheel speed
697 xcorr = signal.correlate(align_me[~isnan], vs[~isnan]) 1a
698 # The max value of the cross correlation indicates the shift that needs to be applied
699 # The +2 comes from the fact that the video motion energy was computed from the difference between frames
700 shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2 1a
702 return shift, t_first + (t_last - t_first) / 2 1a
704 def clean_shifts(self, x, n=1):
705 """
706 Removes artefacts from the computed shifts across time. We assume that the shifts should never increase
707 over time and that the jump between consecutive shifts shouldn't be greater than 1
708 :param x: computed shifts
709 :param n: condition to apply
710 :return:
711 """
712 y = x.copy() 1a
713 dy = np.diff(y, prepend=y[0]) 1a
714 while True: 1a
715 pos = np.where(dy == 1)[0] if n == 1 else np.where(dy > 2)[0] 1a
716 # added frames: this doesn't make sense and this is noise
717 if pos.size == 0: 1a
718 break 1a
719 neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0]
721 if len(pos) > len(neg):
722 neg = np.append(neg, dy.size - 1)
724 iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1)
725 imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss])))
727 idx = np.max([0, iss[imin] - 1])
728 ineg = neg[idx:iss[imin] + 1]
729 ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))]
730 dy[pos[imin]] = 0
731 dy[ineg] = 0
733 return np.cumsum(dy) + y[0] 1a
735 def qc_shifts(self, shifts, shifts_filt):
736 """
737 Compute qc values for the wheel alignment. We consider 4 things
738 1. The number of camera ttl values that are missing (when we have fewer ttls than video frames)
739 2. The number of shifts that have nan values, i.e. windows where the motion energy could not be aligned to the wheel
740 3. The number of large jumps (>10) between the computed shifts
741 4. The number of jumps (>1) between the shifts after they have been cleaned
743 :param shifts: np.array of shifts over session
744 :param shifts_filt: np.array of shifts after being cleaned over session
745 :return:
746 """
748 ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 1a
749 nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100 1a
750 shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size 1a
751 shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size 1a
753 qc = dict() 1a
754 qc['ttl_per'] = ttl_per 1a
755 qc['nan_per'] = nan_per 1a
756 qc['shifts_sum'] = shifts_sum 1a
757 qc['shifts_filt_sum'] = shifts_filt_sum 1a
759 qc_outcome = True 1a
760 # If more than 10% of ttls are missing we don't get new times
761 if ttl_per > 10: 1a
762 qc_outcome = False
763 # If too many of the shifts are nans it means the alignment is not accurate
764 if nan_per > 40: 1a
765 qc_outcome = False
766 # If there are too many artefacts could be errors
767 if shifts_sum > 60: 1a
768 qc_outcome = False
769 # If there are jumps > 1 in the filtered shifts then there is a problem
770 if shifts_filt_sum > 0: 1a
771 qc_outcome = False
773 return qc, qc_outcome 1a
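# Example of the returned values (illustrative numbers only):
#     qc = {'ttl_per': 0.0, 'nan_per': 3.2, 'shifts_sum': 12, 'shifts_filt_sum': 0}
#     qc_outcome = True    # all four values fall below their respective rejection thresholds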
775 def extract_times(self, shifts_filt, t_shifts):
776 """
777 Extracts new camera times after applying the computed shifts across the session
779 :param shifts_filt: filtered shifts computed across session
780 :param t_shifts: time point of computed shifts
781 :return:
782 """
784 # Compute the interpolation function to apply to the ttl times
785 t_new = t_shifts - (shifts_filt * 1 / self.frate) 1a
786 fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") 1a
787 # Apply the function and get out new times
788 new_times = fcn(self.ttl_times) 1a
790 # If we are missing ttls then interpolate and append the correct number at the end
791 if self.tdiff < 0: 1a
792 to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1]
793 new_times = np.r_[new_times, to_app]
795 return new_times 1a
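# Illustrative sketch of the interpolation above with synthetic numbers (frame rate of 60 Hz assumed):
#     t_shifts = np.array([10., 20., 30.]); shifts_filt = np.array([0., 1., 1.])
#     fcn = interpolate.interp1d(t_shifts, t_shifts - shifts_filt / 60, fill_value="extrapolate")
#     fcn(np.array([10., 25.]))    # -> approximately [10.0, 24.983]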
797 @staticmethod
798 def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True,
799 norm=False, axs=None):
800 """
801 Compute and plot trial aligned spike rasters and psth
802 :param spike_times: times of the variable to bin (e.g. spike or wheel sample times)
803 :param events: trial times to align to
804 :param trial_idx: trial idx to sort by
805 :param dividers: indices in the sorted trials at which the groups change (e.g. from find_trial_ids)
806 :param colors: list of colors, one per group
807 :param labels: list of labels, one per group
808 :param weights: optional weights applied to each event when binning (e.g. velocity)
809 :param fr: if True, divide the psth by the bin size to obtain a rate
810 :param norm: if True, baseline-subtract the first bin from the psth and raster
811 :param axs: optional pair of axes (psth, raster) to plot into
812 :return:
813 """
814 pre_time = 0.4
815 post_time = 1
816 raster_bin = 0.01
817 psth_bin = 0.05
818 raster, t_raster = bin_spikes(
819 spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights)
820 psth, t_psth = bin_spikes(
821 spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights)
823 if fr:
824 psth = psth / psth_bin
826 if norm:
827 psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1)
828 raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1)
830 dividers = [0] + dividers + [len(trial_idx)]
831 if axs is None:
832 fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True)
833 else:
834 fig = axs[0].get_figure()
836 label, lidx = np.unique(labels, return_index=True)
837 label_pos = []
838 for lab, lid in zip(label, lidx):
839 idx = np.where(np.array(labels) == lab)[0]
840 for iD in range(len(idx)):
841 if iD == 0:
842 t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]
843 t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]]
844 else:
845 t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]]
846 t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]]
848 psth_div = np.nanmean(psth[t_ids], axis=0)
849 std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids))
851 axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid])
852 axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid])
854 lab_max = idx[np.argmax(t_ints)]
855 label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max])
857 axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower',
858 extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto')
860 width = raster_bin * 4
861 for iD in range(len(dividers) - 1):
862 axs[1].fill_between([post_time + raster_bin / 2, post_time + raster_bin / 2 + width],
863 [dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD])
865 axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width])
866 secax = axs[1].secondary_yaxis('right')
868 secax.set_yticks(label_pos)
869 secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center')
870 for ic, c in enumerate(np.array(colors)[lidx]):
871 secax.get_yticklabels()[ic].set_color(c)
873 axs[0].axvline(0, 0, 1, c='k', ls='--', zorder=10)
874 axs[1].axvline(0, 0, 1, c='k', ls='--', zorder=10)
876 return fig, axs
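# Usage sketch mirroring the call made in plot_with_behavior (wheel velocity aligned to first movement):
#     trial_idx, dividers = find_trial_ids(self.trials, sort='side')
#     fig, axs = self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values,
#                                           trial_idx, dividers, ['g', 'y'], ['left', 'right'],
#                                           weights=self.wheel_vel, fr=False)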
878 def plot_with_behavior(self):
879 """
880 Makes a summary figure of the alignment when behaviour data is available
881 :return:
882 """
884 self.dlc = likelihood_threshold(self.dlc)
885 trial_idx, dividers = find_trial_ids(self.trials, sort='side')
886 feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r')
887 feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r')
889 fig = plt.figure()
890 fig.set_size_inches(15, 9)
891 gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5)
892 gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
893 ax01 = fig.add_subplot(gs0[0, 0])
894 ax02 = fig.add_subplot(gs0[1, 0])
895 ax03 = fig.add_subplot(gs0[2, 0])
896 gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3])
897 ax11 = fig.add_subplot(gs1[0, 0])
898 ax12 = fig.add_subplot(gs1[1, 0])
899 gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3])
900 ax21 = fig.add_subplot(gs2[0, 0])
901 ax22 = fig.add_subplot(gs2[1, 0])
902 gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3])
903 ax31 = fig.add_subplot(gs3[0, 0])
904 ax32 = fig.add_subplot(gs3[1, 0])
905 gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4])
906 ax41 = fig.add_subplot(gs4[0, 0])
907 ax42 = fig.add_subplot(gs4[1, 0])
909 ax01.plot(self.t_shifts, self.shifts, label='shifts')
910 ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
911 ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
912 ax01.legend()
913 ax01.set_ylabel('Frames')
914 ax01.set_xlabel('Time in session')
916 xs = np.searchsorted(self.ttl_times, self.t_shifts)
917 ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps']
918 ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
919 ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
920 ax02.legend()
921 ax02.set_ylabel('Frames')
922 ax02.set_xlabel('Time in session')
924 ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new')
925 ax03.legend()
926 ax03.set_ylim(-5, 5)
927 ax03.set_ylabel('Frames')
928 ax03.set_xlabel('Time in session')
930 self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers,
931 ['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12])
932 ax11.sharex(ax12)
933 ax11.set_ylabel('Wheel velocity')
934 ax11.set_title('Wheel')
935 ax12.set_xlabel('Time from first move')
937 self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'],
938 ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22])
939 ax21.sharex(ax22)
940 ax21.set_ylabel('Paw r velocity')
941 ax21.set_title('Extracted times')
942 ax22.set_xlabel('Time from first move')
944 self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'],
945 ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32])
946 ax31.sharex(ax32)
947 ax31.set_ylabel('Paw r velocity')
948 ax31.set_title('New times')
949 ax32.set_xlabel('Time from first move')
951 ax41.imshow(self.frame_example[0])
952 rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1],
953 self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none')
954 ax41.add_patch(rect)
956 ax42.plot(self.all_me)
958 return fig
960 def plot_without_behavior(self):
961 """
962 Makes a summary figure of the alignment when behaviour data is not available
963 :return:
964 """
966 fig = plt.figure()
967 fig.set_size_inches(7, 7)
968 gs = gridspec.GridSpec(1, 2, figure=fig)
969 gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
970 ax01 = fig.add_subplot(gs0[0, 0])
971 ax02 = fig.add_subplot(gs0[1, 0])
972 ax03 = fig.add_subplot(gs0[2, 0])
974 gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1])
975 ax04 = fig.add_subplot(gs1[0, 0])
976 ax05 = fig.add_subplot(gs1[1, 0])
978 ax01.plot(self.t_shifts, self.shifts, label='shifts')
979 ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
980 ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
981 ax01.legend()
982 ax01.set_ylabel('Frames')
983 ax01.set_xlabel('Time in session')
985 xs = np.searchsorted(self.ttl_times, self.t_shifts)
986 ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps']
987 ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
988 ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
989 ax02.legend()
990 ax02.set_ylabel('Frames')
991 ax02.set_xlabel('Time in session')
993 ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new')
994 ax03.legend()
995 ax03.set_ylim(-5, 5)
996 ax03.set_ylabel('Frames')
997 ax03.set_xlabel('Time in session')
999 ax04.imshow(self.frame_example[0])
1000 rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1],
1001 self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none')
1002 ax04.add_patch(rect)
1004 ax05.plot(self.all_me)
1006 return fig
1008 def process(self):
1009 """
1010 Main function used to apply the video motion wheel alignment to the camera times. This function does the
1011 following
1012 1. Computes the video motion energy across the whole session (computed in windows and parallelised)
1013 2. Computes the shift that should be applied to the camera times across the whole session by computing
1014 the cross correlation between the video motion energy and the wheel speed (computed in
1015 overlapping windows and parallelised)
1016 3. Removes artefacts from the computed shifts
1017 4. Computes the qc for the wheel alignment
1018 5. Extracts the new camera times using the shifts computed from the video wheel alignment
1019 6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session
1020 on alyx
1021 :return:
1022 """
1024 # Compute the motion energy of the wheel for the whole video
1025 wg = WindowGenerator(self.camera_meta['length'], 5000, 4) 1a
1026 out = Parallel(n_jobs=self.nprocess)( 1a
1027 delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast))
1028 # Concatenate the motion energy into one big array
1029 self.all_me = np.array([]) 1a
1030 for vals in out[:-1]: 1a
1031 self.all_me = np.r_[self.all_me, vals] 1a
1033 toverlap = self.twin - 1 1a
1034 all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me] 1a
1035 to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1] 1a
1036 times = np.r_[to_app, self.times] 1a
1038 wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) 1a
1040 out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) 1a
1041 for iw, (first, last) in enumerate(wg.firstlast))
1043 self.shifts = np.array([]) 1a
1044 self.t_shifts = np.array([]) 1a
1045 for vals in out[:-1]: 1a
1046 self.shifts = np.r_[self.shifts, vals[0]] 1a
1047 self.t_shifts = np.r_[self.t_shifts, vals[1]] 1a
1049 idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1]) 1a
1050 self.shifts = self.shifts[idx] 1a
1051 self.t_shifts = self.t_shifts[idx] 1a
1052 shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120) 1a
1053 shifts_filt = self.clean_shifts(shifts_filt, n=1) 1a
1054 self.shifts_filt = self.clean_shifts(shifts_filt, n=2) 1a
1056 self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt) 1a
1058 self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) 1a
1060 if self.upload: 1a
1061 fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior()
1062 save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png'))
1063 save_fig_path.parent.mkdir(exist_ok=True, parents=True)
1064 fig.savefig(save_fig_path)
1065 snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one)
1066 snp.outputs = [save_fig_path]
1067 snp.register_images(widths=['orig'])
1068 plt.close(fig)
1070 return self.new_times 1a
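# Example end-to-end usage (hypothetical session path; requires the raw video, wheel and sync data on disk):
#     aligner = MotionAlignmentFullSession(Path('/data/subject/2021-01-01/001'), 'left', sync='nidq', upload=False)
#     new_times = aligner.process()
#     if aligner.qc_outcome:
#         np.save('_ibl_leftCamera.times.npy', new_times)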