Coverage for ibllib/ephys/sync_probes.py: 88%
172 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1import logging
3import matplotlib.axes
4import matplotlib.pyplot as plt
5import numpy as np
6from scipy.interpolate import interp1d
7import one.alf.io as alfio
8import one.alf.exceptions
9from iblutil.util import Bunch
10import spikeglx
12from ibllib.exceptions import Neuropixel3BSyncFrontsNonMatching
13from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_ibl_sync_map
15_logger = logging.getLogger(__name__)
def apply_sync(sync_file, times, forward=True):
    """
    Convert times between a probe clock and the session (reference) clock.

    :param sync_file: probe sync file (usually of the form _iblrig_ephysData.raw.imec1.sync.npy),
     a 2-column array of (probe time, session time) sync points
    :param times: times in seconds to interpolate
    :param forward: if True, map probe time to session time; otherwise map session time to
     probe time
    :return: interpolated times (linearly extrapolated outside the sync points range)
    """
    sync_points = np.load(sync_file)
    src_col, dst_col = (0, 1) if forward else (1, 0)
    interpolator = interp1d(sync_points[:, src_col], sync_points[:, dst_col],
                            fill_value='extrapolate')
    return interpolator(times)
def sync(ses_path, **kwargs):
    """
    Wrapper for sync_probes.version3A and sync_probes.version3B that automatically determines
    the Neuropixel version from the session folder.

    :param ses_path: session path
    :param kwargs: forwarded to version3A / version3B (display, type, tol, probe_names)
    :return: the (qc, out_files) tuple returned by version3A / version3B, where qc is True on a
        successful sync; None if the detected version is neither '3A' nor '3B'
    """
    version = spikeglx.get_neuropixel_version_from_folder(ses_path)
    if version == '3A':
        return version3A(ses_path, **kwargs)
    if version == '3B':
        return version3B(ses_path, **kwargs)
    # unsupported version: nothing is synchronised, report it instead of returning silently
    _logger.error(f'Ephys sync: unsupported Neuropixel version {version!r} in {ses_path}, '
                  'no probe synchronisation performed')
    return None
def version3A(ses_path, display=True, type='smooth', tol=2.1, probe_names=None):
    """
    Synchronise 3A probes against each other and write one sync / timestamps file pair per probe.

    From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3A and
    output one sync.timestamps.probeN.npy file per acquired probe. By convention the reference
    probe is the one with the most synchronisation pulses.
    Assumes the _spikeglx_sync datasets are already extracted from binary data.

    :param ses_path: session path
    :param display: forwarded to sync_probe_front_times; if truthy the residual drift is plotted
    :param type: linear, exact or smooth
    :param tol: tolerance in samples above which the sync QC fails
    :param probe_names: unused for 3A, kept for signature compatibility with version3B
    :return: tuple (qc, out_files): qc is False if any probe exceeds the tolerance, out_files
        lists the sync and timestamps files written for every probe
    :raises ValueError: if neither frame2ttl nor right_camera fronts are found on all probes, or
        if the cameras started before the ephys acquisition
    """
    ephys_files = spikeglx.glob_ephys_files(ses_path, ext='meta', bin_exists=False)
    nprobes = len(ephys_files)
    if nprobes == 1:
        # a single probe is its own reference: identity mapping
        timestamps = np.array([[0., 0.], [1., 1.]])
        sr = _get_sr(ephys_files[0])
        out_files = _save_timestamps_npy(ephys_files[0], timestamps, sr)
        return True, out_files

    def _aux_sync_fronts(auxiliary_name):
        """Collect the fronts of one auxiliary channel on every probe; None if any probe lacks them."""
        d = Bunch({'times': [], 'nsync': np.zeros(nprobes, )})
        # auxiliary_name: frame2ttl or right_camera
        for ind, ephys_file in enumerate(ephys_files):
            sync_obj = alfio.load_object(
                ephys_file.ap.parent, 'sync', namespace='spikeglx', short_keys=True)
            sync_map = get_ibl_sync_map(ephys_file, '3A')
            # exits if sync label not found for current probe
            if auxiliary_name not in sync_map:
                return None
            isync = np.isin(sync_obj['channels'], np.array([sync_map[auxiliary_name]]))
            # only returns syncs if we get fronts for all probes
            if np.all(~isync):
                return None
            # NOTE(review): counts the fronts of every channel, not only the auxiliary one;
            # confirm this is the intended criterion for choosing the reference probe
            d.nsync[ind] = len(sync_obj.channels)
            d['times'].append(sync_obj['times'][isync])
        return d

    d = _aux_sync_fronts('frame2ttl')
    if not d:
        _logger.warning('Ephys sync: frame2ttl not detected on both probes, using camera sync')
        d = _aux_sync_fronts('right_camera')
        if not d:
            raise ValueError('Ephys sync: neither frame2ttl nor right_camera fronts detected '
                             'on all probes, no sync possible')
        if not min([t[0] for t in d['times']]) > 0.2:
            raise ValueError('Cameras started before ephys, no sync possible')
    # chop off to the lowest number of sync points
    nsyncs = [t.size for t in d['times']]
    if len(set(nsyncs)) > 1:
        _logger.warning("Probes don't have the same number of synchronizations pulses")
    nmin = min(nsyncs)
    # shape (nsync_points, nprobes)
    d['times'] = np.vstack([t[:nmin] for t in d['times']]).transpose()

    # the reference probe is the one with the most sync pulses detected
    iref = np.argmax(d.nsync)
    # get the sampling rate from the reference probe using metadata file
    sr = _get_sr(ephys_files[iref])
    qc_all = True
    out_files = []
    # output timestamps files as per ALF convention, accumulating the files of every probe
    for ind, ephys_file in enumerate(ephys_files):
        if ind == iref:
            timestamps = np.array([[0., 0.], [1., 1.]])
        else:
            timestamps, qc = sync_probe_front_times(d.times[:, ind], d.times[:, iref], sr,
                                                    display=display, type=type, tol=tol)
            qc_all &= qc
        out_files.extend(_save_timestamps_npy(ephys_file, timestamps, sr))
    return qc_all, out_files
