Coverage for ibllib/io/extractors/video_motion.py: 45%

627 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1""" 

2A module for aligning the wheel motion with the rotary encoder. Currently used by the camera QC 

3in order to check timestamp alignment. 

4""" 

5import matplotlib 

6import matplotlib.pyplot as plt 

7import matplotlib.gridspec as gridspec 

8from matplotlib.widgets import RectangleSelector 

9import numpy as np 

10from scipy import signal, ndimage, interpolate 

11import cv2 

12from itertools import cycle 

13import matplotlib.animation as animation 

14import logging 

15from pathlib import Path 

16from joblib import Parallel, delayed, cpu_count 

17 

18from ibldsp.utils import WindowGenerator 

19from one.api import ONE 

20import ibllib.io.video as vidio 

21from iblutil.util import Bunch 

22from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map 

23import ibllib.io.raw_data_loaders as raw 

24import ibllib.io.extractors.camera as cam 

25from ibllib.plots.snapshot import ReportSnapshot 

26import brainbox.video as video 

27import brainbox.behavior.wheel as wh 

28from brainbox.singlecell import bin_spikes 

29from brainbox.behavior.dlc import likelihood_threshold, get_speed 

30from brainbox.task.trials import find_trial_ids 

31import one.alf.io as alfio 

32from one.alf.exceptions import ALFObjectNotFound 

33from one.alf.spec import is_session_path, is_uuid_string 

34 

35 

def find_nearest(array, value):
    """Return the index of the element of `array` closest to `value`.

    :param array: array-like of numbers
    :param value: the value to look for
    :return: index (into the flattened array) of the element with the smallest absolute
     difference to `value`; the first such index when there are ties
    """
    distance = np.abs(np.asarray(array) - value)
    return np.argmin(distance)

40 

41 

