Coverage for ibllib/io/extractors/video_motion.py: 45%

622 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1""" 

2A module for aligning the wheel motion with the rotary encoder. Currently used by the camera QC 

3in order to check timestamp alignment. 

4""" 

5import matplotlib 

6import matplotlib.pyplot as plt 

7import matplotlib.gridspec as gridspec 

8from matplotlib.widgets import RectangleSelector 

9import numpy as np 

10from scipy import signal, ndimage, interpolate 

11import cv2 

12from itertools import cycle 

13import matplotlib.animation as animation 

14import logging 

15from pathlib import Path 

16from joblib import Parallel, delayed, cpu_count 

17 

18from ibldsp.utils import WindowGenerator 

19from one.api import ONE 

20import ibllib.io.video as vidio 

21from iblutil.util import Bunch 

22from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map 

23import ibllib.io.raw_data_loaders as raw 

24import ibllib.io.extractors.camera as cam 

25from ibllib.plots.snapshot import ReportSnapshot 

26import brainbox.video as video 

27import brainbox.behavior.wheel as wh 

28from brainbox.singlecell import bin_spikes 

29from brainbox.behavior.dlc import likelihood_threshold, get_speed 

30from brainbox.task.trials import find_trial_ids 

31import one.alf.io as alfio 

32from one.alf.exceptions import ALFObjectNotFound 

33from one.alf.spec import is_session_path, is_uuid_string 

34 

35 

36def find_nearest(array, value): 

37 array = np.asarray(array) 1deb

38 idx = (np.abs(array - value)).argmin() 1deb

39 return idx 1deb

40 

41 

42class MotionAlignment: 

43 roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))} 

44 

45 def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), stream=False, **kwargs): 

46 self.one = one or ONE() 1deb

47 self.eid = eid 1deb

48 self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid) 1deb

49 self.ref = self.one.dict2ref(self.one.path2ref(self.session_path)) 1deb

50 self.log = log 1deb

51 self.trials = self.wheel = self.camera_times = None 1deb

52 raw_cam_path = self.session_path.joinpath('raw_video_data') 1deb

53 camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*')) 1deb

54 if stream: 1deb

55 self.video_paths = vidio.url_from_eid(self.eid) 

56 else: 

57 self.video_paths = {vidio.label_from_path(x): x for x in camera_path} 1deb

58 self.data = Bunch() 1deb

59 self.alignment = Bunch() 1deb

60 

61 def align_all_trials(self, side='all'): 

62 """Align all wheel motion for all trials""" 

63 if self.trials is None: 

64 self.load_data() 

65 if side == 'all': 

66 side = self.video_paths.keys() 

67 if not isinstance(side, str): 

68 # Try to iterate over sides 

69 [self.align_all_trials(s) for s in side] 

70 if side not in self.video_paths: 

71 raise ValueError(f'{side} camera video file not found') 

72 # Align each trial sequentially 

73 for i in np.arange(self.trials['intervals'].shape[0]): 

74 self.align_motion(i, display=False) 

75 

76 @staticmethod 

77 def set_roi(video_path): 

78 """Manually set the ROIs for a given set of videos 

79 TODO Improve docstring 

80 TODO A method for setting ROIs by label 

81 """ 

82 frame = vidio.get_video_frame(str(video_path), 0) 

83 

84 def line_select_callback(eclick, erelease): 

85 """ 

86 Callback for line selection. 

87 

88 *eclick* and *erelease* are the press and release events. 

89 """ 

90 x1, y1 = eclick.xdata, eclick.ydata 

91 x2, y2 = erelease.xdata, erelease.ydata 

92 print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2)) 

93 return np.array([[x1, x2], [y1, y2]]) 

94 

95 plt.imshow(frame) 

96 roi = RectangleSelector(plt.gca(), line_select_callback, drawtype='box', useblit=True, button=[1, 3], 

97 # don't use middle button 

98 minspanx=5, minspany=5, spancoords='pixels', interactive=True) 

99 plt.show() 

100 ((x1, x2, *_), (y1, *_, y2)) = roi.corners 

101 col = np.arange(round(x1), round(x2), dtype=int) 

102 row = np.arange(round(y1), round(y2), dtype=int) 

103 return col, row 

104 

105 def load_data(self, download=False): 

106 """ 

107 Load wheel, trial and camera timestamp data 

108 :return: wheel, trials 

109 """ 

110 if download: 

111 self.data.wheel = self.one.load_object(self.eid, 'wheel') 

112 self.data.trials = self.one.load_object(self.eid, 'trials') 

113 cam, det = self.one.load_datasets(self.eid, ['*Camera.times*']) 

114 self.data.camera_times = {vidio.label_from_path(d['rel_path']): ts for ts, d in zip(cam, det)} 

115 else: 

116 alf_path = self.session_path / 'alf' 

117 wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent 

118 self.data.wheel = alfio.load_object(wheel_path, 'wheel', short_keys=True) 

119 trials_path = next(alf_path.rglob('*trials.table*')).parent 

120 self.data.trials = alfio.load_object(trials_path, 'trials') 

121 self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) for x in 

122 alf_path.rglob('*Camera.times*')} 

123 assert all(x is not None for x in self.data.values()) 

