Coverage for ibllib/io/extractors/video_motion.py: 59% (229 statements)

1"""
2A module for aligning the wheel motion with the rotary encoder. Currently used by the camera QC
3in order to check timestamp alignment.
4"""
import logging
from itertools import cycle
from pathlib import Path

import cv2
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
import numpy as np
from scipy import signal

from one.api import ONE
import one.alf.io as alfio
from one.alf.spec import is_session_path, is_uuid_string
from iblutil.util import Bunch
import ibllib.io.video as vidio
import brainbox.video as video
import brainbox.behavior.wheel as wh
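

# Example usage (an illustrative sketch, not part of the module; the eid below is a
# hypothetical session UUID and the session data are assumed to be available locally):
#
#     one = ONE()
#     ma = MotionAlignment('c7bd79c9-c47e-4ea5-aea3-74dda991b48e', one=one)
#     ma.load_data(download=False)
#     offset, peak, energy = ma.align_motion(period=(100., 120.), side='left')
#     ma.plot_alignment()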


def find_nearest(array, value):
    """Return the index of the element in `array` closest to `value`."""
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return idx
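
# A minimal illustration of find_nearest (the input values are assumptions for demonstration):
#     >>> find_nearest(np.array([0.0, 0.5, 1.0]), 0.7)
#     1   # 0.5 is the closest element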


class MotionAlignment:
    roi = {
        'left': ((800, 1020), (233, 1096)),
        'right': ((426, 510), (104, 545)),
        'body': ((402, 481), (31, 103))
    }
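    # NB: Each ROI appears to be ((row_start, row_stop), (column_start, column_stop)) in pixels;
    # this ordering is inferred from how the slices are applied to the frames in `align_motion`.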

    def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), **kwargs):
        """Create a motion alignment object for a given session.

        :param eid: A session UUID
        :param one: An ONE instance (a new one is instantiated if None)
        :param log: A logging.Logger instance
        :param kwargs: May include 'session_path' to override the path resolved from the eid
        """
        self.one = one or ONE()
        self.eid = eid
        self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid)
        self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
        self.log = log
        self.trials = self.wheel = self.camera_times = None
        raw_cam_path = self.session_path.joinpath('raw_video_data')
        camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
        self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
        self.data = Bunch()
        self.alignment = Bunch()

    def align_all_trials(self, side='all'):
        """Align the wheel motion for all trials."""
        if self.data.get('trials') is None:
            self.load_data()
        if side == 'all':
            side = self.video_paths.keys()
        if not isinstance(side, str):
            # Iterate over the requested sides
            return [self.align_all_trials(s) for s in side]
        if side not in self.video_paths:
            raise ValueError(f'{side} camera video file not found')
        # Align each trial sequentially, using each trial's interval as the alignment period
        for interval in self.data.trials['intervals']:
            self.align_motion(period=interval, side=side, display=False)

    @staticmethod
    def set_roi(video_path):
        """Manually select a region of interest for a given video.

        TODO Improve docstring
        TODO A method for setting ROIs by label
        """
        frame = vidio.get_video_frame(str(video_path), 0)

        def line_select_callback(eclick, erelease):
            """
            Callback for line selection.

            *eclick* and *erelease* are the press and release events.
            """
            x1, y1 = eclick.xdata, eclick.ydata
            x2, y2 = erelease.xdata, erelease.ydata
            print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
            return np.array([[x1, x2], [y1, y2]])

        plt.imshow(frame)
        roi = RectangleSelector(plt.gca(), line_select_callback,
                                drawtype='box', useblit=True,
                                button=[1, 3],  # don't use middle button
                                minspanx=5, minspany=5,
                                spancoords='pixels',
                                interactive=True)
        plt.show()
        ((x1, x2, *_), (y1, *_, y2)) = roi.corners
        col = np.arange(round(x1), round(x2), dtype=int)
        row = np.arange(round(y1), round(y2), dtype=int)
        return col, row
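
    # Example (a sketch; the mapping from the returned (col, row) arrays to the class `roi`
    # format is an assumption based on how the slices are used in `align_motion`):
    #     col, row = MotionAlignment.set_roi(ma.video_paths['left'])
    #     MotionAlignment.roi['left'] = ((row[0], row[-1] + 1), (col[0], col[-1] + 1))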

    def load_data(self, download=False):
        """
        Load the wheel, trial and camera timestamp data into `self.data`.

        :param download: if True, load the data through ONE, otherwise load from the local
         session 'alf' folder
        :return: None
        """
        if download:
            self.data.wheel = self.one.load_object(self.eid, 'wheel')
            self.data.trials = self.one.load_object(self.eid, 'trials')
            cam = self.one.load(self.eid, ['camera.times'], dclass_output=True)
            self.data.camera_times = {vidio.label_from_path(url): ts
                                      for ts, url in zip(cam.data, cam.url)}
        else:
            alf_path = self.session_path / 'alf'
            self.data.wheel = alfio.load_object(alf_path, 'wheel', short_keys=True)
            self.data.trials = alfio.load_object(alf_path, 'trials')
            self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x)
                                      for x in alf_path.glob('*Camera.times*')}
        assert all(x is not None for x in self.data.values())
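
    # Example (sketch): load from the local 'alf' folder; afterwards `self.data` should contain
    # the 'wheel', 'trials' and 'camera_times' keys used by `align_motion` below.
    #     ma.load_data(download=False)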

    def _set_eid_or_path(self, session_path_or_eid):
        """Parse a given eid or session path.

        If a session UUID is given, the local session path is resolved and stored, and vice versa.
        :param session_path_or_eid: A session eid or path
        :return: None
        """
        self.eid = None
        if is_uuid_string(str(session_path_or_eid)):
            self.eid = session_path_or_eid
            # Try to set session_path if data is found locally
            self.session_path = self.one.eid2path(self.eid)
        elif is_session_path(session_path_or_eid):
            self.session_path = Path(session_path_or_eid)
            if self.one is not None:
                self.eid = self.one.path2eid(self.session_path)
                if not self.eid:
                    self.log.warning('Failed to determine eID from session path')
        else:
            self.log.error('Cannot run alignment: an experiment uuid or session path is required')
            raise ValueError("'session' must be a valid session path or uuid")

    def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
        """
        Align the video to the wheel using cross-correlation of the video motion signal and the
        rotary encoder.

        Parameters
        ----------
        period : (float, float)
            The time period over which to do the alignment.
        side : {'left', 'right', 'body'}
            With which camera to perform the alignment.
        sd_thresh : float
            For plotting where the motion energy goes above this standard deviation threshold.
        display : bool
            When true, displays the aligned wheel motion energy along with the rotary encoder
            signal.

        Returns
        -------
        int
            Frame offset, i.e. by how many frames the video was shifted to match the rotary
            encoder signal. Negative values mean the video was shifted backwards with respect
            to the wheel timestamps.
        float
            The peak cross-correlation.
        numpy.ndarray
            The motion energy used in the cross-correlation, i.e. the frame difference for the
            period given.
        """
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion energy
        camera_path = self.video_paths[side]
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = \
                vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            assert self.alignment.frames.size != 0
        except AssertionError:
            self.log.error('Failed to open video')
            return None, None, None
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000
        pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
        v, _ = wh.velocity_filtered(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Convert to normalized speed
        xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # Cross-correlate the wheel speed trace with the video motion energy
        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # Convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # Perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # Convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)

        # Find the peak of the cross-correlation and convert it to a frame offset
        CORRECTION = 2
        self.alignment.c = max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1,
                         linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

            # Plot the shifted velocity trace over the motion energy
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df
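
    # Interpreting the align_motion output (an illustrative sketch; the variable names below are
    # assumptions, not part of the module):
    #     offset_frames, peak, energy = ma.align_motion(period=(10., 20.), side='left')
    #     # offset_frames: frame shift applied to the video (negative means shifted backwards
    #     # with respect to the wheel timestamps); peak: the peak cross-correlation value.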

    def plot_alignment(self, energy=True, save=False):
        """
        Plot the aligned video frames alongside the wheel trace as an interactive animation.

        :param energy: if True, animate the motion energy (frame differences) instead of the
         raw frames
        :param save: if truthy, save the animation as an mp4; may be a str or Path output folder
        :return: None
        """
        if not self.alignment:
            self.log.error('No alignment data, run `align_motion` first')
            return
        # Change backend based on save flag
        backend = matplotlib.get_backend().lower()
        if (save and backend != 'agg') or (not save and backend == 'agg'):
            new_backend = 'Agg' if save else 'Qt5Agg'
            self.log.warning('Switching backend from %s to %s', backend, new_backend)
            matplotlib.use(new_backend)
        from matplotlib import pyplot as plt

        # Main animated plots
        fig, axes = plt.subplots(nrows=2)
        title = f'{self.ref}'  # ', from {period[0]:.1f}s - {period[1]:.1f}s'
        fig.suptitle(title, fontsize=16)

        wheel = self.data['wheel']
        wheel_mask = self.alignment['to_mask'](wheel.timestamps)
        ts = self.data['camera_times'][self.alignment['label']]
        frame_numbers, = np.where(self.alignment['to_mask'](ts))
        if energy:
            self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2)
            frame_numbers = frame_numbers[1:-1]
        data = {'frame_ids': frame_numbers}

        def init_plot():
            """
            Plot the wheel data for the current trial
            :return: None
            """
            data['im'] = axes[0].imshow(self.alignment['frames'][0])
            axes[0].axis('off')
            axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames')

            # Plot the wheel position
            ax = axes[1]
            ax.clear()
            ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x')

            ts_0 = frame_numbers[0]
            data['idx_0'] = ts_0 - self.alignment['dt_i']
            ts_0 = ts[ts_0 + self.alignment['dt_i']]
            data['ln'] = ax.axvline(x=ts_0, color='k')
            ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)])
            data['frame_num'] = 0
            mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0)

            data['marker'], = ax.plot(
                wheel.timestamps[wheel_mask][mkr],
                wheel.position[wheel_mask][mkr], 'r-x')
            ax.set_ylabel('Wheel position (rad)')
            ax.set_xlabel('Time (s)')
            return

        def animate(i):
            """
            Callback for figure animation. Sets image data for current frame and moves pointer
            along axis
            :param i: unused; the current time step of the calling method
            :return: None
            """
            if i < 0:
                data['frame_num'] -= 1
                if data['frame_num'] < 0:
                    data['frame_num'] = len(self.alignment['frames']) - 1
            else:
                data['frame_num'] += 1
                if data['frame_num'] >= len(self.alignment['frames']):
                    data['frame_num'] = 0
            i = data['frame_num']  # NB: This is index for current trial's frame list

            frame = self.alignment['frames'][i]
            t_x = ts[data['idx_0'] + i]
            data['ln'].set_xdata([t_x, t_x])
            axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)])
            data['im'].set_data(frame)

            mkr = find_nearest(wheel.timestamps[wheel_mask], t_x)
            data['marker'].set_data(
                [wheel.timestamps[wheel_mask][mkr]],
                [wheel.position[wheel_mask][mkr]]
            )

            return data['im'], data['ln'], data['marker']

        anim = animation.FuncAnimation(fig, animate, init_func=init_plot,
                                       frames=(range(len(self.alignment.df))
                                               if save
                                               else cycle(range(60))),
                                       interval=20, blit=False,
                                       repeat=not save, cache_frame_data=False)
        anim.running = False

        def process_key(event):
            """
            Callback for key presses.
            :param event: a figure key_press_event
            :return: None
            """
            if event.key.isspace():
                if anim.running:
                    anim.event_source.stop()
                else:
                    anim.event_source.start()
                anim.running = not anim.running
            elif event.key == 'right':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(1)
                fig.canvas.draw()
            elif event.key == 'left':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(-1)
                fig.canvas.draw()

        fig.canvas.mpl_connect('key_press_event', process_key)

        # init_plot()
        # while True:
        #     animate(0)
        if save:
            filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0])
            if isinstance(save, (str, Path)):
                filename = Path(save).joinpath(filename)
            self.log.info(f'Saving to {filename}')
            # Set up formatting for the movie files
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800)
            anim.save(str(filename), writer=writer)
        else:
            plt.show()
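

# Example (sketch): save the alignment animation to disk rather than display it interactively.
# The output directory below is hypothetical.
#     ma.plot_alignment(energy=True, save='/data/qc_videos')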