Coverage for ibllib/io/extractors/camera.py: 96%

346 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1""" Camera extractor functions. 

2 

3This module handles extraction of camera timestamps for both Bpod and DAQ. 

4""" 

5import logging 

6 

7import cv2 

8import numpy as np 

9import matplotlib.pyplot as plt 

10from iblutil.util import range_str 

11 

12import ibldsp.utils as dsp 

13from ibllib.plots import squares, vertical_lines 

14from ibllib.io.video import assert_valid_label, VideoStreamer 

15from iblutil.numerical import within_ranges 

16from ibllib.io.extractors.ephys_fpga import get_sync_fronts 

17import ibllib.io.raw_data_loaders as raw 

18import ibllib.io.extractors.video_motion as vmotion 

19from ibllib.io.extractors.base import ( 

20 BaseBpodTrialsExtractor, 

21 BaseExtractor, 

22) 

23 

24_logger = logging.getLogger(__name__) 

25 

26 

def extract_camera_sync(sync, chmap=None):
    """
    Extract the camera frame times from the sync fronts.

    For every key of `chmap` that ends with '_camera', the fronts of that channel are fetched
    with `get_sync_fronts` and every second front time, starting from the first, is kept.

    :param sync: dictionary 'times', 'polarities' of fronts detected on sync trace
    :param chmap: dictionary containing channel indices. Must be provided and non-empty
     (asserted), despite the None default.
    :return: dictionary mapping each camera label (the key minus its '_camera' suffix) to its
     array of timestamps
    """
    assert chmap
    camera_keys = [key for key in chmap if key.endswith('_camera')]
    return {
        key.rsplit('_', 1)[0]: get_sync_fronts(sync, chmap[key]).times[::2]
        for key in camera_keys
    }

41 

42 

def get_video_length(video_path):
    """
    Return the number of frames in a video.

    :param video_path: A path to the video file, or a URL string starting with 'http' in which
     case the capture object of a `VideoStreamer` is used
    :return: the frame count reported by the capture object, cast to int
    """
    if isinstance(video_path, str) and video_path.startswith('http'):
        capture = VideoStreamer(video_path).cap
    else:
        capture = cv2.VideoCapture(str(video_path))
    assert capture.isOpened(), f'Failed to open video file {video_path}'
    n_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return n_frames

56 

57 

class CameraTimestampsFPGA(BaseExtractor):
    """Extractor for videos using DAQ sync and channel map."""

    def __init__(self, label, session_path=None):
        """
        :param label: the camera label; validated with `assert_valid_label`
        :param session_path: the session path passed on to the base extractor
        """
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        # NOTE(review): the names below use the raw `label` rather than the validated
        # `self.label`; confirm the two are always identical.
        self.save_names = f'_ibl_{label}Camera.times.npy'
        self.var_names = f'{label}_camera_timestamps'
        # The module logger is switched to DEBUG for the lifetime of this object; the previous
        # level is restored in __del__.
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # __init__ may have raised before `_log_level` was stored (e.g. invalid label), in which
        # case the logger level was never changed and there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio',
                 display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the DAQ. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using
        task TTLs (typically the audio TTLs). Frames missing from the embedded frame count are
        removed from the timestamps array.
        If the pin state file does not exist, or the GPIO alignment raises an AssertionError,
        the left and right camera timestamps may be aligned using the wheel data. If that also
        fails, the raw DAQ timestamps are returned, with the first pulses discarded when there
        are more pulses than video frames.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices. Must be provided.
        video_path : str, pathlib.Path, int
            An optional path for fetching the number of frames. If None, the video is loaded from
            the session path. If an int is provided this is taken to be the total number of frames.
        sync_label : str
            The sync label of the channel that's wired to the GPIO for synchronising the times.
            If None, GPIO alignment is skipped.
        display : bool
            If true, the TTL and GPIO fronts are plotted.
        extrapolate_missing : bool
            If true, any missing timestamps at the beginning and end of the session are
            extrapolated based on the median frame rate, otherwise they will be NaNs.
        **kwargs
            Extra keyword arguments (unused).

        Returns
        -------
        numpy.array
            The extracted camera timestamps.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, self.label)
        raw_ts = fpga_times[self.label]

        if video_path is None:
            filename = f'_iblrig_{self.label}Camera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = (video_path if isinstance(video_path, int) else get_video_length(video_path))
        _logger.debug(f'Number of video frames = {length}')

        if gpio is not None and gpio['indices'].size > 1 and sync_label is not None:
            _logger.info(f'Aligning to {sync_label} TTLs')
            # Extract sync TTLs
            ttl = get_sync_fronts(sync, chmap[sync_label])
            _, ts = raw.load_camera_ssv_times(self.session_path, self.label)
            try:
                # NB: Some of the sync TTLs occur very close together, and are therefore not
                # reflected in the pin state. This function removes those. Also converts frame
                # times to DAQ time.
                gpio, ttl, ts = groom_pin_state(gpio, ttl, ts, display=display)
                # The length of the count and pin state are regularly longer than the length of
                # the video file. Here we assert that the video is either shorter or the same
                # length as the arrays, and we make an assumption that the missing frames are
                # right at the end of the video. We therefore simply shorten the arrays to match
                # the length of the video.
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'
                assert raw_ts.shape[0] > 0, 'no timestamps found in channel indicated for ' \
                                            f'{self.label} camera'
                return align_with_gpio(raw_ts, ttl, gpio, count,
                                       display=display,
                                       extrapolate_missing=extrapolate_missing)
            except AssertionError as ex:
                _logger.critical('Failed to extract using %s: %s', sync_label, ex)

        # If you reach here extracting using sync TTLs was not possible, we attempt to align using wheel motion energy
        _logger.warning('Attempting to align using wheel')

        try:
            if self.label not in ['left', 'right']:
                # Can only use wheel alignment for left and right cameras
                raise ValueError(f'Wheel alignment not supported for {self.label} camera')

            motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, sync='nidq', upload=True)
            new_times = motion_class.process()
            if not motion_class.qc_outcome:
                raise ValueError(f'Wheel alignment for {self.label} camera failed to pass qc: {motion_class.qc}')
            else:
                _logger.warning(f'Wheel alignment for {self.label} camera successful, qc: {motion_class.qc}')
                return new_times

        except Exception as err:
            # Deliberately broad: any failure of the wheel alignment falls back to raw timestamps
            _logger.critical(f'Failed to align with wheel for {self.label} camera: {err}')

        # Last resort: return the raw DAQ times, dropping the earliest pulses if there are more
        # pulses than video frames
        if length < raw_ts.size:
            df = raw_ts.size - length
            _logger.info(f'Discarding first {df} pulses')
            raw_ts = raw_ts[df:]

        return raw_ts