124 

125 def _set_eid_or_path(self, session_path_or_eid): 

126 """Parse a given eID or session path 

127 If a session UUID is given, resolves and stores the local path and vice versa 

128 :param session_path_or_eid: A session eid or path 

129 :return: 

130 """ 

131 self.eid = None 

132 if is_uuid_string(str(session_path_or_eid)): 

133 self.eid = session_path_or_eid 

134 # Try to set session_path if data is found locally 

135 self.session_path = self.one.eid2path(self.eid) 

136 elif is_session_path(session_path_or_eid): 

137 self.session_path = Path(session_path_or_eid) 

138 if self.one is not None: 

139 self.eid = self.one.path2eid(self.session_path) 

140 if not self.eid: 

141 self.log.warning('Failed to determine eID from session path') 

142 else: 

143 self.log.error('Cannot run alignment: an experiment uuid or session path is required') 

144 raise ValueError("'session' must be a valid session path or uuid") 

145 

146 def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False): 

147 """ 

148 Align video to the wheel using cross-correlation of the video motion signal and the rotary 

149 encoder. 

150 

151 Parameters 

152 ---------- 

153 period : (float, float) 

154 The time period over which to do the alignment. 

155 side : {'left', 'right'} 

156 With which camera to perform the alignment. 

157 sd_thresh : float 

158 For plotting where the motion energy goes above this standard deviation threshold. 

159 display : bool 

160 When true, displays the aligned wheel motion energy along with the rotary encoder 

161 signal. 

162 

163 Returns 

164 ------- 

165 int 

166 Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder 

167 signal. Negative values mean the video was shifted backwards with respect to the wheel 

168 timestamps. 

169 float 

170 The peak cross-correlation. 

171 numpy.ndarray 

172 The motion energy used in the cross-correlation, i.e. the frame difference for the 

173 period given. 

174 """ 

175 # Get data samples within period 

176 wheel = self.data['wheel'] 1deb

177 self.alignment.label = side 1deb

178 self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1]) 1deb

179 camera_times = self.data['camera_times'][side] 1deb

180 cam_mask = self.alignment.to_mask(camera_times) 1deb

181 frame_numbers, = np.where(cam_mask) 1deb

182 

183 if frame_numbers.size == 0: 1deb

184 raise ValueError('No frames during given period') 1b

185 

186 # Motion Energy 

187 camera_path = self.video_paths[side] 1deb

188 roi = (*[slice(*r) for r in self.roi[side]], 0) 1deb

189 try: 1deb

190 # TODO Add function arg to make grayscale 

191 self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi) 1deb

192 assert self.alignment.frames.size != 0 1deb

193 except AssertionError: 

194 self.log.error('Failed to open video') 

195 return None, None, None 

196 self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2) 1deb

197 self.alignment.period = period # For plotting 1deb

198 

199 # Calculate rotary encoder velocity trace 

200 x = camera_times[cam_mask] 1deb

201 Fs = 1000 1deb

202 pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs) 1deb

203 v, _ = wh.velocity_filtered(pos, Fs) 1deb

204 interp_mask = self.alignment.to_mask(t) 1deb

205 # Convert to normalized speed 

206 xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x]) 1deb

207 vs = np.abs(v[interp_mask][xs]) 1deb

208 vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) 1deb

209 

210 # FIXME This can be used as a goodness of fit measure 

211 USE_CV2 = False 1deb

212 if USE_CV2: 1deb

213 # convert from numpy format to openCV format 

214 dfCV = np.float32(self.alignment.df.reshape((-1, 1))) 

215 reCV = np.float32(vs.reshape((-1, 1))) 

216 

217 # perform cross correlation 

218 resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED) 

219 

220 # convert result back to numpy array 

221 xcorr = np.asarray(resultCv) 

222 else: 

223 xcorr = signal.correlate(self.alignment.df, vs) 1deb

224 

225 # Cross correlate wheel speed trace with the motion energy 

226 CORRECTION = 2 1deb

227 self.alignment.c = max(xcorr) 1deb

228 self.alignment.xcorr = np.argmax(xcorr) 1deb

229 self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION 1deb

230 self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames') 1deb

231 

232 if display: 1deb

233 # Plot the motion energy 

234 fig, ax = plt.subplots(2, 1, sharex='all') 

235 y = np.pad(self.alignment.df, 1, 'edge') 

236 ax[0].plot(x, y, '-x', label='wheel motion energy') 

237 thresh = stDev > sd_thresh 

238 ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, linewidth=0.5, linestyle=':', 

239 label=f'>{sd_thresh} s.d. diff') 

240 ax[1].plot(t[interp_mask], np.abs(v[interp_mask])) 

241 

242 # Plot other stuff 

243 dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]]) 

244 fps = 1 / np.diff(camera_times).mean() 

245 ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)') 

246 ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps)) 

247 ax[0].set_ylabel('rate of change (a.u.)') 

248 ax[0].legend() 

249 ax[1].set_ylabel('wheel speed (rad / s)') 

250 ax[1].set_xlabel('Time (s)') 

251 

252 title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s' 

253 fig.suptitle(title, fontsize=16) 

254 fig.set_size_inches(19.2, 9.89) 

255 

