Coverage for brainbox/metrics/electrode_drift.py: 12%

32 statements  

coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

import numpy as np

from ibldsp import smooth, utils, fourier
from iblutil.numerical import bincount2D


def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Estimate the electrode drift from spike-sorted data.

    :param spike_times: spike times (seconds)
    :param spike_amps: spike amplitudes (V)
    :param spike_depths: spike depths along the probe, in the same units as the returned drift (usually um)
    :param display: if True, plot diagnostics and also return the figures
    :return: drift (ntimes vector) in input units (usually um)
    :return: ts (ntimes vector) time scale in seconds
    :return: if display is True, a list with the two diagnostic figures is also returned
    """

    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25, 3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins

    NXCORR = 50  # maximum positive and negative lag, in depth samples, to search for the depth shift
    NT_SMOOTH = 9  # length of the Hanning smoothing window in samples (DT_SECS rate)

    # experimental: try the amp with a log scale
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = (np.min(spike_times), np.max(spike_times))
    nt = int((np.ceil(tmax) - np.floor(tmin)) / DT_SECS)

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    abins = (np.log10(spike_amps * 1e6) - AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)

    for i, abin in enumerate(np.unique(abins)):
        inds = np.where(np.logical_and(abins == abin, ~np.isnan(spike_depths)))[0]
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds], DEPTH_BIN_UM, DT_SECS,
                             [0, nd * DEPTH_BIN_UM], [np.floor(tmin), np.ceil(tmax)])
        atd_hist[i] = a[:-1, :-1]
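    # atd_hist now stacks one (nt, nd) time-by-depth spike-count map per amplitude bin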

    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    # k-filter along the depth direction
    lp = fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
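    # fdscale is in cycles per um; lp is a low-pass gain over spatial frequency that keeps
    # slow depth variations and attenuates structure finer than roughly 8-16 um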

    # compute the depth lag by xcorr
    # to experiment: LP the fft for a better tracking ?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    # xcorrelation against reference
    xcorr = np.real(np.fft.ifft(lp * atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
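    # columns of xcorr now cover depth lags from -NXCORR to +NXCORR samples, with zero lag at index NXCORR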

    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]
    # import easyqc
    # easyqc.viewdata(xcorr - np.mean(xcorr, 1)[:, np.newaxis], DEPTH_BIN_UM, title='xcor')

    # to experiment: parabolic fit to get max values
    raw_drift = (utils.parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')
    drift = drift - np.mean(drift)
    ts = DT_SECS * np.arange(drift.size)
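    # drift: zero-mean depth displacement in the input depth units (usually um), one sample per DT_SECS;
    # ts: matching time axis in seconds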

    if display:  # pragma: no cover
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        fig1, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True, figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([- NXCORR * 2, 3840 + NXCORR * 2])
        fig2, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True, figsize=(20, 10))
        axs[0].plot(ts, drift)
        dd = np.interp(spike_times, ts, drift)
        driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([- NXCORR * 2, 3840 + NXCORR * 2])
        return drift, ts, [fig1, fig2]

    return drift, ts
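
# --- Usage sketch (not part of the module above) -------------------------------------------
# A minimal, hypothetical example calling estimate_drift on synthetic spikes; the synthetic
# values and random seed are made up for illustration. It assumes spike times in seconds,
# amplitudes in volts and depths in micrometers, and that brainbox and its ibldsp/iblutil
# dependencies are installed.
import numpy as np
from brainbox.metrics.electrode_drift import estimate_drift

rng = np.random.default_rng(0)
n_spikes = 100_000
spike_times = np.sort(rng.uniform(0, 600, n_spikes))   # seconds, ~10 minutes of recording
spike_amps = 10 ** rng.uniform(-4.5, -3.5, n_spikes)   # volts, i.e. roughly 30-300 uV
spike_depths = rng.uniform(0, 3840, n_spikes)          # um along the probe

drift, ts = estimate_drift(spike_times, spike_amps, spike_depths, display=False)
print(drift.shape, ts.shape)  # one drift estimate per second of recording, in um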