class MotionAlignment:
    """
    Align the motion energy of the wheel region of a camera video with the rotary encoder trace.

    Typical use is `load_data()` followed by `align_motion(period, side)` and, optionally,
    `plot_alignment()` to inspect the result.
    """
    # Pair of (start, stop) ranges per camera label. Each pair is turned into two slices (plus
    # channel index 0) and passed as the `mask` when loading video frames in `align_motion`.
    roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))}

    def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), stream=False, **kwargs):
        """
        :param eid: experiment UUID of the session
        :param one: a ONE instance; a new one is instantiated when omitted
        :param log: logger used for messages
        :param stream: if True the video locations are obtained with `vidio.url_from_eid` instead of
         from the local 'raw_video_data' folder
        :param kwargs: session_path - local session path; resolved from the eid when omitted
        """
        self.one = one or ONE()
        self.eid = eid
        # Accept str as well as Path: the path is used with `.joinpath` below
        self.session_path = Path(kwargs.pop('session_path', None) or self.one.eid2path(eid))
        self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
        self.log = log
        self.trials = self.wheel = self.camera_times = None
        raw_cam_path = self.session_path.joinpath('raw_video_data')
        camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
        if stream:
            self.video_paths = vidio.url_from_eid(self.eid)
        else:
            self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
        self.data = Bunch()  # populated by `load_data`
        self.alignment = Bunch()  # populated by `align_motion`

    def align_all_trials(self, side='all'):
        """
        Align the wheel motion for every trial interval, one trial at a time.

        :param side: a camera label, an iterable of camera labels, or 'all' for every video found
        :raises ValueError: no video was found for the requested camera label
        """
        # The session data live in `self.data` (see `load_data`)
        if self.data.get('trials') is None:
            self.load_data()
        if side == 'all':
            side = list(self.video_paths)
        if not isinstance(side, str):
            # Several sides were given: process each in turn, then stop here
            for label in side:
                self.align_all_trials(label)
            return
        if side not in self.video_paths:
            raise ValueError(f'{side} camera video file not found')
        # Align each trial sequentially, using the trial interval as the alignment period
        for interval in self.data['trials']['intervals']:
            self.align_motion(period=tuple(interval), side=side, display=False)

    @staticmethod
    def set_roi(video_path):
        """
        Manually draw a rectangular ROI on the first frame of a video.

        Opens a figure showing frame 0 and blocks until it is closed.

        TODO A method for setting ROIs by label

        :param video_path: path to the video file
        :return: (col, row) integer index arrays spanning the selected rectangle
        """
        frame = vidio.get_video_frame(str(video_path), 0)

        def line_select_callback(eclick, erelease):
            """
            Callback for line selection.

            *eclick* and *erelease* are the press and release events.
            """
            x1, y1 = eclick.xdata, eclick.ydata
            x2, y2 = erelease.xdata, erelease.ydata
            print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
            return np.array([[x1, x2], [y1, y2]])

        plt.imshow(frame)
        # NB: the `drawtype` keyword is not passed: it was deprecated in Matplotlib 3.5 and later
        # removed; a box is drawn by default. Buttons 1 and 3 only, i.e. not the middle button.
        roi = RectangleSelector(plt.gca(), line_select_callback, useblit=True, button=[1, 3],
                                minspanx=5, minspany=5, spancoords='pixels', interactive=True)
        plt.show()
        ((x1, x2, *_), (y1, *_, y2)) = roi.corners
        col = np.arange(round(x1), round(x2), dtype=int)
        row = np.arange(round(y1), round(y2), dtype=int)
        return col, row

    def load_data(self, download=False):
        """
        Load wheel, trial and camera timestamp data into `self.data`.

        :param download: if True the data are loaded through ONE, otherwise they are read from the
         local 'alf' folder of the session
        :return: None
        """
        if download:
            self.data.wheel = self.one.load_object(self.eid, 'wheel')
            self.data.trials = self.one.load_object(self.eid, 'trials')
            # Local names chosen so as not to shadow the module-level `cam` import
            times, details = self.one.load_datasets(self.eid, ['*Camera.times*'])
            self.data.camera_times = {vidio.label_from_path(d['rel_path']): ts for ts, d in zip(times, details)}
        else:
            alf_path = self.session_path / 'alf'
            wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
            self.data.wheel = alfio.load_object(wheel_path, 'wheel', short_keys=True)
            trials_path = next(alf_path.rglob('*trials.table*')).parent
            self.data.trials = alfio.load_object(trials_path, 'trials')
            self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) for x in
                                      alf_path.rglob('*Camera.times*')}
        assert all(x is not None for x in self.data.values())

    def _set_eid_or_path(self, session_path_or_eid):
        """Parse a given eID or session path

        If a session UUID is given, resolves and stores the local path and vice versa

        :param session_path_or_eid: A session eid or path
        :return: None
        :raises ValueError: the input is neither a UUID string nor a session path
        """
        self.eid = None
        if is_uuid_string(str(session_path_or_eid)):
            self.eid = session_path_or_eid
            # Try to set session_path if data is found locally
            self.session_path = self.one.eid2path(self.eid)
        elif is_session_path(session_path_or_eid):
            self.session_path = Path(session_path_or_eid)
            if self.one is not None:
                self.eid = self.one.path2eid(self.session_path)
                if not self.eid:
                    self.log.warning('Failed to determine eID from session path')
        else:
            self.log.error('Cannot run alignment: an experiment uuid or session path is required')
            raise ValueError("'session' must be a valid session path or uuid")

    def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
        """
        Align video to the wheel using cross-correlation of the video motion signal and the rotary
        encoder.

        Parameters
        ----------
        period : (float, float)
            The time period over which to do the alignment.
        side : {'left', 'right'}
            With which camera to perform the alignment.
        sd_thresh : float
            For plotting where the motion energy goes above this standard deviation threshold.
        display : bool
            When true, displays the aligned wheel motion energy along with the rotary encoder
            signal.

        Returns
        -------
        int
            Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
            signal.  Negative values mean the video was shifted backwards with respect to the wheel
            timestamps.
        float
            The peak cross-correlation.
        numpy.ndarray
            The motion energy used in the cross-correlation, i.e. the frame difference for the
            period given.

        Raises
        ------
        ValueError
            No camera timestamps fall within the given period.

        Notes
        -----
        Returns (None, None, None) when no video frames could be loaded.
        """
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion Energy
        camera_path = self.video_paths[side]
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            assert self.alignment.frames.size != 0
        except AssertionError:
            self.log.error('Failed to open video')
            return None, None, None
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000
        pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
        v, _ = wh.velocity_filtered(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Convert to normalized speed, sampled at the wheel samples nearest to each frame time
        xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)

        # Cross correlate wheel speed trace with the motion energy
        CORRECTION = 2
        self.alignment.c = max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1,
                         linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

            # Plot other stuff
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df

    def plot_alignment(self, energy=True, save=False):
        """
        Animate the video frames of the last alignment alongside the wheel position.

        When shown interactively, space toggles play/pause and the left/right arrow keys step one
        frame backwards/forwards.

        :param energy: if True the frame differences are displayed instead of the raw frames. NB:
         this replaces `self.alignment['frames']` with the frame differences
        :param save: if truthy the animation is written to an mp4 file using ffmpeg instead of
         being displayed; if a str or Path it is used as the output directory
        :return: None
        """
        if not self.alignment:
            self.log.error('No alignment data, run `align_motion` first')
            return
        # Change backend based on save flag
        backend = matplotlib.get_backend().lower()
        if (save and backend != 'agg') or (not save and backend == 'agg'):
            new_backend = 'Agg' if save else 'Qt5Agg'
            self.log.warning('Switching backend from %s to %s', backend, new_backend)
            matplotlib.use(new_backend)
        from matplotlib import pyplot as plt

        # Main animated plots
        fig, axes = plt.subplots(nrows=2)
        title = f'{self.ref}'  # ', from {period[0]:.1f}s - {period[1]:.1f}s'
        fig.suptitle(title, fontsize=16)

        wheel = self.data['wheel']
        wheel_mask = self.alignment['to_mask'](wheel.timestamps)
        ts = self.data['camera_times'][self.alignment['label']]
        frame_numbers, = np.where(self.alignment['to_mask'](ts))
        if energy:
            self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2)
            frame_numbers = frame_numbers[1:-1]
        data = {'frame_ids': frame_numbers}

        def init_plot():
            """
            Plot the wheel data for the current trial
            :return: None
            """
            data['im'] = axes[0].imshow(self.alignment['frames'][0])
            axes[0].axis('off')
            axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames')

            # Plot the wheel position
            ax = axes[1]
            ax.clear()
            ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x')

            ts_0 = frame_numbers[0]
            data['idx_0'] = ts_0 - self.alignment['dt_i']
            ts_0 = ts[ts_0 + self.alignment['dt_i']]
            data['ln'] = ax.axvline(x=ts_0, color='k')
            ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)])
            data['frame_num'] = 0
            mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0)

            data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x')
            ax.set_ylabel('Wheel position (rad))')
            ax.set_xlabel('Time (s))')
            return

        def animate(i):
            """
            Callback for figure animation.  Sets image data for current frame and moves pointer
            along axis
            :param i: only the sign is used: a negative value steps one frame backwards, anything
             else steps one frame forwards (wrapping around at either end)
            :return: the updated artists
            """
            if i < 0:
                data['frame_num'] -= 1
                if data['frame_num'] < 0:
                    data['frame_num'] = len(self.alignment['frames']) - 1
            else:
                data['frame_num'] += 1
                if data['frame_num'] >= len(self.alignment['frames']):
                    data['frame_num'] = 0
            i = data['frame_num']  # NB: This is index for current trial's frame list

            frame = self.alignment['frames'][i]
            t_x = ts[data['idx_0'] + i]
            data['ln'].set_xdata([t_x, t_x])
            axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)])
            data['im'].set_data(frame)

            mkr = find_nearest(wheel.timestamps[wheel_mask], t_x)
            data['marker'].set_data([wheel.timestamps[wheel_mask][mkr]], [wheel.position[wheel_mask][mkr]])

            return data['im'], data['ln'], data['marker']

        anim = animation.FuncAnimation(fig, animate, init_func=init_plot,
                                       frames=(range(len(self.alignment.df)) if save else cycle(range(60))),
                                       interval=20, blit=False, repeat=not save, cache_frame_data=False)
        anim.running = False

        def process_key(event):
            """
            Callback for key presses.
            :param event: a figure key_press_event
            :return: None
            """
            if event.key.isspace():
                if anim.running:
                    anim.event_source.stop()
                else:
                    anim.event_source.start()
                # Logical negation: bitwise `~` on a bool yields -1 / -2, which are both truthy
                anim.running = not anim.running
            elif event.key == 'right':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(1)
                fig.canvas.draw()
            elif event.key == 'left':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(-1)
                fig.canvas.draw()

        fig.canvas.mpl_connect('key_press_event', process_key)

        if save:
            filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0])
            if isinstance(save, (str, Path)):
                filename = Path(save).joinpath(filename)
            self.log.info(f'Saving to {filename}')
            # Set up formatting for the movie files
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800)
            anim.save(str(filename), writer=writer)
        else:
            plt.show()

