Coverage for ibllib/io/extractors/camera.py: 96%

346 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-02 18:55 +0100

1""" Camera extractor functions. 

2 

3This module handles extraction of camera timestamps for both Bpod and DAQ. 

4""" 

5import logging 

6 

7import cv2 

8import numpy as np 

9import matplotlib.pyplot as plt 

10from iblutil.util import range_str 

11 

12import ibldsp.utils as dsp 

13from ibllib.plots import squares, vertical_lines 

14from ibllib.io.video import assert_valid_label, VideoStreamer 

15from iblutil.numerical import within_ranges 

16from ibllib.io.extractors.ephys_fpga import get_sync_fronts 

17import ibllib.io.raw_data_loaders as raw 

18import ibllib.io.extractors.video_motion as vmotion 

19from ibllib.io.extractors.base import ( 

20 BaseBpodTrialsExtractor, 

21 BaseExtractor, 

22) 

23 

24_logger = logging.getLogger(__name__) 

25 

26 

def extract_camera_sync(sync, chmap=None):
    """
    Extract the camera frame times from the sync fronts.

    Every key of `chmap` ending with '_camera' is treated as a camera channel. The part of the
    key before its last underscore is used as the camera label in the returned dictionary.

    :param sync: dictionary 'times', 'polarities' of fronts detected on sync trace
    :param chmap: dictionary mapping channel names to channel indices. Despite the None default
     this must be a non-empty mapping, otherwise an AssertionError is raised.
    :return: dictionary of camera label -> every second front time (starting with the first)
     of that camera's sync channel
    """
    assert chmap
    camera_keys = (key for key in chmap if key.endswith('_camera'))
    return {
        key.rsplit('_', 1)[0]: get_sync_fronts(sync, chmap[key]).times[::2]
        for key in camera_keys
    }

41 

42 

def get_video_length(video_path):
    """
    Return the number of frames in a video.

    :param video_path: A path to the video. If this is a string starting with 'http' the capture
     object of a VideoStreamer is used, otherwise the path is opened with cv2.VideoCapture.
    :return: The frame count reported by the capture object, as an int
    :raises AssertionError: If the video could not be opened
    """
    is_url = isinstance(video_path, str) and video_path.startswith('http')
    cap = VideoStreamer(video_path).cap if is_url else cv2.VideoCapture(str(video_path))
    try:
        assert cap.isOpened(), f'Failed to open video file {video_path}'
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Always free the capture handle, including when the video failed to open
        cap.release()

56 

57 

class CameraTimestampsFPGA(BaseExtractor):
    """Extractor for videos using DAQ sync and channel map."""

    def __init__(self, label, session_path=None):
        """
        Parameters
        ----------
        label : str
            The camera label; validated with `assert_valid_label`.
        session_path : str, pathlib.Path
            The session path, passed on to the base extractor.
        """
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        self.save_names = f'_ibl_{label}Camera.times.npy'
        self.var_names = f'{label}_camera_timestamps'
        # The module logger is set to DEBUG while this object lives; the previous level is
        # restored in __del__
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # __init__ may have raised before the previous level was stored (e.g. invalid label),
        # in which case there is nothing to restore
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio',
                 display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the DAQ. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using
        task TTLs (typically the audio TTLs). Frames missing from the embedded frame count are
        removed from the timestamps array.
        If the GPIO cannot be used (or that alignment fails an assertion), the left and right
        camera timestamps may be aligned using the wheel data. If that also fails, the raw DAQ
        timestamps are returned, with the first pulses discarded when there are more pulses than
        video frames.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices. Must be provided and non-empty.
        video_path : str, pathlib.Path, int
            An optional path for fetching the number of frames. If None, the video is loaded from
            the session path. If an int is provided this is taken to be the total number of frames.
        sync_label : str
            The sync label of the channel that's wired to the GPIO for synchronising the times.
            If None, the GPIO alignment is skipped.
        display : bool
            If true, the TTL and GPIO fronts are plotted.
        extrapolate_missing : bool
            If true, any missing timestamps at the beginning and end of the session are
            extrapolated based on the median frame rate, otherwise they will be NaNs.
        **kwargs
            Extra keyword arguments (unused).

        Returns
        -------
        numpy.array
            The extracted camera timestamps.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, self.label)
        raw_ts = fpga_times[self.label]

        if video_path is None:
            filename = f'_iblrig_{self.label}Camera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = (video_path if isinstance(video_path, int) else get_video_length(video_path))
        _logger.debug(f'Number of video frames = {length}')

        if gpio is not None and gpio['indices'].size > 1 and sync_label is not None:
            _logger.info(f'Aligning to {sync_label} TTLs')
            # Extract sync TTLs
            ttl = get_sync_fronts(sync, chmap[sync_label])
            _, ts = raw.load_camera_ssv_times(self.session_path, self.label)
            try:
                # NB: Some of the sync TTLs occur very close together, and are therefore not
                # reflected in the pin state. This function removes those. Also converts frame
                # times to DAQ time.
                gpio, ttl, ts = groom_pin_state(gpio, ttl, ts, display=display)
                # The length of the count and pin state are regularly longer than the length of
                # the video file. Here we assert that the video is either shorter or the same
                # length as the arrays, and we make an assumption that the missing frames are
                # right at the end of the video. We therefore simply shorten the arrays to match
                # the length of the video.
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'
                assert raw_ts.shape[0] > 0, 'no timestamps found in channel indicated for ' \
                                            f'{self.label} camera'
                return align_with_gpio(raw_ts, ttl, gpio, count,
                                       display=display,
                                       extrapolate_missing=extrapolate_missing)
            except AssertionError as ex:
                # Any failed assertion above (or in the helpers) means we fall through to the
                # wheel based alignment below
                _logger.critical('Failed to extract using %s: %s', sync_label, ex)

        # If you reach here extracting using sync TTLs was not possible, we attempt to align using wheel motion energy
        _logger.warning('Attempting to align using wheel')

        try:
            if self.label not in ['left', 'right']:
                # Can only use wheel alignment for left and right cameras
                raise ValueError(f'Wheel alignment not supported for {self.label} camera')

            motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, sync='nidq', upload=True)
            new_times = motion_class.process()
            if not motion_class.qc_outcome:
                raise ValueError(f'Wheel alignment for {self.label} camera failed to pass qc: {motion_class.qc}')
            else:
                _logger.warning(f'Wheel alignment for {self.label} camera successful, qc: {motion_class.qc}')
                return new_times

        except Exception as err:
            # Deliberately broad: wheel alignment is best-effort and any failure falls back to
            # returning the raw DAQ timestamps
            _logger.critical(f'Failed to align with wheel for {self.label} camera: {err}')

        if length < raw_ts.size:
            # More pulses than frames: drop the surplus from the start
            df = raw_ts.size - length
            _logger.info(f'Discarding first {df} pulses')
            raw_ts = raw_ts[df:]

        return raw_ts