256 return self.alignment.dt_i, self.alignment.c, self.alignment.df 1deb

257 

258 def plot_alignment(self, energy=True, save=False): 

259 if not self.alignment: 1b

260 self.log.error('No alignment data, run `align_motion` first') 

261 return 

262 # Change backend based on save flag 

263 backend = matplotlib.get_backend().lower() 1b

264 if (save and backend != 'agg') or (not save and backend == 'agg'): 1b

265 new_backend = 'Agg' if save else 'Qt5Agg' 

266 self.log.warning('Switching backend from %s to %s', backend, new_backend) 

267 matplotlib.use(new_backend) 

268 from matplotlib import pyplot as plt 1b

269 

270 # Main animated plots 

271 fig, axes = plt.subplots(nrows=2) 1b

272 title = f'{self.ref}' # ', from {period[0]:.1f}s - {period[1]:.1f}s' 1b

273 fig.suptitle(title, fontsize=16) 1b

274 

275 wheel = self.data['wheel'] 1b

276 wheel_mask = self.alignment['to_mask'](wheel.timestamps) 1b

277 ts = self.data['camera_times'][self.alignment['label']] 1b

278 frame_numbers, = np.where(self.alignment['to_mask'](ts)) 1b

279 if energy: 1b

280 self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2) 1b

281 frame_numbers = frame_numbers[1:-1] 1b

282 data = {'frame_ids': frame_numbers} 1b

283 

284 def init_plot(): 1b

285 """ 

286 Plot the wheel data for the current trial 

287 :return: None 

288 """ 

289 data['im'] = axes[0].imshow(self.alignment['frames'][0]) 1b

290 axes[0].axis('off') 1b

291 axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames') 1b

292 

293 # Plot the wheel position 

294 ax = axes[1] 1b

295 ax.clear() 1b

296 ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x') 1b

297 

298 ts_0 = frame_numbers[0] 1b

299 data['idx_0'] = ts_0 - self.alignment['dt_i'] 1b

300 ts_0 = ts[ts_0 + self.alignment['dt_i']] 1b

301 data['ln'] = ax.axvline(x=ts_0, color='k') 1b

302 ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)]) 1b

303 data['frame_num'] = 0 1b

304 mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0) 1b

305 

306 data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x') 1b

307 ax.set_ylabel('Wheel position (rad))') 1b

308 ax.set_xlabel('Time (s))') 1b

309 return 1b

310 

311 def animate(i): 1b

312 """ 

313 Callback for figure animation. Sets image data for current frame and moves pointer 

314 along axis 

315 :param i: unused; the current time step of the calling method 

316 :return: None 

317 """ 

318 if i < 0: 1b

319 data['frame_num'] -= 1 

320 if data['frame_num'] < 0: 

321 data['frame_num'] = len(self.alignment['frames']) - 1 

322 else: 

323 data['frame_num'] += 1 1b

324 if data['frame_num'] >= len(self.alignment['frames']): 1b

325 data['frame_num'] = 0 1b

326 i = data['frame_num'] # NB: This is index for current trial's frame list 1b

327 

328 frame = self.alignment['frames'][i] 1b

329 t_x = ts[data['idx_0'] + i] 1b

330 data['ln'].set_xdata([t_x, t_x]) 1b

331 axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)]) 1b

332 data['im'].set_data(frame) 1b

333 

334 mkr = find_nearest(wheel.timestamps[wheel_mask], t_x) 1b

335 data['marker'].set_data([wheel.timestamps[wheel_mask][mkr]], [wheel.position[wheel_mask][mkr]]) 1b

336 

337 return data['im'], data['ln'], data['marker'] 1b

338 

339 anim = animation.FuncAnimation(fig, animate, init_func=init_plot, 1b

340 frames=(range(len(self.alignment.df)) if save else cycle(range(60))), interval=20, 

341 blit=False, repeat=not save, cache_frame_data=False) 

342 anim.running = False 1b

343 

344 def process_key(event): 1b

345 """ 

346 Callback for key presses. 

347 :param event: a figure key_press_event 

348 :return: None 

349 """ 

350 if event.key.isspace(): 

351 if anim.running: 

352 anim.event_source.stop() 

353 else: 

354 anim.event_source.start() 

355 anim.running = ~anim.running 

356 elif event.key == 'right': 

357 if anim.running: 

358 anim.event_source.stop() 

359 anim.running = False 

360 animate(1) 

361 fig.canvas.draw() 

362 elif event.key == 'left': 

363 if anim.running: 

364 anim.event_source.stop() 

365 anim.running = False 

366 animate(-1) 

367 fig.canvas.draw() 

368 

369 fig.canvas.mpl_connect('key_press_event', process_key) 1b

370 

371 # init_plot() 

372 # while True: 

373 # animate(0) 

374 if save: 1b

375 filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0]) 1b

376 if isinstance(save, (str, Path)): 1b

377 filename = Path(save).joinpath(filename) 1b

378 self.log.info(f'Saving to {filename}') 1b

379 # Set up formatting for the movie files 

380 Writer = animation.writers['ffmpeg'] 1b

381 writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800) 1b

382 anim.save(str(filename), writer=writer) 1b

383 else: 

384 plt.show() 

385 