385 

386 

387class MotionAlignmentFullSession: 

def __init__(self, session_path, label, **kwargs):
    """
    Class to extract camera times using video motion energy wheel alignment
    :param session_path: path of the session
    :param label: video label, only 'left' and 'right' videos are supported
    :param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20)
                   upload - whether to upload summary figure to alyx (default False)
                   twin - the window length used when computing the shifts between the wheel and video
                   nprocess - the number of CPU processes to use; 'nprocesses' is accepted as an alias
                    (defaults to three quarters of the CPU count, at least 1)
                   sync - the type of sync scheme used (options 'nidq' or 'bpod')
                   location - whether the code is being run on SDSC or not (options 'SDSC' or None)
    """
    # Accept str as well as Path: `load_data` calls `.joinpath` on the session path
    self.session_path = Path(session_path)
    self.label = label
    self.threshold = kwargs.get('threshold', 20)
    self.upload = kwargs.get('upload', False)
    self.twin = kwargs.get('twin', 150)
    # This docstring historically advertised 'nprocesses' while the code read 'nprocess': honour both
    nprocess = kwargs.get('nprocess', kwargs.get('nprocesses'))
    if nprocess is None:
        n_cpu = cpu_count()
        # Never go below one process (a single core would otherwise give int(0.75) == 0)
        nprocess = max(1, int(n_cpu - n_cpu / 4))
    self.nprocess = nprocess

    self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None))
    self.roi, self.mask = self.get_roi_mask()

    # NB: `self.one` and `self.eid` are only set when uploading
    if self.upload:
        self.one = ONE(mode='remote')
        self.one.alyx.authenticate()
        self.eid = self.one.path2eid(self.session_path)

414 

def load_data(self, sync='nidq', location=None):
    """
    Loads relevant data from disk to perform motion alignment

    Sets the wheel traces, the camera path and meta data, the camera TTLs and timestamps, the
    frame rate, the optional behaviour data (trials and dlc) and an example frame as attributes.

    :param sync: type of sync used, 'nidq' or 'bpod'
    :param location: where the code is being run, if location='SDSC', the dataset uuids are removed
     when loading the data
    :return: None
    """
    def fix_keys(alf_object):
        """
        Given an alf object removes the dataset uuid from the keys
        :param alf_object: the loaded ALF object
        :return: a Bunch with the same values, keyed by the part of each key before the first '.'
        """
        ob = Bunch()
        for key in alf_object.keys():
            vals = alf_object[key]
            ob[key.split('.')[0]] = vals
        return ob

    alf_path = self.session_path.joinpath('alf')
    wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
    wheel = (fix_keys(alfio.load_object(wheel_path, 'wheel')) if location == 'SDSC'
             else alfio.load_object(wheel_path, 'wheel'))
    self.wheel_timestamps = wheel.timestamps
    # Compute interpolated wheel position and wheel times
    wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000)
    # Compute wheel velocity
    self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000)
    # Locate the raw video and read its meta data
    self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4')))
    self.camera_meta = vidio.get_video_meta(self.camera_path)

    # TODO should read in the description file to get the correct sync location
    if sync == 'nidq':
        # If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object
        sync_data, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data')
        sr = get_sync_fronts(sync_data, chmap[f'{self.label}_camera'])
        self.ttls = sr.times[::2]
    else:
        # Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data
        cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path)
        cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data')
        self.ttls = cam_extractor._times_from_bpod()

    # Check if the ttl and video sizes match up
    self.tdiff = self.ttls.size - self.camera_meta['length']

    # Load in original camera times if available otherwise set to ttls
    camera_times = next(alf_path.rglob(f'_ibl_{self.label}Camera.times*.npy'), None)
    self.camera_times = alfio.load_file_content(camera_times) if camera_times else self.ttls

    if self.tdiff < 0:
        # In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with
        # nans but if this is too many we reject the wheel alignment based on the qc
        self.ttl_times = self.ttls
        self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)]
        if self.camera_times.size != self.camera_meta['length']:
            self.camera_times = np.r_[self.camera_times, np.full((np.abs(self.tdiff)), np.nan)]
        self.short_flag = True
    elif self.tdiff > 0:
        # In this case there are more ttls than camera frames. This happens often, for now we remove the first
        # tdiff ttls from the ttls
        self.ttl_times = self.ttls[self.tdiff:]
        self.times = self.ttls[self.tdiff:]
        if self.camera_times.size != self.camera_meta['length']:
            self.camera_times = self.camera_times[self.tdiff:]
        self.short_flag = False
    else:
        # The number of ttls matches the number of frames: use the ttls as they are. Without this
        # branch `ttl_times`, `times` and `short_flag` would be left undefined
        self.ttl_times = self.ttls
        self.times = self.ttls
        self.short_flag = False

    # Compute the frame rate of the camera
    self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times)))

    # We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots, having
    # trial aligned paw velocity (from the dlc) is a nice sanity check to make sure the alignment went well
    try:
        self.trials = alfio.load_file_content(next(alf_path.rglob('_ibl_trials.table*.pqt')))
        self.dlc = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.dlc*.pqt')))
        self.dlc = likelihood_threshold(self.dlc)
        self.behavior = True
    except (ALFObjectNotFound, StopIteration):
        self.behavior = False

    # Load in a single frame (frame 10, first channel) that we will use for the summary plot
    self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0])

