Coverage for ibllib/io/extractors/video_motion.py: 45% of 622 statements
1"""
2A module for aligning the wheel motion with the rotary encoder. Currently used by the camera QC
3in order to check timestamp alignment.
4"""
import logging
from itertools import cycle
from pathlib import Path

import cv2
import matplotlib
import matplotlib.animation as animation
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
import numpy as np
from joblib import Parallel, delayed, cpu_count
from scipy import signal, ndimage, interpolate

from ibldsp.utils import WindowGenerator
from one.api import ONE
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
from one.alf.spec import is_session_path, is_uuid_string
from iblutil.util import Bunch
import ibllib.io.video as vidio
import ibllib.io.raw_data_loaders as raw
import ibllib.io.extractors.camera as cam
from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map
from ibllib.plots.snapshot import ReportSnapshot
import brainbox.video as video
import brainbox.behavior.wheel as wh
from brainbox.singlecell import bin_spikes
from brainbox.behavior.dlc import likelihood_threshold, get_speed
from brainbox.task.trials import find_trial_ids


def find_nearest(array, value):
    """Return the index of the element of `array` closest to `value`."""
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return idx


class MotionAlignment:
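    """
    Align a camera video to the wheel trace via cross-correlation of the video motion energy
    with the wheel speed. Used by the camera QC to check timestamp alignment.

    Example (a sketch added for illustration; the eid is a placeholder):
        ma = MotionAlignment(eid, one=ONE())
        ma.load_data(download=True)
        offset, cmax, me = ma.align_motion(period=(100, 130), side='left', display=True)
    """
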
    roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))}

    def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), stream=False, **kwargs):
        self.one = one or ONE()
        self.eid = eid
        self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid)
        self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
        self.log = log
        self.trials = self.wheel = self.camera_times = None
        raw_cam_path = self.session_path.joinpath('raw_video_data')
        camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
        if stream:
            self.video_paths = vidio.url_from_eid(self.eid)
        else:
            self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
        self.data = Bunch()
        self.alignment = Bunch()

    def align_all_trials(self, side='all'):
        """Align all wheel motion for all trials"""
        if 'trials' not in self.data:
            self.load_data()
        if side == 'all':
            side = self.video_paths.keys()
        if not isinstance(side, str):
            # Try to iterate over sides
            [self.align_all_trials(s) for s in side]
            return
        if side not in self.video_paths:
            raise ValueError(f'{side} camera video file not found')
        # Align each trial sequentially over its interval
        for i in np.arange(self.data.trials['intervals'].shape[0]):
            self.align_motion(period=self.data.trials['intervals'][i], side=side, display=False)

    @staticmethod
    def set_roi(video_path):
        """Manually set the ROIs for a given set of videos
        TODO Improve docstring
        TODO A method for setting ROIs by label
        """
        frame = vidio.get_video_frame(str(video_path), 0)

        def line_select_callback(eclick, erelease):
            """
            Callback for line selection.

            *eclick* and *erelease* are the press and release events.
            """
            x1, y1 = eclick.xdata, eclick.ydata
            x2, y2 = erelease.xdata, erelease.ydata
            print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
            return np.array([[x1, x2], [y1, y2]])

        plt.imshow(frame)
        roi = RectangleSelector(plt.gca(), line_select_callback, drawtype='box', useblit=True,
                                button=[1, 3],  # don't use middle button
                                minspanx=5, minspany=5, spancoords='pixels', interactive=True)
        plt.show()
        ((x1, x2, *_), (y1, *_, y2)) = roi.corners
        col = np.arange(round(x1), round(x2), dtype=int)
        row = np.arange(round(y1), round(y2), dtype=int)
        return col, row

    def load_data(self, download=False):
        """
        Load wheel, trial and camera timestamp data
        :return: wheel, trials
        """
        if download:
            self.data.wheel = self.one.load_object(self.eid, 'wheel')
            self.data.trials = self.one.load_object(self.eid, 'trials')
            cam, det = self.one.load_datasets(self.eid, ['*Camera.times*'])
            self.data.camera_times = {vidio.label_from_path(d['rel_path']): ts for ts, d in zip(cam, det)}
        else:
            alf_path = self.session_path / 'alf'
            wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
            self.data.wheel = alfio.load_object(wheel_path, 'wheel', short_keys=True)
            trials_path = next(alf_path.rglob('*trials.table*')).parent
            self.data.trials = alfio.load_object(trials_path, 'trials')
            self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x)
                                      for x in alf_path.rglob('*Camera.times*')}
        assert all(x is not None for x in self.data.values())

    def _set_eid_or_path(self, session_path_or_eid):
        """Parse a given eID or session path
        If a session UUID is given, resolves and stores the local path and vice versa
        :param session_path_or_eid: A session eid or path
        :return:
        """
        self.eid = None
        if is_uuid_string(str(session_path_or_eid)):
            self.eid = session_path_or_eid
            # Try to set session_path if data is found locally
            self.session_path = self.one.eid2path(self.eid)
        elif is_session_path(session_path_or_eid):
            self.session_path = Path(session_path_or_eid)
            if self.one is not None:
                self.eid = self.one.path2eid(self.session_path)
                if not self.eid:
                    self.log.warning('Failed to determine eID from session path')
        else:
            self.log.error('Cannot run alignment: an experiment uuid or session path is required')
            raise ValueError("'session' must be a valid session path or uuid")

    def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
        """
        Align video to the wheel using cross-correlation of the video motion signal and the rotary
        encoder.

        Parameters
        ----------
        period : (float, float)
            The time period over which to do the alignment.
        side : {'left', 'right'}
            The camera to use for the alignment.
        sd_thresh : float
            For plotting where the motion energy goes above this standard deviation threshold.
        display : bool
            When true, displays the aligned wheel motion energy along with the rotary encoder
            signal.

        Returns
        -------
        int
            Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
            signal. Negative values mean the video was shifted backwards with respect to the wheel
            timestamps.
        float
            The peak cross-correlation.
        numpy.ndarray
            The motion energy used in the cross-correlation, i.e. the frame difference for the
            period given.
        """
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion Energy
        camera_path = self.video_paths[side]
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            assert self.alignment.frames.size != 0
        except AssertionError:
            self.log.error('Failed to open video')
            return None, None, None
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000
        pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
        v, _ = wh.velocity_filtered(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Convert to normalized speed
        xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)
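
        # Note (added for clarity): `signal.correlate` in 'full' mode has its zero-lag term at
        # index len(vs) - 1, so np.argmax(xcorr) - (xs.size - 1) is the raw lag in frames.
        # CORRECTION below folds in this off-by-one together with the two-frame difference used to
        # compute the motion energy (cf. the analogous +2 in MotionAlignmentFullSession.compute_shifts).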
        # The peak of the cross-correlation between the wheel speed trace and the motion energy
        # gives the required shift
        CORRECTION = 2
        self.alignment.c = max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1,
                         linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

            # Plot the shifted wheel speed over the motion energy
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df

    def plot_alignment(self, energy=True, save=False):
        if not self.alignment:
            self.log.error('No alignment data, run `align_motion` first')
            return
        # Change backend based on save flag
        backend = matplotlib.get_backend().lower()
        if (save and backend != 'agg') or (not save and backend == 'agg'):
            new_backend = 'Agg' if save else 'Qt5Agg'
            self.log.warning('Switching backend from %s to %s', backend, new_backend)
            matplotlib.use(new_backend)
        from matplotlib import pyplot as plt

        # Main animated plots
        fig, axes = plt.subplots(nrows=2)
        title = f'{self.ref}'  # ', from {period[0]:.1f}s - {period[1]:.1f}s'
        fig.suptitle(title, fontsize=16)

        wheel = self.data['wheel']
        wheel_mask = self.alignment['to_mask'](wheel.timestamps)
        ts = self.data['camera_times'][self.alignment['label']]
        frame_numbers, = np.where(self.alignment['to_mask'](ts))
        if energy:
            self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2)
            frame_numbers = frame_numbers[1:-1]
        data = {'frame_ids': frame_numbers}

        def init_plot():
            """
            Plot the wheel data for the current trial
            :return: None
            """
            data['im'] = axes[0].imshow(self.alignment['frames'][0])
            axes[0].axis('off')
            axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames')

            # Plot the wheel position
            ax = axes[1]
            ax.clear()
            ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x')

            ts_0 = frame_numbers[0]
            data['idx_0'] = ts_0 - self.alignment['dt_i']
            ts_0 = ts[ts_0 + self.alignment['dt_i']]
            data['ln'] = ax.axvline(x=ts_0, color='k')
            ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)])
            data['frame_num'] = 0
            mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0)

            data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x')
            ax.set_ylabel('Wheel position (rad)')
            ax.set_xlabel('Time (s)')
            return

        def animate(i):
            """
            Callback for figure animation. Sets image data for current frame and moves pointer
            along axis
            :param i: unused; the current time step of the calling method
            :return: None
            """
            if i < 0:
                data['frame_num'] -= 1
                if data['frame_num'] < 0:
                    data['frame_num'] = len(self.alignment['frames']) - 1
            else:
                data['frame_num'] += 1
                if data['frame_num'] >= len(self.alignment['frames']):
                    data['frame_num'] = 0
            i = data['frame_num']  # NB: This is index for current trial's frame list

            frame = self.alignment['frames'][i]
            t_x = ts[data['idx_0'] + i]
            data['ln'].set_xdata([t_x, t_x])
            axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)])
            data['im'].set_data(frame)

            mkr = find_nearest(wheel.timestamps[wheel_mask], t_x)
            data['marker'].set_data([wheel.timestamps[wheel_mask][mkr]], [wheel.position[wheel_mask][mkr]])

            return data['im'], data['ln'], data['marker']

        anim = animation.FuncAnimation(fig, animate, init_func=init_plot,
                                       frames=(range(len(self.alignment.df)) if save else cycle(range(60))),
                                       interval=20, blit=False, repeat=not save, cache_frame_data=False)
        anim.running = False

        def process_key(event):
            """
            Callback for key presses.
            :param event: a figure key_press_event
            :return: None
            """
            if event.key.isspace():
                if anim.running:
                    anim.event_source.stop()
                else:
                    anim.event_source.start()
                anim.running = not anim.running
            elif event.key == 'right':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(1)
                fig.canvas.draw()
            elif event.key == 'left':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(-1)
                fig.canvas.draw()

        fig.canvas.mpl_connect('key_press_event', process_key)

        # init_plot()
        # while True:
        #     animate(0)
        if save:
            filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0])
            if isinstance(save, (str, Path)):
                filename = Path(save).joinpath(filename)
            self.log.info(f'Saving to {filename}')
            # Set up formatting for the movie files
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800)
            anim.save(str(filename), writer=writer)
        else:
            plt.show()


class MotionAlignmentFullSession:
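    """
    Extract new camera times for a session by aligning the video motion energy to the wheel.

    Example (a sketch added for illustration; the session path is a placeholder and the raw video
    and sync data must be present on disk):
        aligner = MotionAlignmentFullSession(session_path, 'left', sync='nidq', upload=False)
        new_times = aligner.process()
        if not aligner.qc_outcome:
            ...  # e.g. fall back to the original ttl-derived times
    """
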
    def __init__(self, session_path, label, **kwargs):
        """
        Class to extract camera times using video motion energy wheel alignment
        :param session_path: path of the session
        :param label: video label, only 'left' and 'right' videos are supported
        :param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20)
                       upload - whether to upload summary figure to alyx (default False)
                       twin - the window length used when computing the shifts between the wheel and video
                       nprocess - the number of CPU processes to use
                       sync - the type of sync scheme used (options 'nidq' or 'bpod')
                       location - whether the code is being run on SDSC or not (options 'SDSC' or None)
        """
        self.session_path = session_path
        self.label = label
        self.threshold = kwargs.get('threshold', 20)
        self.upload = kwargs.get('upload', False)
        self.twin = kwargs.get('twin', 150)
        self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4))

        self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None))
        self.roi, self.mask = self.get_roi_mask()

        if self.upload:
            self.one = ONE(mode='remote')
            self.one.alyx.authenticate()
            self.eid = self.one.path2eid(self.session_path)

    def load_data(self, sync='nidq', location=None):
        """
        Loads relevant data from disk to perform motion alignment
        :param sync: type of sync used, 'nidq' or 'bpod'
        :param location: where the code is being run, if location='SDSC', the dataset uuids are removed
                         when loading the data
        :return:
        """
        def fix_keys(alf_object):
            """
            Given an alf object removes the dataset uuid from the keys
            :param alf_object:
            :return:
            """
            ob = Bunch()
            for key in alf_object.keys():
                vals = alf_object[key]
                ob[key.split('.')[0]] = vals
            return ob

        alf_path = self.session_path.joinpath('alf')
        wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
        wheel = (fix_keys(alfio.load_object(wheel_path, 'wheel')) if location == 'SDSC'
                 else alfio.load_object(wheel_path, 'wheel'))
        self.wheel_timestamps = wheel.timestamps
        # Compute interpolated wheel position and wheel times
        wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000)
        # Compute wheel velocity
        self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000)
        # Load in original camera times
        self.camera_times = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.times*.npy')))
        self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4')))
        self.camera_meta = vidio.get_video_meta(self.camera_path)

        # TODO should read in the description file to get the correct sync location
        if sync == 'nidq':
            # If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object
            sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data')
            sr = get_sync_fronts(sync, chmap[f'{self.label}_camera'])
            self.ttls = sr.times[::2]
        else:
            # Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data
            cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path)
            cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data')
            self.ttls = cam_extractor._times_from_bpod()

        # Check if the ttl and video sizes match up
        self.tdiff = self.ttls.size - self.camera_meta['length']

        if self.tdiff < 0:
            # In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with
            # nans but if this is too many we reject the wheel alignment based on the qc
            self.ttl_times = self.ttls
            self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)]
            self.short_flag = True
        elif self.tdiff > 0:
            # In this case there are more ttls than camera frames. This happens often, for now we remove the first
            # tdiff ttls from the ttls
            self.ttl_times = self.ttls[self.tdiff:]
            self.times = self.ttls[self.tdiff:]
            self.short_flag = False
        else:
            # The number of ttls matches the number of camera frames
            self.ttl_times = self.ttls
            self.times = self.ttls
            self.short_flag = False

        # Compute the frame rate of the camera
        self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times)))

        # We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots, having
        # trial aligned paw velocity (from the dlc) is a nice sanity check to make sure the alignment went well
        try:
            self.trials = alfio.load_file_content(next(alf_path.rglob('_ibl_trials.table*.pqt')))
            self.dlc = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.dlc*.pqt')))
            self.dlc = likelihood_threshold(self.dlc)
            self.behavior = True
        except (ALFObjectNotFound, StopIteration):
            self.behavior = False

        # Load in a single frame that we will use for the summary plot
        self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0])

    def get_roi_mask(self):
        """
        Compute the region of interest mask for a given camera. This corresponds to a box in the video that we will
        use to compute the wheel motion energy
        :return:
        """
        if self.label == 'right':
            roi = ((450, 512), (120, 200))
        else:
            roi = ((900, 1024), (850, 1010))
        roi_mask = (*[slice(*r) for r in roi], 0)

        return roi, roi_mask

    def find_contaminated_frames(self, video_frames, threshold=20, normalise=True):
        """
        Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine
        frames with contamination an Otsu thresholding is applied to each frame to detect the artefact from the
        background image
        :param video_frames: np array of video frames (nframes, nwidth, nheight)
        :param threshold: threshold to differentiate artefact from background
        :param normalise: whether to normalise the threshold values for each frame to the baseline
        :return: mask of frames that are contaminated
        """
        high = np.zeros((video_frames.shape[0]))
        # Iterate through each frame and compute and store the otsu threshold value for each frame
        for idx, frame in enumerate(video_frames):
            ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            high[idx] = ret

        # If normalise is True, we subtract the minimum threshold value from each frame's value
        if normalise:
            high -= np.min(high)

        # Identify the frames that have a threshold value greater than the specified threshold cutoff
        contaminated_frames = np.where(high > threshold)[0]

        return contaminated_frames

    def compute_motion_energy(self, first, last, wg, iw):
        """
        Computes the video motion energy for frame indexes between first and last. This function is written to be run
        in a parallel fashion using joblib.parallel
        :param first: first frame index of frame interval to consider
        :param last: last frame index of frame interval to consider
        :param wg: WindowGenerator
        :param iw: iteration of the WindowGenerator
        :return:
        """
        if iw == wg.nwin - 1:
            return

        # Open the video and read in the relevant video frames between first idx and last idx
        cap = cv2.VideoCapture(self.camera_path)
        frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask)
        # Identify if any of the frames have artefacts in them
        idx = self.find_contaminated_frames(frames, self.threshold)

        # If some of the frames are contaminated we find all the continuous intervals of contamination
        # and set the value for contaminated pixels for these frames to the average of the first frame before and after
        # this contamination interval
        if len(idx) != 0:
            before_status = False
            after_status = False

            counter = 0
            n_frames = 200
            # If it is the first frame that is contaminated, we need to read in a bit more of the video to find a
            # frame prior to contamination. We attempt this 20 times, after that we just take the value for the first
            # frame
            while np.any(idx == 0) and counter < 20 and iw != 0:
                n_before_offset = (counter + 1) * n_frames
                first -= n_frames
                extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first, first + n_frames),
                                                              mask=self.mask)
                frames = np.concatenate([extra_frames, frames], axis=0)

                idx = self.find_contaminated_frames(frames, self.threshold)
                before_status = True
                counter += 1
            if counter > 0:
                print(f'In before: {counter}')

            counter = 0
            # If it is the last frame that is contaminated, we need to read in a bit more of the video to find a
            # frame after the contamination. We attempt this 20 times, after that we just take the value for the last
            # frame
            while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1:
                n_after_offset = (counter + 1) * n_frames
                last += n_frames
                extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last - n_frames, last),
                                                              mask=self.mask)
                frames = np.concatenate([frames, extra_frames], axis=0)
                idx = self.find_contaminated_frames(frames, self.threshold)
                after_status = True
                counter += 1

            if counter > 0:
                print(f'In after: {counter}')

            # We find all the continuous intervals that contain contamination and fix the affected pixels
            # by taking the average value of the frame prior and after contamination
            intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
            for ints in intervals:
                if len(ints) > 0 and ints[0] == 0:
                    ints = ints[1:]
                if len(ints) > 0 and ints[-1] == frames.shape[0] - 1:
                    ints = ints[:-1]
                if len(ints) == 0:
                    continue
                th_all = np.zeros_like(frames[0])
                # We find all affected pixels
                for idx in ints:
                    img = np.copy(frames[idx])
                    blur = cv2.GaussianBlur(img, (5, 5), 0)
                    ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                    th = cv2.GaussianBlur(th, (5, 5), 10)
                    th_all += th
                # Compute the average image of the frame prior and after the interval
                vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1)
                # For each frame set the affected pixels to the value of the clean average image
                for idx in ints:
                    img = frames[idx]
                    img[th_all > 0] = vals[th_all > 0]

            # If we have read in extra video frames we need to cut these off and make sure we only
            # consider the frames between the interval first and last given as args
            if before_status:
                frames = frames[n_before_offset:]
            if after_status:
                frames = frames[:(-1 * n_after_offset)]

        # Once the frames have been cleaned we compute the motion energy between frames
        frame_me, _ = video.motion_energy(frames, diff=2, normalize=False)

        cap.release()

        return frame_me[2:]

    def compute_shifts(self, times, me, first, last, iw, wg):
        """
        Compute the cross-correlation between the video motion energy and the wheel velocity to find the mismatch
        between the camera ttls and the video frames. This function is written to run in a parallel manner using
        joblib.parallel

        :param times: the times of the video frames across the whole session (ttls)
        :param me: the video motion energy computed across the whole session
        :param first: first time idx to consider
        :param last: last time idx to consider
        :param wg: WindowGenerator
        :param iw: iteration of the WindowGenerator
        :return:
        """
        # If we are in the last window we exit
        if iw == wg.nwin - 1:
            return np.nan, np.nan

        # Find the time interval we are interested in
        t_first = times[first]
        t_last = times[last]

        # If both times during this interval are nan exit
        if np.isnan(t_last) and np.isnan(t_first):
            return np.nan, np.nan
        # If only the last time is nan, we find the last non nan time value
        elif np.isnan(t_last):
            t_last = times[np.where(~np.isnan(times))[0][-1]]

        # Find the mask of timepoints that fall in this interval
        mask = np.logical_and(times >= t_first, times <= t_last)
        # Restrict the video motion energy to this interval and normalise the values
        align_me = me[np.where(mask)[0]]
        align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me))

        # Find closest timepoints in wheel that match the time interval
        wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last)
        if np.sum(wh_mask) == 0:
            return np.nan, np.nan
        # Find the indices of the wheel times closest to each video frame time
        xs = np.searchsorted(self.wheel_time[wh_mask], times[mask])
        xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1
        # Convert to normalized speed
        vs = np.abs(self.wheel_vel[wh_mask][xs])
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # Account for nan values in the video motion energy
        isnan = np.isnan(align_me)
        if np.sum(isnan) > 0:
            where_nan = np.where(isnan)[0]
            assert where_nan[0] == 0  # nan values should form a contiguous block
            assert where_nan[-1] == np.sum(isnan) - 1  # at the start of the window

        if np.all(isnan):
            return np.nan, np.nan

        # Compute the cross correlation between the video motion energy and the wheel speed
        xcorr = signal.correlate(align_me[~isnan], vs[~isnan])
        # The max value of the cross correlation indicates the shift that needs to be applied
        # The +2 comes from the fact that the video motion energy was computed from the difference between frames
        shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2

        return shift, t_first + (t_last - t_first) / 2

    def clean_shifts(self, x, n=1):
        """
        Removes artefacts from the computed shifts across time. We assume that the shifts should never increase
        over time and that the jump between consecutive shifts shouldn't be greater than 1
        :param x: computed shifts
        :param n: condition to apply; when n=1 single frame jumps (diff == 1) are removed, otherwise jumps
                  greater than 2 are removed
        :return:
        """
        y = x.copy()
        dy = np.diff(y, prepend=y[0])
        while True:
            # positive jumps correspond to added frames: these don't make sense and are treated as noise
            pos = np.where(dy == 1)[0] if n == 1 else np.where(dy > 2)[0]
            if pos.size == 0:
                break
            neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0]

            if len(pos) > len(neg):
                neg = np.append(neg, dy.size - 1)

            iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1)
            imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss])))

            idx = np.max([0, iss[imin] - 1])
            ineg = neg[idx:iss[imin] + 1]
            ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))]
            dy[pos[imin]] = 0
            dy[ineg] = 0

        return np.cumsum(dy) + y[0]

    def qc_shifts(self, shifts, shifts_filt):
        """
        Compute qc values for the wheel alignment. We consider 4 things
        1. The number of camera ttl values that are missing (when we have less ttls than video frames)
        2. The number of shifts that have nan values, which indicates windows where the video motion energy
           could not be aligned to the wheel
        3. The number of large jumps (>10) between the computed shifts
        4. The number of jumps (>1) between the shifts after they have been cleaned

        :param shifts: np.array of shifts over session
        :param shifts_filt: np.array of shifts after being cleaned over session
        :return:
        """
        ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0
        nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100
        shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size
        shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size

        qc = dict()
        qc['ttl_per'] = ttl_per
        qc['nan_per'] = nan_per
        qc['shifts_sum'] = shifts_sum
        qc['shifts_filt_sum'] = shifts_filt_sum

        qc_outcome = True
        # If more than 10% of ttls are missing we don't get new times
        if ttl_per > 10:
            qc_outcome = False
        # If too many of the shifts are nans it means the alignment is not accurate
        if nan_per > 40:
            qc_outcome = False
        # If there are too many artefacts the computed shifts could be errors
        if shifts_sum > 60:
            qc_outcome = False
        # If there are jumps > 1 in the filtered shifts then there is a problem
        if shifts_filt_sum > 0:
            qc_outcome = False

        return qc, qc_outcome

    def extract_times(self, shifts_filt, t_shifts):
        """
        Extracts new camera times after applying the computed shifts across the session

        :param shifts_filt: filtered shifts computed across session
        :param t_shifts: time point of computed shifts
        :return:
        """
        # Compute the interpolation function to apply to the ttl times
        t_new = t_shifts - (shifts_filt * 1 / self.frate)
        fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate")
        # Apply the function and get out new times
        new_times = fcn(self.ttl_times)

        # If we are missing ttls then interpolate and append the correct number at the end
        if self.tdiff < 0:
            to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1]
            new_times = np.r_[new_times, to_app]

        return new_times

    @staticmethod
    def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True,
                              norm=False, axs=None):
        """
        Compute and plot trial aligned spike rasters and psth
        :param spike_times: times of variable
        :param events: trial times to align to
        :param trial_idx: trial idx to sort by
        :param dividers: indices at which the sorted trials switch between groups
        :param colors: plotting colour for each trial group
        :param labels: label for each trial group
        :param weights: optional weight for each sample in spike_times
        :param fr: whether to divide the psth by the bin size to give a rate
        :param norm: whether to baseline-subtract the psth and raster
        :param axs: optional pair of axes (psth axis, raster axis) to plot into
        :return:
        """
        pre_time = 0.4
        post_time = 1
        raster_bin = 0.01
        psth_bin = 0.05
        raster, t_raster = bin_spikes(
            spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights)
        psth, t_psth = bin_spikes(
            spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights)

        if fr:
            psth = psth / psth_bin

        if norm:
            psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1)
            raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1)

        dividers = [0] + dividers + [len(trial_idx)]
        if axs is None:
            fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True)
        else:
            fig = axs[0].get_figure()

        label, lidx = np.unique(labels, return_index=True)
        label_pos = []
        for lab, lid in zip(label, lidx):
            idx = np.where(np.array(labels) == lab)[0]
            for iD in range(len(idx)):
                if iD == 0:
                    t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]
                    t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]]
                else:
                    t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]]
                    t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]]

            psth_div = np.nanmean(psth[t_ids], axis=0)
            std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids))

            axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid])
            axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid])

            lab_max = idx[np.argmax(t_ints)]
            label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max])

        axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower',
                      extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto')

        width = raster_bin * 4
        for iD in range(len(dividers) - 1):
            axs[1].fill_between([post_time + raster_bin / 2, post_time + raster_bin / 2 + width],
                                [dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD])

        axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width])
        secax = axs[1].secondary_yaxis('right')

        secax.set_yticks(label_pos)
        secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center')
        for ic, c in enumerate(np.array(colors)[lidx]):
            secax.get_yticklabels()[ic].set_color(c)

        axs[0].axvline(0, *axs[0].get_ylim(), c='k', ls='--', zorder=10)  # TODO this doesn't always work
        axs[1].axvline(0, *axs[1].get_ylim(), c='k', ls='--', zorder=10)

        return fig, axs

    def plot_with_behavior(self):
        """
        Makes a summary figure of the alignment when behaviour data is available
        :return:
        """
        self.dlc = likelihood_threshold(self.dlc)
        trial_idx, dividers = find_trial_ids(self.trials, sort='side')
        feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r')
        feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r')

        fig = plt.figure()
        fig.set_size_inches(15, 9)
        gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5)
        gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
        ax01 = fig.add_subplot(gs0[0, 0])
        ax02 = fig.add_subplot(gs0[1, 0])
        ax03 = fig.add_subplot(gs0[2, 0])
        gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3])
        ax11 = fig.add_subplot(gs1[0, 0])
        ax12 = fig.add_subplot(gs1[1, 0])
        gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3])
        ax21 = fig.add_subplot(gs2[0, 0])
        ax22 = fig.add_subplot(gs2[1, 0])
        gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3])
        ax31 = fig.add_subplot(gs3[0, 0])
        ax32 = fig.add_subplot(gs3[1, 0])
        gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4])
        ax41 = fig.add_subplot(gs4[0, 0])
        ax42 = fig.add_subplot(gs4[1, 0])

        ax01.plot(self.t_shifts, self.shifts, label='shifts')
        ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
        ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
        ax01.legend()
        ax01.set_ylabel('Frames')
        ax01.set_xlabel('Time in session')

        xs = np.searchsorted(self.ttl_times, self.t_shifts)
        ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps']
        ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
        ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
        ax02.legend()
        ax02.set_ylabel('Frames')
        ax02.set_xlabel('Time in session')

        ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new')
        ax03.legend()
        ax03.set_ylim(-5, 5)
        ax03.set_ylabel('Frames')
        ax03.set_xlabel('Time in session')

        self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers,
                                   ['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12])
        ax11.sharex(ax12)
        ax11.set_ylabel('Wheel velocity')
        ax11.set_title('Wheel')
        ax12.set_xlabel('Time from first move')

        self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers,
                                   ['g', 'y'], ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22])
        ax21.sharex(ax22)
        ax21.set_ylabel('Paw r velocity')
        ax21.set_title('Extracted times')
        ax22.set_xlabel('Time from first move')

        self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers,
                                   ['g', 'y'], ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32])
        ax31.sharex(ax32)
        ax31.set_ylabel('Paw r velocity')
        ax31.set_title('New times')
        ax32.set_xlabel('Time from first move')

        ax41.imshow(self.frame_example[0])
        rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1],
                                            self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none')
        ax41.add_patch(rect)

        ax42.plot(self.all_me)

        return fig

    def plot_without_behavior(self):
        """
        Makes a summary figure of the alignment when behaviour data is not available
        :return:
        """
        fig = plt.figure()
        fig.set_size_inches(7, 7)
        gs = gridspec.GridSpec(1, 2, figure=fig)
        gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
        ax01 = fig.add_subplot(gs0[0, 0])
        ax02 = fig.add_subplot(gs0[1, 0])
        ax03 = fig.add_subplot(gs0[2, 0])

        gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1])
        ax04 = fig.add_subplot(gs1[0, 0])
        ax05 = fig.add_subplot(gs1[1, 0])

        ax01.plot(self.t_shifts, self.shifts, label='shifts')
        ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
        ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
        ax01.legend()
        ax01.set_ylabel('Frames')
        ax01.set_xlabel('Time in session')

        xs = np.searchsorted(self.ttl_times, self.t_shifts)
        ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps']
        ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
        ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
        ax02.legend()
        ax02.set_ylabel('Frames')
        ax02.set_xlabel('Time in session')

        ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new')
        ax03.legend()
        ax03.set_ylim(-5, 5)
        ax03.set_ylabel('Frames')
        ax03.set_xlabel('Time in session')

        ax04.imshow(self.frame_example[0])
        rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1],
                                            self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none')
        ax04.add_patch(rect)

        ax05.plot(self.all_me)

        return fig

    def process(self):
        """
        Main function used to apply the video motion wheel alignment to the camera times. This function does the
        following
        1. Computes the video motion energy across the whole session (computed in windows and parallelised)
        2. Computes the shift that should be applied to the camera times across the whole session by computing
           the cross correlation between the video motion energy and the wheel speed (computed in
           overlapping windows and parallelised)
        3. Removes artefacts from the computed shifts
        4. Computes the qc for the wheel alignment
        5. Extracts the new camera times using the shifts computed from the video wheel alignment
        6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session
           on alyx
        :return:
        """
        # Compute the motion energy of the wheel for the whole video
        wg = WindowGenerator(self.camera_meta['length'], 5000, 4)
        out = Parallel(n_jobs=self.nprocess)(
            delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast))
        # Concatenate the motion energy into one big array
        self.all_me = np.array([])
        for vals in out[:-1]:
            self.all_me = np.r_[self.all_me, vals]

        toverlap = self.twin - 1
        all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me]
        to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1]
        times = np.r_[to_app, self.times]

        wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap))

        out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg)
                                 for iw, (first, last) in enumerate(wg.firstlast))

        self.shifts = np.array([])
        self.t_shifts = np.array([])
        for vals in out[:-1]:
            self.shifts = np.r_[self.shifts, vals[0]]
            self.t_shifts = np.r_[self.t_shifts, vals[1]]

        idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1])
        self.shifts = self.shifts[idx]
        self.t_shifts = self.t_shifts[idx]
        shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120)
        shifts_filt = self.clean_shifts(shifts_filt, n=1)
        self.shifts_filt = self.clean_shifts(shifts_filt, n=2)

        self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt)

        self.new_times = self.extract_times(self.shifts_filt, self.t_shifts)

        if self.upload:
            fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior()
            save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png'))
            save_fig_path.parent.mkdir(exist_ok=True, parents=True)
            fig.savefig(save_fig_path)
            snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one)
            snp.outputs = [save_fig_path]
            snp.register_images(widths=['orig'])
            plt.close(fig)

        return self.new_times