386 

387class MotionAlignmentFullSession: 

388 def __init__(self, session_path, label, **kwargs): 

389 """ 

390 Class to extract camera times using video motion energy wheel alignment 

391 :param session_path: path of the session 

392 :param label: video label, only 'left' and 'right' videos are supported 

393 :param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20) 

394 upload - whether to upload summary figure to alyx (default False) 

395 twin - the window length used when computing the shifts between the wheel and video 

396 nprocesses - the number of CPU processes to use 

397 sync - the type of sync scheme used (options 'nidq' or 'bpod') 

398 location - whether the code is being run on SDSC or not (options 'SDSC' or None) 

399 """ 

400 self.session_path = session_path 1hfga

401 self.label = label 1hfga

402 self.threshold = kwargs.get('threshold', 20) 1hfga

403 self.upload = kwargs.get('upload', False) 1hfga

404 self.twin = kwargs.get('twin', 150) 1hfga

405 self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) 1hfga

406 

407 self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None)) 1hfga

408 self.roi, self.mask = self.get_roi_mask() 1a

409 

410 if self.upload: 1a

411 self.one = ONE(mode='remote') 

412 self.one.alyx.authenticate() 

413 self.eid = self.one.path2eid(self.session_path) 

414 

415 def load_data(self, sync='nidq', location=None): 

416 """ 

417 Loads relevant data from disk to perform motion alignment 

418 :param sync: type of sync used, 'nidq' or 'bpod' 

419 :param location: where the code is being run, if location='SDSC', the dataset uuids are removed 

420 when loading the data 

421 :return: 

422 """ 

423 def fix_keys(alf_object): 1hfga

424 """ 

425 Given an alf object removes the dataset uuid from the keys 

426 :param alf_object: 

427 :return: 

428 """ 

429 ob = Bunch() 

430 for key in alf_object.keys(): 

431 vals = alf_object[key] 

432 ob[key.split('.')[0]] = vals 

433 return ob 

434 

435 alf_path = self.session_path.joinpath('alf') 1hfga

436 wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent 1hfga

437 wheel = (fix_keys(alfio.load_object(wheel_path, 'wheel')) if location == 'SDSC' 1fga

438 else alfio.load_object(wheel_path, 'wheel')) 

439 self.wheel_timestamps = wheel.timestamps 1fga

440 # Compute interpolated wheel position and wheel times 

441 wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) 1fga

442 # Compute wheel velocity 

443 self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) 1fga

444 # Load in original camera times 

445 self.camera_times = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.times*.npy'))) 1fga

446 self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4'))) 1fga

447 self.camera_meta = vidio.get_video_meta(self.camera_path) 1fa

448 

449 # TODO should read in the description file to get the correct sync location 

450 if sync == 'nidq': 1fa

451 # If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object 

452 sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data') 1fa

453 sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) 1fa

454 self.ttls = sr.times[::2] 1fa

455 else: 

456 # Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data 

457 cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path) 

458 cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data') 

459 self.ttls = cam_extractor._times_from_bpod() 

460 

461 # Check if the ttl and video sizes match up 

462 self.tdiff = self.ttls.size - self.camera_meta['length'] 1fa

463 

464 if self.tdiff < 0: 1fa

465 # In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with 

466 # nans but if this is too many we reject the wheel alignment based on the qc 

467 self.ttl_times = self.ttls 

468 self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)] 

469 self.short_flag = True 

470 elif self.tdiff > 0: 1fa

471 # In this case there are more ttls than camera frames. This happens often, for now we remove the first 

472 # tdiff ttls from the ttls 

473 self.ttl_times = self.ttls[self.tdiff:] 1fa

474 self.times = self.ttls[self.tdiff:] 1fa

475 self.short_flag = False 1fa

476 

477 # Compute the frame rate of the camera 

478 self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) 1fa

479 

480 # We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots, having 

481 # trial aligned paw velocity (from the dlc) is a nice sanity check to make sure the alignment went well 

482 try: 1fa

483 self.trials = alfio.load_file_content(next(alf_path.rglob('_ibl_trials.table*.pqt'))) 1fa

484 self.dlc = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.dlc*.pqt'))) 

485 self.dlc = likelihood_threshold(self.dlc) 

486 self.behavior = True 

487 except (ALFObjectNotFound, StopIteration): 1fa

488 self.behavior = False 1fa

489 

490 # Load in a single frame that we will use for the summary plot 

491 self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) 1fa

492 

493 def get_roi_mask(self): 

494 """ 

495 Compute the region of interest mask for a given camera. This corresponds to a box in the video that we will 

496 use to compute the wheel motion energy 

497 :return: 

498 """ 

499 

500 if self.label == 'right': 1a

501 roi = ((450, 512), (120, 200)) 1a

502 else: 

503 roi = ((900, 1024), (850, 1010)) 

504 roi_mask = (*[slice(*r) for r in roi], 0) 1a

505 

506 return roi, roi_mask 1a

507 

508 def find_contaminated_frames(self, video_frames, thresold=20, normalise=True): 