499 

def get_roi_mask(self):
    """
    Return the region of interest for the camera given by `self.label`, i.e. the box in the video
    that is used to compute the wheel motion energy.

    :return: roi - pair of (start, stop) ranges for the 'right' camera, or the default ranges for
     any other label; roi_mask - the matching two slices followed by index 0
    """
    default_roi = ((900, 1024), (850, 1010))
    known_rois = {'right': ((450, 512), (120, 200))}
    roi = known_rois.get(self.label, default_roi)
    roi_mask = tuple(slice(start, stop) for start, stop in roi) + (0,)
    return roi, roi_mask

514 

def find_contaminated_frames(self, video_frames, thresold=20, normalise=True):
    """
    Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine
    frames with contamination an Otsu thresholding is applied to each (Gaussian blurred) frame and the
    threshold value found for the frame is compared to a cutoff.

    :param video_frames: np array of video frames, the first dimension being the frame number
    :param thresold: cutoff above which the Otsu threshold value of a frame marks it as contaminated
     (the parameter name is misspelt but kept for backward compatibility)
    :param normalise: whether to subtract the minimum Otsu threshold value across frames before applying
     the cutoff
    :return: indices of the frames that are contaminated
    """
    n_frames = len(video_frames)
    if n_frames == 0:
        # Nothing to inspect; also avoids taking the minimum of an empty array below
        return np.empty(0, dtype=np.intp)

    high = np.zeros(n_frames)
    # Iterate through each frame and compute and store the otsu threshold value for each frame
    for idx, frame in enumerate(video_frames):
        ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        high[idx] = ret

    # If normalise is True, the threshold values are expressed relative to their minimum across frames
    if normalise:
        high -= np.min(high)

    # Identify the frames that have a threshold value greater than the specified threshold cutoff
    contaminated_frames = np.where(high > thresold)[0]

    return contaminated_frames

539 

