Coverage for ibllib/io/extractors/camera.py: 96%
346 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1""" Camera extractor functions.
3This module handles extraction of camera timestamps for both Bpod and DAQ.
4"""
5import logging
7import cv2
8import numpy as np
9import matplotlib.pyplot as plt
10from iblutil.util import range_str
12import ibldsp.utils as dsp
13from ibllib.plots import squares, vertical_lines
14from ibllib.io.video import assert_valid_label, VideoStreamer
15from iblutil.numerical import within_ranges
16from ibllib.io.extractors.ephys_fpga import get_sync_fronts
17import ibllib.io.raw_data_loaders as raw
18import ibllib.io.extractors.video_motion as vmotion
19from ibllib.io.extractors.base import (
20 BaseBpodTrialsExtractor,
21 BaseExtractor,
22)
# Module-level logger used by the extractor classes and helper functions in this module.
_logger = logging.getLogger(__name__)
def extract_camera_sync(sync, chmap=None):
    """
    Extract camera frame times from the sync fronts.

    :param sync: dictionary 'times', 'polarities' of fronts detected on sync trace
    :param chmap: dictionary mapping channel names to sync channel indices. Despite the None
        default a non-empty mapping is required (an AssertionError is raised otherwise).
        Every key ending in '_camera' (e.g. 'left_camera') is treated as a camera channel.
    :return: dictionary keyed by camera label (the channel name without its '_camera'
        suffix) holding every other front time (fronts 0, 2, 4, ...) of that channel
    """
    assert chmap
    camera_channels = [name for name in chmap if name.endswith('_camera')]
    return {
        name.rsplit('_', 1)[0]: get_sync_fronts(sync, chmap[name]).times[::2]
        for name in camera_channels
    }
def get_video_length(video_path):
    """
    Return the number of frames in a video.

    :param video_path: A path to the video (str or pathlib.Path). Strings starting with
        'http' are treated as URLs and opened through VideoStreamer.
    :return: int, the frame count reported by OpenCV (CAP_PROP_FRAME_COUNT)
    :raises AssertionError: if the video could not be opened
    """
    is_url = isinstance(video_path, str) and video_path.startswith('http')
    cap = VideoStreamer(video_path).cap if is_url else cv2.VideoCapture(str(video_path))
    try:
        assert cap.isOpened(), f'Failed to open video file {video_path}'
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture handle on every path, including a failed open check
        cap.release()