173 

174 

class CameraTimestampsCamlog(BaseExtractor):
    """Extractor returning the DAQ camera TTL times, matched in number to the video frames."""

    def __init__(self, label, session_path=None):
        """
        :param label: the camera label; validated with `assert_valid_label`
        :param session_path: the session path passed on to the base extractor
        """
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        self.save_names = f'_ibl_{label}Camera.times.npy'
        self.var_names = f'{label}_camera_timestamps'
        # The module logger is switched to DEBUG for the lifetime of this object; the previous
        # level is restored in __del__.
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # __init__ may have raised before `_log_level` was stored (e.g. invalid label), in which
        # case the logger level was never changed and there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None,
                 display=False, extrapolate_missing=True, **kwargs):
        """
        Return the camera TTL times from the DAQ, one per video frame.

        The number of frames is read from the session's raw video file. A difference of exactly
        one between the number of TTL pulses and video frames is tolerated: an extra pulse is
        dropped from the end, and an extra frame is given a timestamp one median inter-pulse
        interval after the last pulse. Any other mismatch raises an AssertionError.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices. Must be provided.
        video_path, display, extrapolate_missing, **kwargs
            Currently unused by this extractor.

        Returns
        -------
        numpy.array
            The camera timestamps, same length as the number of video frames.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        video_frames = get_video_length(self.session_path.joinpath('raw_video_data', f'_iblrig_{self.label}Camera.raw.mp4'))
        raw_ts = fpga_times[self.label]

        # For left camera sometimes we have one extra pulse than video frame
        if (raw_ts.size - video_frames) == 1:
            _logger.warning(f'One extra sync pulse detected for {self.label} camera')
            raw_ts = raw_ts[:-1]
        elif (raw_ts.size - video_frames) == -1:
            _logger.warning(f'One extra video frame detected for {self.label} camera')
            med_time = np.median(np.diff(raw_ts))
            raw_ts = np.r_[raw_ts, np.array([raw_ts[-1] + med_time])]

        assert video_frames == raw_ts.size, f'dimension mismatch between video frames and TTL pulses for {self.label} camera' \
                                            f' by {np.abs(video_frames - raw_ts.size)} frames'

        return raw_ts

207 

208 

class CameraTimestampsBpod(BaseBpodTrialsExtractor):
    """
    Get the camera timestamps from the Bpod

    The camera events are logged only during the events not in between, so the times need
    to be interpolated
    """
    save_names = '_ibl_leftCamera.times.npy'
    var_names = 'left_camera_timestamps'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The module logger is switched to DEBUG for the lifetime of this object; the previous
        # level is restored in __del__.
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # The base initializer may have raised before `_log_level` was stored, in which case the
        # logger level was never changed and there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, video_path=None, display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the Bpod. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using the
        sync TTLs. Frames missing from the embedded frame count are removed from the timestamps
        array.
        If the pin state file does not exist (or the GPIO alignment raises an AssertionError),
        the Bpod timestamps are padded at the end, or trimmed at the start, to match the number
        of video frames. Alignment using the wheel data is not implemented.

        :param video_path: an optional path for fetching the number of frames. If None,
            the video is loaded from the session path. If an int is provided this is taken to be
            the total number of frames.
        :param display: if True, the TTL and GPIO fronts are plotted.
        :param extrapolate_missing: if True, any missing timestamps at the beginning and end of
            the session are extrapolated based on the median frame rate, otherwise they will be NaNs.
        :return: a numpy array of camera timestamps
        """
        raw_ts = self._times_from_bpod()
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, 'left')
        if video_path is None:
            filename = '_iblrig_leftCamera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = video_path if isinstance(video_path, int) else get_video_length(video_path)
        _logger.debug(f'Number of video frames = {length}')

        # Check if the GPIO is usable for extraction. GPIO is None if the file does not exist,
        # is empty, or contains only one value (i.e. doesn't change)
        if gpio is not None and gpio['indices'].size > 1:
            _logger.info('Aligning to sync TTLs')
            # Extract audio TTLs
            _, audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials,
                                            task_collection=self.task_collection)
            _, ts = raw.load_camera_ssv_times(self.session_path, 'left')
            # There are many sync TTLs that are for some reason missed by the GPIO. Conversely
            # the last GPIO doesn't often correspond to any audio TTL. These will be removed.
            # The drift appears to be less severe than the DAQ, so when assigning TTLs we'll take
            # the nearest TTL within 500ms. The go cue TTLs comprise two short pulses ~3ms apart.
            # We will fuse any TTLs less than 5ms apart to make assignment more accurate.
            try:
                gpio, audio, ts = groom_pin_state(gpio, audio, ts, take='nearest',
                                                  tolerance=.5, min_diff=5e-3, display=display)
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'

                return align_with_gpio(raw_ts, audio, gpio, count,
                                       extrapolate_missing, display=display)
            except AssertionError as ex:
                _logger.critical('Failed to extract using audio: %s', ex)

        # If you reach here extracting using audio TTLs was not possible
        _logger.warning('Alignment by wheel data not yet implemented')
        # Extrapolate at median frame rate
        n_missing = length - raw_ts.size
        if n_missing > 0:
            _logger.warning(f'{n_missing} fewer Bpod timestamps than frames; '
                            f'{"extrapolating" if extrapolate_missing else "appending nans"}')
            # Median inter-frame interval in seconds (a period, not a rate): the k-th appended
            # timestamp is therefore k periods after the last recorded one.
            period = np.median(np.diff(raw_ts))
            to_app = ((np.arange(n_missing, ) + 1) * period + raw_ts[-1]
                      if extrapolate_missing
                      else np.full(n_missing, np.nan))
            raw_ts = np.r_[raw_ts, to_app]  # Append the missing times
        elif n_missing < 0:
            _logger.warning(f'{abs(n_missing)} fewer frames than Bpod timestamps')
            _logger.info(f'Discarding first {abs(n_missing)} pulses')
            raw_ts = raw_ts[abs(n_missing):]

        return raw_ts

    def _times_from_bpod(self):
        """
        Build a continuous array of camera frame times from the per-trial Bpod port events.

        The 'Port1In' events of each trial are used as frame times. Trials without any such
        event are skipped. The gaps between consecutive trials are filled with evenly spaced
        times, the number of which is estimated from the gap duration and the median frame rate.

        :return: a numpy array of strictly increasing frame times
        """
        cam_times = []
        n_frames = 0
        n_out_of_sync = 0
        missed_trials = []
        for ind, trial in enumerate(self.bpod_trials):
            # get upgoing and downgoing fronts
            events = trial['behavior_data']['Events timestamps']
            pin = np.array(events.get('Port1In') or [np.nan])
            pout = np.array(events.get('Port1Out') or [np.nan])
            # some trials at startup may not have the camera working, discard
            if np.isnan(pin).all():
                missed_trials.append(ind)
                continue
            # if the trial starts in the middle of a square, discard the first downgoing front
            if pout[0] < pin[0]:
                pout = pout[1:]
            # same if the last sample is during an upgoing front,
            # always put size as it happens last
            pin = pin[:pout.size]
            period = np.median(np.diff(pin))  # median inter-frame interval within the trial
            if ind > 0:
                # check that the pulses have the same length and that we don't miss frames during
                # the trial, the refresh rate of bpod is 100us
                test1 = np.all(np.abs(1 - (pin - pout) / np.median(pin - pout)) < 0.1)
                test2 = np.all(np.abs(np.diff(pin) - period) <= 0.00011)
                if not all([test1, test2]):
                    n_out_of_sync += 1
            # grow a list of cam times for each trial
            cam_times.append(pin)
            n_frames += pin.size

        if missed_trials:
            _logger.debug('trial(s) %s missing TTL events', range_str(missed_trials))
        if n_out_of_sync > 0:
            _logger.warning(f"{n_out_of_sync} trials with bpod camera frame times not within"
                            f" 10% of the expected sampling rate")

        t_first_frame = np.array([c[0] for c in cam_times])
        t_last_frame = np.array([c[-1] for c in cam_times])
        frate = 1 / np.nanmedian(np.array([np.median(np.diff(c)) for c in cam_times]))
        intertrial_duration = t_first_frame[1:] - t_last_frame[:-1]
        intertrial_missed_frames = np.int32(np.round(intertrial_duration * frate)) - 1

        # initialize the full times array
        frame_times = np.zeros(n_frames + int(np.sum(intertrial_missed_frames)))
        ii = 0
        for trial, cam_time in enumerate(cam_times):
            # populate first the recovered times within the trials
            frame_times[ii: ii + cam_time.size] = cam_time
            ii += cam_time.size
            if trial == (len(cam_times) - 1):
                break
            # then extrapolate in-between
            nmiss = intertrial_missed_frames[trial]
            frame_times[ii: ii + nmiss] = (cam_time[-1] + intertrial_duration[trial] /
                                           (nmiss + 1) * (np.arange(nmiss) + 1))
            ii += nmiss
        assert all(np.diff(frame_times) > 0)  # negative diffs implies a big problem
        return frame_times

