Coverage for ibllib/io/extractors/video_motion.py: 59%

229 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1""" 

2A module for aligning the wheel motion with the rotary encoder. Currently used by the camera QC 

3in order to check timestamp alignment. 

4""" 

5import matplotlib 

6import matplotlib.pyplot as plt 

7from matplotlib.widgets import RectangleSelector 

8import numpy as np 

9from scipy import signal 

10import cv2 

11from itertools import cycle 

12import matplotlib.animation as animation 

13import logging 

14from pathlib import Path 

15 

16from one.api import ONE 

17import ibllib.io.video as vidio 

18from iblutil.util import Bunch 

19import brainbox.video as video 

20import brainbox.behavior.wheel as wh 

21import one.alf.io as alfio 

22from one.alf.spec import is_session_path, is_uuid_string 

23 

24 

25def find_nearest(array, value): 

26 array = np.asarray(array) 1bca

27 idx = (np.abs(array - value)).argmin() 1bca

28 return idx 1bca

29 

30 

31class MotionAlignment: 

32 roi = { 

33 'left': ((800, 1020), (233, 1096)), 

34 'right': ((426, 510), (104, 545)), 

35 'body': ((402, 481), (31, 103)) 

36 } 

37 

38 def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), **kwargs): 

39 self.one = one or ONE() 1bca

40 self.eid = eid 1bca

41 self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid) 1bca

42 self.ref = self.one.dict2ref(self.one.path2ref(self.session_path)) 1bca

43 self.log = log 1bca

44 self.trials = self.wheel = self.camera_times = None 1bca

45 raw_cam_path = self.session_path.joinpath('raw_video_data') 1bca

46 camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*')) 1bca

47 self.video_paths = {vidio.label_from_path(x): x for x in camera_path} 1bca

48 self.data = Bunch() 1bca

49 self.alignment = Bunch() 1bca

50 

51 def align_all_trials(self, side='all'): 

52 """Align all wheel motion for all trials""" 

53 if self.trials is None: 

54 self.load_data() 

55 if side == 'all': 

56 side = self.video_paths.keys() 

57 if not isinstance(side, str): 

58 # Try to iterate over sides 

59 [self.align_all_trials(s) for s in side] 

60 if side not in self.video_paths: 

61 raise ValueError(f'{side} camera video file not found') 

62 # Align each trial sequentially 

63 for i in np.arange(self.trials['intervals'].shape[0]): 

64 self.align_motion(i, display=False) 

65 

66 @staticmethod 

67 def set_roi(video_path): 

68 """Manually set the ROIs for a given set of videos 

69 TODO Improve docstring 

70 TODO A method for setting ROIs by label 

71 """ 

72 frame = vidio.get_video_frame(str(video_path), 0) 

73 

74 def line_select_callback(eclick, erelease): 

75 """ 

76 Callback for line selection. 

77 

78 *eclick* and *erelease* are the press and release events. 

79 """ 

80 x1, y1 = eclick.xdata, eclick.ydata 

81 x2, y2 = erelease.xdata, erelease.ydata 

82 print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2)) 

83 return np.array([[x1, x2], [y1, y2]]) 

84 

85 plt.imshow(frame) 

86 roi = RectangleSelector(plt.gca(), line_select_callback, 

87 drawtype='box', useblit=True, 

88 button=[1, 3], # don't use middle button 

89 minspanx=5, minspany=5, 

90 spancoords='pixels', 

91 interactive=True) 

92 plt.show() 

93 ((x1, x2, *_), (y1, *_, y2)) = roi.corners 

94 col = np.arange(round(x1), round(x2), dtype=int) 

95 row = np.arange(round(y1), round(y2), dtype=int) 

96 return col, row 

97 

98 def load_data(self, download=False): 

99 """ 

100 Load wheel, trial and camera timestamp data 

101 :return: wheel, trials 

102 """ 

103 if download: 

104 self.data.wheel = self.one.load_object(self.eid, 'wheel') 

105 self.data.trials = self.one.load_object(self.eid, 'trials') 

106 cam = self.one.load(self.eid, ['camera.times'], dclass_output=True) 

107 self.data.camera_times = {vidio.label_from_path(url): ts 

108 for ts, url in zip(cam.data, cam.url)} 

109 else: 

110 alf_path = self.session_path / 'alf' 

111 self.data.wheel = alfio.load_object(alf_path, 'wheel', short_keys=True) 

112 self.data.trials = alfio.load_object(alf_path, 'trials') 

113 self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) 

114 for x in alf_path.glob('*Camera.times*')} 

115 assert all(x is not None for x in self.data.values()) 

116 

117 def _set_eid_or_path(self, session_path_or_eid): 

118 """Parse a given eID or session path 

119 If a session UUID is given, resolves and stores the local path and vice versa 

120 :param session_path_or_eid: A session eid or path 

121 :return: 

122 """ 

123 self.eid = None 

124 if is_uuid_string(str(session_path_or_eid)): 

125 self.eid = session_path_or_eid 

126 # Try to set session_path if data is found locally 