def compute_motion_energy(self, first, last, wg, iw):
    """
    Computes the video motion energy for frame indexes between first and last. This function is written to be run
    in a parallel fashion using joblib.parallel

    Frames flagged by `find_contaminated_frames` are repaired before the motion energy is computed: within
    each run of consecutive contaminated frames the affected pixels are replaced with the average of the
    frame just before and the frame just after the run.

    :param first: first frame index of frame interval to consider
    :param last: last frame index of frame interval to consider
    :param wg: WindowGenerator
    :param iw: iteration of the WindowGenerator
    :return: the motion energy of the interval with its first two values removed, or None for the last
     window
    """

    if iw == wg.nwin - 1:
        return

    n_frames = 200  # number of extra frames to read at a time when looking for a clean frame
    max_attempts = 20  # maximum number of extra reads in each direction

    # Open the video and read in the relevant video frames between first idx and last idx
    cap = cv2.VideoCapture(self.camera_path)
    try:
        frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask)
        # Identify if any of the frames have artefacts in them
        idx = self.find_contaminated_frames(frames, self.threshold)

        # If some of the frames are contaminated we find all the continuous intervals of contamination
        # and set the value for contaminated pixels for these frames to the average of the first frame before
        # and after this contamination interval
        if len(idx) != 0:

            n_before_offset = 0  # number of frames prepended to `frames`
            n_after_offset = 0  # number of frames appended to `frames`

            counter = 0
            # If it is the first frame that is contaminated, we need to read in a bit more of the video to find
            # a frame prior to contamination. We attempt this a limited number of times, after that we just take
            # the value for the first frame
            while np.any(idx == 0) and counter < max_attempts and iw != 0:
                start = max(first - n_frames, 0)
                if start == first:
                    break  # already at the beginning of the video
                # Read the frames immediately preceding the ones we hold so that `frames` stays contiguous
                extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(start, first),
                                                              mask=self.mask)
                frames = np.concatenate([extra_frames, frames], axis=0)
                n_before_offset += extra_frames.shape[0]
                first = start

                idx = self.find_contaminated_frames(frames, self.threshold)
                counter += 1
            if counter > 0:
                print(f'In before: {counter}')

            counter = 0
            # If it is the last frame that is contaminated, we need to read in a bit more of the video to find a
            # frame after the contamination. We attempt this a limited number of times, after that we just take
            # the value for the last frame
            while np.any(idx == frames.shape[0] - 1) and counter < max_attempts and iw != wg.nwin - 1:
                # Read the frames immediately following the ones we hold so that `frames` stays contiguous
                extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames),
                                                              mask=self.mask)
                frames = np.concatenate([frames, extra_frames], axis=0)
                n_after_offset += extra_frames.shape[0]
                last += n_frames

                idx = self.find_contaminated_frames(frames, self.threshold)
                counter += 1

            if counter > 0:
                print(f'In after: {counter}')

            # We find all the continuous intervals that contain contamination and fix the affected pixels
            # by taking the average value of the frame prior and after contamination
            final_frame = frames.shape[0] - 1
            intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
            for ints in intervals:
                # A frame at either edge has no neighbour on that side: leave it as is and use it as reference
                if len(ints) > 0 and ints[0] == 0:
                    ints = ints[1:]
                if len(ints) > 0 and ints[-1] == final_frame:
                    ints = ints[:-1]
                if len(ints) == 0:
                    continue  # nothing left to repair in this interval

                # We find all affected pixels. A boolean mask is accumulated rather than summing the
                # thresholded images, which could wrap around for integer frames
                affected = np.zeros(frames[0].shape, dtype=bool)
                for i_frame in ints:
                    blur = cv2.GaussianBlur(np.copy(frames[i_frame]), (5, 5), 0)
                    _, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                    th = cv2.GaussianBlur(th, (5, 5), 10)
                    affected |= th > 0
                # Compute the average image of the frame prior and after the interval
                vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1)
                # For each frame set the affected pixels to the value of the clean average image
                for i_frame in ints:
                    frames[i_frame][affected] = vals[affected]

            # If we have read in extra video frames we need to cut these off and make sure we only
            # consider the frames between the interval first and last given as args
            if n_before_offset:
                frames = frames[n_before_offset:]
            if n_after_offset:
                frames = frames[:-n_after_offset]

        # Once the frames have been cleaned we compute the motion energy between frames
        frame_me, _ = video.motion_energy(frames, diff=2, normalize=False)
    finally:
        # Always release the capture, including when reading or processing the frames fails
        cap.release()

    return frame_me[2:]

638 

def compute_shifts(self, times, me, first, last, iw, wg):
    """
    Cross-correlate the video motion energy with the wheel speed within one window in order to find the
    mismatch between the camera ttls and the video frames. This function is written to run in a parallel
    manner using joblib.parallel

    :param times: the times of the video frames across the whole session (ttls)
    :param me: the video motion energy computed across the whole session
    :param first: first time idx to consider
    :param last: last time idx to consider
    :param iw: iteration of the WindowGenerator
    :param wg: WindowGenerator
    :return: (shift, time at the centre of the window); (nan, nan) when no shift can be computed
    """
    no_result = (np.nan, np.nan)

    # Nothing is computed for the last window
    if iw == wg.nwin - 1:
        return no_result

    # Time interval covered by this window
    t_first, t_last = times[first], times[last]
    if np.isnan(t_last):
        if np.isnan(t_first):
            # Both ends are undefined
            return no_result
        # Only the end is undefined: fall back to the last defined time
        t_last = times[np.flatnonzero(~np.isnan(times))[-1]]

    # Motion energy of the frames within the interval, scaled to [0, 1]
    in_window = (times >= t_first) & (times <= t_last)
    align_me = me[np.flatnonzero(in_window)]
    me_min, me_max = np.nanmin(align_me), np.nanmax(align_me)
    align_me = (align_me - me_min) / (me_max - me_min)

    # Wheel samples within the interval
    wheel_in_window = (self.wheel_time >= t_first) & (self.wheel_time <= t_last)
    n_wheel = np.sum(wheel_in_window)
    if n_wheel == 0:
        return no_result
    # Wheel sample at the insertion point of each frame time, clipped to the last sample
    wheel_idx = np.searchsorted(self.wheel_time[wheel_in_window], times[in_window])
    wheel_idx[wheel_idx == n_wheel] = n_wheel - 1
    # Wheel speed at those samples, scaled to [0, 1]
    speed = np.abs(self.wheel_vel[wheel_in_window][wheel_idx])
    speed_min, speed_max = np.min(speed), np.max(speed)
    speed = (speed - speed_min) / (speed_max - speed_min)

    # Undefined motion energy values are only expected as one block at the start of the window
    undefined = np.isnan(align_me)
    n_undefined = np.sum(undefined)
    if n_undefined > 0:
        undefined_idx = np.flatnonzero(undefined)
        assert undefined_idx[0] == 0
        assert undefined_idx[-1] == n_undefined - 1

    if np.all(undefined):
        return no_result

    # The position of the peak of the cross correlation gives the shift to apply
    valid_me = align_me[~undefined]
    xcorr = signal.correlate(valid_me, speed[~undefined])
    # The +2 comes from the fact that the video motion energy was computed from the difference between frames
    shift = np.nanargmax(xcorr) - valid_me.size + 2

    return shift, t_first + (t_last - t_first) / 2

703 