173 

174 

class CameraTimestampsCamlog(BaseExtractor):
    """Extractor returning the DAQ camera TTL times, reconciled with the video frame count."""

    def __init__(self, label, session_path=None):
        """
        Parameters
        ----------
        label : str
            The camera label; validated with `assert_valid_label`.
        session_path : str, pathlib.Path
            The session path, passed on to the base extractor.
        """
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        self.save_names = f'_ibl_{label}Camera.times.npy'
        self.var_names = f'{label}_camera_timestamps'
        # The module logger is set to DEBUG while this object lives; the previous level is
        # restored in __del__
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # __init__ may have raised before the previous level was stored (e.g. invalid label),
        # in which case there is nothing to restore
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None,
                 display=False, extrapolate_missing=True, **kwargs):
        """
        Return one DAQ camera TTL time per video frame.

        The number of frames is read from the `_iblrig_{label}Camera.raw.mp4` file in the
        session's raw_video_data folder. If there is exactly one more pulse than frames the last
        pulse is dropped; if there is exactly one pulse fewer, a timestamp is appended one median
        inter-pulse interval after the last pulse. Any other mismatch fails the final assertion.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices. Must be provided and non-empty.
        video_path, display, extrapolate_missing, **kwargs
            Unused by this extractor.

        Returns
        -------
        numpy.array
            The camera timestamps, one per video frame.

        Raises
        ------
        AssertionError
            If the number of pulses and video frames differ by more than one.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        video_frames = get_video_length(self.session_path.joinpath('raw_video_data', f'_iblrig_{self.label}Camera.raw.mp4'))
        raw_ts = fpga_times[self.label]

        # For left camera sometimes we have one extra pulse than video frame
        n_extra = raw_ts.size - video_frames
        if n_extra == 1:
            _logger.warning(f'One extra sync pulse detected for {self.label} camera')
            raw_ts = raw_ts[:-1]
        elif n_extra == -1:
            _logger.warning(f'One extra video frame detected for {self.label} camera')
            med_time = np.median(np.diff(raw_ts))
            raw_ts = np.r_[raw_ts, raw_ts[-1] + med_time]

        assert video_frames == raw_ts.size, f'dimension mismatch between video frames and TTL pulses for {self.label} camera' \
            f' by {np.abs(video_frames - raw_ts.size)} frames'

        return raw_ts

207 

208 

class CameraTimestampsBpod(BaseBpodTrialsExtractor):
    """
    Get the camera timestamps from the Bpod

    The camera events are only logged during the trials, not in between them, so the frame
    times between trials need to be interpolated
    """
    save_names = '_ibl_leftCamera.times.npy'
    var_names = 'left_camera_timestamps'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The module logger is set to DEBUG while this object lives; the previous level is
        # restored in __del__
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # __init__ may have raised before the previous level was stored, in which case there
        # is nothing to restore
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, video_path=None, display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the Bpod. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using the
        sync TTLs. Frames missing from the embedded frame count are removed from the timestamps
        array.
        If the GPIO cannot be used (or that alignment fails an assertion), the Bpod times are
        padded at the end or trimmed at the start to match the number of video frames (alignment
        by wheel data is not implemented).
        :param video_path: an optional path for fetching the number of frames. If None,
        the video is loaded from the session path. If an int is provided this is taken to be
        the total number of frames.
        :param display: if True, the TTL and GPIO fronts are plotted.
        :param extrapolate_missing: if True, any missing timestamps at the beginning and end of
        the session are extrapolated based on the median frame rate, otherwise they will be NaNs.
        :return: a numpy array of camera timestamps
        """
        raw_ts = self._times_from_bpod()
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, 'left')
        if video_path is None:
            filename = '_iblrig_leftCamera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = video_path if isinstance(video_path, int) else get_video_length(video_path)
        _logger.debug(f'Number of video frames = {length}')

        # Check if the GPIO is usable for extraction. GPIO is None if the file does not exist,
        # is empty, or contains only one value (i.e. doesn't change)
        if gpio is not None and gpio['indices'].size > 1:
            _logger.info('Aligning to sync TTLs')
            # Extract audio TTLs
            _, audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials,
                                            task_collection=self.task_collection)
            _, ts = raw.load_camera_ssv_times(self.session_path, 'left')
            # There are many sync TTLs that are for some reason missed by the GPIO. Conversely
            # the last GPIO doesn't often correspond to any audio TTL. These will be removed.
            # The drift appears to be less severe than the DAQ, so when assigning TTLs we'll take
            # the nearest TTL within 500ms. The go cue TTLs comprise two short pulses ~3ms apart.
            # We will fuse any TTLs less than 5ms apart to make assignment more accurate.
            try:
                gpio, audio, ts = groom_pin_state(gpio, audio, ts, take='nearest',
                                                  tolerance=.5, min_diff=5e-3, display=display)
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'

                return align_with_gpio(raw_ts, audio, gpio, count,
                                       extrapolate_missing, display=display)
            except AssertionError as ex:
                # Any failed assertion above (or in the helpers) means we fall back to the
                # simple length matching below
                _logger.critical('Failed to extract using audio: %s', ex)

        # If you reach here extracting using audio TTLs was not possible
        _logger.warning('Alignment by wheel data not yet implemented')
        # Extrapolate at median frame rate
        n_missing = length - raw_ts.size
        if n_missing > 0:
            _logger.warning(f'{n_missing} fewer Bpod timestamps than frames; '
                            f'{"extrapolating" if extrapolate_missing else "appending nans"}')
            # Median inter-frame interval (same units as raw_ts). Each appended frame is one
            # interval after the previous, so the interval is multiplied, not divided by.
            period = np.median(np.diff(raw_ts))
            to_app = ((np.arange(n_missing) + 1) * period + raw_ts[-1]
                      if extrapolate_missing
                      else np.full(n_missing, np.nan))
            raw_ts = np.r_[raw_ts, to_app]  # Append the missing times
        elif n_missing < 0:
            _logger.warning(f'{abs(n_missing)} fewer frames than Bpod timestamps')
            _logger.info(f'Discarding first {abs(n_missing)} pulses')
            raw_ts = raw_ts[abs(n_missing):]

        return raw_ts

    def _times_from_bpod(self):
        """
        Build a continuous array of camera frame times from the Bpod trial events.

        For each trial the 'Port1In' event times are taken as the frame times. Trials without
        any such event are skipped. The gaps between consecutive trials are then filled with
        evenly spaced times, the number of which is estimated from the median frame rate.

        :return: a numpy array of strictly increasing frame times
        :raises AssertionError: if the resulting times are not strictly increasing
        """
        ntrials = len(self.bpod_trials)

        cam_times = []
        n_frames = 0
        n_out_of_sync = 0
        missed_trials = []
        for ind in range(ntrials):
            # get upgoing and downgoing fronts
            events = self.bpod_trials[ind]['behavior_data']['Events timestamps']
            pin = np.array(events.get('Port1In') or [np.nan])
            pout = np.array(events.get('Port1Out') or [np.nan])
            # some trials at startup may not have the camera working, discard
            if np.isnan(pin).all():
                missed_trials.append(ind)
                continue
            # if the trial starts in the middle of a square, discard the first downgoing front
            if pout[0] < pin[0]:
                pout = pout[1:]
            # same if the last sample is during an upgoing front,
            # always put size as it happens last
            pin = pin[:pout.size]
            # NB: despite the name this is the median inter-frame interval of the trial
            frate = np.median(np.diff(pin))
            if ind > 0:
                # Check that the pulses have the same length and that we don't miss frames
                # during the trial, the refresh rate of bpod is 100us. Failures are only counted
                # and reported in a warning below.
                test1 = np.all(np.abs(1 - (pin - pout) / np.median(pin - pout)) < 0.1)
                test2 = np.all(np.abs(np.diff(pin) - frate) <= 0.00011)
                if not all([test1, test2]):
                    n_out_of_sync += 1
            # grow a list of cam times for each trial
            cam_times.append(pin)
            n_frames += pin.size

        if missed_trials:
            _logger.debug('trial(s) %s missing TTL events', range_str(missed_trials))
        if n_out_of_sync > 0:
            _logger.warning(f"{n_out_of_sync} trials with bpod camera frame times not within"
                            f" 10% of the expected sampling rate")

        t_first_frame = np.array([c[0] for c in cam_times])
        t_last_frame = np.array([c[-1] for c in cam_times])
        # Here frate is a true rate: the inverse of the median of the per-trial median intervals
        frate = 1 / np.nanmedian(np.array([np.median(np.diff(c)) for c in cam_times]))
        intertrial_duration = t_first_frame[1:] - t_last_frame[:-1]
        intertrial_missed_frames = np.int32(np.round(intertrial_duration * frate)) - 1

        # initialize the full times array
        frame_times = np.zeros(n_frames + int(np.sum(intertrial_missed_frames)))
        ii = 0
        for trial, cam_time in enumerate(cam_times):
            # populate first the recovered times within the trials
            frame_times[ii: ii + cam_time.size] = cam_time
            ii += cam_time.size
            if trial == (len(cam_times) - 1):
                break
            # then extrapolate in-between
            nmiss = intertrial_missed_frames[trial]
            frame_times[ii: ii + nmiss] = (cam_time[-1] + intertrial_duration[trial] /
                                           (nmiss + 1) * (np.arange(nmiss) + 1))
            ii += nmiss
        assert all(np.diff(frame_times) > 0)  # negative diffs implies a big problem
        return frame_times