127 self.session_path = self.one.eid2path(self.eid) 

128 elif is_session_path(session_path_or_eid): 

129 self.session_path = Path(session_path_or_eid) 

130 if self.one is not None: 

131 self.eid = self.one.path2eid(self.session_path) 

132 if not self.eid: 

133 self.log.warning('Failed to determine eID from session path') 

134 else: 

135 self.log.error('Cannot run alignment: an experiment uuid or session path is required') 

136 raise ValueError("'session' must be a valid session path or uuid") 

137 

138 def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False): 

139 """ 

140 Align video to the wheel using cross-correlation of the video motion signal and the rotary 

141 encoder. 

142 

143 Parameters 

144 ---------- 

145 period : (float, float) 

146 The time period over which to do the alignment. 

147 side : {'left', 'right'} 

148 With which camera to perform the alignment. 

149 sd_thresh : float 

150 For plotting where the motion energy goes above this standard deviation threshold. 

151 display : bool 

152 When true, displays the aligned wheel motion energy along with the rotary encoder 

153 signal. 

154 

155 Returns 

156 ------- 

157 int 

158 Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder 

159 signal. Negative values mean the video was shifted backwards with respect to the wheel 

160 timestamps. 

161 float 

162 The peak cross-correlation. 

163 numpy.ndarray 

164 The motion energy used in the cross-correlation, i.e. the frame difference for the 

165 period given. 

166 """ 

167 # Get data samples within period 

168 wheel = self.data['wheel'] 1bca

169 self.alignment.label = side 1bca

170 self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1]) 1bca

171 camera_times = self.data['camera_times'][side] 1bca

172 cam_mask = self.alignment.to_mask(camera_times) 1bca

173 frame_numbers, = np.where(cam_mask) 1bca

174 

175 if frame_numbers.size == 0: 1bca

176 raise ValueError('No frames during given period') 1a

177 

178 # Motion Energy 

179 camera_path = self.video_paths[side] 1bca

180 roi = (*[slice(*r) for r in self.roi[side]], 0) 1bca

181 try: 1bca

182 # TODO Add function arg to make grayscale 

183 self.alignment.frames = \ 1bca

184 vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi) 

185 assert self.alignment.frames.size != 0 1bca

186 except AssertionError: 

187 self.log.error('Failed to open video') 

188 return None, None, None 

189 self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2) 1bca

190 self.alignment.period = period # For plotting 1bca

191 

192 # Calculate rotary encoder velocity trace 

193 x = camera_times[cam_mask] 1bca

194 Fs = 1000 1bca

195 pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs) 1bca

196 v, _ = wh.velocity_filtered(pos, Fs) 1bca

197 interp_mask = self.alignment.to_mask(t) 1bca

198 # Convert to normalized speed 

199 xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x]) 1bca

200 vs = np.abs(v[interp_mask][xs]) 1bca

201 vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) 1bca

202 

203 # FIXME This can be used as a goodness of fit measure 

204 USE_CV2 = False 1bca

205 if USE_CV2: 1bca

206 # convert from numpy format to openCV format 

207 dfCV = np.float32(self.alignment.df.reshape((-1, 1))) 

208 reCV = np.float32(vs.reshape((-1, 1))) 

209 

210 # perform cross correlation 

211 resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED) 

212 

213 # convert result back to numpy array 

214 xcorr = np.asarray(resultCv) 

215 else: 

216 xcorr = signal.correlate(self.alignment.df, vs) 1bca

217 

218 # Cross correlate wheel speed trace with the motion energy 

219 CORRECTION = 2 1bca

220 self.alignment.c = max(xcorr) 1bca

221 self.alignment.xcorr = np.argmax(xcorr) 1bca

222 self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION 1bca

223 self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames') 1bca

224 

225 if display: 1bca

226 # Plot the motion energy 

227 fig, ax = plt.subplots(2, 1, sharex='all') 

228 y = np.pad(self.alignment.df, 1, 'edge') 

229 ax[0].plot(x, y, '-x', label='wheel motion energy') 

230 thresh = stDev > sd_thresh 

231 ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, 

232 linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff') 

233 ax[1].plot(t[interp_mask], np.abs(v[interp_mask])) 

234 

235 # Plot other stuff 

236 dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]]) 

237 fps = 1 / np.diff(camera_times).mean() 

238 ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)') 

239 ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps)) 

240 ax[0].set_ylabel('rate of change (a.u.)') 

241 ax[0].legend() 

242 ax[1].set_ylabel('wheel speed (rad / s)') 

243 ax[1].set_xlabel('Time (s)') 

244 

245 title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s' 

246 fig.suptitle(title, fontsize=16) 

247 fig.set_size_inches(19.2, 9.89) 

248 

249 return self.alignment.dt_i, self.alignment.c, self.alignment.df 1bca

250 

251 def plot_alignment(self, energy=True, save=False): 

252 if not self.alignment: 1a

253 self.log.error('No alignment data, run `align_motion` first') 

254 return 

255 # Change backend based on save flag 

256 backend = matplotlib.get_backend().lower() 1a

