Coverage for ibllib/ephys/sync_probes.py: 88%
171 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1import logging
3import matplotlib.axes
4import matplotlib.pyplot as plt
5import numpy as np
6from scipy.interpolate import interp1d
7import one.alf.io as alfio
8import one.alf.exceptions
9from iblutil.util import Bunch
10import spikeglx
12from ibllib.exceptions import Neuropixel3BSyncFrontsNonMatching
13from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_ibl_sync_map
15_logger = logging.getLogger(__name__)
def apply_sync(sync_file, times, forward=True):
    """
    Convert times between a probe clock and the session (reference) clock.

    :param sync_file: probe sync file (usually of the form _iblrig_ephysData.raw.imec1.sync.npy)
     holding an (n, 2) array of sync points: column 0 probe time, column 1 session time
    :param times: times in seconds to interpolate
    :param forward: if True, maps probe time to session time; if False, maps session time
     to probe time
    :return: interpolated times (linear extrapolation outside the sync points)
    """
    sync_points = np.load(sync_file)
    probe_times, session_times = sync_points[:, 0], sync_points[:, 1]
    if forward:
        source, target = probe_times, session_times
    else:
        source, target = session_times, probe_times
    mapping = interp1d(source, target, fill_value='extrapolate')
    return mapping(times)
def sync(ses_path, **kwargs):
    """
    Synchronise the probes of a session, choosing `version3A` or `version3B` according to the
    Neuropixel version reported by `spikeglx.get_neuropixel_version_from_folder`.

    :param ses_path: session path
    :param kwargs: forwarded unchanged to version3A / version3B
    :return: the result of the selected function, a tuple (qc bool, list of output files);
     None if the detected version is neither '3A' nor '3B'
    """
    version = spikeglx.get_neuropixel_version_from_folder(ses_path)
    for label, sync_function in (('3A', version3A), ('3B', version3B)):
        if version == label:
            return sync_function(ses_path, **kwargs)
    return None
def version3A(ses_path, display=True, type='smooth', tol=2.1):
    """
    From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3A and
    output one pair of `.sync.npy` / `.timestamps.npy` files per acquired probe. By convention
    the reference probe is the one with the most synchronisation fronts.
    Assumes the _spikeglx_sync datasets are already extracted from binary data.

    :param ses_path: session path
    :param display: forwarded to sync_probe_front_times (plots the residual drift)
    :param type: linear, exact or smooth
    :param tol: tolerance in samples, forwarded to sync_probe_front_times
    :return: tuple (qc, out_files). qc is False if any probe exceeds the tolerance;
     out_files lists the sync and timestamps files written for every probe
    :raises ValueError: if no probe is found, if neither frame2ttl nor right_camera fronts
     are detected on all probes, or if the cameras started before ephys
    """
    ephys_files = spikeglx.glob_ephys_files(ses_path, ext='meta', bin_exists=False)
    nprobes = len(ephys_files)
    if nprobes == 0:
        raise ValueError(f'Ephys sync: no probe found in {ses_path}')
    if nprobes == 1:
        # a single probe is its own reference: identity mapping
        timestamps = np.array([[0., 0.], [1., 1.]])
        sr = _get_sr(ephys_files[0])
        out_files = _save_timestamps_npy(ephys_files[0], timestamps, sr)
        return True, out_files

    def _get_aux_sync_fronts(auxiliary_name):
        """Collect the fronts of `auxiliary_name` on every probe; None if any probe lacks them."""
        d = Bunch({'times': [], 'nsync': np.zeros(nprobes, )})
        # auxiliary_name: frame2ttl or right_camera
        for ind, ephys_file in enumerate(ephys_files):
            sync = alfio.load_object(
                ephys_file.ap.parent, 'sync', namespace='spikeglx', short_keys=True)
            sync_map = get_ibl_sync_map(ephys_file, '3A')
            # exits if sync label not found for current probe
            if auxiliary_name not in sync_map:
                return None
            isync = np.isin(sync['channels'], np.array([sync_map[auxiliary_name]]))
            # only returns syncs if we get fronts for all probes
            if np.all(~isync):
                return None
            # number of fronts across all channels, used to pick the reference probe
            d.nsync[ind] = len(sync.channels)
            d['times'].append(sync['times'][isync])
        return d

    d = _get_aux_sync_fronts('frame2ttl')
    if not d:
        _logger.warning('Ephys sync: frame2ttl not detected on both probes, using camera sync')
        d = _get_aux_sync_fronts('right_camera')
        if not d:
            raise ValueError('Ephys sync: neither frame2ttl nor right_camera fronts detected '
                             'on all probes, no sync possible')
        if not min([t[0] for t in d['times']]) > 0.2:
            raise ValueError('Cameras started before ephys, no sync possible')
    # chop off to the lowest number of sync points
    nsyncs = [t.size for t in d['times']]
    if len(set(nsyncs)) > 1:
        _logger.warning("Probes don't have the same number of synchronizations pulses")
    nmin = min(nsyncs)
    # shape (nsync, nprobes): one column of front times per probe
    d['times'] = np.r_[[t[:nmin] for t in d['times']]].transpose()

    # the reference probe is the one with the most sync pulses detected
    iref = np.argmax(d.nsync)
    # get the sampling rate from the reference probe using metadata file
    # NOTE(review): this rate is used for every probe below (sync_probe_front_times documents
    # `sr` as the slave probe's rate, and version3B uses each probe's own rate) — confirm
    sr = _get_sr(ephys_files[iref])
    qc_all = True
    out_files = []
    # output timestamps files as per ALF convention
    for ind, ephys_file in enumerate(ephys_files):
        if ind == iref:
            timestamps = np.array([[0., 0.], [1., 1.]])
        else:
            timestamps, qc = sync_probe_front_times(d.times[:, ind], d.times[:, iref], sr,
                                                    display=display, type=type, tol=tol)
            qc_all &= qc
        # accumulate the files of every probe (previously only the last probe was returned)
        out_files.extend(_save_timestamps_npy(ephys_file, timestamps, sr))
    return qc_all, out_files