def version3B(ses_path, display=True, type=None, tol=2.5, probe_names=None):
    """
    Synchronise 3B probes against the nidq board and write one sync / timestamps file pair per probe.

    From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3B and
    output one sync.timestamps.probeN.npy file per acquired probe, using the imec_sync fronts of
    the nidq file as the reference.
    Assumes the _spikeglx_sync datasets are already extracted from binary data.

    :param ses_path: session path
    :param display: forwarded to sync_probe_front_times; if truthy the residual drift is plotted
    :param type: linear, exact or smooth. When None, 'smooth' is used unless the probe sync
        pulses fail the regularity check, in which case 'exact' is used
    :param tol: tolerance in samples above which the sync QC fails
    :param probe_names: by default will rglob all probes in the directory. If specified, this will filter
        the probes on which to perform the synchronisation, defaults to None, optional
    :return: tuple (qc, out_files): qc is False if any check fails, out_files lists the sync and
        timestamps files written for every probe
    :raises ValueError: if the session does not contain exactly one nidq file
    :raises Neuropixel3BSyncFrontsNonMatching: if a probe and the nidq pulse counts differ by
        more than 10 %
    """
    DEFAULT_TYPE = 'smooth'
    ephys_files = spikeglx.glob_ephys_files(ses_path, ext='meta', bin_exists=False)
    for ef in ephys_files:
        try:
            ef['sync'] = alfio.load_object(ef.path, 'sync', namespace='spikeglx', short_keys=True)
            ef['sync_map'] = get_ibl_sync_map(ef, '3B')
        except one.alf.exceptions.ALFObjectNotFound:
            # a missing sync object is tolerated only for probes excluded by probe_names
            if probe_names is None or ef.path.parts[-1] in probe_names:
                raise
    nidq_file = [ef for ef in ephys_files if ef.get('nidq')]
    ephys_files = [ef for ef in ephys_files if not ef.get('nidq')]
    if probe_names is not None:
        ephys_files = [ef for ef in ephys_files if ef.path.parts[-1] in probe_names]
    # the probes are synchronised against exactly one nidq reference
    # (explicit raise: an assert would be stripped under python -O)
    if len(nidq_file) != 1:
        raise ValueError(f'Expected exactly one nidq file, found {len(nidq_file)} in {ses_path}')
    nidq_file = nidq_file[0]
    sync_nidq = get_sync_fronts(nidq_file.sync, nidq_file.sync_map['imec_sync'])

    qc_all = True
    out_files = []
    for ef in ephys_files:
        sync_probe = get_sync_fronts(ef.sync, ef.sync_map['imec_sync'])
        sr = _get_sr(ef)
        # we say that the number of pulses should be within 10 %
        if not np.isclose(sync_nidq.times.size, sync_probe.times.size, rtol=0.1):
            raise Neuropixel3BSyncFrontsNonMatching(f"{ses_path}")

        # Find the indexes in case the sizes don't match
        if sync_nidq.times.size != sync_probe.times.size:
            _logger.warning(f'Sync mismatch by {np.abs(sync_nidq.times.size - sync_probe.times.size)} '
                            f'NIDQ sync times: {sync_nidq.times.size}, Probe sync times {sync_probe.times.size}')
        sync_idx = min(sync_nidq.times.size, sync_probe.times.size)

        # if the qc of the diff finds anomalies, do not attempt to smooth the interp function
        qcdiff = _check_diff_3b(sync_probe)
        if not qcdiff:
            qc_all = False
            type_probe = type or 'exact'
        else:
            type_probe = type or DEFAULT_TYPE
        timestamps, qc = sync_probe_front_times(sync_probe.times[:sync_idx], sync_nidq.times[:sync_idx], sr,
                                                display=display, type=type_probe, tol=tol)
        qc_all &= qc
        out_files.extend(_save_timestamps_npy(ef, timestamps, sr))
    return qc_all, out_files
