Coverage for ibllib/io/extractors/camera.py: 96%
366 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1""" Camera extractor functions.
3This module handles extraction of camera timestamps for both Bpod and DAQ.
4"""
5import logging
6from functools import partial
8import cv2
9import numpy as np
10import matplotlib.pyplot as plt
11from iblutil.util import range_str
13import ibldsp.utils as dsp
14from ibllib.plots import squares, vertical_lines
15from ibllib.io.video import assert_valid_label, VideoStreamer
16from iblutil.numerical import within_ranges
17from ibllib.io.extractors.base import get_session_extractor_type
18from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map
19import ibllib.io.raw_data_loaders as raw
20import ibllib.io.extractors.video_motion as vmotion
21from ibllib.io.extractors.base import (
22 BaseBpodTrialsExtractor,
23 BaseExtractor,
24 run_extractor_classes,
25 _get_task_types_json_config
26)
28_logger = logging.getLogger(__name__)
def extract_camera_sync(sync, chmap=None):
    """
    Extract camera frame times from the sync matrix.

    Every channel in `chmap` whose name ends with '_camera' is treated as a camera TTL
    channel; the camera label is the part of the name before the final underscore.

    :param sync: dictionary 'times', 'polarities' of fronts detected on sync trace
    :param chmap: dictionary mapping channel names to channel indices. Despite the default,
        it must be provided and non-empty (an AssertionError is raised otherwise).
    :return: dictionary mapping camera label (e.g. 'left') to the times of every other
        front on that channel (i.e. one time per TTL pulse)
    """
    assert chmap
    camera_channels = (name for name in chmap if name.endswith('_camera'))
    return {
        name.rsplit('_', 1)[0]: get_sync_fronts(sync, chmap[name]).times[::2]
        for name in camera_channels
    }
def get_video_length(video_path):
    """
    Return the number of frames in a video.

    Parameters
    ----------
    video_path : str, pathlib.Path
        A local path to the video file, or a URL string starting with 'http' (opened via
        VideoStreamer).

    Returns
    -------
    int
        The frame count as reported by OpenCV's CAP_PROP_FRAME_COUNT property.

    Raises
    ------
    AssertionError
        The video could not be opened.
    """
    is_url = isinstance(video_path, str) and video_path.startswith('http')
    cap = VideoStreamer(video_path).cap if is_url else cv2.VideoCapture(str(video_path))
    try:
        assert cap.isOpened(), f'Failed to open video file {video_path}'
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture on every path, including a failed open, so the handle isn't leaked
        cap.release()
class CameraTimestampsFPGA(BaseExtractor):
    """Extractor for videos using DAQ sync and channel map."""

    def __init__(self, label, session_path=None):
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        # Build output names from the validated label so they agree with `self.label`
        self.save_names = f'_ibl_{self.label}Camera.times.npy'
        self.var_names = f'{self.label}_camera_timestamps'
        # Temporarily make the module logger verbose; the previous level is restored in __del__
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # If __init__ raised before storing the level (e.g. invalid label) the logger level was
        # never changed, so there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio',
                 display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the DAQ. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using
        task TTLs (typically the audio TTLs). Frames missing from the embedded frame count are
        removed from the timestamps array.
        If the pin state alignment is not possible, the left and right camera timestamps may be
        aligned using the wheel data. If that also fails, the raw DAQ timestamps are returned,
        with any leading surplus pulses discarded.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices (required).
        video_path : str, pathlib.Path, int
            An optional path for fetching the number of frames. If None, the video is loaded from
            the session path. If an int is provided this is taken to be the total number of frames.
        sync_label : str
            The sync label of the channel that's wired to the GPIO for synchronising the times.
        display : bool
            If true, the TTL and GPIO fronts are plotted.
        extrapolate_missing : bool
            If true, any missing timestamps at the beginning and end of the session are
            extrapolated based on the median frame rate, otherwise they will be NaNs.
        **kwargs
            Extra keyword arguments (unused).

        Returns
        -------
        numpy.array
            The extracted camera timestamps.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, self.label)
        raw_ts = fpga_times[self.label]

        if video_path is None:
            filename = f'_iblrig_{self.label}Camera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = video_path if isinstance(video_path, int) else get_video_length(video_path)
        _logger.debug(f'Number of video frames = {length}')

        if gpio is not None and gpio['indices'].size > 1 and sync_label is not None:
            _logger.info(f'Aligning to {sync_label} TTLs')
            # Extract sync TTLs
            ttl = get_sync_fronts(sync, chmap[sync_label])
            _, ts = raw.load_camera_ssv_times(self.session_path, self.label)
            try:
                # NB: Some of the sync TTLs occur very close together, and are therefore not
                # reflected in the pin state. This function removes those. Also converts frame
                # times to DAQ time.
                gpio, ttl, ts = groom_pin_state(gpio, ttl, ts, display=display)
                # The length of the count and pin state are regularly longer than the length of
                # the video file. Here we assert that the video is either shorter or the same
                # length as the arrays, and we make an assumption that the missing frames are
                # right at the end of the video. We therefore simply shorten the arrays to match
                # the length of the video.
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'
                assert raw_ts.shape[0] > 0, 'no timestamps found in channel indicated for ' \
                                            f'{self.label} camera'
                return align_with_gpio(raw_ts, ttl, gpio, count,
                                       display=display,
                                       extrapolate_missing=extrapolate_missing)
            except AssertionError as ex:
                _logger.critical('Failed to extract using %s: %s', sync_label, ex)

        # If you reach here extracting using sync TTLs was not possible, we attempt to align
        # using wheel motion energy
        _logger.warning('Attempting to align using wheel')

        try:
            if self.label not in ['left', 'right']:
                # Can only use wheel alignment for left and right cameras
                raise ValueError(f'Wheel alignment not supported for {self.label} camera')

            motion_class = vmotion.MotionAlignmentFullSession(
                self.session_path, self.label, sync='nidq', upload=True)
            new_times = motion_class.process()
            if not motion_class.qc_outcome:
                raise ValueError(
                    f'Wheel alignment for {self.label} camera failed to pass qc: {motion_class.qc}')
            _logger.warning(f'Wheel alignment for {self.label} camera successful, qc: {motion_class.qc}')
            return new_times

        except Exception as err:
            # Deliberate best-effort: any wheel alignment failure falls back to the raw DAQ times
            _logger.critical(f'Failed to align with wheel for {self.label} camera: {err}')

        if length < raw_ts.size:
            df = raw_ts.size - length
            _logger.info(f'Discarding first {df} pulses')
            raw_ts = raw_ts[df:]
        elif length > raw_ts.size:
            # Nothing is extrapolated here; make the size mismatch visible to the caller
            _logger.warning(f'{length - raw_ts.size} fewer DAQ timestamps than video frames '
                            f'for {self.label} camera')

        return raw_ts