364 

365 

def align_with_gpio(timestamps, ttl, pin_state, count, extrapolate_missing=True, display=False):
    """
    Groom the raw DAQ or Bpod camera timestamps using the frame embedded GPIO and frame counter.

    Parameters
    ----------
    timestamps : numpy.array
        An array of raw DAQ or Bpod camera timestamps.
    ttl : dict
        A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}.
    pin_state : dict
        A dictionary containing GPIO pin state values, with keys {'indices', 'polarities',
        'times'}. The 'times' array must be the same size as `ttl['times']`.
    count : numpy.array
        An array of frame numbers.
    extrapolate_missing : bool
        If true and the number of timestamps is fewer than the number of frame counts, the
        remaining timestamps are extrapolated based on the frame rate, otherwise they are NaNs.
    display : bool
        Plot the resulting timestamps.

    Returns
    -------
    numpy.array
        The corrected frame timestamps.

    Raises
    ------
    AssertionError
        If the frame count or timestamps are not strictly increasing, the number of GPIO and
        sync TTL fronts differ, or the alignment is inconsistent.
    """
    # Some assertions made on the raw data
    # assert count.size == pin_state.size, 'frame count and pin state size mismatch'
    assert all(np.diff(count) > 0), 'frame count not strictly increasing'
    assert all(np.diff(timestamps) > 0), 'DAQ/Bpod camera times not strictly increasing'
    same_n_ttl = pin_state['times'].size == ttl['times'].size
    assert same_n_ttl, 'more ttl TTLs detected on camera than TTLs sent'

    # Here we will ensure that the DAQ camera times match the number of video frames in
    # length. We will make the following assumptions:
    #
    # 1. The number of DAQ camera times is equal to or greater than the number of video frames.
    # 2. No TTLs were missed between the camera and DAQ.
    # 3. No pin states were missed by Bonsai.
    # 4. No pixel count data was missed by Bonsai.
    #
    # In other words the count and pin state arrays accurately reflect the number of frames
    # sent by the camera and should therefore be the same length, and the length of the frame
    # counter should match the number of saved video frames.
    #
    # The missing frame timestamps are removed in three stages:
    #
    # 1. Remove any timestamps that occurred before video frame acquisition in Bonsai.
    # 2. Remove any timestamps where the frame counter reported missing frames, i.e. remove the
    #    dropped frames which occurred throughout the session.
    # 3. Remove the trailing timestamps at the end of the session if the camera was turned off
    #    in the wrong order.

    # Align on first pin state change
    first_uptick = pin_state['indices'][0]
    first_ttl = np.searchsorted(timestamps, ttl['times'][0])
    # Here we find up to which index in the DAQ times we discard by taking the difference
    # between the index of the first pin state change (when the sync TTL was reported by the
    # camera) and the index of the first sync TTL in DAQ time. We subtract the difference
    # between the frame count at the first pin state change and the index to account for any
    # video frames that were not saved during this period (we will remove those from the
    # camera DAQ times later).
    # Minus any frames that were dropped between the start of frame acquisition and the first TTL
    start = first_ttl - first_uptick - (count[first_uptick] - first_uptick)
    # Get approximate frame rate for extrapolating timestamps (if required)
    frate = round(1 / np.nanmedian(np.diff(timestamps)))

    if start < 0:
        n_missing = abs(start)
        _logger.warning(f'{n_missing} missing DAQ/Bpod timestamp(s) at start; '
                        f'{"extrapolating" if extrapolate_missing else "prepending nans"}')
        # Offsets of n_missing..1 frame periods before the first timestamp, mirroring the
        # 1..n_missing periods used when appending at the end
        to_app = (timestamps[0] - np.arange(n_missing, 0, -1) / frate
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        timestamps = np.r_[to_app, timestamps]  # Prepend the missing times
        start = 0

    # Remove the extraneous timestamps from the beginning and end
    end = count[-1] + 1 + start
    ts = timestamps[start:end]
    if (n_missing := count[-1] - ts.size + 1) > 0:
        # For ephys sessions there may be fewer DAQ times than frame counts if DAQ acquisition is
        # turned off before the video acquisition workflow. For Bpod this always occurs because
        # Bpod finishes before the camera workflow. For Bpod the times are already extrapolated
        # for these late frames.
        _logger.warning(f'{n_missing} fewer DAQ/Bpod timestamps than frame counts; '
                        f'{"extrapolating" if extrapolate_missing else "appending nans"}')
        to_app = ((np.arange(n_missing, ) + 1) / frate + ts[-1]
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        ts = np.r_[ts, to_app]  # Append the missing times
    assert ts.size >= count.size, 'fewer timestamps than frame counts'
    assert ts.size == count[-1] + 1, 'more frames recorded in frame count than timestamps '

    # Remove the rest of the dropped frames
    ts = ts[count]
    assert np.searchsorted(ts, ttl['times'][0]) == first_uptick, \
        'time of first sync TTL doesn\'t match after alignment'
    if ts.size != count.size:
        _logger.error('number of timestamps and frames don\'t match after alignment')

    if display:
        # Plot to check
        fig, axes = plt.subplots(1, 1)
        y = within_ranges(np.arange(ts.size), pin_state['indices'].reshape(-1, 2)).astype(float)
        y *= 1e-5  # For scale when zoomed in
        axes.plot(ts, y, marker='d', color='blue', drawstyle='steps-pre', label='GPIO')
        axes.plot(ts, np.zeros_like(ts), 'kx', label='DAQ timestamps')
        vertical_lines(ttl['times'], ymin=0, ymax=1e-5,
                       color='r', linestyle=':', ax=axes, label='sync TTL')
        plt.legend()

    return ts

480 

481 

def attribute_times(arr, events, tol=.1, injective=True, take='first'):
    """
    Returns the indices of the first array that correspond to the values of the second.

    Given two arrays of timestamps, the function will return the indices of the values in the
    first array that most likely correspond to the values of the second. For each of the values
    in the second array, the difference is taken and the index of either the first sufficiently
    close value, or simply the closest one, is assigned.

    If injective is True, once a value has been assigned to an event it can't be assigned to
    another. In other words there is a one-to-one mapping between the two arrays.

    Parameters
    ----------
    arr : array_like
        A 1D array of event times to attribute to those in `events`. Non-finite and masked
        values are never assigned. Integer arrays are accepted. The input is not modified.
    events : array_like
        An array of event times considered a subset of `arr`.
    tol : float
        The max absolute difference between values in order to be considered a match.
    injective : bool
        If true, once a value has been assigned it will not be assigned again.
    take : {'first', 'nearest', 'after'}
        If 'first' the first value within tolerance is assigned; if 'nearest' the
        closest value is assigned; if 'after' assign the first event after.

    Returns
    -------
    numpy.array
        An array the same length as `events` containing indices of `arr` corresponding to each
        event, or -1 where no value of `arr` could be attributed.

    Raises
    ------
    ValueError
        If `take` is not one of 'first', 'nearest' or 'after'.
    """
    if (take := take.lower()) not in ('first', 'nearest', 'after'):
        raise ValueError('Parameter `take` must be either "first", "nearest", or "after"')
    # Work on a float copy in which unavailable values (masked, NaN, +/-inf) are set to +inf so
    # that they can never fall within tolerance of an event. A float copy is required because
    # integer arrays cannot hold inf.
    unavailable = np.ma.getmaskarray(arr)
    values = np.array(np.ma.getdata(arr), dtype=float)
    values[unavailable | ~np.isfinite(values)] = np.inf

    events = np.asarray(events)
    assigned = np.full(events.shape, -1, dtype=int)  # Initialize output array
    min_tol = 0 if take == 'after' else -tol
    for i, x in enumerate(events):
        dx = values - x
        candidates = np.logical_and(min_tol < dx, dx < tol)
        if candidates.any():  # is any value within tolerance
            idx = np.abs(dx).argmin() if take == 'nearest' else np.flatnonzero(candidates)[0]
            assigned[i] = idx
            if injective:
                values[idx] = np.inf  # If one-to-one, remove the assigned value
    return assigned

531 

532 

533def groom_pin_state(gpio, ttl, ts, tolerance=2., display=False, take='first', min_diff=0.): 

534 """ 

535 Align the GPIO pin state to the DAQ sync TTLs. Any sync TTLs not reflected in the pin 

536 state are removed from the dict and the times of the detected fronts are converted to DAQ 

537 time. At the end of this the number of GPIO fronts should equal the number of TTLs. 

538 

539 Note: 

540 - This function is ultra safe: we probably don't need assign all the ups and down fronts. 

541 separately and could potentially even align the timestamps without removing the missed fronts 

542 - The input gpio and TTL dicts may be modified by this function. 

543 - For training sessions the frame rate is only 30Hz and the TTLs tend to be broken up by 

544 small gaps. Setting the min_diff to 5ms helps the timestamp assignment accuracy. 

545 

546 Parameters 

547 ---------- 

548 gpio : dict 

549 A dictionary containing GPIO pin state values, with keys {'indices', 'polarities'}. 

550 ttl : dict 

551 A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}. 

552 ts : numpy.array 

553 The camera frame times (the camera frame TTLs acquired by the main DAQ). 

554 tolerance : float 

555 Two pulses need to be within this many seconds to be considered related. 

556 display : bool 

557 If true, the resulting timestamps are plotted against the raw audio signal. 

558 take : {'first', 'nearest'} 

559 If 'first' the first value within tolerance is assigned; if 'nearest' the 

560 closest value is assigned. 

561 min_diff : float 

562 Sync TTL fronts less than min_diff seconds apart will be removed. 

563 

564 Returns 

565 ------- 

566 dict 

567 Dictionary of GPIO DAQ front indices, polarities and DAQ aligned times. 

568 dict 

569 Sync TTL times and polarities sans the TTLs not detected in the frame data. 

570 numpy.array 

571 Frame times in DAQ time. 

572 

573 See Also 

574 -------- 

575 ibllib.io.extractors.ephys_fpga._get_sync_fronts 

576 """ 

577 # Check that the dimensions match 

578 if np.any(gpio['indices'] >= ts.size): 1fdbkcgae

579 _logger.warning('GPIO events occurring beyond timestamps array length') 1g

580 keep = gpio['indices'] < ts.size 1g

581 gpio = {k: gpio[k][keep] for k, v in gpio.items()} 1g

582 assert ttl and ttl['times'].size > 0, 'no sync TTLs for session' 1fdbkcgae

583 assert ttl['times'].size == ttl['polarities'].size, 'sync TTL data dimension mismatch' 1fdbkcgae

584 # make sure that there are no 2 consecutive fall or consecutive rise events 

585 assert np.all(np.abs(np.diff(ttl['polarities'])) == 2), 'consecutive high/low sync TTL events' 1fdbkcgae

586 # make sure first TTL is high 

587 assert ttl['polarities'][0] == 1 1fdbkcgae

588 # make sure ttl times in order 

589 assert np.all(np.diff(ttl['times']) > 0) 1fdbkcgae

590 # make sure raw timestamps increase 

591 assert np.all(np.diff(ts) > 0), 'timestamps must strictly increase' 1fdbkcgae

592 # make sure there are state changes 

593 assert gpio['indices'].any(), 'no TTLs detected in GPIO' 1fdbkcgae

594 # # make sure first GPIO state is high 

595 assert gpio['polarities'][0] == 1 1fdbkcgae

596 """ 1fdbcgae

597 Some sync TTLs appear to be so short that they are not recorded by the camera. These can 

598 be as short as a few microseconds. Applying a cutoff based on framerate was unsuccessful. 

599 Assigning each sync TTL to each pin state change is not easy because some onsets occur very 

600 close together (sometimes < 70ms), on the order of the delay between TTL and frame time. 

601 Also, the two clocks have some degree of drift, so the delay between sync TTL and pin state 

602 change may be zero or even negative. 

603 

604 Here we split the events into sync TTL onsets (lo->hi) and TTL offsets (hi->lo). For each 

605 uptick in the GPIO pin state, we take the first TTL onset time that was within 100ms of it. 

606 We ensure that each sync TTL is assigned only once, so a TTL that is closer to frame 3 than 

607 frame 1 may still be assigned to frame 1. 

608 """ 

609 ifronts = gpio['indices'] # The pin state flips 1fdbcgae

610 sync_times = ttl['times'] 1fdbcgae

611 if ifronts.size != ttl['times'].size: 1fdbcgae

612 _logger.warning('more sync TTLs than GPIO state changes, assigning timestamps') 1fdbcgae

613 to_remove = np.zeros(ifronts.size, dtype=bool) # unassigned GPIO fronts to remove 1fdbcgae

614 low2high = ifronts[gpio['polarities'] == 1] 1fdbcgae

615 high2low = ifronts[gpio['polarities'] == -1] 1fdbcgae

616 assert low2high.size >= high2low.size 1fdbcgae

617 

618 # Remove and/or fuse short TTLs 

619 if min_diff > 0: 1fdbcgae

620 short, = np.where(np.diff(ttl['times']) < min_diff) 1fbae

621 sync_times = np.delete(ttl['times'], np.r_[short, short + 1]) 1fbae

622 _logger.debug(f'Removed {short.size * 2} fronts TLLs less than ' 1fbae

623 f'{min_diff * 1e3:.0f}ms apart') 

624 assert sync_times.size > 0, f'all sync TTLs less than {min_diff}s' 1fbae

625 

626 # Onsets 

627 ups = ts[low2high] - ts[low2high][0] # times relative to first GPIO high 1fdbcgae

628 onsets = sync_times[::2] - sync_times[0] # TTL times relative to first onset 1fdbcgae

629 # assign GPIO fronts to ttl onset 

630 assigned = attribute_times(onsets, ups, tol=tolerance, take=take) 1fdbcgae

631 unassigned = np.setdiff1d(np.arange(onsets.size), assigned[assigned > -1]) 1fdbcgae

632 if unassigned.size > 0: 1fdbcgae

633 _logger.debug(f'{unassigned.size} sync TTL rises were not detected by the camera') 1fdbcgae

634 # Check that all pin state upticks could be attributed to an onset TTL 

635 if np.any(missed := assigned == -1): 1fdbcgae

636 _logger.warning(f'{sum(missed)} pin state rises could not be attributed to a sync TTL') 1fbgae

637 if display: 1fbgae

638 ax = plt.subplot() 1a

639 vertical_lines(ups[assigned > -1], 1a

640 linestyle='-', color='g', ax=ax, 

641 label='assigned GPIO up state') 

642 vertical_lines(ups[missed], 1a

643 linestyle='-', color='r', ax=ax, 

644 label='unassigned GPIO up state') 

645 vertical_lines(onsets[unassigned], 1a

646 linestyle=':', color='k', ax=ax, 

647 alpha=0.3, label='sync TTL onset') 

648 vertical_lines(onsets[assigned], 1a

649 linestyle=':', color='b', ax=ax, label='assigned TTL onset') 

650 plt.legend() 1a

651 plt.show() 1a

652 # Remove the missed fronts 

653 to_remove = np.in1d(gpio['indices'], low2high[missed]) 1fbgae

654 assigned = assigned[~missed] 1fbgae

655 onsets_ = sync_times[::2][assigned] 1fdbcgae

656 

657 # Offsets 

658 downs = ts[high2low] - ts[high2low][0] 1fdbcgae

659 offsets = sync_times[1::2] - sync_times[1] 1fdbcgae

660 assigned = attribute_times(offsets, downs, tol=tolerance, take=take) 1fdbcgae

661 unassigned = np.setdiff1d(np.arange(offsets.size), assigned[assigned > -1]) 1fdbcgae

662 if unassigned.size > 0: 1fdbcgae

663 _logger.debug(f'{unassigned.size} sync TTL falls were not detected by the camera') 1dbcgae

664 # Check that all pin state downticks could be attributed to an offset TTL 

665 if np.any(missed := assigned == -1): 1fdbcgae

666 _logger.warning(f'{sum(missed)} pin state falls could not be attributed to a sync TTL') 1bgae

667 # Remove the missed fronts 

668 to_remove |= np.in1d(gpio['indices'], high2low[missed]) 1bgae

669 assigned = assigned[~missed] 1bgae

670 offsets_ = sync_times[1::2][assigned] 1fdbcgae

671 

672 # Sync TTLs groomed 

673 if np.any(to_remove): 1fdbcgae

674 # Check for any orphaned fronts (only one pin state edge was assigned) 

675 to_remove = np.pad(to_remove, (0, to_remove.size % 2), 'edge') # Ensure even size 1fbgae

676 # Perform xor to find GPIOs where only onset or offset is marked for removal 

677 orphaned = to_remove.reshape(-1, 2).sum(axis=1) == 1 1fbgae

678 if orphaned.any(): 1fbgae

679 """If there are orphaned GPIO fronts (i.e. only one edge was assigned to a sync 1fe

680 TTL front), remove the orphaned front and its assigned sync TTL. In other words

681 if both edges cannot be assigned to a sync TTL, we ignore the TTL entirely. 

682 This is a sign that the assignment was bad and extraction may fail.""" 

683 _logger.warning('Some onsets but not offsets (or vice versa) were not assigned; ' 1fe

684 'this may be a sign of faulty wiring or clock drift') 

685 # Find indices of GPIO upticks where only the downtick was marked for removal 

686 orphaned_onsets, = np.where(~to_remove.reshape(-1, 2)[:, 0] & orphaned) 1fe

687 # The onsets_ array already has the other TTLs removed (same size as to_remove == 

688 # False) so subtract the number of removed elements from index. 

689 for i, v in enumerate(orphaned_onsets): 1fe

690 orphaned_onsets[i] -= to_remove.reshape(-1, 2)[:v, 0].sum() 1e

691 # Same for offsets... 

692 orphaned_offsets, = np.where(~to_remove.reshape(-1, 2)[:, 1] & orphaned) 1fe

693 for i, v in enumerate(orphaned_offsets): 1fe

694 orphaned_offsets[i] -= to_remove.reshape(-1, 2)[:v, 1].sum() 1fe

695 # Remove orphaned ttl onsets and offsets 

696 onsets_ = np.delete(onsets_, orphaned_onsets[orphaned_onsets < onsets_.size]) 1fe

697 offsets_ = np.delete(offsets_, orphaned_offsets[orphaned_offsets < offsets_.size]) 1fe

698 _logger.debug(f'{orphaned.sum()} orphaned TTLs removed') 1fe

699 to_remove.reshape(-1, 2)[orphaned] = True 1fe

700 

701 # Remove those unassigned GPIOs 

702 gpio = {k: v[~to_remove[:v.size]] for k, v in gpio.items()} 1fbgae

703 ifronts = gpio['indices'] 1fbgae

704 

705 # Assert that we've removed discrete TTLs 

706 # A failure means e.g. an up-going front of one TTL was missed but not the down-going one. 

707 assert np.all(np.abs(np.diff(gpio['polarities'])) == 2) 1fbgae

708 assert gpio['polarities'][0] == 1 1fbgae

709 

710 ttl_ = {'times': np.empty(ifronts.size), 'polarities': gpio['polarities']} 1fdbcgae

711 ttl_['times'][::2] = onsets_ 1fdbcgae

712 ttl_['times'][1::2] = offsets_ 1fdbcgae

713 else: 

714 ttl_ = ttl.copy() 1fg

715 

716 # Align the frame times to DAQ 

717 fcn_a2b, drift_ppm = dsp.sync_timestamps(ts[ifronts], ttl_['times']) 1fdbcgae

718 _logger.debug(f'frame ttl alignment drift = {drift_ppm:.2f}ppm') 1fdbcgae

719 # Add times to GPIO dict 

720 gpio['times'] = fcn_a2b(ts[ifronts]) 1fdbcgae

721 

722 if display: 1fdbcgae

723 # Plot all the onsets and offsets 

724 ax = plt.subplot() 1ca

725 # All sync TTLs 

726 squares(ttl['times'], ttl['polarities'], 1ca

727 ax=ax, label='sync TTLs', linestyle=':', color='k', yrange=[0, 1], alpha=0.3) 

728 # GPIO 

729 x = np.insert(gpio['times'], 0, 0) 1ca

730 y = np.arange(x.size) % 2 1ca

731 squares(x, y, ax=ax, label='GPIO') 1ca

732 y = within_ranges(np.arange(ts.size), ifronts.reshape(-1, 2)) # 0 or 1 for each frame 1ca

733 ax.plot(fcn_a2b(ts), y, 'kx', label='cam times') 1ca

734 # Assigned ttl 

735 squares(ttl_['times'], ttl_['polarities'], 1ca

736 ax=ax, label='assigned sync TTL', linestyle=':', color='g', yrange=[0, 1]) 

737 ax.legend() 1ca

738 plt.xlabel('DAQ time (s)') 1ca

739 ax.set_yticks([0, 1]) 1ca

740 ax.set_title('GPIO - sync TTL alignment') 1ca

741 plt.show() 1ca

742 

743 return gpio, ttl_, fcn_a2b(ts) 1fdbcgae