def sync_probe_front_times(t, tref, sr, display=False, type='smooth', tol=2.0):
    """
    From 2 timestamps vectors of equivalent length, output timestamps array to be used for
    linear interpolation.

    The main drift is computed through linear regression. For the 'smooth' type, a low-pass
    filtered version of the residual is added to the linear drift. The precision is checked by
    ensuring each point lies less than `tol` samples away from the interpolated prediction.

    :param t: time-serie to be synchronized
    :param tref: time-serie of the reference
    :param sr: sampling rate of the slave probe
    :param display: if truthy, plot the residual drift; a matplotlib Axes is used as the target
    :param type: 'smooth', 'exact' or 'linear'
    :param tol: tolerance in samples
    :return: a 2 columns by n-sync points array where each row corresponds
     to a sync point: self time (secs), reference time (secs)
    :return: quality Bool. False if tolerance is exceeded
    :raises ValueError: if `type` is not one of 'smooth', 'exact' or 'linear'
    """
    if type not in ('smooth', 'exact', 'linear'):
        raise ValueError(f"type must be one of 'smooth', 'exact' or 'linear', got {type!r}")
    qc = True
    # main drift through linear regression (higher order terms first: slope / intercept)
    pol = np.polyfit(t, tref, 1)
    residual = tref - np.polyval(pol, t)
    if type == 'smooth':
        # the interp function from camera fronts is not smooth due to the locking of detections
        # to the sampling rate of digital channels. The residual is fit using frequency domain
        # smoothing
        import ibldsp.fourier
        CAMERA_UPSAMPLING_RATE_HZ = 300
        PAD_LENGTH_SECS = 60
        STAT_LENGTH_SECS = 30  # median length to compute padding value
        SYNC_SAMPLING_RATE_SECS = 20
        t_upsamp = np.arange(tref[0], tref[-1], 1 / CAMERA_UPSAMPLING_RATE_HZ)
        res_upsamp = np.interp(t_upsamp, tref, residual)
        # padding needs extra care as the function oscillates; pad to a power of 2 length as
        # FFT performance is poor for sizes with large prime factors. At least
        # PAD_LENGTH_SECS of padding are added so both pads are non-empty.
        nech = res_upsamp.size + (CAMERA_UPSAMPLING_RATE_HZ * PAD_LENGTH_SECS)
        lpad = 2 ** np.ceil(np.log2(nech)) - res_upsamp.size
        lpad = [int(np.floor(lpad / 2) + lpad % 2), int(np.floor(lpad / 2))]
        res_filt = np.pad(res_upsamp, lpad, mode='median',
                          stat_length=CAMERA_UPSAMPLING_RATE_HZ * STAT_LENGTH_SECS)
        fbounds = [0.001, 0.002]
        res_filt = ibldsp.fourier.lp(res_filt, 1 / CAMERA_UPSAMPLING_RATE_HZ, fbounds)[lpad[0]:-lpad[1]]
        # one sync point every SYNC_SAMPLING_RATE_SECS seconds
        tout = np.arange(0, np.max(tref) + SYNC_SAMPLING_RATE_SECS, SYNC_SAMPLING_RATE_SECS)
        sync_points = np.c_[tout, np.polyval(pol, tout) + np.interp(tout, t_upsamp, res_filt)]
        if display:
            if isinstance(display, matplotlib.axes.Axes):
                ax = display
            else:
                ax = plt.axes()
            ax.plot(tref, residual * sr, label='residual')
            ax.plot(t_upsamp, res_filt * sr, label='smoothed residual')
            ax.plot(tout, np.interp(tout, t_upsamp, res_filt) * sr, '*', label='interp timestamps')
            ax.legend()
            ax.set_xlabel('time (sec)')
            ax.set_ylabel('Residual drift (samples @ 30kHz)')
    elif type == 'exact':
        sync_points = np.c_[t, tref]
        if display:
            ax = display if isinstance(display, matplotlib.axes.Axes) else plt.gca()
            ax.plot(tref, residual * sr, label='residual')
            ax.set_ylabel('Residual drift (samples @ 30kHz)')
            ax.set_xlabel('time (sec)')
    else:  # 'linear'
        sync_points = np.c_[np.array([0., 1.]), np.polyval(pol, np.array([0., 1.]))]
        if display:
            ax = display if isinstance(display, matplotlib.axes.Axes) else plt.gca()
            ax.plot(tref, residual * sr)
            ax.set_ylabel('Residual drift (samples @ 30kHz)')
            ax.set_xlabel('time (sec)')
    # test that the interp is within tol sample
    fcn = interp1d(sync_points[:, 0], sync_points[:, 1], fill_value='extrapolate')
    if np.any(np.abs((tref - fcn(t)) * sr) > (tol)):
        _logger.error(f'Synchronization check exceeds tolerance of {tol} samples. Check !!')
        qc = False
    return sync_points, qc