509 """ 

510 Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine 

511 frames with contamination an Otsu thresholding is applied to each frame to detect the artefact from the 

512 background image 

513 :param video_frames: np array of video frames (nframes, nwidth, nheight) 

514 :param thresold: threshold to differentiate artefact from background 

515 :param normalise: whether to normalise the threshold values for each frame to the baseline 

516 :return: mask of frames that are contaminated 

517 """ 

518 high = np.zeros((video_frames.shape[0])) 

519 # Iterate through each frame and compute and store the otsu threshold value for each frame 

520 for idx, frame in enumerate(video_frames): 

521 ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 

522 high[idx] = ret 

523 

524 # If normalise is True, we divide the threshold values for each frame by the minimum value 

525 if normalise: 

526 high -= np.min(high) 

527 

528 # Identify the frames that have a threshold value greater than the specified threshold cutoff 

529 contaminated_frames = np.where(high > thresold)[0] 

530 

531 return contaminated_frames 

532 

533 def compute_motion_energy(self, first, last, wg, iw): 

534 """ 

535 Computes the video motion energy for frame indexes between first and last. This function is written to be run 

536 in a parallel fashion jusing joblib.parallel 

537 :param first: first frame index of frame interval to consider 

538 :param last: last frame index of frame interval to consider 

539 :param wg: WindowGenerator 

540 :param iw: iteration of the WindowGenerator 

541 :return: 

542 """ 

543 

544 if iw == wg.nwin - 1: 

545 return 

546 

547 # Open the video and read in the relvant video frames between first idx and last idx 

548 cap = cv2.VideoCapture(self.camera_path) 

549 frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask) 

550 # Identify if any of the frames have artefacts in them 

551 idx = self.find_contaminated_frames(frames, self.threshold) 

552 

553 # If some of the frames are contaminated we find all the continuous intervals of contamination 

554 # and set the value for contaminated pixels for these frames to the average of the first frame before and after 

555 # this contamination interval 

556 if len(idx) != 0: 

557 

558 before_status = False 

559 after_status = False 

560 

561 counter = 0 

562 n_frames = 200 

563 # If it is the first frame that is contaminated, we need to read in a bit more of the video to find a 

564 # frame prior to contamination. We attempt this 20 times, after that we just take the value for the first 

565 # frame 

566 while np.any(idx == 0) and counter < 20 and iw != 0: 

567 n_before_offset = (counter + 1) * n_frames 

568 first -= n_frames 

569 extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first - n_frames, first), 

570 mask=self.mask) 

571 frames = np.concatenate([extra_frames, frames], axis=0) 

572 

573 idx = self.find_contaminated_frames(frames, self.threshold) 

574 before_status = True 

575 counter += 1 

576 if counter > 0: 

577 print(f'In before: {counter}') 

578 

579 counter = 0 

580 # If it is the last frame that is contaminated, we need to read in a bit more of the video to find a 

581 # frame after the contamination. We attempt this 20 times, after that we just take the value for the last 

582 # frame 

583 while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1: 

584 n_after_offset = (counter + 1) * n_frames 

585 last += n_frames 

586 extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), mask=self.mask) 

587 frames = np.concatenate([frames, extra_frames], axis=0) 

588 idx = self.find_contaminated_frames(frames, self.threshold) 

589 after_status = True 

590 counter += 1 

591 

592 if counter > 0: 

593 print(f'In after: {counter}') 

594 

595 # We find all the continuous intervals that contain contamination and fix the affected pixels 

596 # by taking the average value of the frame prior and after contamination 

597 intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1) 

598 for ints in intervals: 

599 if len(ints) > 0 and ints[0] == 0: 

600 ints = ints[1:] 

601 if len(ints) > 0 and ints[-1] == frames.shape[0] - 1: 

602 ints = ints[:-1] 

603 th_all = np.zeros_like(frames[0]) 

604 # We find all affected pixels 

605 for idx in ints: 

606 img = np.copy(frames[idx]) 

607 blur = cv2.GaussianBlur(img, (5, 5), 0) 

608 ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 

609 th = cv2.GaussianBlur(th, (5, 5), 10) 

610 th_all += th 

611 # Compute the average image of the frame prior and after the interval 

612 vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1) 

613 # For each frame set the affected pixels to the value of the clean average image 

614 for idx in ints: 

615 img = frames[idx] 

616 img[th_all > 0] = vals[th_all > 0] 

617 

618 # If we have read in extra video frames we need to cut these off and make sure we only 

619 # consider the frames between the interval first and last given as args 

620 if before_status: 

621 frames = frames[n_before_offset:] 

622 if after_status: 

623 frames = frames[:(-1 * n_after_offset)] 

624 

625 # Once the frames have been cleaned we compute the motion energy between frames 

626 frame_me, _ = video.motion_energy(frames, diff=2, normalize=False) 

627 

628 cap.release() 

629 

630 return frame_me[2:] 

631 

632 def compute_shifts(self, times, me, first, last, iw, wg): 

633 """ 

634 Compute the cross-correlation between the video motion energy and the wheel velocity to find the mismatch 

635 between the camera ttls and the video frames. This function is written to run in a parallel manner using 

636 joblib.parallel 

637 

638 :param times: the times of the video frames across the whole session (ttls) 

639 :param me: the video motion energy computed across the whole session 

640 :param first: first time idx to consider 

641 :param last: last time idx to consider 

642 :param wg: WindowGenerator 

643 :param iw: iteration of the WindowGenerator 

644 :return: 

645 """ 