def version3B(ses_path, display=True, type=None, tol=2.5, probe_names=None):
    """
    From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3B and
    output one pair of `.sync.npy` / `.timestamps.npy` files per acquired probe. The NIDQ
    board is the time reference: the imec_sync fronts of each probe are matched against the
    imec_sync fronts recorded on the NIDQ.
    Assumes the _spikeglx_sync datasets are already extracted from binary data.

    :param ses_path: session path
    :param display: forwarded to sync_probe_front_times (plots the residual drift)
    :param type: linear, exact or smooth. If None, 'smooth' is used, or 'exact' for probes
     whose sync intervals fail the `_check_diff_3b` QC
    :param tol: tolerance in samples, forwarded to sync_probe_front_times
    :param probe_names: by default will rglob all probes in the directory. If specified, this
     will filter the probes on which to perform the synchronisation, defaults to None, optional
    :return: tuple (qc, out_files). qc is False if any probe fails a check; out_files lists
     the sync and timestamps files written for every probe
    :raises ValueError: if the session does not contain exactly one NIDQ file
    :raises Neuropixel3BSyncFrontsNonMatching: if a probe's number of sync fronts differs from
     the NIDQ's by more than 10 %
    """
    DEFAULT_TYPE = 'smooth'
    ephys_files = spikeglx.glob_ephys_files(ses_path, ext='meta', bin_exists=False)
    for ef in ephys_files:
        try:
            ef['sync'] = alfio.load_object(ef.path, 'sync', namespace='spikeglx', short_keys=True)
            ef['sync_map'] = get_ibl_sync_map(ef, '3B')
        except one.alf.exceptions.ALFObjectNotFound:
            # a missing sync object is only tolerated for probes excluded by `probe_names`;
            # the NIDQ sync is always needed as it is the time reference
            if ef.get('nidq') or probe_names is None or ef.path.parts[-1] in probe_names:
                raise
    nidq_files = [ef for ef in ephys_files if ef.get('nidq')]
    ephys_files = [ef for ef in ephys_files if not ef.get('nidq')]
    if probe_names is not None:
        ephys_files = [ef for ef in ephys_files if ef.path.parts[-1] in probe_names]
    # exactly one nidq is required (explicit raise: an `assert` is stripped under `python -O`)
    if len(nidq_files) != 1:
        raise ValueError(f'Ephys sync: expected exactly one NIDQ file in {ses_path}, '
                         f'found {len(nidq_files)}')
    nidq_file = nidq_files[0]
    sync_nidq = get_sync_fronts(nidq_file.sync, nidq_file.sync_map['imec_sync'])

    qc_all = True
    out_files = []
    for ef in ephys_files:
        sync_probe = get_sync_fronts(ef.sync, ef.sync_map['imec_sync'])
        sr = _get_sr(ef)
        # we say that the number of pulses should be within 10 %
        if not np.isclose(sync_nidq.times.size, sync_probe.times.size, rtol=0.1):
            raise Neuropixel3BSyncFrontsNonMatching(f"{ses_path}")

        # Find the indexes in case the sizes don't match
        if sync_nidq.times.size != sync_probe.times.size:
            _logger.warning(f'Sync mismatch by {np.abs(sync_nidq.times.size - sync_probe.times.size)} '
                            f'NIDQ sync times: {sync_nidq.times.size}, Probe sync times {sync_probe.times.size}')
        # both series are truncated to their first `sync_idx` fronts
        # NOTE(review): this assumes the extra fronts are at the end of the recording — confirm
        sync_idx = np.min([sync_nidq.times.size, sync_probe.times.size])

        # if the qc of the diff finds anomalies, do not attempt to smooth the interp function
        qcdiff = _check_diff_3b(sync_probe)
        if not qcdiff:
            qc_all = False
            type_probe = type or 'exact'
        else:
            type_probe = type or DEFAULT_TYPE
        timestamps, qc = sync_probe_front_times(sync_probe.times[:sync_idx], sync_nidq.times[:sync_idx], sr,
                                                display=display, type=type_probe, tol=tol)
        qc_all &= qc
        out_files.extend(_save_timestamps_npy(ef, timestamps, sr))
    return qc_all, out_files