def clean_shifts(self, x, n=1):
    """
    Removes artefacts from the computed shifts across time. We assume that the shifts should never increase
    over time, so every upward step is cancelled together with a nearby downward step.

    :param x: computed shifts
    :param n: which steps to consider: with n=1, upward steps of exactly +1 are paired with downward steps of
     exactly -1; with any other value, upward steps > 2 are paired with downward steps < -2
    :return: the cleaned shifts
    """
    original = x.copy()
    steps = np.diff(original, prepend=original[0])
    unit_steps = n == 1

    while True:
        # Upward steps: these don't make sense (added frames) and are treated as noise
        ups = np.flatnonzero(steps == 1) if unit_steps else np.flatnonzero(steps > 2)
        if ups.size == 0:
            break
        downs = np.flatnonzero(steps == -1) if unit_steps else np.flatnonzero(steps < -2)

        # With more upward than downward steps the last sample acts as an extra downward step
        if len(ups) > len(downs):
            downs = np.append(downs, steps.size - 1)

        # For each upward step, the position in `downs` of the first downward step at or after it
        right = np.minimum(np.searchsorted(downs, ups), downs.size - 1)
        dist_left = np.abs(ups - downs[right - 1])
        dist_right = np.abs(ups - downs[right])
        # Deal first with the upward step that lies closest to a downward step...
        best = np.argmin(np.minimum(dist_left, dist_right))

        # ...and pair it with the nearest of its two neighbouring downward steps
        lo = max(0, right[best] - 1)
        candidates = downs[lo:right[best] + 1]
        partner = candidates[np.argmin(np.abs(ups[best] - candidates))]
        steps[ups[best]] = 0
        steps[partner] = 0

    return np.cumsum(steps) + original[0]

734 

def qc_shifts(self, shifts, shifts_filt):
    """
    Compute qc values for the wheel alignment. Four metrics are considered
    1. ttl_per: the percentage of camera ttls that are missing (when we have less ttls than video frames)
    2. nan_per: the percentage of cleaned shifts that are nan
    3. shifts_sum: the number of large jumps (>10) between consecutive computed shifts
    4. shifts_filt_sum: the number of jumps (>1) between consecutive shifts after they have been cleaned

    :param shifts: np.array of shifts over session
    :param shifts_filt: np.array of shifts after being cleaned over session
    :return: dict of the qc metrics, and a bool that is True only if every metric is within bounds
    """
    if self.tdiff < 0:
        ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100
    else:
        ttl_per = 0

    qc = {
        'ttl_per': ttl_per,
        'nan_per': (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100,
        'shifts_sum': np.where(np.abs(np.diff(shifts)) > 10)[0].size,
        'shifts_filt_sum': np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size,
    }

    failed = (
        qc['ttl_per'] > 10  # more than 10% of ttls are missing
        or qc['nan_per'] > 40  # too many undefined shifts for the alignment to be accurate
        or qc['shifts_sum'] > 60  # too many artefacts, could be errors
        or qc['shifts_filt_sum'] > 0  # jumps > 1 remain in the filtered shifts
    )

    return qc, not failed

774 

def extract_times(self, shifts_filt, t_shifts):
    """
    Apply the shifts computed across the session to the ttl times to get the new camera times.

    Each shift (in frames) is converted to seconds with self.frate and subtracted from its time point. The
    resulting mapping is linearly interpolated (and extrapolated outside the range of t_shifts) at
    self.ttl_times. When self.tdiff is negative, abs(self.tdiff) extra timestamps spaced by 1 / self.frate are
    appended after the last new time.

    :param shifts_filt: filtered shifts computed across session
    :param t_shifts: time point of computed shifts
    :return: np.array of new camera times
    """
    # Mapping from the time of each shift to that time corrected by the shift
    corrected = t_shifts - shifts_filt / self.frate
    ttl_to_new = interpolate.interp1d(t_shifts, corrected, fill_value="extrapolate")
    new_times = ttl_to_new(self.ttl_times)

    if self.tdiff >= 0:
        return new_times

    # We are missing ttls: append the correct number of evenly spaced times at the end
    n_missing = np.abs(self.tdiff)
    extra = new_times[-1] + (np.arange(n_missing) + 1) / self.frate
    return np.r_[new_times, extra]

796 

@staticmethod
def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True,
                          norm=False, axs=None):
    """
    Compute and plot a trial aligned raster and psth.

    The data are binned around each event with bin_spikes, from 0.4 s before to 1 s after, using 10 ms bins
    for the raster and 50 ms bins for the psth. The psth (mean +/- standard error across trials) is drawn on
    axs[0], one trace per unique label, and the raster sorted by trial_idx is drawn on axs[1] with a coloured
    bar for each group of trials on its right hand side.

    :param spike_times: times of variable
    :param events: trial times to align to
    :param trial_idx: trial idx to sort by
    :param dividers: list of positions in trial_idx that separate the groups of trials; 0 and len(trial_idx)
     are added internally as first and last boundary
    :param colors: colors of the groups, indexed by group number
    :param labels: labels of the groups, indexed by group number; groups sharing a label are pooled in the psth
    :param weights: passed on to bin_spikes
    :param fr: if True the psth is divided by the psth bin size
    :param norm: if True the first bin of each trial is subtracted from the psth and from the raster
    :param axs: the two axes to plot on (psth, raster); when None a new figure is created
    :return: figure, axes
    """
    pre_time = 0.4
    post_time = 1
    raster_bin = 0.01
    psth_bin = 0.05
    raster, t_raster = bin_spikes(
        spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights)
    psth, t_psth = bin_spikes(
        spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights)

    if fr:
        psth = psth / psth_bin

    if norm:
        # Subtract the first bin of each row from the whole row
        psth = psth - psth[:, [0]]
        raster = raster - raster[:, [0]]

    dividers = [0] + dividers + [len(trial_idx)]
    if axs is None:
        fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True)
    else:
        fig = axs[0].get_figure()

    # lidx holds the position of the first occurrence of each unique label; it is used to pick its color
    label, lidx = np.unique(labels, return_index=True)
    label_pos = []
    for lab, lid in zip(label, lidx):
        groups = np.where(np.array(labels) == lab)[0]
        for i_group, group in enumerate(groups):
            # NOTE(review): the slice is offset by one compared to the group boundaries used for the coloured
            # bars below - kept as is, confirm this is intended
            group_ids = trial_idx[dividers[group] + 1:dividers[group + 1] + 1]
            group_len = dividers[group + 1] - dividers[group]
            if i_group == 0:
                t_ids, t_ints = group_ids, group_len
            else:
                t_ids = np.r_[t_ids, group_ids]
                t_ints = np.r_[t_ints, group_len]

        psth_div = np.nanmean(psth[t_ids], axis=0)
        std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids))

        axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid])
        axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid])

        # The label is placed at the middle of the largest group that carries it
        lab_max = groups[np.argmax(t_ints)]
        label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max])

    axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower',
                  extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto')

    # Coloured bar to the right of the raster, one segment per group of trials
    width = raster_bin * 4
    bar_x = [post_time + raster_bin / 2, post_time + raster_bin / 2 + width]
    for i_group, (lower, upper) in enumerate(zip(dividers[:-1], dividers[1:])):
        axs[1].fill_between(bar_x, [upper, upper], [lower, lower], color=colors[i_group])

    axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width])
    secax = axs[1].secondary_yaxis('right')

    secax.set_yticks(label_pos)
    secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center')
    for ic, c in enumerate(np.array(colors)[lidx]):
        secax.get_yticklabels()[ic].set_color(c)

    # ymin / ymax of axvline are fractions of the axes height (0 to 1), not data values: passing the data
    # limits there misplaces or hides the line. The defaults span the full height whatever the y limits are.
    for ax in (axs[0], axs[1]):
        ax.axvline(0, c='k', ls='--', zorder=10)

    return fig, axs