def _get_sr(ephys_file):
    """
    Return the sampling rate of a probe, read from the .meta file paired with its AP binary.

    :param ephys_file: ephys file bunch exposing an `ap` path
    :return: sampling rate as reported by spikeglx from the metadata
    """
    meta_path = ephys_file.ap.with_suffix('.meta')
    return spikeglx._get_fs_from_meta(spikeglx.read_meta_data(meta_path))
def _save_timestamps_npy(ephys_file, tself_tref, sr):
    """
    Write the sync and timestamps .npy files next to a probe's AP binary.

    The sync file stores the (self_time_secs, ref_time_secs) array as given. The timestamps
    file stores a copy whose first column is converted to samples by multiplying by `sr`.

    :param ephys_file: ephys file bunch exposing an `ap` path (name containing '.ap.')
    :param tself_tref: 2-column array of (self time, reference time) in seconds
    :param sr: sampling rate used to convert the first column to samples
    :return: [sync file path, timestamps file path]
    """
    ap_path = ephys_file.ap

    def _sibling(tag):
        # swap the '.ap.' part of the name for `tag` and use a .npy extension
        return ap_path.parent.joinpath(ap_path.name.replace('.ap.', tag)).with_suffix('.npy')

    sync_path = _sibling('.sync.')
    np.save(sync_path, tself_tref)

    ts_path = _sibling('.timestamps.')
    samples_vs_ref = np.copy(tself_tref)
    samples_vs_ref[:, 0] *= np.float64(sr)
    np.save(ts_path, samples_vs_ref)
    return [sync_path, ts_path]
def _check_diff_3b(sync, thresh_ppm=150):
    """
    Check the regularity of the rising fronts of a 3B sync signal.

    The intervals between consecutive rising fronts (polarity == 1) are compared to their
    median; the check passes when every interval deviates from the median by less than
    `thresh_ppm` parts per million.

    :param sync: object exposing `times` and `polarities` arrays (e.g. output of get_sync_fronts)
    :param thresh_ppm: maximum relative deviation from the median interval, in parts per million
    :return: True on a pass result (all values below threshold), False otherwise
    """
    d = np.diff(sync.times[sync.polarities == 1])
    if d.size == 0:
        # fewer than 2 rising fronts: nothing to compare (the median of an empty array warns)
        return True
    dt = np.median(d)
    qc_pass = np.all(np.abs((d - dt) / dt * 1e6) < thresh_ppm)
    if not qc_pass:
        _logger.error(f'Synchronizations bursts over {thresh_ppm} ppm between sync pulses. '
                      'Sync using "exact" match between pulses.')
    return qc_pass