class CameraTimestampsFPGA(BaseExtractor):
    """Extractor for videos using DAQ sync and channel map."""

    def __init__(self, label, session_path=None):
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        self.save_names = f'_ibl_{label}Camera.times.npy'
        self.var_names = f'{label}_camera_timestamps'
        # The module logger is set to DEBUG while this extractor exists; the previous level
        # is stored so that __del__ can restore it.
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # If __init__ raised before the level was recorded (e.g. invalid label), the logger
        # level was never changed, so there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio',
                 display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the DAQ. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using
        task TTLs (typically the audio TTLs). Frames missing from the embedded frame count are
        removed from the timestamps array.
        If the pin state file does not exist, the left and right camera timestamps may be aligned
        using the wheel data.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices. Required: it must contain the camera channel
            and, for GPIO alignment, the `sync_label` channel.
        video_path : str, pathlib.Path, int
            An optional path for fetching the number of frames. If None, the video is loaded from
            the session path. If an int is provided this is taken to be the total number of frames.
        sync_label : str
            The sync label of the channel that's wired to the GPIO for synchronising the times.
        display : bool
            If true, the TTL and GPIO fronts are plotted.
        extrapolate_missing : bool
            If true, any missing timestamps at the beginning and end of the session are
            extrapolated based on the median frame rate, otherwise they will be NaNs.
        **kwargs
            Extra keyword arguments (unused).

        Returns
        -------
        numpy.array
            The extracted camera timestamps. If neither GPIO nor wheel alignment succeeds, the
            raw DAQ frame times are returned, with leading pulses discarded when there are more
            pulses than video frames.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, self.label)
        raw_ts = fpga_times[self.label]

        if video_path is None:
            filename = f'_iblrig_{self.label}Camera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = (video_path if isinstance(video_path, int) else get_video_length(video_path))
        _logger.debug(f'Number of video frames = {length}')

        if gpio is not None and gpio['indices'].size > 1 and sync_label is not None:
            _logger.info(f'Aligning to {sync_label} TTLs')
            # Extract sync TTLs
            ttl = get_sync_fronts(sync, chmap[sync_label])
            _, ts = raw.load_camera_ssv_times(self.session_path, self.label)
            try:
                # NB: Some of the sync TTLs occur very close together, and are therefore not
                # reflected in the pin state. This function removes those. Also converts frame
                # times to DAQ time.
                gpio, ttl, ts = groom_pin_state(gpio, ttl, ts, display=display)
                # The length of the count and pin state are regularly longer than the length of
                # the video file. Here we assert that the video is either shorter or the same
                # length as the arrays, and we make an assumption that the missing frames are
                # right at the end of the video. We therefore simply shorten the arrays to match
                # the length of the video.
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'
                assert raw_ts.shape[0] > 0, ('no timestamps found in channel indicated for '
                                             f'{self.label} camera')
                return align_with_gpio(raw_ts, ttl, gpio, count,
                                       display=display,
                                       extrapolate_missing=extrapolate_missing)
            except AssertionError as ex:
                _logger.critical('Failed to extract using %s: %s', sync_label, ex)

        # If you reach here extracting using sync TTLs was not possible, we attempt to align using wheel motion energy
        _logger.warning('Attempting to align using wheel')

        try:
            if self.label not in ['left', 'right']:
                # Can only use wheel alignment for left and right cameras
                raise ValueError(f'Wheel alignment not supported for {self.label} camera')

            motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, sync='nidq', upload=True)
            new_times = motion_class.process()
            if not motion_class.qc_outcome:
                raise ValueError(f'Wheel alignment for {self.label} camera failed to pass qc: {motion_class.qc}')
            _logger.warning(f'Wheel alignment for {self.label} camera successful, qc: {motion_class.qc}')
            return new_times

        except Exception as err:
            # Deliberate best-effort: any wheel alignment failure falls back to the raw times
            _logger.critical(f'Failed to align with wheel for {self.label} camera: {err}')

        if length < raw_ts.size:
            df = raw_ts.size - length
            _logger.info(f'Discarding first {df} pulses')
            raw_ts = raw_ts[df:]

        return raw_ts