def sync_probe_front_times(t, tref, sr, display=False, type='smooth', tol=2.0):
    """
    From 2 timestamps vectors of equivalent length, output timestamps array to be used for
    linear interpolation.

    The main drift is computed through linear regression of `tref` against `t`. With the
    'smooth' method, a low-pass filtered version of the residual is added to the linear drift.
    With 'exact', the sync points are the fronts themselves; with 'linear', only the regression
    is kept. The precision is checked by ensuring that each front lies less than `tol` samples
    away from the interpolated prediction.

    :param t: time-serie to be synchronized
    :param tref: time-serie of the reference
    :param sr: sampling rate of the slave probe, used to express residuals in samples
    :param display: if truthy, plot the residual drift; may be a matplotlib Axes to draw into
    :param type: 'smooth', 'exact' or 'linear'
    :param tol: tolerance in samples
    :return: a 2 columns by n-sync points array where each row corresponds
     to a sync point: time in the clock of `t`, corresponding time in the clock of `tref`
    :return: quality Bool. False if tolerance is exceeded
    :raises ValueError: if `type` is not one of 'smooth', 'exact' or 'linear'
    """
    if type not in ('smooth', 'exact', 'linear'):
        raise ValueError(f"Unknown sync type {type!r}: expected 'smooth', 'exact' or 'linear'")
    qc = True
    # the main drift is computed through linear regression
    pol = np.polyfit(t, tref, 1)  # higher order terms first: slope / int for linear
    residual = tref - np.polyval(pol, t)
    if type == 'smooth':
        # the interp function from camera fronts is not smooth due to the locking of detections
        # to the sampling rate of digital channels. The residual is fit using frequency domain
        # smoothing
        import neurodsp.fourier
        CAMERA_UPSAMPLING_RATE_HZ = 300
        PAD_LENGTH_SECS = 60
        STAT_LENGTH_SECS = 30  # median length to compute padding value
        SYNC_SAMPLING_RATE_SECS = 20
        t_upsamp = np.arange(tref[0], tref[-1], 1 / CAMERA_UPSAMPLING_RATE_HZ)
        res_upsamp = np.interp(t_upsamp, tref, residual)
        # pad with at least PAD_LENGTH_SECS of data up to the next power of 2, as the FFT is
        # much faster for such sizes; median padding limits edge effects of the oscillating residual
        nech = res_upsamp.size + (CAMERA_UPSAMPLING_RATE_HZ * PAD_LENGTH_SECS)
        lpad = 2 ** np.ceil(np.log2(nech)) - res_upsamp.size
        lpad = [int(np.floor(lpad / 2) + lpad % 2), int(np.floor(lpad / 2))]
        res_filt = np.pad(res_upsamp, lpad, mode='median',
                          stat_length=CAMERA_UPSAMPLING_RATE_HZ * STAT_LENGTH_SECS)
        # low-pass corner frequencies, presumably in Hz given the sampling interval in seconds
        fbounds = [0.001, 0.002]
        res_filt = neurodsp.fourier.lp(res_filt, 1 / CAMERA_UPSAMPLING_RATE_HZ, fbounds)[lpad[0]:-lpad[1]]
        # one sync point every SYNC_SAMPLING_RATE_SECS seconds, covering the whole recording
        tout = np.arange(0, np.max(tref) + SYNC_SAMPLING_RATE_SECS, SYNC_SAMPLING_RATE_SECS)
        sync_points = np.c_[tout, np.polyval(pol, tout) + np.interp(tout, t_upsamp, res_filt)]
        if display:
            if isinstance(display, matplotlib.axes.Axes):
                ax = display
            else:
                ax = plt.axes()
            ax.plot(tref, residual * sr, label='residual')
            ax.plot(t_upsamp, res_filt * sr, label='smoothed residual')
            ax.plot(tout, np.interp(tout, t_upsamp, res_filt) * sr, '*', label='interp timestamps')
            ax.legend()
            ax.set_xlabel('time (sec)')
            ax.set_ylabel('Residual drift (samples @ 30kHz)')
    elif type == 'exact':
        sync_points = np.c_[t, tref]
        if display:
            ax = display if isinstance(display, matplotlib.axes.Axes) else plt.gca()
            ax.plot(tref, residual * sr, label='residual')
            ax.set_ylabel('Residual drift (samples @ 30kHz)')
            ax.set_xlabel('time (sec)')
    else:  # type == 'linear'
        sync_points = np.c_[np.array([0., 1.]), np.polyval(pol, np.array([0., 1.]))]
        if display:
            ax = display if isinstance(display, matplotlib.axes.Axes) else plt.gca()
            ax.plot(tref, residual * sr)
            ax.set_ylabel('Residual drift (samples @ 30kHz)')
            ax.set_xlabel('time (sec)')
    # test that the interp is within tol sample
    fcn = interp1d(sync_points[:, 0], sync_points[:, 1], fill_value='extrapolate')
    if np.any(np.abs((tref - fcn(t)) * sr) > tol):
        _logger.error(f'Synchronization check exceeds tolerance of {tol} samples. Check !!')
        qc = False
    return sync_points, qc
