Coverage for brainbox/metrics/electrode_drift.py: 12%
32 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 09:55 +0000
1import numpy as np
3from ibldsp import smooth, utils, fourier
4from iblutil.numerical import bincount2D
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Electrode drift for spike sorted data.

    Spikes are binned into an (amplitude, time, depth) histogram. Each time bin's depth
    profile is cross-correlated, in the Fourier domain and after a low-pass filter along
    depth, against the median depth spectrum over time. The lag of the correlation peak
    gives the drift for that time bin, which is then smoothed and mean-subtracted.

    :param spike_times: spike times in seconds (1D array)
    :param spike_amps: spike amplitudes (1D array); transformed as log10(amps * 1e6) and
     binned between 10**1.25 and 10**3.25, so presumably in Volts - TODO confirm.
     Spikes with NaN amplitude are ignored.
    :param spike_depths: spike depths (1D array), in the units of the output drift (usually um).
     Spikes with NaN depth are ignored.
    :param display: if True, plot the drift above the driftmap before and after correction
    :return: drift (ntimes vector) in input units (usually um), mean-subtracted
    :return: ts (ntimes vector) time of each drift sample in seconds; the first sample is at
     floor(min(spike_times))
    :return: [fig1, fig2] list of matplotlib figures, only returned when display=True
    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25, 3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins

    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the Hanning smoothing window in samples (DT_SECS rate)

    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = np.min(spike_times), np.max(spike_times)
    # time limits of the histogram; the first time bin starts at t0
    t0, t1 = np.floor(tmin), np.ceil(tmax)
    nt = int((t1 - t0) / DT_SECS)

    # 3d histogram of spikes along amplitude, time and depth
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    # experimental: amplitudes are binned on a log scale
    abins = (np.log10(spike_amps * 1e6) - AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    # clip to valid bin indices; NaN amplitudes stay NaN and never match a bin below
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)
    valid_depth = ~np.isnan(spike_depths)
    # Index the histogram by the amplitude bin itself rather than by the position in
    # np.unique(abins): the latter overflows atd_hist when NaN amplitudes are present and
    # puts data in the wrong slot when some amplitude bins are empty.
    for iamp in range(N_AMP):
        inds = np.flatnonzero((abins == iamp) & valid_depth)
        if inds.size == 0:
            continue  # empty amplitude bin: its histogram slice stays at zero
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds], DEPTH_BIN_UM, DT_SECS,
                             [0, nd * DEPTH_BIN_UM], [t0, t1])
        # drop the last row/column so the slice matches the (nt, nd) histogram shape
        atd_hist[iamp] = a[:-1, :-1]

    # low-pass filter along the depth (k) direction
    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    lp = fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
    # cross-correlate each time bin against the median depth spectrum over time (the reference)
    # to experiment: LP the fft for a better tracking ?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    xcorr = np.real(np.fft.ifft(lp * atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)  # sum over amplitude bins -> (nt, nd)
    # keep lags in [-NXCORR, NXCORR] samples; zero lag ends up at column NXCORR
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]

    # parabolic fit around the xcorr maximum of each time bin, converted to a depth lag
    raw_drift = (utils.parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')
    drift = drift - np.mean(drift)
    # Time of each drift sample: the histogram time bins start at t0 (the time limits passed
    # to bincount2D), so ts must start there too, not at 0.
    ts = t0 + DT_SECS * np.arange(drift.size)

    if display:  # pragma: no cover
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        # NOTE(review): 3840 is a hard-coded depth extent, presumably the probe length in um
        ylim = [- NXCORR * 2, 3840 + NXCORR * 2]
        fig1, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True, figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim(ylim)
        # same view after subtracting the interpolated drift from each spike depth
        fig2, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True, figsize=(20, 10))
        axs[0].plot(ts, drift)
        dd = np.interp(spike_times, ts, drift)
        driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim(ylim)
        return drift, ts, [fig1, fig2]

    return drift, ts