646 

647 # If we are in the last window we exit 

648 if iw == wg.nwin - 1: 1a

649 return np.nan, np.nan 1a

650 

651 # Find the time interval we are interested in 

652 t_first = times[first] 1a

653 t_last = times[last] 1a

654 

655 # If both times during this interval are nan exit 

656 if np.isnan(t_last) and np.isnan(t_first): 1a

657 return np.nan, np.nan 

658 # If only the last time is nan, we find the last non nan time value 

659 elif np.isnan(t_last): 1a

660 t_last = times[np.where(~np.isnan(times))[0][-1]] 

661 

662 # Find the mask of timepoints that fall in this interval 

663 mask = np.logical_and(times >= t_first, times <= t_last) 1a

664 # Restrict the video motion energy to this interval and normalise the values 

665 align_me = me[np.where(mask)[0]] 1a

666 align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me)) 1a

667 

668 # Find closest timepoints in wheel that match the time interval 

669 wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last) 1a

670 if np.sum(wh_mask) == 0: 1a

671 return np.nan, np.nan 

672 # Find the mask for the wheel times 

673 xs = np.searchsorted(self.wheel_time[wh_mask], times[mask]) 1a

674 xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1 1a

675 # Convert to normalized speed 

676 vs = np.abs(self.wheel_vel[wh_mask][xs]) 1a

677 vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) 1a

678 

679 # Account for nan values in the video motion energy 

680 isnan = np.isnan(align_me) 1a

681 if np.sum(isnan) > 0: 1a

682 where_nan = np.where(isnan)[0] 1a

683 assert where_nan[0] == 0 1a

684 assert where_nan[-1] == np.sum(isnan) - 1 1a

685 

686 if np.all(isnan): 1a

687 return np.nan, np.nan 

688 

689 # Compute the cross correlation between the video motion energy and the wheel speed 

690 xcorr = signal.correlate(align_me[~isnan], vs[~isnan]) 1a

691 # The max value of the cross correlation indicates the shift that needs to be applied 

692 # The +2 comes from the fact that the video motion energy was computed from the difference between frames 

693 shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2 1a

694 

695 return shift, t_first + (t_last - t_first) / 2 1a

696 

697 def clean_shifts(self, x, n=1): 

698 """ 

699 Removes artefacts from the computed shifts across time. We assume that the shifts should never increase 

700 over time and that the jump between consecutive shifts shouldn't be greater than 1 

701 :param x: computed shifts 

702 :param n: condition to apply 

703 :return: 

704 """ 

705 y = x.copy() 1a

706 dy = np.diff(y, prepend=y[0]) 1a

707 while True: 1a

708 pos = np.where(dy == 1)[0] if n == 1 else np.where(dy > 2)[0] 1a

709 # added frames: this doesn't make sense and this is noise 

710 if pos.size == 0: 1a

711 break 1a

712 neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0] 

713 

714 if len(pos) > len(neg): 

715 neg = np.append(neg, dy.size - 1) 

716 

717 iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1) 

718 imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss]))) 

719 

720 idx = np.max([0, iss[imin] - 1]) 

721 ineg = neg[idx:iss[imin] + 1] 

722 ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))] 

723 dy[pos[imin]] = 0 

724 dy[ineg] = 0 

725 

726 return np.cumsum(dy) + y[0] 1a

727 

728 def qc_shifts(self, shifts, shifts_filt): 

729 """ 

730 Compute qc values for the wheel alignment. We consider 4 things 

731 1. The number of camera ttl values that are missing (when we have less ttls than video frames) 

732 2. The number of shifts that have nan values, this means the video motion energy computation 

733 3. The number of large jumps (>10) between the computed shifts 

734 4. The number of jumps (>1) between the shifts after they have been cleaned 

735 

736 :param shifts: np.array of shifts over session 

737 :param shifts_filt: np.array of shifts after being cleaned over session 

738 :return: 

739 """ 

740 

741 ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 1a

742 nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100 1a

743 shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size 1a

744 shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size 1a

745 

746 qc = dict() 1a

747 qc['ttl_per'] = ttl_per 1a

748 qc['nan_per'] = nan_per 1a

749 qc['shifts_sum'] = shifts_sum 1a

750 qc['shifts_filt_sum'] = shifts_filt_sum 1a

751 

752 qc_outcome = True 1a

753 # If more than 10% of ttls are missing we don't get new times 

754 if ttl_per > 10: 1a

755 qc_outcome = False 

756 # If too many of the shifts are nans it means the alignment is not accurate 

757 if nan_per > 40: 1a

758 qc_outcome = False 

759 # If there are too many artefacts could be errors 

760 if shifts_sum > 60: 1a

761 qc_outcome = False 

762 # If there are jumps > 1 in the filtered shifts then there is a problem 

763 if shifts_filt_sum > 0: 1a

764 qc_outcome = False 

765 

766 return qc, qc_outcome 1a

767 

768 def extract_times(self, shifts_filt, t_shifts): 