364 

365 

def align_with_gpio(timestamps, ttl, pin_state, count, extrapolate_missing=True, display=False):
    """
    Groom the raw DAQ or Bpod camera timestamps using the frame embedded GPIO and frame counter.

    Parameters
    ----------
    timestamps : numpy.array
        An array of raw DAQ or Bpod camera timestamps.
    ttl : dict
        A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}.
    pin_state : dict
        A dictionary containing GPIO pin state values. The keys 'indices' and 'times' are read
        here; 'times' must be the same size as ttl['times'].
    count : numpy.array
        An array of frame numbers.
    extrapolate_missing : bool
        If true and the number of timestamps is fewer than the number of frame counts, the
        remaining timestamps are extrapolated based on the frame rate, otherwise they are NaNs.
    display : bool
        Plot the resulting timestamps.

    Returns
    -------
    numpy.array
        The corrected frame timestamps.

    Raises
    ------
    AssertionError
        If the inputs are not strictly increasing, the number of pin state and TTL times differ,
        or the first sync TTL does not line up with the first pin state change after alignment.
    """
    # Some assertions made on the raw data
    # assert count.size == pin_state.size, 'frame count and pin state size mismatch'
    assert np.all(np.diff(count) > 0), 'frame count not strictly increasing'
    assert np.all(np.diff(timestamps) > 0), 'DAQ/Bpod camera times not strictly increasing'
    same_n_ttl = pin_state['times'].size == ttl['times'].size
    assert same_n_ttl, 'more ttl TTLs detected on camera than TTLs sent'

    # Here we will ensure that the DAQ camera times match the number of video frames in
    # length. We will make the following assumptions:
    #
    # 1. The number of DAQ camera times is equal to or greater than the number of video frames.
    # 2. No TTLs were missed between the camera and DAQ.
    # 3. No pin states were missed by Bonsai.
    # 4. No pixel count data was missed by Bonsai.
    #
    # In other words the count and pin state arrays accurately reflect the number of frames
    # sent by the camera and should therefore be the same length, and the length of the frame
    # counter should match the number of saved video frames.
    #
    # The missing frame timestamps are removed in three stages:
    #
    # 1. Remove any timestamps that occurred before video frame acquisition in Bonsai.
    # 2. Remove any timestamps where the frame counter reported missing frames, i.e. remove the
    #    dropped frames which occurred throughout the session.
    # 3. Remove the trailing timestamps at the end of the session if the camera was turned off
    #    in the wrong order.

    # Align on first pin state change
    first_uptick = pin_state['indices'][0]
    first_ttl = np.searchsorted(timestamps, ttl['times'][0])
    # Here we find up to which index in the DAQ times we discard by taking the difference
    # between the index of the first pin state change (when the sync TTL was reported by the
    # camera) and the index of the first sync TTL in DAQ time. We subtract the difference
    # between the frame count at the first pin state change and the index to account for any
    # video frames that were not saved during this period (we will remove those from the
    # camera DAQ times later).
    # Minus any frames that were dropped between the start of frame acquisition and the first TTL
    start = first_ttl - first_uptick - (count[first_uptick] - first_uptick)
    # Get approximate frame rate for extrapolating timestamps (if required)
    frate = round(1 / np.nanmedian(np.diff(timestamps)))

    if start < 0:
        n_missing = abs(start)
        _logger.warning(f'{n_missing} missing DAQ/Bpod timestamp(s) at start; '
                        f'{"extrapolating" if extrapolate_missing else "prepending nans"}')
        # Step back one frame period at a time so that the last prepended time is exactly one
        # period before the first real timestamp (mirrors the appending branch below)
        to_app = (timestamps[0] - np.arange(n_missing, 0, -1) / frate
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        timestamps = np.r_[to_app, timestamps]  # Prepend the missing times
        start = 0

    # Remove the extraneous timestamps from the beginning and end
    end = count[-1] + 1 + start
    ts = timestamps[start:end]
    if (n_missing := count[-1] - ts.size + 1) > 0:
        # For ephys sessions there may be fewer DAQ times than frame counts if DAQ acquisition is
        # turned off before the video acquisition workflow. For Bpod this always occurs because
        # Bpod finishes before the camera workflow. For Bpod the times are already extrapolated
        # for these late frames.
        _logger.warning(f'{n_missing} fewer DAQ/Bpod timestamps than frame counts; '
                        f'{"extrapolating" if extrapolate_missing else "appending nans"}')
        to_app = ((np.arange(n_missing, ) + 1) / frate + ts[-1]
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        ts = np.r_[ts, to_app]  # Append the missing times
    assert ts.size >= count.size, 'fewer timestamps than frame counts'
    assert ts.size == count[-1] + 1, 'more frames recorded in frame count than timestamps '

    # Remove the rest of the dropped frames
    ts = ts[count]
    assert np.searchsorted(ts, ttl['times'][0]) == first_uptick, \
        'time of first sync TTL doesn\'t match after alignment'
    if ts.size != count.size:
        _logger.error('number of timestamps and frames don\'t match after alignment')

    if display:
        # Plot to check
        fig, axes = plt.subplots(1, 1)
        y = within_ranges(np.arange(ts.size), pin_state['indices'].reshape(-1, 2)).astype(float)
        y *= 1e-5  # For scale when zoomed in
        axes.plot(ts, y, marker='d', color='blue', drawstyle='steps-pre', label='GPIO')
        axes.plot(ts, np.zeros_like(ts), 'kx', label='DAQ timestamps')
        vertical_lines(ttl['times'], ymin=0, ymax=1e-5,
                       color='r', linestyle=':', ax=axes, label='sync TTL')
        plt.legend()

    return ts

480 

481 

def attribute_times(arr, events, tol=.1, injective=True, take='first'):
    """
    Match each value of one array of times to a value in another.

    For every value in `events`, the differences to the values of `arr` are computed and the
    index of a value of `arr` lying within `tol` of the event is selected: either the first such
    value, the nearest one, or the first one strictly after the event. Non-finite values in
    `arr` are never selected.

    If injective is True, a value of `arr` that has been assigned to an event can't be assigned
    to a later one, i.e. the mapping is one-to-one. Events are processed in order.

    Parameters
    ----------
    arr : numpy.array
        An array of event times to attribute to those in `events`.
    events : numpy.array
        An array of event times considered a subset of `arr`.
    tol : float
        The max absolute difference between values in order to be considered a match
        (exclusive).
    injective : bool
        If true, once a value has been assigned it will not be assigned again.
    take : {'first', 'nearest', 'after'}
        If 'first' the first value within tolerance is assigned; if 'nearest' the
        closest value is assigned; if 'after' assign the first value after the event that is
        within tolerance. Case-insensitive.

    Returns
    -------
    numpy.array
        An integer array the same shape as `events` containing indices of `arr` corresponding
        to each event, or -1 where no value could be assigned.

    Raises
    ------
    ValueError
        If `take` is not one of the supported options.
    """
    mode = take.lower()
    if mode not in ('first', 'nearest', 'after'):
        raise ValueError('Parameter `take` must be either "first", "nearest", or "after"')
    # Private copy of the times in which invalid values are replaced by inf so that they can
    # never fall within tolerance; already assigned values are knocked out the same way
    pool = np.ma.masked_invalid(arr, copy=True).filled(np.inf)
    lower = 0 if mode == 'after' else -tol
    matches = np.full(events.shape, -1, dtype=int)  # -1 marks an unassigned event
    for n, event in enumerate(events):
        delta = pool - event
        eligible = (lower < delta) & (delta < tol)
        if not eligible.any():
            continue  # nothing within tolerance
        if mode == 'nearest':
            pick = np.abs(delta).argmin()
        else:
            pick = np.flatnonzero(eligible)[0]
        matches[n] = pick
        if injective:
            pool[pick] = np.inf  # one-to-one: this value can't be assigned again
    return matches

531 

532 

533def groom_pin_state(gpio, ttl, ts, tolerance=2., display=False, take='first', min_diff=0.): 

534 """ 

535 Align the GPIO pin state to the DAQ sync TTLs. Any sync TTLs not reflected in the pin 

536 state are removed from the dict and the times of the detected fronts are converted to DAQ 

537 time. At the end of this the number of GPIO fronts should equal the number of TTLs. 

538 

539 Note: 

540 - This function is ultra safe: we probably don't need assign all the ups and down fronts. 

541 separately and could potentially even align the timestamps without removing the missed fronts 

542 - The input gpio and TTL dicts may be modified by this function. 

543 - For training sessions the frame rate is only 30Hz and the TTLs tend to be broken up by 

544 small gaps. Setting the min_diff to 5ms helps the timestamp assignment accuracy. 

545 

546 Parameters 

547 ---------- 

548 gpio : dict 

549 A dictionary containing GPIO pin state values, with keys {'indices', 'polarities'}. 

550 ttl : dict 

551 A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}. 

552 ts : numpy.array 

553 The camera frame times (the camera frame TTLs acquired by the main DAQ). 

554 tolerance : float 

555 Two pulses need to be within this many seconds to be considered related. 

556 display : bool 

557 If true, the resulting timestamps are plotted against the raw audio signal. 

558 take : {'first', 'nearest'} 

559 If 'first' the first value within tolerance is assigned; if 'nearest' the 

560 closest value is assigned. 

561 min_diff : float 

562 Sync TTL fronts less than min_diff seconds apart will be removed. 

563 

564 Returns 

565 ------- 

566 dict 

567 Dictionary of GPIO DAQ front indices, polarities and DAQ aligned times. 

568 dict 

569 Sync TTL times and polarities sans the TTLs not detected in the frame data. 

570 numpy.array 

571 Frame times in DAQ time. 

572 

573 See Also 

574 -------- 

575 ibllib.io.extractors.ephys_fpga._get_sync_fronts 

576 """ 

577 # Check that the dimensions match 

578 if np.any(gpio['indices'] >= ts.size): 1gdbnchafe

579 _logger.warning('GPIO events occurring beyond timestamps array length') 1h

580 keep = gpio['indices'] < ts.size 1h

581 gpio = {k: gpio[k][keep] for k, v in gpio.items()} 1h

582 assert ttl and ttl['times'].size > 0, 'no sync TTLs for session' 1gdbnchafe

583 assert ttl['times'].size == ttl['polarities'].size, 'sync TTL data dimension mismatch' 1gdbnchafe

584 # make sure that there are no 2 consecutive fall or consecutive rise events 

585 assert np.all(np.abs(np.diff(ttl['polarities'])) == 2), 'consecutive high/low sync TTL events' 1gdbnchafe

586 # make sure first TTL is high 

587 assert ttl['polarities'][0] == 1 1gdbnchafe

588 # make sure ttl times in order 

589 assert np.all(np.diff(ttl['times']) > 0) 1gdbnchafe

590 # make sure raw timestamps increase 

591 assert np.all(np.diff(ts) > 0), 'timestamps must strictly increase' 1gdbnchafe

592 # make sure there are state changes 

593 assert gpio['indices'].any(), 'no TTLs detected in GPIO' 1gdbnchafe

594 # # make sure first GPIO state is high 

595 assert gpio['polarities'][0] == 1 1gdbnchafe

596 """ 1gdbchafe

597 Some sync TTLs appear to be so short that they are not recorded by the camera. These can 

598 be as short as a few microseconds. Applying a cutoff based on framerate was unsuccessful. 

599 Assigning each sync TTL to each pin state change is not easy because some onsets occur very 

600 close together (sometimes < 70ms), on the order of the delay between TTL and frame time. 

601 Also, the two clocks have some degree of drift, so the delay between sync TTL and pin state 

602 change may be zero or even negative. 

603 

604 Here we split the events into sync TTL onsets (lo->hi) and TTL offsets (hi->lo). For each 

605 uptick in the GPIO pin state, we take the first TTL onset time that was within 100ms of it. 

606 We ensure that each sync TTL is assigned only once, so a TTL that is closer to frame 3 than 

607 frame 1 may still be assigned to frame 1. 

608 """ 

609 ifronts = gpio['indices'] # The pin state flips 1gdbchafe

610 sync_times = ttl['times'] 1gdbchafe

611 if ifronts.size != ttl['times'].size: 1gdbchafe

612 _logger.warning('more sync TTLs than GPIO state changes, assigning timestamps') 1gdbchaf

613 to_remove = np.zeros(ifronts.size, dtype=bool) # unassigned GPIO fronts to remove 1gdbchaf

614 low2high = ifronts[gpio['polarities'] == 1] 1gdbchaf

615 high2low = ifronts[gpio['polarities'] == -1] 1gdbchaf

616 assert low2high.size >= high2low.size 1gdbchaf

617 

618 # Remove and/or fuse short TTLs 

619 if min_diff > 0: 1gdbchaf

620 short, = np.where(np.diff(ttl['times']) < min_diff) 1gbaf

621 sync_times = np.delete(ttl['times'], np.r_[short, short + 1]) 1gbaf

622 _logger.debug(f'Removed {short.size * 2} fronts TLLs less than ' 1gbaf

623 f'{min_diff * 1e3:.0f}ms apart') 

624 assert sync_times.size > 0, f'all sync TTLs less than {min_diff}s' 1gbaf

625 

626 # Onsets 

627 ups = ts[low2high] - ts[low2high][0] # times relative to first GPIO high 1gdbchaf

628 onsets = sync_times[::2] - sync_times[0] # TTL times relative to first onset 1gdbchaf

629 # assign GPIO fronts to ttl onset 

630 assigned = attribute_times(onsets, ups, tol=tolerance, take=take) 1gdbchaf

631 unassigned = np.setdiff1d(np.arange(onsets.size), assigned[assigned > -1]) 1gdbchaf

632 if unassigned.size > 0: 1gdbchaf

633 _logger.debug(f'{unassigned.size} sync TTL rises were not detected by the camera') 1gdbchaf

634 # Check that all pin state upticks could be attributed to an onset TTL 

635 if np.any(missed := assigned == -1): 1gdbchaf

636 _logger.warning(f'{sum(missed)} pin state rises could not be attributed to a sync TTL') 1gbhaf

637 if display: 1gbhaf

638 ax = plt.subplot() 1a

639 vertical_lines(ups[assigned > -1], 1a

640 linestyle='-', color='g', ax=ax, 

641 label='assigned GPIO up state') 

642 vertical_lines(ups[missed], 1a

643 linestyle='-', color='r', ax=ax, 

644 label='unassigned GPIO up state') 

645 vertical_lines(onsets[unassigned], 1a

646 linestyle=':', color='k', ax=ax, 

647 alpha=0.3, label='sync TTL onset') 

648 vertical_lines(onsets[assigned], 1a

649 linestyle=':', color='b', ax=ax, label='assigned TTL onset') 

650 plt.legend() 1a

651 plt.show() 1a

652 # Remove the missed fronts 

653 to_remove = np.in1d(gpio['indices'], low2high[missed]) 1gbhaf

654 assigned = assigned[~missed] 1gbhaf

655 onsets_ = sync_times[::2][assigned] 1gdbchaf

656 

657 # Offsets 

658 downs = ts[high2low] - ts[high2low][0] 1gdbchaf

659 offsets = sync_times[1::2] - sync_times[1] 1gdbchaf

660 assigned = attribute_times(offsets, downs, tol=tolerance, take=take) 1gdbchaf

661 unassigned = np.setdiff1d(np.arange(offsets.size), assigned[assigned > -1]) 1gdbchaf

662 if unassigned.size > 0: 1gdbchaf

663 _logger.debug(f'{unassigned.size} sync TTL falls were not detected by the camera') 1dbchaf

664 # Check that all pin state downticks could be attributed to an offset TTL 

665 if np.any(missed := assigned == -1): 1gdbchaf

666 _logger.warning(f'{sum(missed)} pin state falls could not be attributed to a sync TTL') 1bhaf

667 # Remove the missed fronts 

668 to_remove |= np.in1d(gpio['indices'], high2low[missed]) 1bhaf

669 assigned = assigned[~missed] 1bhaf

670 offsets_ = sync_times[1::2][assigned] 1gdbchaf

671 

672 # Sync TTLs groomed 

673 if np.any(to_remove): 1gdbchaf

674 # Check for any orphaned fronts (only one pin state edge was assigned) 

675 to_remove = np.pad(to_remove, (0, to_remove.size % 2), 'edge') # Ensure even size 1gbhaf

676 # Perform xor to find GPIOs where only onset or offset is marked for removal 

677 orphaned = to_remove.reshape(-1, 2).sum(axis=1) == 1 1gbhaf

678 if orphaned.any(): 1gbhaf

679 """If there are orphaned GPIO fronts (i.e. only one edge was assigned to a sync 1gf

680 TTL front), remove the orphaned front its assigned sync TTL. In other words 

681 if both edges cannot be assigned to a sync TTL, we ignore the TTL entirely. 

682 This is a sign that the assignment was bad and extraction may fail.""" 

683 _logger.warning('Some onsets but not offsets (or vice versa) were not assigned; ' 1gf

684 'this may be a sign of faulty wiring or clock drift') 

685 # Find indices of GPIO upticks where only the downtick was marked for removal 

686 orphaned_onsets, = np.where(~to_remove.reshape(-1, 2)[:, 0] & orphaned) 1gf

687 # The onsets_ array already has the other TTLs removed (same size as to_remove == 

688 # False) so subtract the number of removed elements from index. 

689 for i, v in enumerate(orphaned_onsets): 1gf

690 orphaned_onsets[i] -= to_remove.reshape(-1, 2)[:v, 0].sum() 1f

691 # Same for offsets... 

692 orphaned_offsets, = np.where(~to_remove.reshape(-1, 2)[:, 1] & orphaned) 1gf

693 for i, v in enumerate(orphaned_offsets): 1gf

694 orphaned_offsets[i] -= to_remove.reshape(-1, 2)[:v, 1].sum() 1gf

695 # Remove orphaned ttl onsets and offsets 

696 onsets_ = np.delete(onsets_, orphaned_onsets[orphaned_onsets < onsets_.size]) 1gf

697 offsets_ = np.delete(offsets_, orphaned_offsets[orphaned_offsets < offsets_.size]) 1gf

698 _logger.debug(f'{orphaned.sum()} orphaned TTLs removed') 1gf

699 to_remove.reshape(-1, 2)[orphaned] = True 1gf

700 

701 # Remove those unassigned GPIOs 

702 gpio = {k: v[~to_remove[:v.size]] for k, v in gpio.items()} 1gbhaf

703 ifronts = gpio['indices'] 1gbhaf

704 

705 # Assert that we've removed discrete TTLs 

706 # A failure means e.g. an up-going front of one TTL was missed but not the down-going one. 

707 assert np.all(np.abs(np.diff(gpio['polarities'])) == 2) 1gbhaf

708 assert gpio['polarities'][0] == 1 1gbhaf

709 

710 ttl_ = {'times': np.empty(ifronts.size), 'polarities': gpio['polarities']} 1gdbchaf

711 ttl_['times'][::2] = onsets_ 1gdbchaf

712 ttl_['times'][1::2] = offsets_ 1gdbchaf

713 else: 

714 ttl_ = ttl.copy() 1ghe

715 

716 # Align the frame times to DAQ 

717 fcn_a2b, drift_ppm = dsp.sync_timestamps(ts[ifronts], ttl_['times']) 1gdbchafe

718 _logger.debug(f'frame ttl alignment drift = {drift_ppm:.2f}ppm') 1gdbchafe

719 # Add times to GPIO dict 

720 gpio['times'] = fcn_a2b(ts[ifronts]) 1gdbchafe

721 

722 if display: 1gdbchafe

723 # Plot all the onsets and offsets 

724 ax = plt.subplot() 1ca

725 # All sync TTLs 

726 squares(ttl['times'], ttl['polarities'], 1ca

727 ax=ax, label='sync TTLs', linestyle=':', color='k', yrange=[0, 1], alpha=0.3) 

728 # GPIO 

729 x = np.insert(gpio['times'], 0, 0) 1ca

730 y = np.arange(x.size) % 2 1ca

731 squares(x, y, ax=ax, label='GPIO') 1ca

732 y = within_ranges(np.arange(ts.size), ifronts.reshape(-1, 2)) # 0 or 1 for each frame 1ca

733 ax.plot(fcn_a2b(ts), y, 'kx', label='cam times') 1ca

734 # Assigned ttl 

735 squares(ttl_['times'], ttl_['polarities'], 1ca

736 ax=ax, label='assigned sync TTL', linestyle=':', color='g', yrange=[0, 1]) 

737 ax.legend() 1ca

738 plt.xlabel('DAQ time (s)') 1ca

739 ax.set_yticks([0, 1]) 1ca

740 ax.set_title('GPIO - sync TTL alignment') 1ca

741 plt.show() 1ca

742 

743 return gpio, ttl_, fcn_a2b(ts) 1gdbchafe