class CameraTimestampsCamlog(BaseExtractor):
    """
    Extractor returning the DAQ camera frame TTL times, matched one-to-one with video frames.

    No GPIO / frame-counter alignment is performed; a difference of exactly one frame between
    the number of TTLs and the number of video frames is corrected, anything larger fails.
    """

    def __init__(self, label, session_path=None):
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        self.save_names = f'_ibl_{label}Camera.times.npy'
        self.var_names = f'{label}_camera_timestamps'
        # The module logger is set to DEBUG while this extractor exists; the previous level
        # is stored so that __del__ can restore it.
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # If __init__ raised before the level was recorded (e.g. invalid label), the logger
        # level was never changed, so there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None,
                 display=False, extrapolate_missing=True, **kwargs):
        """
        Extract the camera timestamps from the DAQ frame TTLs.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices, including the camera channel.
        video_path, display, extrapolate_missing
            Accepted for interface compatibility but currently unused: the video is always
            read from 'raw_video_data' in the session path.
        **kwargs
            Extra keyword arguments (unused).

        Returns
        -------
        numpy.array
            The camera timestamps, one per video frame.

        Raises
        ------
        AssertionError
            If the number of TTLs and video frames differ by more than one.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        video_frames = get_video_length(self.session_path.joinpath('raw_video_data', f'_iblrig_{self.label}Camera.raw.mp4'))
        raw_ts = fpga_times[self.label]

        n_extra_pulses = raw_ts.size - video_frames
        # For left camera sometimes we have one extra pulse than video frame
        if n_extra_pulses == 1:
            _logger.warning(f'One extra sync pulse detected for {self.label} camera')
            raw_ts = raw_ts[:-1]
        elif n_extra_pulses == -1:
            _logger.warning(f'One extra video frame detected for {self.label} camera')
            # Append one timestamp, one median frame interval after the last pulse
            med_time = np.median(np.diff(raw_ts))
            raw_ts = np.r_[raw_ts, np.array([raw_ts[-1] + med_time])]

        assert video_frames == raw_ts.size, (
            f'dimension mismatch between video frames and TTL pulses for {self.label} camera'
            f' by {np.abs(video_frames - raw_ts.size)} frames')

        return raw_ts
class CameraTimestampsBpod(BaseBpodTrialsExtractor):
    """
    Get the camera timestamps from the Bpod

    The camera events are logged only during the events not in between, so the times need
    to be interpolated
    """
    save_names = '_ibl_leftCamera.times.npy'
    var_names = 'left_camera_timestamps'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The module logger is set to DEBUG while this extractor exists; the previous level
        # is stored so that __del__ can restore it.
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # If __init__ raised before the level was recorded, the logger level was never
        # changed, so there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, video_path=None, display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the Bpod. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using the
        sync TTLs. Frames missing from the embedded frame count are removed from the timestamps
        array.
        If the pin state file does not exist, the Bpod timestamps are padded or trimmed to the
        number of video frames (alignment using the wheel data is not yet implemented).
        :param video_path: an optional path for fetching the number of frames. If None,
         the video is loaded from the session path. If an int is provided this is taken to be
         the total number of frames.
        :param display: if True, the TTL and GPIO fronts are plotted.
        :param extrapolate_missing: if True, any missing timestamps at the beginning and end of
         the session are extrapolated based on the median frame rate, otherwise they will be NaNs.
        :return: a numpy array of camera timestamps
        """
        raw_ts = self._times_from_bpod()
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, 'left')
        if video_path is None:
            filename = '_iblrig_leftCamera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = video_path if isinstance(video_path, int) else get_video_length(video_path)
        _logger.debug(f'Number of video frames = {length}')

        # Check if the GPIO is usable for extraction. GPIO is None if the file does not exist,
        # is empty, or contains only one value (i.e. doesn't change)
        if gpio is not None and gpio['indices'].size > 1:
            _logger.info('Aligning to sync TTLs')
            # Extract audio TTLs
            _, audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials,
                                            task_collection=self.task_collection)
            _, ts = raw.load_camera_ssv_times(self.session_path, 'left')
            # There are many sync TTLs that are for some reason missed by the GPIO. Conversely
            # the last GPIO doesn't often correspond to any audio TTL. These will be removed.
            # The drift appears to be less severe than the DAQ, so when assigning TTLs we'll take
            # the nearest TTL within 500ms. The go cue TTLs comprise two short pulses ~3ms apart.
            # We will fuse any TTLs less than 5ms apart to make assignment more accurate.
            try:
                gpio, audio, ts = groom_pin_state(gpio, audio, ts, take='nearest',
                                                  tolerance=.5, min_diff=5e-3, display=display)
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'

                return align_with_gpio(raw_ts, audio, gpio, count,
                                       extrapolate_missing, display=display)
            except AssertionError as ex:
                _logger.critical('Failed to extract using audio: %s', ex)

        # If you reach here extracting using audio TTLs was not possible
        _logger.warning('Alignment by wheel data not yet implemented')
        # Extrapolate at median frame rate
        n_missing = length - raw_ts.size
        if n_missing > 0:
            _logger.warning(f'{n_missing} fewer Bpod timestamps than frames; '
                            f'{"extrapolating" if extrapolate_missing else "appending nans"}')
            # Median inter-frame interval in seconds. The appended frames are spaced by this
            # period (multiplying, not dividing: this is a period, not a rate).
            frame_period = np.median(np.diff(raw_ts))
            to_app = ((np.arange(n_missing, ) + 1) * frame_period + raw_ts[-1]
                      if extrapolate_missing
                      else np.full(n_missing, np.nan))
            raw_ts = np.r_[raw_ts, to_app]  # Append the missing times
        elif n_missing < 0:
            _logger.warning(f'{abs(n_missing)} fewer frames than Bpod timestamps')
            _logger.info(f'Discarding first {abs(n_missing)} pulses')
            raw_ts = raw_ts[abs(n_missing):]

        return raw_ts

    def _times_from_bpod(self):
        """
        Reconstruct the camera frame times from the Bpod 'Port1In' / 'Port1Out' events.

        Within each trial the frame times are the 'Port1In' event times. Trials without any
        'Port1In' event are skipped. Frames in the gaps between trials are filled in at evenly
        spaced times, their number estimated from the median frame rate.

        Returns
        -------
        numpy.array
            The camera frame times in Bpod time.
        """
        ntrials = len(self.bpod_trials)

        cam_times = []
        n_frames = 0
        n_out_of_sync = 0
        missed_trials = []
        for ind in range(ntrials):
            # get upgoing and downgoing fronts
            events = self.bpod_trials[ind]['behavior_data']['Events timestamps']
            pin = np.array(events.get('Port1In') or [np.nan])
            pout = np.array(events.get('Port1Out') or [np.nan])
            # some trials at startup may not have the camera working, discard
            if np.isnan(pin).all():
                missed_trials.append(ind)
                continue
            # if the trial starts in the middle of a square, discard the first downgoing front
            if pout[0] < pin[0]:
                pout = pout[1:]
            # same if the last sample is during an upgoing front,
            # always put size as it happens last
            pin = pin[:pout.size]
            frate = np.median(np.diff(pin))
            if ind > 0:
                # Check that the pulses have the same length and that we don't miss frames
                # during the trial; the refresh rate of bpod is 100us
                test1 = np.all(np.abs(1 - (pin - pout) / np.median(pin - pout)) < 0.1)
                test2 = np.all(np.abs(np.diff(pin) - frate) <= 0.00011)
                if not all([test1, test2]):
                    n_out_of_sync += 1
            # grow a list of cam times for each trial
            cam_times.append(pin)
            n_frames += pin.size

        if missed_trials:
            _logger.debug('trial(s) %s missing TTL events', range_str(missed_trials))
        if n_out_of_sync > 0:
            _logger.warning(f"{n_out_of_sync} trials with bpod camera frame times not within"
                            f" 10% of the expected sampling rate")

        t_first_frame = np.array([c[0] for c in cam_times])
        t_last_frame = np.array([c[-1] for c in cam_times])
        frate = 1 / np.nanmedian(np.array([np.median(np.diff(c)) for c in cam_times]))
        intertrial_duration = t_first_frame[1:] - t_last_frame[:-1]
        intertrial_missed_frames = np.int32(np.round(intertrial_duration * frate)) - 1

        # initialize the full times array
        frame_times = np.zeros(n_frames + int(np.sum(intertrial_missed_frames)))
        ii = 0
        for trial, cam_time in enumerate(cam_times):
            # populate first the recovered times within the trials
            frame_times[ii: ii + cam_time.size] = cam_time
            ii += cam_time.size
            if trial == (len(cam_times) - 1):
                break
            # then interpolate in-between trials at evenly spaced times
            nmiss = intertrial_missed_frames[trial]
            frame_times[ii: ii + nmiss] = (cam_time[-1] + intertrial_duration[trial] /
                                           (nmiss + 1) * (np.arange(nmiss) + 1))
            ii += nmiss
        assert all(np.diff(frame_times) > 0)  # negative diffs implies a big problem
        return frame_times