769 """ 

770 Extracts new camera times after applying the computed shifts across the session 

771 

772 :param shifts_filt: filtered shifts computed across session 

773 :param t_shifts: time point of computed shifts 

774 :return: 

775 """ 

776 

777 # Compute the interpolation function to apply to the ttl times 

778 t_new = t_shifts - (shifts_filt * 1 / self.frate) 1a

779 fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") 1a

780 # Apply the function and get out new times 

781 new_times = fcn(self.ttl_times) 1a

782 

783 # If we are missing ttls then interpolate and append the correct number at the end 

784 if self.tdiff < 0: 1a

785 to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1] 

786 new_times = np.r_[new_times, to_app] 

787 

788 return new_times 1a

789 

790 @staticmethod 

791 def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, 

792 norm=False, axs=None): 

793 """ 

794 Compute and plot trial aligned spike rasters and psth 

795 :param spike_times: times of variable 

796 :param events: trial times to align to 

797 :param trial_idx: trial idx to sort by 

798 :param dividers: 

799 :param colors: 

800 :param labels: 

801 :param weights: 

802 :param fr: 

803 :param norm: 

804 :param axs: 

805 :return: 

806 """ 

807 pre_time = 0.4 

808 post_time = 1 

809 raster_bin = 0.01 

810 psth_bin = 0.05 

811 raster, t_raster = bin_spikes( 

812 spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights) 

813 psth, t_psth = bin_spikes( 

814 spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights) 

815 

816 if fr: 

817 psth = psth / psth_bin 

818 

819 if norm: 

820 psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1) 

821 raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1) 

822 

823 dividers = [0] + dividers + [len(trial_idx)] 

824 if axs is None: 

825 fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True) 

826 else: 

827 fig = axs[0].get_figure() 

828 

829 label, lidx = np.unique(labels, return_index=True) 

830 label_pos = [] 

831 for lab, lid in zip(label, lidx): 

832 idx = np.where(np.array(labels) == lab)[0] 

833 for iD in range(len(idx)): 

834 if iD == 0: 

835 t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1] 

836 t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]] 

837 else: 

838 t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]] 

839 t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]] 

840 

841 psth_div = np.nanmean(psth[t_ids], axis=0) 

842 std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids)) 

843 

844 axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid]) 

845 axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid]) 

846 

847 lab_max = idx[np.argmax(t_ints)] 

848 label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max]) 

849 

850 axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower', 

851 extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto') 

852 

853 width = raster_bin * 4 

854 for iD in range(len(dividers) - 1): 

855 axs[1].fill_between([post_time + raster_bin / 2, post_time + raster_bin / 2 + width], 

856 [dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD]) 

857 

858 axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width]) 

859 secax = axs[1].secondary_yaxis('right') 

860 

861 secax.set_yticks(label_pos) 

862 secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center') 

863 for ic, c in enumerate(np.array(colors)[lidx]): 

864 secax.get_yticklabels()[ic].set_color(c) 

865 

866 axs[0].axvline(0, *axs[0].get_ylim(), c='k', ls='--', zorder=10) # TODO this doesn't always work 

867 axs[1].axvline(0, *axs[1].get_ylim(), c='k', ls='--', zorder=10) 

868 

869 return fig, axs 

870 

871 def plot_with_behavior(self): 

872 """ 

873 Makes a summary figure of the alignment when behaviour data is available 

874 :return: 

875 """ 

876 

877 self.dlc = likelihood_threshold(self.dlc) 

878 trial_idx, dividers = find_trial_ids(self.trials, sort='side') 

879 feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r') 

880 feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r') 

881 

882 fig = plt.figure() 

883 fig.set_size_inches(15, 9) 

884 gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5) 

885 gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) 

886 ax01 = fig.add_subplot(gs0[0, 0]) 

887 ax02 = fig.add_subplot(gs0[1, 0]) 

888 ax03 = fig.add_subplot(gs0[2, 0]) 

889 gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3]) 

890 ax11 = fig.add_subplot(gs1[0, 0]) 

891 ax12 = fig.add_subplot(gs1[1, 0]) 

892 gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3]) 

893 ax21 = fig.add_subplot(gs2[0, 0]) 

894 ax22 = fig.add_subplot(gs2[1, 0]) 

895 gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3]) 

896 ax31 = fig.add_subplot(gs3[0, 0]) 

897 ax32 = fig.add_subplot(gs3[1, 0]) 

898 gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4]) 

899 ax41 = fig.add_subplot(gs4[0, 0]) 

900 ax42 = fig.add_subplot(gs4[1, 0]) 

901 

902 ax01.plot(self.t_shifts, self.shifts, label='shifts') 

903 ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') 

904 ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) 

905 ax01.legend() 

906 ax01.set_ylabel('Frames') 

907 ax01.set_xlabel('Time in session') 

908 

909 xs = np.searchsorted(self.ttl_times, self.t_shifts) 

910 ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] 

911 ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') 

912 ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) 

913 ax02.legend() 

914 ax02.set_ylabel('Frames') 

915 ax02.set_xlabel('Time in session') 

916 

917 ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new') 

918 ax03.legend() 

919 ax03.set_ylim(-5, 5) 

920 ax03.set_ylabel('Frames') 

921 ax03.set_xlabel('Time in session') 

922 