877 

def plot_with_behavior(self):
    """
    Make the summary figure of the alignment when behaviour data is available.

    The figure has five columns
    - the raw and cleaned shifts, the difference between self.times and self.camera_times at the time of each
      shift, and the difference between self.camera_times and self.new_times (all in frames)
    - three trial aligned rasters (aligned to the first movement time, trials sorted by side): wheel velocity,
      right paw speed using self.camera_times and right paw speed using self.new_times
    - an example frame with the roi outlined, and the motion energy of the whole session

    Note that self.dlc is replaced by its likelihood thresholded version.

    :return: matplotlib figure
    """
    self.dlc = likelihood_threshold(self.dlc)
    trial_idx, dividers = find_trial_ids(self.trials, sort='side')
    feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r')
    feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r')

    fig = plt.figure()
    fig.set_size_inches(15, 9)
    gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5)

    # First column: three stacked time series
    gs_series = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
    ax_shifts, ax_ttl, ax_new = [fig.add_subplot(gs_series[row, 0]) for row in range(3)]
    # Columns two to four: a psth axis on top of a raster axis
    raster_axes = []
    for col in (1, 2, 3):
        gs_raster = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, col], height_ratios=[1, 3])
        raster_axes.append([fig.add_subplot(gs_raster[0, 0]), fig.add_subplot(gs_raster[1, 0])])
    # Last column: example frame and motion energy
    gs_video = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4])
    ax_frame = fig.add_subplot(gs_video[0, 0])
    ax_me = fig.add_subplot(gs_video[1, 0])

    fps = self.camera_meta['fps']

    ax_shifts.plot(self.t_shifts, self.shifts, label='shifts')
    ax_shifts.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
    ax_shifts.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
    ax_shifts.legend()

    xs = np.searchsorted(self.ttl_times, self.t_shifts)
    ttl_diff = (self.times - self.camera_times)[xs] * fps
    ax_ttl.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
    ax_ttl.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
    ax_ttl.legend()

    ax_new.plot(self.camera_times, (self.camera_times - self.new_times) * fps, 'k', label='extracted - new')
    ax_new.legend()
    ax_new.set_ylim(-5, 5)

    for ax in (ax_shifts, ax_ttl, ax_new):
        ax.set_ylabel('Frames')
        ax.set_xlabel('Time in session')

    # The three rasters only differ by the signal that is binned, the time base and the titles
    first_move = self.trials['firstMovement_times'].values
    panels = (
        (self.wheel_timestamps, self.wheel_vel, 'Wheel velocity', 'Wheel'),
        (self.camera_times, feature_ext, 'Paw r velocity', 'Extracted times'),
        (self.new_times, feature_new, 'Paw r velocity', 'New times'),
    )
    for (ax_psth, ax_raster), (times, values, ylabel, title) in zip(raster_axes, panels):
        self.single_cluster_raster(times, first_move, trial_idx, dividers, ['g', 'y'], ['left', 'right'],
                                   weights=values, fr=False, axs=[ax_psth, ax_raster])
        ax_psth.sharex(ax_raster)
        ax_psth.set_ylabel(ylabel)
        ax_psth.set_title(title)
        ax_raster.set_xlabel('Time from first move')

    ax_frame.imshow(self.frame_example[0])
    # self.roi[0] gives the vertical extent and self.roi[1] the horizontal extent of the rectangle
    roi_y, roi_x = self.roi[0], self.roi[1]
    rect = matplotlib.patches.Rectangle((roi_x[1], roi_y[0]), roi_x[0] - roi_x[1], roi_y[1] - roi_y[0],
                                        linewidth=4, edgecolor='g', facecolor='none')
    ax_frame.add_patch(rect)

    ax_me.plot(self.all_me)

    return fig