def align_with_gpio(timestamps, ttl, pin_state, count, extrapolate_missing=True, display=False):
    """
    Groom the raw DAQ or Bpod camera timestamps using the frame embedded GPIO and frame counter.

    Parameters
    ----------
    timestamps : numpy.array
        An array of raw DAQ or Bpod camera timestamps.
    ttl : dict
        A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}.
    pin_state : dict
        A dictionary containing GPIO pin state values, e.g. as returned by `groom_pin_state`.
        The keys 'indices' (frame index of each pin state change) and 'times' are read.
    count : numpy.array
        An array of frame numbers.
    extrapolate_missing : bool
        If true and the number of timestamps is fewer than the number of frame counts, the
        remaining timestamps are extrapolated based on the frame rate, otherwise they are NaNs.
    display : bool
        Plot the resulting timestamps.

    Returns
    -------
    numpy.array
        The corrected frame timestamps.

    Raises
    ------
    AssertionError
        If the raw data violate the assumptions below or the alignment is inconsistent.

    Notes
    -----
    Here we ensure that the DAQ camera times match the number of video frames in length.
    We make the following assumptions:

    1. The number of DAQ camera times is equal to or greater than the number of video frames.
    2. No TTLs were missed between the camera and DAQ.
    3. No pin states were missed by Bonsai.
    4. No pixel count data was missed by Bonsai.

    In other words the count and pin state arrays accurately reflect the number of frames
    sent by the camera and should therefore be the same length, and the length of the frame
    counter should match the number of saved video frames.

    The missing frame timestamps are removed in three stages:

    1. Remove any timestamps that occurred before video frame acquisition in Bonsai.
    2. Remove any timestamps where the frame counter reported missing frames, i.e. remove the
       dropped frames which occurred throughout the session.
    3. Remove the trailing timestamps at the end of the session if the camera was turned off
       in the wrong order.
    """
    # Some assertions made on the raw data
    assert np.all(np.diff(count) > 0), 'frame count not strictly increasing'
    assert np.all(np.diff(timestamps) > 0), 'DAQ/Bpod camera times not strictly increasing'
    same_n_ttl = pin_state['times'].size == ttl['times'].size
    assert same_n_ttl, 'more ttl TTLs detected on camera than TTLs sent'

    # Align on first pin state change
    first_uptick = pin_state['indices'][0]
    first_ttl = np.searchsorted(timestamps, ttl['times'][0])
    # Find up to which index in the DAQ times we discard by taking the difference between the
    # index of the first pin state change (when the sync TTL was reported by the camera) and
    # the index of the first sync TTL in DAQ time. We also subtract any frames that were
    # dropped between the start of frame acquisition and the first TTL (the difference between
    # the frame count at the first pin state change and its index); the remaining dropped
    # frames are removed later.
    start = first_ttl - first_uptick - (count[first_uptick] - first_uptick)
    # Get approximate frame rate (Hz) for extrapolating timestamps (if required)
    frate = round(1 / np.nanmedian(np.diff(timestamps)))

    if start < 0:
        n_missing = abs(start)
        _logger.warning(f'{n_missing} missing DAQ/Bpod timestamp(s) at start; '
                        f'{"extrapolating" if extrapolate_missing else "prepending nans"}')
        # The prepended frames lie n_missing, ..., 1 frame periods before timestamps[0], so the
        # last prepended frame is exactly one period before the first recorded timestamp.
        to_app = (timestamps[0] - np.arange(n_missing, 0, -1) / frate
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        timestamps = np.r_[to_app, timestamps]  # Prepend the missing times
        start = 0

    # Remove the extraneous timestamps from the beginning and end
    end = count[-1] + 1 + start
    ts = timestamps[start:end]
    if (n_missing := count[-1] - ts.size + 1) > 0:
        # For ephys sessions there may be fewer DAQ times than frame counts if DAQ acquisition
        # is turned off before the video acquisition workflow. For Bpod this always occurs
        # because Bpod finishes before the camera workflow. For Bpod the times are already
        # extrapolated for these late frames.
        _logger.warning(f'{n_missing} fewer DAQ/Bpod timestamps than frame counts; '
                        f'{"extrapolating" if extrapolate_missing else "appending nans"}')
        to_app = ((np.arange(n_missing, ) + 1) / frate + ts[-1]
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        ts = np.r_[ts, to_app]  # Append the missing times
    assert ts.size >= count.size, 'fewer timestamps than frame counts'
    assert ts.size == count[-1] + 1, 'more frames recorded in frame count than timestamps '

    # Remove the rest of the dropped frames (indexing by count always yields count.size values)
    ts = ts[count]
    assert np.searchsorted(ts, ttl['times'][0]) == first_uptick, \
        'time of first sync TTL doesn\'t match after alignment'

    if display:
        # Plot to check
        _, axes = plt.subplots(1, 1)
        y = within_ranges(np.arange(ts.size), pin_state['indices'].reshape(-1, 2)).astype(float)
        y *= 1e-5  # For scale when zoomed in
        axes.plot(ts, y, marker='d', color='blue', drawstyle='steps-pre', label='GPIO')
        axes.plot(ts, np.zeros_like(ts), 'kx', label='DAQ timestamps')
        vertical_lines(ttl['times'], ymin=0, ymax=1e-5,
                       color='r', linestyle=':', ax=axes, label='sync TTL')
        plt.legend()

    return ts
def attribute_times(arr, events, tol=.1, injective=True, take='first'):
    """
    Returns the values of the first array that correspond to those of the second.

    Given two arrays of timestamps, the function will return the values of the first array
    that most likely correspond to the values of the second. For each of the values in the
    second array, the absolute difference is taken and the index of either the first sufficiently
    close value, or simply the closest one, is assigned.

    If injective is True, once a value has been assigned to an event it can't be assigned to
    another. In other words there is a one-to-one mapping between the two arrays.

    Non-finite values (NaN, +/-inf) in `arr` are never assigned.

    Parameters
    ----------
    arr : array_like
        A 1D array of event times to attribute to those in `events`. Integer and list inputs
        are accepted; the input is not modified.
    events : array_like
        A 1D array of event times considered a subset of `arr`.
    tol : float
        The max absolute difference between values in order to be considered a match.
    injective : bool
        If true, once a value has been assigned it will not be assigned again.
    take : {'first', 'nearest', 'after'}
        If 'first' the first value within tolerance is assigned; if 'nearest' the
        closest value is assigned; if 'after' assign the first event after. Case-insensitive.

    Returns
    -------
    numpy.array
        An array the same length as `events` containing indices of `arr` corresponding to each
        event, or -1 where no value could be assigned.

    Raises
    ------
    ValueError
        If `take` is not one of the accepted values.
    """
    if (take := take.lower()) not in ('first', 'nearest', 'after'):
        raise ValueError('Parameter `take` must be either "first", "nearest", or "after"')
    # Float copy so integer inputs can hold the +inf placeholder and the input stays untouched
    values = np.array(arr, dtype=float)
    events = np.asarray(events)
    # Values flagged here (non-finite, or already assigned when injective) are treated as +inf
    used = ~np.isfinite(values)
    assigned = np.full(events.shape, -1, dtype=int)  # Initialize output array
    min_tol = 0 if take == 'after' else -tol
    for i, x in enumerate(events):
        dx = np.where(used, np.inf, values) - x
        candidates = (min_tol < dx) & (dx < tol)
        if candidates.any():  # is any value within tolerance
            idx = np.abs(dx).argmin() if take == 'nearest' else np.flatnonzero(candidates)[0]
            assigned[i] = idx
            if injective:  # If one-to-one, remove the assigned value
                used[idx] = True
    return assigned
def groom_pin_state(gpio, ttl, ts, tolerance=2., display=False, take='first', min_diff=0.):
    """
    Align the GPIO pin state to the DAQ sync TTLs. Any sync TTLs not reflected in the pin
    state are removed from the dict and the times of the detected fronts are converted to DAQ
    time. At the end of this the number of GPIO fronts should equal the number of TTLs.

    Note:
    - This function is ultra safe: we probably don't need to assign all the up and down fronts
      separately and could potentially even align the timestamps without removing the missed fronts
    - The input gpio and TTL dicts may be modified by this function.
    - For training sessions the frame rate is only 30Hz and the TTLs tend to be broken up by
      small gaps. Setting the min_diff to 5ms helps the timestamp assignment accuracy.

    Parameters
    ----------
    gpio : dict
        A dictionary containing GPIO pin state values, with keys {'indices', 'polarities'}.
    ttl : dict
        A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}.
    ts : numpy.array
        The camera frame times (the camera frame TTLs acquired by the main DAQ).
    tolerance : float
        Two pulses need to be within this many seconds to be considered related.
    display : bool
        If true, the resulting timestamps are plotted against the raw audio signal.
    take : {'first', 'nearest', 'after'}
        Passed to `attribute_times`. If 'first' the first value within tolerance is assigned;
        if 'nearest' the closest value is assigned; if 'after' the first value after.
    min_diff : float
        Sync TTL fronts less than min_diff seconds apart will be removed.

    Returns
    -------
    dict
        Dictionary of GPIO DAQ front indices, polarities and DAQ aligned times.
    dict
        Sync TTL times and polarities sans the TTLs not detected in the frame data.
    numpy.array
        Frame times in DAQ time.

    Raises
    ------
    AssertionError
        If the inputs are inconsistent or the fronts cannot be paired.

    See Also
    --------
    ibllib.io.extractors.ephys_fpga._get_sync_fronts
    """
    # Check that the dimensions match
    if np.any(gpio['indices'] >= ts.size):
        _logger.warning('GPIO events occurring beyond timestamps array length')
        keep = gpio['indices'] < ts.size
        gpio = {k: v[keep] for k, v in gpio.items()}
    assert ttl and ttl['times'].size > 0, 'no sync TTLs for session'
    assert ttl['times'].size == ttl['polarities'].size, 'sync TTL data dimension mismatch'
    # make sure that there are no 2 consecutive fall or consecutive rise events
    assert np.all(np.abs(np.diff(ttl['polarities'])) == 2), 'consecutive high/low sync TTL events'
    # make sure first TTL is high
    assert ttl['polarities'][0] == 1
    # make sure ttl times in order
    assert np.all(np.diff(ttl['times']) > 0)
    # make sure raw timestamps increase
    assert np.all(np.diff(ts) > 0), 'timestamps must strictly increase'
    # make sure there are state changes
    assert gpio['indices'].any(), 'no TTLs detected in GPIO'
    # make sure first GPIO state is high
    assert gpio['polarities'][0] == 1
    # Some sync TTLs appear to be so short that they are not recorded by the camera. These can
    # be as short as a few microseconds. Applying a cutoff based on framerate was unsuccessful.
    # Assigning each sync TTL to each pin state change is not easy because some onsets occur very
    # close together (sometimes < 70ms), on the order of the delay between TTL and frame time.
    # Also, the two clocks have some degree of drift, so the delay between sync TTL and pin state
    # change may be zero or even negative.
    #
    # Here we split the events into sync TTL onsets (lo->hi) and TTL offsets (hi->lo). For each
    # uptick in the GPIO pin state, we take the first (or nearest, depending on `take`) TTL onset
    # time that was within `tolerance` seconds of it. We ensure that each sync TTL is assigned
    # only once, so a TTL that is closer to frame 3 than frame 1 may still be assigned to frame 1.
    ifronts = gpio['indices']  # The pin state flips
    sync_times = ttl['times']
    if ifronts.size != ttl['times'].size:
        _logger.warning('more sync TTLs than GPIO state changes, assigning timestamps')
        to_remove = np.zeros(ifronts.size, dtype=bool)  # unassigned GPIO fronts to remove
        low2high = ifronts[gpio['polarities'] == 1]
        high2low = ifronts[gpio['polarities'] == -1]
        assert low2high.size >= high2low.size

        # Remove and/or fuse short TTLs
        if min_diff > 0:
            short, = np.where(np.diff(ttl['times']) < min_diff)
            sync_times = np.delete(ttl['times'], np.r_[short, short + 1])
            _logger.debug(f'Removed {short.size * 2} fronts TLLs less than '
                          f'{min_diff * 1e3:.0f}ms apart')
            assert sync_times.size > 0, f'all sync TTLs less than {min_diff}s'

        # Onsets
        ups = ts[low2high] - ts[low2high][0]  # times relative to first GPIO high
        onsets = sync_times[::2] - sync_times[0]  # TTL times relative to first onset
        # assign GPIO fronts to ttl onset
        assigned = attribute_times(onsets, ups, tol=tolerance, take=take)
        unassigned = np.setdiff1d(np.arange(onsets.size), assigned[assigned > -1])
        if unassigned.size > 0:
            _logger.debug(f'{unassigned.size} sync TTL rises were not detected by the camera')
        # Check that all pin state upticks could be attributed to an onset TTL
        if np.any(missed := assigned == -1):
            _logger.warning(f'{sum(missed)} pin state rises could not be attributed to a sync TTL')
            if display:
                ax = plt.subplot()
                vertical_lines(ups[assigned > -1],
                               linestyle='-', color='g', ax=ax,
                               label='assigned GPIO up state')
                vertical_lines(ups[missed],
                               linestyle='-', color='r', ax=ax,
                               label='unassigned GPIO up state')
                vertical_lines(onsets[unassigned],
                               linestyle=':', color='k', ax=ax,
                               alpha=0.3, label='sync TTL onset')
                # Exclude the -1 (unassigned) entries, which would otherwise index the last onset
                vertical_lines(onsets[assigned[~missed]],
                               linestyle=':', color='b', ax=ax, label='assigned TTL onset')
                plt.legend()
                plt.show()
            # Remove the missed fronts
            to_remove = np.isin(gpio['indices'], low2high[missed])
            assigned = assigned[~missed]
        onsets_ = sync_times[::2][assigned]

        # Offsets
        downs = ts[high2low] - ts[high2low][0]
        offsets = sync_times[1::2] - sync_times[1]
        assigned = attribute_times(offsets, downs, tol=tolerance, take=take)
        unassigned = np.setdiff1d(np.arange(offsets.size), assigned[assigned > -1])
        if unassigned.size > 0:
            _logger.debug(f'{unassigned.size} sync TTL falls were not detected by the camera')
        # Check that all pin state downticks could be attributed to an offset TTL
        if np.any(missed := assigned == -1):
            _logger.warning(f'{sum(missed)} pin state falls could not be attributed to a sync TTL')
            # Remove the missed fronts
            to_remove |= np.isin(gpio['indices'], high2low[missed])
            assigned = assigned[~missed]
        offsets_ = sync_times[1::2][assigned]

        # Sync TTLs groomed
        if np.any(to_remove):
            # Check for any orphaned fronts (only one pin state edge was assigned)
            to_remove = np.pad(to_remove, (0, to_remove.size % 2), 'edge')  # Ensure even size
            # Perform xor to find GPIOs where only onset or offset is marked for removal
            orphaned = to_remove.reshape(-1, 2).sum(axis=1) == 1
            if orphaned.any():
                # If there are orphaned GPIO fronts (i.e. only one edge was assigned to a sync
                # TTL front), remove the orphaned front and its assigned sync TTL. In other words
                # if both edges cannot be assigned to a sync TTL, we ignore the TTL entirely.
                # This is a sign that the assignment was bad and extraction may fail.
                _logger.warning('Some onsets but not offsets (or vice versa) were not assigned; '
                                'this may be a sign of faulty wiring or clock drift')
                # Find indices of GPIO upticks where only the downtick was marked for removal
                orphaned_onsets, = np.where(~to_remove.reshape(-1, 2)[:, 0] & orphaned)
                # The onsets_ array already has the other TTLs removed (same size as to_remove ==
                # False) so subtract the number of removed elements from index.
                for i, v in enumerate(orphaned_onsets):
                    orphaned_onsets[i] -= to_remove.reshape(-1, 2)[:v, 0].sum()
                # Same for offsets...
                orphaned_offsets, = np.where(~to_remove.reshape(-1, 2)[:, 1] & orphaned)
                for i, v in enumerate(orphaned_offsets):
                    orphaned_offsets[i] -= to_remove.reshape(-1, 2)[:v, 1].sum()
                # Remove orphaned ttl onsets and offsets
                onsets_ = np.delete(onsets_, orphaned_onsets[orphaned_onsets < onsets_.size])
                offsets_ = np.delete(offsets_, orphaned_offsets[orphaned_offsets < offsets_.size])
                _logger.debug(f'{orphaned.sum()} orphaned TTLs removed')
                # reshape returns a view here, so this marks both edges of each orphan in to_remove
                to_remove.reshape(-1, 2)[orphaned] = True

            # Remove those unassigned GPIOs
            gpio = {k: v[~to_remove[:v.size]] for k, v in gpio.items()}
            ifronts = gpio['indices']

            # Assert that we've removed discrete TTLs
            # A failure means e.g. an up-going front of one TTL was missed but not the down-going one.
            assert np.all(np.abs(np.diff(gpio['polarities'])) == 2)
            assert gpio['polarities'][0] == 1

        ttl_ = {'times': np.empty(ifronts.size), 'polarities': gpio['polarities']}
        ttl_['times'][::2] = onsets_
        ttl_['times'][1::2] = offsets_
    else:
        ttl_ = ttl.copy()

    # Align the frame times to DAQ
    fcn_a2b, drift_ppm = dsp.sync_timestamps(ts[ifronts], ttl_['times'])
    _logger.debug(f'frame ttl alignment drift = {drift_ppm:.2f}ppm')
    # Add times to GPIO dict
    gpio['times'] = fcn_a2b(ts[ifronts])

    if display:
        # Plot all the onsets and offsets
        ax = plt.subplot()
        # All sync TTLs
        squares(ttl['times'], ttl['polarities'],
                ax=ax, label='sync TTLs', linestyle=':', color='k', yrange=[0, 1], alpha=0.3)
        # GPIO
        x = np.insert(gpio['times'], 0, 0)
        y = np.arange(x.size) % 2
        squares(x, y, ax=ax, label='GPIO')
        y = within_ranges(np.arange(ts.size), ifronts.reshape(-1, 2))  # 0 or 1 for each frame
        ax.plot(fcn_a2b(ts), y, 'kx', label='cam times')
        # Assigned ttl
        squares(ttl_['times'], ttl_['polarities'],
                ax=ax, label='assigned sync TTL', linestyle=':', color='g', yrange=[0, 1])
        ax.legend()
        plt.xlabel('DAQ time (s)')
        ax.set_yticks([0, 1])
        ax.set_title('GPIO - sync TTL alignment')
        plt.show()

    return gpio, ttl_, fcn_a2b(ts)