class CameraTimestampsCamlog(BaseExtractor):
    """
    Extractor for camera timestamps taken directly from the DAQ camera TTLs.

    No GPIO alignment is performed: the DAQ camera pulse times are returned as-is, after
    tolerating a single-pulse mismatch with the number of video frames.
    """

    def __init__(self, label, session_path=None):
        super().__init__(session_path)
        self.label = assert_valid_label(label)
        # Build output names from the validated label so they agree with `self.label`
        self.save_names = f'_ibl_{self.label}Camera.times.npy'
        self.var_names = f'{self.label}_camera_timestamps'
        # Temporarily make the module logger verbose; the previous level is restored in __del__
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # If __init__ raised before storing the level (e.g. invalid label) the logger level was
        # never changed, so there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, sync=None, chmap=None, video_path=None,
                 display=False, extrapolate_missing=True, **kwargs):
        """
        Extract the camera timestamps from the DAQ camera TTL channel.

        Parameters
        ----------
        sync : dict
            Dictionary 'times', 'polarities' of fronts detected on sync trace.
        chmap : dict
            Dictionary containing channel indices (required).
        video_path : str, pathlib.Path, int
            An optional path for fetching the number of frames. If None, the video is loaded from
            the session path. If an int is provided this is taken to be the total number of frames.
        display, extrapolate_missing, **kwargs
            Unused; accepted for interface compatibility with the other camera extractors.

        Returns
        -------
        numpy.array
            The camera timestamps, one per video frame.

        Raises
        ------
        AssertionError
            The number of TTL pulses and video frames differ by more than one.
        """
        fpga_times = extract_camera_sync(sync=sync, chmap=chmap)
        if video_path is None:
            video_path = self.session_path.joinpath(
                'raw_video_data', f'_iblrig_{self.label}Camera.raw.mp4')
        # Permit the video path to be the length for development and debugging purposes
        video_frames = video_path if isinstance(video_path, int) else get_video_length(video_path)
        raw_ts = fpga_times[self.label]

        n_extra_pulses = raw_ts.size - video_frames
        if n_extra_pulses == 1:
            # For left camera sometimes we have one extra pulse than video frame
            _logger.warning(f'One extra sync pulse detected for {self.label} camera')
            raw_ts = raw_ts[:-1]
        elif n_extra_pulses == -1:
            # One pulse short: extrapolate the last timestamp at the median frame interval
            _logger.warning(f'One extra video frame detected for {self.label} camera')
            med_time = np.median(np.diff(raw_ts))
            raw_ts = np.r_[raw_ts, np.array([raw_ts[-1] + med_time])]

        assert video_frames == raw_ts.size, (
            f'dimension mismatch between video frames and TTL pulses for {self.label} camera '
            f'by {np.abs(video_frames - raw_ts.size)} frames')

        return raw_ts