923 self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers, 

924 ['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12]) 

925 ax11.sharex(ax12) 

926 ax11.set_ylabel('Wheel velocity') 

927 ax11.set_title('Wheel') 

928 ax12.set_xlabel('Time from first move') 

929 

930 self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], 

931 ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22]) 

932 ax21.sharex(ax22) 

933 ax21.set_ylabel('Paw r velocity') 

934 ax21.set_title('Extracted times') 

935 ax22.set_xlabel('Time from first move') 

936 

937 self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], 

938 ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32]) 

939 ax31.sharex(ax32) 

940 ax31.set_ylabel('Paw r velocity') 

941 ax31.set_title('New times') 

942 ax32.set_xlabel('Time from first move') 

943 

944 ax41.imshow(self.frame_example[0]) 

945 rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], 

946 self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none') 

947 ax41.add_patch(rect) 

948 

949 ax42.plot(self.all_me) 

950 

951 return fig 

952 

953 def plot_without_behavior(self): 

954 """ 

955 Makes a summary figure of the alignment when behaviour data is not available 

956 :return: 

957 """ 

958 

959 fig = plt.figure() 

960 fig.set_size_inches(7, 7) 

961 gs = gridspec.GridSpec(1, 2, figure=fig) 

962 gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) 

963 ax01 = fig.add_subplot(gs0[0, 0]) 

964 ax02 = fig.add_subplot(gs0[1, 0]) 

965 ax03 = fig.add_subplot(gs0[2, 0]) 

966 

967 gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1]) 

968 ax04 = fig.add_subplot(gs1[0, 0]) 

969 ax05 = fig.add_subplot(gs1[1, 0]) 

970 

971 ax01.plot(self.t_shifts, self.shifts, label='shifts') 

972 ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') 

973 ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) 

974 ax01.legend() 

975 ax01.set_ylabel('Frames') 

976 ax01.set_xlabel('Time in session') 

977 

978 xs = np.searchsorted(self.ttl_times, self.t_shifts) 

979 ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] 

980 ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') 

981 ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) 

982 ax02.legend() 

983 ax02.set_ylabel('Frames') 

984 ax02.set_xlabel('Time in session') 

985 

986 ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new') 

987 ax03.legend() 

988 ax03.set_ylim(-5, 5) 

989 ax03.set_ylabel('Frames') 

990 ax03.set_xlabel('Time in session') 

991 

992 ax04.imshow(self.frame_example[0]) 

993 rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], 

994 self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none') 

995 ax04.add_patch(rect) 

996 

997 ax05.plot(self.all_me) 

998 

999 return fig 

1000 

1001 def process(self): 

1002 """ 

1003 Main function used to apply the video motion wheel alignment to the camera times. This function does the 

1004 following 

1005 1. Computes the video motion energy across the whole session (computed in windows and parallelised) 

1006 2. Computes the shift that should be applied to the camera times across the whole session by computing 

1007 the cross correlation between the video motion energy and the wheel speed (computed in 

1008 overlapping windows and parallelised) 

1009 3. Removes artefacts from the computed shifts 

1010 4. Computes the qc for the wheel alignment 

1011 5. Extracts the new camera times using the shifts computed from the video wheel alignment 

1012 6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session 

1013 on alyx 

1014 :return: 

1015 """ 

1016 

1017 # Compute the motion energy of the wheel for the whole video 

1018 wg = WindowGenerator(self.camera_meta['length'], 5000, 4) 1a

1019 out = Parallel(n_jobs=self.nprocess)( 1a

1020 delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast)) 

1021 # Concatenate the motion energy into one big array 

1022 self.all_me = np.array([]) 1a

1023 for vals in out[:-1]: 1a

1024 self.all_me = np.r_[self.all_me, vals] 1a

1025 

1026 toverlap = self.twin - 1 1a

1027 all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me] 1a

1028 to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1] 1a

1029 times = np.r_[to_app, self.times] 1a

1030 

1031 wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) 1a

1032 

1033 out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) 1a

1034 for iw, (first, last) in enumerate(wg.firstlast)) 

1035 

1036 self.shifts = np.array([]) 1a

1037 self.t_shifts = np.array([]) 1a

1038 for vals in out[:-1]: 1a

1039 self.shifts = np.r_[self.shifts, vals[0]] 1a

1040 self.t_shifts = np.r_[self.t_shifts, vals[1]] 1a

1041 

1042 idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1]) 1a

1043 self.shifts = self.shifts[idx] 1a

1044 self.t_shifts = self.t_shifts[idx] 1a

1045 shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120) 1a

1046 shifts_filt = self.clean_shifts(shifts_filt, n=1) 1a

1047 self.shifts_filt = self.clean_shifts(shifts_filt, n=2) 1a

1048 

1049 self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt) 1a

1050 

1051 self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) 1a

1052 

1053 if self.upload: 1a

1054 fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior() 

1055 save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png')) 

1056 save_fig_path.parent.mkdir(exist_ok=True, parents=True) 

1057 fig.savefig(save_fig_path) 

1058 snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) 

1059 snp.outputs = [save_fig_path] 

1060 snp.register_images(widths=['orig']) 

1061 plt.close(fig) 

1062 

1063 return self.new_times 1a