def _get_sr(ephys_file):
    """Return the sampling rate read from the `.meta` file that sits next to the AP file."""
    meta_path = ephys_file.ap.with_suffix('.meta')
    meta_data = spikeglx.read_meta_data(meta_path)
    return spikeglx._get_fs_from_meta(meta_data)
259def _save_timestamps_npy(ephys_file, tself_tref, sr):
260 # this is the file with self_time_secs, ref_time_secs output
261 file_sync = ephys_file.ap.parent.joinpath(ephys_file.ap.name.replace('.ap.', '.sync.') 1defablcijmgh
262 ).with_suffix('.npy')
263 np.save(file_sync, tself_tref) 1defablcijmgh
264 # this is the timestamps file
265 file_ts = ephys_file.ap.parent.joinpath(ephys_file.ap.name.replace('.ap.', '.timestamps.') 1defablcijmgh
266 ).with_suffix('.npy')
267 timestamps = np.copy(tself_tref) 1defablcijmgh
268 timestamps[:, 0] *= np.float64(sr) 1defablcijmgh
269 np.save(file_ts, timestamps) 1defablcijmgh
270 return [file_sync, file_ts] 1defablcijmgh
273def _check_diff_3b(sync):
274 """
275 Checks that the diff between consecutive sync pulses is below 150 PPM
276 Returns True on a pass result (all values below threshold)
277 """
278 THRESH_PPM = 150 1defacgh
279 d = np.diff(sync.times[sync.polarities == 1]) 1defacgh
280 dt = np.median(d) 1defacgh
281 qc_pass = np.all(np.abs((d - dt) / dt * 1e6) < THRESH_PPM) 1defacgh
282 if not qc_pass: 1defacgh
283 _logger.error(f'Synchronizations bursts over {THRESH_PPM} ppm between sync pulses. ' 1c
284 'Sync using "exact" match between pulses.')
285 return qc_pass 1defacgh