959 

def plot_without_behavior(self):
    """
    Make the summary figure of the alignment when behaviour data is not available.

    The left column shows the raw and cleaned shifts, the difference between self.times and
    self.camera_times at the time of each shift, and the difference between self.camera_times and
    self.new_times (all in frames). The right column shows an example frame with the roi outlined and the
    motion energy of the whole session.

    :return: matplotlib figure
    """
    fig = plt.figure()
    fig.set_size_inches(7, 7)
    gs = gridspec.GridSpec(1, 2, figure=fig)

    # Left column: three stacked time series
    gs_series = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
    ax_shifts, ax_ttl, ax_new = [fig.add_subplot(gs_series[row, 0]) for row in range(3)]
    # Right column: example frame and motion energy
    gs_video = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1])
    ax_frame, ax_me = [fig.add_subplot(gs_video[row, 0]) for row in range(2)]

    fps = self.camera_meta['fps']

    ax_shifts.plot(self.t_shifts, self.shifts, label='shifts')
    ax_shifts.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
    ax_shifts.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
    ax_shifts.legend()

    xs = np.searchsorted(self.ttl_times, self.t_shifts)
    ttl_diff = (self.times - self.camera_times)[xs] * fps
    ax_ttl.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
    ax_ttl.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
    ax_ttl.legend()

    ax_new.plot(self.camera_times, (self.camera_times - self.new_times) * fps, 'k', label='extracted - new')
    ax_new.legend()
    ax_new.set_ylim(-5, 5)

    for ax in (ax_shifts, ax_ttl, ax_new):
        ax.set_ylabel('Frames')
        ax.set_xlabel('Time in session')

    ax_frame.imshow(self.frame_example[0])
    # self.roi[0] gives the vertical extent and self.roi[1] the horizontal extent of the rectangle
    roi_y, roi_x = self.roi[0], self.roi[1]
    rect = matplotlib.patches.Rectangle((roi_x[1], roi_y[0]), roi_x[0] - roi_x[1], roi_y[1] - roi_y[0],
                                        linewidth=4, edgecolor='g', facecolor='none')
    ax_frame.add_patch(rect)

    ax_me.plot(self.all_me)

    return fig

1007 

def process(self):
    """
    Main function used to apply the video motion wheel alignment to the camera times. This function does the
    following
    1. Computes the video motion energy across the whole session (computed in windows and parallelised over
       self.nprocess jobs)
    2. Computes the shift that should be applied to the camera times across the whole session by computing
       the cross correlation between the video motion energy and the wheel speed (computed in overlapping
       windows)
    3. Removes artefacts from the computed shifts
    4. Computes the qc for the wheel alignment
    5. Extracts the new camera times using the shifts computed from the video wheel alignment
    6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session
       on alyx

    Sets self.all_me, self.shifts, self.t_shifts, self.shifts_filt, self.qc, self.qc_outcome and self.new_times.

    :return: np.array of new camera times
    """
    fps = self.camera_meta['fps']

    # Compute the motion energy of the wheel for the whole video
    wg = WindowGenerator(self.camera_meta['length'], 5000, 4)
    out = Parallel(n_jobs=self.nprocess)(
        delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast))
    # Concatenate the motion energy into one big array in a single call (the output of the last window is
    # left out); the leading empty array keeps the result a float array even when nothing is concatenated
    self.all_me = np.r_[(np.array([]), *out[:-1])]

    # Pad the start of the motion energy with nans, and the start of the times with evenly spaced timestamps,
    # by the number of frames that two consecutive windows overlap
    toverlap = self.twin - 1
    n_pad = int(fps * toverlap)
    all_me = np.r_[np.full(n_pad, np.nan), self.all_me]
    to_app = self.times[0] - ((np.arange(n_pad) + 1) / self.frate)[::-1]
    times = np.r_[to_app, self.times]

    wg = WindowGenerator(all_me.size - 1, int(fps * self.twin), n_pad)

    out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg)
                             for iw, (first, last) in enumerate(wg.firstlast))

    # Each output is a (shift, time of shift) pair; again the output of the last window is left out
    self.shifts = np.r_[(np.array([]), *[vals[0] for vals in out[:-1]])]
    self.t_shifts = np.r_[(np.array([]), *[vals[1] for vals in out[:-1]])]

    # Only keep the shifts that fall within the range of the ttl times
    idx = (self.t_shifts >= self.ttl_times[0]) & (self.t_shifts < self.ttl_times[-1])
    self.shifts = self.shifts[idx]
    self.t_shifts = self.t_shifts[idx]
    shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120)
    shifts_filt = self.clean_shifts(shifts_filt, n=1)
    self.shifts_filt = self.clean_shifts(shifts_filt, n=2)

    self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt)

    self.new_times = self.extract_times(self.shifts_filt, self.t_shifts)

    if self.upload:
        fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior()
        try:
            save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png'))
            save_fig_path.parent.mkdir(exist_ok=True, parents=True)
            fig.savefig(save_fig_path)
            snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one)
            snp.outputs = [save_fig_path]
            snp.register_images(widths=['orig'])
        finally:
            # Release the figure even when saving or registering the image fails
            plt.close(fig)

    return self.new_times