class CameraTimestampsBpod(BaseBpodTrialsExtractor):
    """
    Get the left camera timestamps from the Bpod.

    Bpod only logs the camera TTLs (Port1In / Port1Out events) within each trial, not in
    between trials, so the frame times between trials are interpolated.
    """
    save_names = '_ibl_leftCamera.times.npy'
    var_names = 'left_camera_timestamps'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Temporarily make the module logger verbose; the previous level is restored in __del__
        self._log_level = _logger.level
        _logger.setLevel(logging.DEBUG)

    def __del__(self):
        # If __init__ raised before storing the level the logger level was never changed,
        # so there is nothing to restore.
        log_level = getattr(self, '_log_level', None)
        if log_level is not None:
            _logger.setLevel(log_level)

    def _extract(self, video_path=None, display=False, extrapolate_missing=True, **kwargs):
        """
        The raw timestamps are taken from the Bpod. These are the times of the camera's frame TTLs.
        If the pin state file exists, these timestamps are aligned to the video frames using the
        sync (audio) TTLs. Frames missing from the embedded frame count are removed from the
        timestamps array.
        Otherwise the timestamps are padded or trimmed to match the number of video frames
        (alignment using the wheel data is not yet implemented).

        Parameters
        ----------
        video_path : str, pathlib.Path, int
            An optional path for fetching the number of frames. If None, the video is loaded from
            the session path. If an int is provided this is taken to be the total number of frames.
        display : bool
            If True, the TTL and GPIO fronts are plotted.
        extrapolate_missing : bool
            If True, any missing timestamps at the beginning and end of the session are
            extrapolated based on the median frame rate, otherwise they will be NaNs.
        **kwargs
            Extra keyword arguments (unused).

        Returns
        -------
        numpy.array
            The camera timestamps.
        """
        raw_ts = self._times_from_bpod()
        count, (*_, gpio) = raw.load_embedded_frame_data(self.session_path, 'left')
        if video_path is None:
            filename = '_iblrig_leftCamera.raw.mp4'
            video_path = self.session_path.joinpath('raw_video_data', filename)
        # Permit the video path to be the length for development and debugging purposes
        length = video_path if isinstance(video_path, int) else get_video_length(video_path)
        _logger.debug(f'Number of video frames = {length}')

        # Check if the GPIO is usable for extraction. GPIO is None if the file does not exist,
        # is empty, or contains only one value (i.e. doesn't change)
        if gpio is not None and gpio['indices'].size > 1:
            _logger.info('Aligning to sync TTLs')
            # Extract audio TTLs
            _, audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials,
                                            task_collection=self.task_collection)
            _, ts = raw.load_camera_ssv_times(self.session_path, 'left')
            # There are many sync TTLs that are for some reason missed by the GPIO. Conversely
            # the last GPIO doesn't often correspond to any audio TTL. These will be removed.
            # The drift appears to be less severe than the DAQ, so when assigning TTLs we'll take
            # the nearest TTL within 500ms. The go cue TTLs comprise two short pulses ~3ms apart.
            # We will fuse any TTLs less than 5ms apart to make assignment more accurate.
            try:
                gpio, audio, ts = groom_pin_state(gpio, audio, ts, take='nearest',
                                                  tolerance=.5, min_diff=5e-3, display=display)
                if count.size > length:
                    count = count[:length]
                else:
                    assert length == count.size, 'fewer counts than frames'

                return align_with_gpio(raw_ts, audio, gpio, count,
                                       extrapolate_missing, display=display)
            except AssertionError as ex:
                _logger.critical('Failed to extract using audio: %s', ex)

        # If you reach here extracting using audio TTLs was not possible
        _logger.warning('Alignment by wheel data not yet implemented')
        # Extrapolate at median frame rate
        n_missing = length - raw_ts.size
        if n_missing > 0:
            _logger.warning(f'{n_missing} fewer Bpod timestamps than frames; '
                            f'{"extrapolating" if extrapolate_missing else "appending nans"}')
            # Median inter-frame interval in seconds (i.e. 1 / frame rate); the appended times
            # continue from the last timestamp at this spacing.
            frame_period = np.median(np.diff(raw_ts))
            to_app = ((np.arange(n_missing) + 1) * frame_period + raw_ts[-1]
                      if extrapolate_missing
                      else np.full(n_missing, np.nan))
            raw_ts = np.r_[raw_ts, to_app]  # Append the missing times
        elif n_missing < 0:
            _logger.warning(f'{abs(n_missing)} fewer frames than Bpod timestamps')
            _logger.info(f'Discarding first {abs(n_missing)} pulses')
            raw_ts = raw_ts[abs(n_missing):]

        return raw_ts

    def _times_from_bpod(self):
        """
        Reconstruct the full camera frame times from the per-trial Bpod Port1 events.

        Within each trial the Port1In times are taken as frame times; between consecutive trials
        the missing frames are interpolated at evenly spaced intervals.

        Returns
        -------
        numpy.array
            Strictly increasing frame times (asserted).
        """
        cam_times = []
        n_frames = 0
        n_out_of_sync = 0
        missed_trials = []
        for ind, trial_data in enumerate(self.bpod_trials):
            events = trial_data['behavior_data']['Events timestamps']
            # get upgoing and downgoing fronts
            pin = events.get('Port1In')
            pout = events.get('Port1Out')
            # some trials at startup may not have the camera working, discard
            if pin is None:
                missed_trials.append(ind)
                continue
            pin = np.array(pin)
            pout = np.array(pout)
            # if the trial starts in the middle of a square, discard the first downgoing front
            if pout[0] < pin[0]:
                pout = pout[1:]
            # same if the last sample is during an upgoing front,
            # always put size as it happens last
            pin = pin[:pout.size]
            period = np.median(np.diff(pin))
            if ind > 0:
                # assert that the pulses have the same length and that we don't miss frames
                # during the trial, the refresh rate of bpod is 100us
                same_width = np.all(np.abs(1 - (pin - pout) / np.median(pin - pout)) < 0.1)
                regular = np.all(np.abs(np.diff(pin) - period) <= 0.00011)
                if not (same_width and regular):
                    n_out_of_sync += 1
            # grow a list of cam times for each trial
            cam_times.append(pin)
            n_frames += pin.size

        if missed_trials:
            _logger.debug('trial(s) %s missing TTL events', range_str(missed_trials))
        if n_out_of_sync > 0:
            _logger.warning(f"{n_out_of_sync} trials with bpod camera frame times not within"
                            f" 10% of the expected sampling rate")

        t_first_frame = np.array([c[0] for c in cam_times])
        t_last_frame = np.array([c[-1] for c in cam_times])
        frate = 1 / np.nanmedian(np.array([np.median(np.diff(c)) for c in cam_times]))
        intertrial_duration = t_first_frame[1:] - t_last_frame[:-1]
        intertrial_missed_frames = np.int32(np.round(intertrial_duration * frate)) - 1

        # initialize the full times array
        frame_times = np.zeros(n_frames + int(np.sum(intertrial_missed_frames)))
        ii = 0
        last_trial = len(cam_times) - 1
        for trial, cam_time in enumerate(cam_times):
            # populate first the recovered times within the trials
            frame_times[ii: ii + cam_time.size] = cam_time
            ii += cam_time.size
            if trial == last_trial:
                break
            # then interpolate in-between trials
            nmiss = intertrial_missed_frames[trial]
            frame_times[ii: ii + nmiss] = (cam_time[-1] + intertrial_duration[trial] /
                                           (nmiss + 1) * (np.arange(nmiss) + 1))
            ii += nmiss
        assert np.all(np.diff(frame_times) > 0)  # negative diffs implies a big problem
        return frame_times