257 if (save and backend != 'agg') or (not save and backend == 'agg'): 1a

258 new_backend = 'Agg' if save else 'Qt5Agg' 

259 self.log.warning('Switching backend from %s to %s', backend, new_backend) 

260 matplotlib.use(new_backend) 

261 from matplotlib import pyplot as plt 1a

262 

263 # Main animated plots 

264 fig, axes = plt.subplots(nrows=2) 1a

265 title = f'{self.ref}' # ', from {period[0]:.1f}s - {period[1]:.1f}s' 1a

266 fig.suptitle(title, fontsize=16) 1a

267 

268 wheel = self.data['wheel'] 1a

269 wheel_mask = self.alignment['to_mask'](wheel.timestamps) 1a

270 ts = self.data['camera_times'][self.alignment['label']] 1a

271 frame_numbers, = np.where(self.alignment['to_mask'](ts)) 1a

272 if energy: 1a

273 self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2) 1a

274 frame_numbers = frame_numbers[1:-1] 1a

275 data = {'frame_ids': frame_numbers} 1a

276 

277 def init_plot(): 1a

278 """ 

279 Plot the wheel data for the current trial 

280 :return: None 

281 """ 

282 data['im'] = axes[0].imshow(self.alignment['frames'][0]) 1a

283 axes[0].axis('off') 1a

284 axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames') 1a

285 

286 # Plot the wheel position 

287 ax = axes[1] 1a

288 ax.clear() 1a

289 ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x') 1a

290 

291 ts_0 = frame_numbers[0] 1a

292 data['idx_0'] = ts_0 - self.alignment['dt_i'] 1a

293 ts_0 = ts[ts_0 + self.alignment['dt_i']] 1a

294 data['ln'] = ax.axvline(x=ts_0, color='k') 1a

295 ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)]) 1a

296 data['frame_num'] = 0 1a

297 mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0) 1a

298 

299 data['marker'], = ax.plot( 1a

300 wheel.timestamps[wheel_mask][mkr], 

301 wheel.position[wheel_mask][mkr], 'r-x') 

302 ax.set_ylabel('Wheel position (rad))') 1a

303 ax.set_xlabel('Time (s))') 1a

304 return 1a

305 

306 def animate(i): 1a

307 """ 

308 Callback for figure animation. Sets image data for current frame and moves pointer 

309 along axis 

310 :param i: unused; the current time step of the calling method 

311 :return: None 

312 """ 

313 if i < 0: 1a

314 data['frame_num'] -= 1 

315 if data['frame_num'] < 0: 

316 data['frame_num'] = len(self.alignment['frames']) - 1 

317 else: 

318 data['frame_num'] += 1 1a

319 if data['frame_num'] >= len(self.alignment['frames']): 1a

320 data['frame_num'] = 0 1a

321 i = data['frame_num'] # NB: This is index for current trial's frame list 1a

322 

323 frame = self.alignment['frames'][i] 1a

324 t_x = ts[data['idx_0'] + i] 1a

325 data['ln'].set_xdata([t_x, t_x]) 1a

326 axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)]) 1a

327 data['im'].set_data(frame) 1a

328 

329 mkr = find_nearest(wheel.timestamps[wheel_mask], t_x) 1a

330 data['marker'].set_data( 1a

331 wheel.timestamps[wheel_mask][mkr], 

332 wheel.position[wheel_mask][mkr] 

333 ) 

334 

335 return data['im'], data['ln'], data['marker'] 1a

336 

337 anim = animation.FuncAnimation(fig, animate, init_func=init_plot, 1a

338 frames=(range(len(self.alignment.df)) 

339 if save 

340 else cycle(range(60))), 

341 interval=20, blit=False, 

342 repeat=not save, cache_frame_data=False) 

343 anim.running = False 1a

344 

345 def process_key(event): 1a

346 """ 

347 Callback for key presses. 

348 :param event: a figure key_press_event 

349 :return: None 

350 """ 

351 if event.key.isspace(): 

352 if anim.running: 

353 anim.event_source.stop() 

354 else: 

355 anim.event_source.start() 

356 anim.running = ~anim.running 

357 elif event.key == 'right': 

358 if anim.running: 

359 anim.event_source.stop() 

360 anim.running = False 

361 animate(1) 

362 fig.canvas.draw() 

363 elif event.key == 'left': 

364 if anim.running: 

365 anim.event_source.stop() 

366 anim.running = False 

367 animate(-1) 

368 fig.canvas.draw() 

369 

370 fig.canvas.mpl_connect('key_press_event', process_key) 1a

371 

372 # init_plot() 

373 # while True: 

374 # animate(0) 

375 if save: 1a

376 filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0]) 1a

377 if isinstance(save, (str, Path)): 1a

378 filename = Path(save).joinpath(filename) 1a

379 self.log.info(f'Saving to {filename}') 1a

380 # Set up formatting for the movie files 

381 Writer = animation.writers['ffmpeg'] 1a

382 writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800) 1a

383 anim.save(str(filename), writer=writer) 1a

384 else: 

385 plt.show()