def align_with_gpio(timestamps, ttl, pin_state, count, extrapolate_missing=True, display=False):
    """
    Groom the raw DAQ or Bpod camera timestamps using the frame embedded GPIO and frame counter.

    Parameters
    ----------
    timestamps : numpy.array
        An array of raw DAQ or Bpod camera timestamps.
    ttl : dict
        A dictionary of sync TTLs, with keys {'times', 'polarities'}.
    pin_state : dict
        A dictionary containing GPIO pin state values, with keys {'indices', 'polarities',
        'times'} (as returned by `groom_pin_state`). 'times' must have one entry per sync TTL.
    count : numpy.array
        An array of frame numbers.
    extrapolate_missing : bool
        If true and timestamps are missing at the start or end relative to the frame counts,
        the missing timestamps are extrapolated based on the frame rate, otherwise they are NaNs.
    display : bool
        Plot the resulting timestamps.

    Returns
    -------
    numpy.array
        The corrected frame timestamps, one per element of `count`.

    Raises
    ------
    AssertionError
        The inputs are inconsistent (non-increasing counts or times, mismatched TTL numbers)
        or the alignment failed.
    """
    # Some assertions made on the raw data
    assert all(np.diff(count) > 0), 'frame count not strictly increasing'
    assert all(np.diff(timestamps) > 0), 'DAQ/Bpod camera times not strictly increasing'
    same_n_ttl = pin_state['times'].size == ttl['times'].size
    assert same_n_ttl, 'more ttl TTLs detected on camera than TTLs sent'

    # Here we will ensure that the DAQ camera times match the number of video frames in
    # length. We will make the following assumptions:
    #
    #   1. The number of DAQ camera times is equal to or greater than the number of video frames.
    #   2. No TTLs were missed between the camera and DAQ.
    #   3. No pin states were missed by Bonsai.
    #   4. No pixel count data was missed by Bonsai.
    #
    # In other words the count and pin state arrays accurately reflect the number of frames
    # sent by the camera and should therefore be the same length, and the length of the frame
    # counter should match the number of saved video frames.
    #
    # The missing frame timestamps are removed in three stages:
    #
    #   1. Remove any timestamps that occurred before video frame acquisition in Bonsai.
    #   2. Remove any timestamps where the frame counter reported missing frames, i.e. remove
    #      the dropped frames which occurred throughout the session.
    #   3. Remove the trailing timestamps at the end of the session if the camera was turned
    #      off in the wrong order.

    # Align on first pin state change
    first_uptick = pin_state['indices'][0]
    first_ttl = np.searchsorted(timestamps, ttl['times'][0])
    # Here we find up to which index in the DAQ times we discard by taking the difference
    # between the index of the first pin state change (when the sync TTL was reported by the
    # camera) and the index of the first sync TTL in DAQ time. We subtract the difference
    # between the frame count at the first pin state change and the index to account for any
    # video frames that were not saved during this period (we will remove those from the
    # camera DAQ times later).
    # Minus any frames that were dropped between the start of frame acquisition and the first TTL
    start = first_ttl - first_uptick - (count[first_uptick] - first_uptick)
    # Get approximate frame rate for extrapolating timestamps (if required)
    frate = round(1 / np.nanmedian(np.diff(timestamps)))

    if start < 0:
        n_missing = abs(start)
        _logger.warning(f'{n_missing} missing DAQ/Bpod timestamp(s) at start; '
                        f'{"extrapolating" if extrapolate_missing else "prepending nans"}')
        # The frames immediately preceding the first timestamp, one frame period apart:
        # t0 - n/frate, ..., t0 - 1/frate (mirrors the extrapolation at the end below)
        to_app = (timestamps[0] - np.arange(n_missing, 0, -1) / frate
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        timestamps = np.r_[to_app, timestamps]  # Prepend the missing times
        start = 0

    # Remove the extraneous timestamps from the beginning and end
    end = count[-1] + 1 + start
    ts = timestamps[start:end]
    if (n_missing := count[-1] - ts.size + 1) > 0:
        # For ephys sessions there may be fewer DAQ times than frame counts if DAQ acquisition
        # is turned off before the video acquisition workflow. For Bpod this always occurs
        # because Bpod finishes before the camera workflow. For Bpod the times are already
        # extrapolated for these late frames.
        _logger.warning(f'{n_missing} fewer DAQ/Bpod timestamps than frame counts; '
                        f'{"extrapolating" if extrapolate_missing else "appending nans"}')
        to_app = ((np.arange(n_missing) + 1) / frate + ts[-1]
                  if extrapolate_missing
                  else np.full(n_missing, np.nan))
        ts = np.r_[ts, to_app]  # Append the missing times
    assert ts.size >= count.size, 'fewer timestamps than frame counts'
    assert ts.size == count[-1] + 1, 'more frames recorded in frame count than timestamps '

    # Remove the rest of the dropped frames (the result has exactly count.size elements)
    ts = ts[count]
    assert np.searchsorted(ts, ttl['times'][0]) == first_uptick, \
        'time of first sync TTL doesn\'t match after alignment'

    if display:
        # Plot to check
        _, axes = plt.subplots(1, 1)
        y = within_ranges(np.arange(ts.size), pin_state['indices'].reshape(-1, 2)).astype(float)
        y *= 1e-5  # For scale when zoomed in
        axes.plot(ts, y, marker='d', color='blue', drawstyle='steps-pre', label='GPIO')
        axes.plot(ts, np.zeros_like(ts), 'kx', label='DAQ timestamps')
        vertical_lines(ttl['times'], ymin=0, ymax=1e-5,
                       color='r', linestyle=':', ax=axes, label='sync TTL')
        plt.legend()

    return ts
def attribute_times(arr, events, tol=.1, injective=True, take='first'):
    """
    Returns the values of the first array that correspond to those of the second.

    Given two arrays of timestamps, the function will return the values of the first array
    that most likely correspond to the values of the second. For each of the values in the
    second array, the absolute difference is taken and the index of either the first sufficiently
    close value, or simply the closest one, is assigned.

    If injective is True, once a value has been assigned to an event it can't be assigned to
    another. In other words there is a one-to-one mapping between the two arrays.

    Parameters
    ----------
    arr : array_like
        An array of event times to attribute to those in `events`. Non-finite values (NaN, inf)
        are never assigned. Integer arrays are supported. `arr` is not modified.
    events : array_like
        An array of event times considered a subset of `arr`.
    tol : float
        The max absolute difference between values in order to be considered a match
        (the difference must be strictly less than `tol`).
    injective : bool
        If true, once a value has been assigned it will not be assigned again.
    take : {'first', 'nearest', 'after'}
        If 'first' the first value within tolerance is assigned; if 'nearest' the
        closest value is assigned; if 'after' assign the first value strictly after the event.
        Case-insensitive.

    Returns
    -------
    numpy.array
        An array the same length as `events` containing indices of `arr` corresponding to each
        event, or -1 where no value of `arr` could be attributed.

    Raises
    ------
    ValueError
        `take` is not one of the supported options.
    """
    if (take := take.lower()) not in ('first', 'nearest', 'after'):
        raise ValueError('Parameter `take` must be either "first", "nearest", or "after"')
    # Work on a float copy so that integer inputs are supported and the caller's array is left
    # untouched. Invalid values and (when injective) already-assigned values are set to +inf,
    # which can never fall within tolerance.
    values = np.array(arr, dtype=float)
    values[~np.isfinite(values)] = np.inf
    events = np.asarray(events)
    assigned = np.full(events.shape, -1, dtype=int)  # Initialize output array
    # For 'after' only values strictly later than the event are candidates
    min_tol = 0 if take == 'after' else -tol
    for i, x in enumerate(events):
        dx = values - x
        candidates = (min_tol < dx) & (dx < tol)
        if not candidates.any():  # no value within tolerance
            continue
        idx = np.abs(dx).argmin() if take == 'nearest' else np.flatnonzero(candidates)[0]
        assigned[i] = idx
        if injective:
            values[idx] = np.inf  # One-to-one: remove the assigned value from consideration
    return assigned
538def groom_pin_state(gpio, ttl, ts, tolerance=2., display=False, take='first', min_diff=0.):
539 """
540 Align the GPIO pin state to the DAQ sync TTLs. Any sync TTLs not reflected in the pin
541 state are removed from the dict and the times of the detected fronts are converted to DAQ
542 time. At the end of this the number of GPIO fronts should equal the number of TTLs.
544 Note:
545 - This function is ultra safe: we probably don't need assign all the ups and down fronts.
546 separately and could potentially even align the timestamps without removing the missed fronts
547 - The input gpio and TTL dicts may be modified by this function.
548 - For training sessions the frame rate is only 30Hz and the TTLs tend to be broken up by
549 small gaps. Setting the min_diff to 5ms helps the timestamp assignment accuracy.
551 Parameters
552 ----------
553 gpio : dict
554 A dictionary containing GPIO pin state values, with keys {'indices', 'polarities'}.
555 ttl : dict
556 A dictionary of DAQ sync TTLs, with keys {'times', 'polarities'}.
557 ts : numpy.array
558 The camera frame times (the camera frame TTLs acquired by the main DAQ).
559 tolerance : float
560 Two pulses need to be within this many seconds to be considered related.
561 display : bool
562 If true, the resulting timestamps are plotted against the raw audio signal.
563 take : {'first', 'nearest'}
564 If 'first' the first value within tolerance is assigned; if 'nearest' the
565 closest value is assigned.
566 min_diff : float
567 Sync TTL fronts less than min_diff seconds apart will be removed.
569 Returns
570 -------
571 dict
572 Dictionary of GPIO DAQ front indices, polarities and DAQ aligned times.
573 dict
574 Sync TTL times and polarities sans the TTLs not detected in the frame data.
575 numpy.array
576 Frame times in DAQ time.
578 See Also
579 --------
580 ibllib.io.extractors.ephys_fpga._get_sync_fronts
581 """
582 # Check that the dimensions match
583 if np.any(gpio['indices'] >= ts.size): 1igcrefdjbah
584 _logger.warning('GPIO events occurring beyond timestamps array length') 1j
585 keep = gpio['indices'] < ts.size 1j
586 gpio = {k: gpio[k][keep] for k, v in gpio.items()} 1j
587 assert ttl and ttl['times'].size > 0, 'no sync TTLs for session' 1igcrefdjbah
588 assert ttl['times'].size == ttl['polarities'].size, 'sync TTL data dimension mismatch' 1igcrefdjbah
589 # make sure that there are no 2 consecutive fall or consecutive rise events
590 assert np.all(np.abs(np.diff(ttl['polarities'])) == 2), 'consecutive high/low sync TTL events' 1igcrefdjbah
591 # make sure first TTL is high
592 assert ttl['polarities'][0] == 1 1igcrefdjbah
593 # make sure ttl times in order
594 assert np.all(np.diff(ttl['times']) > 0) 1igcrefdjbah
595 # make sure raw timestamps increase
596 assert np.all(np.diff(ts) > 0), 'timestamps must strictly increase' 1igcrefdjbah
597 # make sure there are state changes
598 assert gpio['indices'].any(), 'no TTLs detected in GPIO' 1igcrefdjbah
599 # # make sure first GPIO state is high
600 assert gpio['polarities'][0] == 1 1igcrefdjbah
601 """ 1igcefdjbah
602 Some sync TTLs appear to be so short that they are not recorded by the camera. These can
603 be as short as a few microseconds. Applying a cutoff based on framerate was unsuccessful.
604 Assigning each sync TTL to each pin state change is not easy because some onsets occur very
605 close together (sometimes < 70ms), on the order of the delay between TTL and frame time.
606 Also, the two clocks have some degree of drift, so the delay between sync TTL and pin state
607 change may be zero or even negative.
609 Here we split the events into sync TTL onsets (lo->hi) and TTL offsets (hi->lo). For each
610 uptick in the GPIO pin state, we take the first TTL onset time that was within 100ms of it.
611 We ensure that each sync TTL is assigned only once, so a TTL that is closer to frame 3 than
612 frame 1 may still be assigned to frame 1.
613 """
614 ifronts = gpio['indices'] # The pin state flips 1igcefdjbah
615 sync_times = ttl['times'] 1igcefdjbah
616 if ifronts.size != ttl['times'].size: 1igcefdjbah
617 _logger.warning('more sync TTLs than GPIO state changes, assigning timestamps') 1igcefdjbah
618 to_remove = np.zeros(ifronts.size, dtype=bool) # unassigned GPIO fronts to remove 1igcefdjbah
619 low2high = ifronts[gpio['polarities'] == 1] 1igcefdjbah
620 high2low = ifronts[gpio['polarities'] == -1] 1igcefdjbah
621 assert low2high.size >= high2low.size 1igcefdjbah
623 # Remove and/or fuse short TTLs
624 if min_diff > 0: 1igcefdjbah
625 short, = np.where(np.diff(ttl['times']) < min_diff) 1icbah
626 sync_times = np.delete(ttl['times'], np.r_[short, short + 1]) 1icbah
627 _logger.debug(f'Removed {short.size * 2} fronts TLLs less than ' 1icbah
628 f'{min_diff * 1e3:.0f}ms apart')
629 assert sync_times.size > 0, f'all sync TTLs less than {min_diff}s' 1icbah
631 # Onsets
632 ups = ts[low2high] - ts[low2high][0] # times relative to first GPIO high 1igcefdjbah
633 onsets = sync_times[::2] - sync_times[0] # TTL times relative to first onset 1igcefdjbah
634 # assign GPIO fronts to ttl onset
635 assigned = attribute_times(onsets, ups, tol=tolerance, take=take) 1igcefdjbah
636 unassigned = np.setdiff1d(np.arange(onsets.size), assigned[assigned > -1]) 1igcefdjbah
637 if unassigned.size > 0: 1igcefdjbah
638 _logger.debug(f'{unassigned.size} sync TTL rises were not detected by the camera') 1igcefdjbah
639 # Check that all pin state upticks could be attributed to an onset TTL
640 if np.any(missed := assigned == -1): 1igcefdjbah
641 _logger.warning(f'{sum(missed)} pin state rises could not be attributed to a sync TTL') 1icjbah
642 if display: 1icjbah
643 ax = plt.subplot() 1a
644 vertical_lines(ups[assigned > -1], 1a
645 linestyle='-', color='g', ax=ax,
646 label='assigned GPIO up state')
647 vertical_lines(ups[missed], 1a
648 linestyle='-', color='r', ax=ax,
649 label='unassigned GPIO up state')
650 vertical_lines(onsets[unassigned], 1a
651 linestyle=':', color='k', ax=ax,
652 alpha=0.3, label='sync TTL onset')
653 vertical_lines(onsets[assigned], 1a
654 linestyle=':', color='b', ax=ax, label='assigned TTL onset')
655 plt.legend() 1a
656 plt.show() 1a
657 # Remove the missed fronts
658 to_remove = np.in1d(gpio['indices'], low2high[missed]) 1icjbah
659 assigned = assigned[~missed] 1icjbah
660 onsets_ = sync_times[::2][assigned] 1igcefdjbah
662 # Offsets
663 downs = ts[high2low] - ts[high2low][0] 1igcefdjbah
664 offsets = sync_times[1::2] - sync_times[1] 1igcefdjbah
665 assigned = attribute_times(offsets, downs, tol=tolerance, take=take) 1igcefdjbah
666 unassigned = np.setdiff1d(np.arange(offsets.size), assigned[assigned > -1]) 1igcefdjbah
667 if unassigned.size > 0: 1igcefdjbah
668 _logger.debug(f'{unassigned.size} sync TTL falls were not detected by the camera') 1gcefdjbah
669 # Check that all pin state downticks could be attributed to an offset TTL
670 if np.any(missed := assigned == -1): 1igcefdjbah
671 _logger.warning(f'{sum(missed)} pin state falls could not be attributed to a sync TTL') 1cjbah
672 # Remove the missed fronts
673 to_remove |= np.in1d(gpio['indices'], high2low[missed]) 1cjbah
674 assigned = assigned[~missed] 1cjbah
675 offsets_ = sync_times[1::2][assigned] 1igcefdjbah
677 # Sync TTLs groomed
678 if np.any(to_remove): 1igcefdjbah
679 # Check for any orphaned fronts (only one pin state edge was assigned)
680 to_remove = np.pad(to_remove, (0, to_remove.size % 2), 'edge') # Ensure even size 1icjbah
681 # Perform xor to find GPIOs where only onset or offset is marked for removal
682 orphaned = to_remove.reshape(-1, 2).sum(axis=1) == 1 1icjbah
683 if orphaned.any(): 1icjbah
684 """If there are orphaned GPIO fronts (i.e. only one edge was assigned to a sync 1ih
685 TTL front), remove the orphaned front its assigned sync TTL. In other words
686 if both edges cannot be assigned to a sync TTL, we ignore the TTL entirely.
687 This is a sign that the assignment was bad and extraction may fail."""
688 _logger.warning('Some onsets but not offsets (or vice versa) were not assigned; ' 1ih
689 'this may be a sign of faulty wiring or clock drift')
690 # Find indices of GPIO upticks where only the downtick was marked for removal
691 orphaned_onsets, = np.where(~to_remove.reshape(-1, 2)[:, 0] & orphaned) 1ih
692 # The onsets_ array already has the other TTLs removed (same size as to_remove ==
693 # False) so subtract the number of removed elements from index.
694 for i, v in enumerate(orphaned_onsets): 1ih
695 orphaned_onsets[i] -= to_remove.reshape(-1, 2)[:v, 0].sum() 1h
696 # Same for offsets...
697 orphaned_offsets, = np.where(~to_remove.reshape(-1, 2)[:, 1] & orphaned) 1ih
698 for i, v in enumerate(orphaned_offsets): 1ih
699 orphaned_offsets[i] -= to_remove.reshape(-1, 2)[:v, 1].sum() 1ih
700 # Remove orphaned ttl onsets and offsets
701 onsets_ = np.delete(onsets_, orphaned_onsets[orphaned_onsets < onsets_.size]) 1ih
702 offsets_ = np.delete(offsets_, orphaned_offsets[orphaned_offsets < offsets_.size]) 1ih
703 _logger.debug(f'{orphaned.sum()} orphaned TTLs removed') 1ih
704 to_remove.reshape(-1, 2)[orphaned] = True 1ih
706 # Remove those unassigned GPIOs
707 gpio = {k: v[~to_remove[:v.size]] for k, v in gpio.items()} 1icjbah
708 ifronts = gpio['indices'] 1icjbah
710 # Assert that we've removed discrete TTLs
711 # A failure means e.g. an up-going front of one TTL was missed but not the down-going one.
712 assert np.all(np.abs(np.diff(gpio['polarities'])) == 2) 1icjbah
713 assert gpio['polarities'][0] == 1 1icjbah
715 ttl_ = {'times': np.empty(ifronts.size), 'polarities': gpio['polarities']} 1igcefdjbah
716 ttl_['times'][::2] = onsets_ 1igcefdjbah
717 ttl_['times'][1::2] = offsets_ 1igcefdjbah
718 else:
719 ttl_ = ttl.copy() 1ij
721 # Align the frame times to DAQ
722 fcn_a2b, drift_ppm = dsp.sync_timestamps(ts[ifronts], ttl_['times']) 1igcefdjbah
723 _logger.debug(f'frame ttl alignment drift = {drift_ppm:.2f}ppm') 1igcefdjbah
724 # Add times to GPIO dict
725 gpio['times'] = fcn_a2b(ts[ifronts]) 1igcefdjbah
727 if display: 1igcefdjbah
728 # Plot all the onsets and offsets
729 ax = plt.subplot() 1da
730 # All sync TTLs
731 squares(ttl['times'], ttl['polarities'], 1da
732 ax=ax, label='sync TTLs', linestyle=':', color='k', yrange=[0, 1], alpha=0.3)
733 # GPIO
734 x = np.insert(gpio['times'], 0, 0) 1da
735 y = np.arange(x.size) % 2 1da
736 squares(x, y, ax=ax, label='GPIO') 1da
737 y = within_ranges(np.arange(ts.size), ifronts.reshape(-1, 2)) # 0 or 1 for each frame 1da
738 ax.plot(fcn_a2b(ts), y, 'kx', label='cam times') 1da
739 # Assigned ttl
740 squares(ttl_['times'], ttl_['polarities'], 1da
741 ax=ax, label='assigned sync TTL', linestyle=':', color='g', yrange=[0, 1])
742 ax.legend() 1da
743 plt.xlabel('DAQ time (s)') 1da
744 ax.set_yticks([0, 1]) 1da
745 ax.set_title('GPIO - sync TTL alignment') 1da
746 plt.show() 1da
748 return gpio, ttl_, fcn_a2b(ts) 1igcefdjbah
def extract_all(session_path, sync_type=None, save=True, **kwargs):
    """
    Extract camera timestamps for a session.

    For DAQ-synced ('nidq') sessions, one extractor per requested camera label is run
    against the sync pulses. For any other sync type, the left camera timestamps are
    extracted with the Bpod extractor.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path, e.g. '/path/to/subject/yyyy-mm-dd/001'.
    sync_type : str
        The sync label from the experiment description file. 'nidq' selects DAQ-based
        extraction; any other value selects Bpod-based extraction. If None (or empty), it
        is inferred from the session type: 'nidq' for 'ephys' sessions, otherwise 'bpod'.
    save : bool
        If True (default), save the camera timestamp files to disk.
    **kwargs
        Extra keyword args to pass to the camera extractor classes. The following keys are
        also read by this function:

        - sync_collection (str): the subdirectory containing the sync files; defaults to
          'raw_ephys_data'. Left in kwargs, so it is also forwarded to the extractors.
        - camlog (bool): for 'nidq' sync, use CameraTimestampsCamlog instead of
          CameraTimestampsFPGA; defaults to False. Also forwarded to the extractors.
        - session_type (str): (DEPRECATED) the session type, e.g. 'ephys'. Only used when
          `sync_type` is not given; otherwise it is read from the session settings.
        - labels (str, tuple of str): the camera label(s) to extract. For 'nidq' sync this
          defaults to ('left', 'right', 'body'). For Bpod sync only 'left' is supported.
          This key is consumed here and not forwarded.
        - sync, chmap: pre-loaded sync dict and channel map for 'nidq' sync. When 'sync' is
          absent, both are loaded from `sync_collection`.

    Returns
    -------
    list of numpy.array
        List of extracted output data, i.e. the camera times.
    list of pathlib.Path
        The paths of the extracted data, if save = True

    Raises
    ------
    ValueError
        If `sync_type` is not given and the session type has no matching extractor.
    AssertionError
        If any camera other than 'left' is requested for a Bpod-synced session.
    """
    sync_collection = kwargs.get('sync_collection', 'raw_ephys_data')
    camlog = kwargs.get('camlog', False)

    if not sync_type:  # infer from session type
        session_type = kwargs.get('session_type') or get_session_extractor_type(session_path)
        if not session_type or session_type not in _get_task_types_json_config().values():
            raise ValueError(f"Session type {session_type} has no matching extractor")
        sync_type = 'nidq' if session_type == 'ephys' else 'bpod'

    if sync_type == 'nidq':
        labels = assert_valid_label(kwargs.pop('labels', ('left', 'right', 'body')))
        labels = (labels,) if isinstance(labels, str) else labels  # Ensure list/tuple
        CamExtractor = CameraTimestampsCamlog if camlog else CameraTimestampsFPGA
        # One extractor per camera, each bound to its label
        extractor = [partial(CamExtractor, label) for label in labels]
        if 'sync' not in kwargs:
            kwargs['sync'], kwargs['chmap'] = get_sync_and_chn_map(session_path, sync_collection)
    else:  # assume Bpod otherwise
        # Bpod extraction only handles the left camera. Check the requested labels
        # themselves rather than just their truthiness, so that e.g. 'right' or
        # ('left', 'right') are rejected instead of silently ignored.
        labels = kwargs.pop('labels', 'left')
        labels = [labels] if isinstance(labels, str) or labels is None else list(labels)
        assert set(labels) == {'left'}, 'only left camera is currently supported'
        extractor = CameraTimestampsBpod

    outputs, files = run_extractor_classes(
        extractor, session_path=session_path, save=save, **kwargs)
    return outputs, files