Coverage for ibllib/plots/misc.py: 22%

153 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1#!/usr/bin/env python 

2# -*- coding:utf-8 -*- 

3import numpy as np 

4import matplotlib.pyplot as plt 

5import scipy 

6 

7import ibldsp as dsp 

8 

9 

def wiggle(w, fs=1, gain=0.71, color='k', ax=None, fill=True, linewidth=0.5, t0=0, clip=2, sf=None,
           **kwargs):
    """
    Matplotlib display of wiggle traces

    :param w: 2D array (numpy array dimension nsamples, ntraces); a 1D array is displayed as a
     single trace and trailing dimensions are flattened into the trace dimension
    :param fs: sampling frequency
    :param gain: display gain ; Note that if sf is given, gain is not used
    :param color: ('k') color of traces
    :param ax: (None) matplotlib axes object, defaults to the current axes
    :param fill: (True) fill variable area above 0
    :param linewidth: (0.5) width of the trace lines
    :param t0: (0) timestamp of the first sample
    :param clip: (2) after scaling, traces are clipped to [-clip, clip]; a falsy value
     (None, 0) disables clipping for both the fill and the line
    :param sf: scaling factor ; if None, uses the gain / SQRT of waveform RMS
    :param kwargs: extra keyword arguments passed to `ax.plot()`
    :return: the matplotlib axes the traces were drawn in
    """
    w = w.reshape(w.shape[0], -1)
    nech, ntr = w.shape
    tscale = np.arange(nech) / fs
    if sf is None:
        sf = gain / np.sqrt(dsp.utils.rms(w.flatten()))

    def _clip(x):
        # a falsy clip means "no clipping"; applied identically to fill and line
        if not clip:
            return x
        return np.maximum(np.minimum(x, clip), -clip)

    def insert_zeros(trace):
        """
        Insert the linearly interpolated zero-crossings into the trace and the time vector,
        so that the variable-area fill stops exactly on the zero line.

        :return: (trace with zeros inserted, time vector with crossing times inserted)
        """
        # indices i where the sign flips between trace[i] and trace[i + 1]
        zc_idx = np.where(np.diff(np.signbit(trace)))[0]
        x1 = tscale[zc_idx]
        x2 = tscale[zc_idx + 1]
        y1 = trace[zc_idx]
        y2 = trace[zc_idx + 1]
        a = (y2 - y1) / (x2 - x1)
        tt_zero = x1 - y1 / a
        # np.insert puts each value before the given index of the ORIGINAL array,
        # i.e. right after sample zc_idx[i], in a single pass
        tt_zi = np.insert(tscale, zc_idx + 1, tt_zero)
        trace_zi = np.insert(trace, zc_idx + 1, 0.)
        return trace_zi, tt_zi

    if ax is None:
        ax = plt.gca()
    # distinct loop name: the trace count `ntr` is still needed for the x-limits below
    for itr in range(ntr):
        trace = w[:, itr] * sf
        if fill:
            trace_zi, t_trace = insert_zeros(trace)
            trace_zi = _clip(trace_zi)
            ax.fill_betweenx(t_trace + t0, itr, trace_zi + itr,
                             where=trace_zi >= 0,
                             facecolor=color,
                             linewidth=linewidth)
        ax.plot(_clip(trace) + itr, tscale + t0, color, linewidth=linewidth, **kwargs)

    ax.set_xlim(-1, ntr + 1)
    ax.set_ylim(tscale[0] + t0, tscale[-1] + t0)
    ax.set_ylabel('Time (s)')
    ax.set_xlabel('Trace')
    ax.invert_yaxis()

    return ax

74 

75 

class Density:
    """
    Interactive density (image) display of a 2D array of traces.

    Once displayed, 'ctrl+a' multiplies the displayed data by sqrt(2) and 'ctrl+z' divides it
    by sqrt(2).
    """

    def __init__(self, w, fs=30_000, cmap='Greys_r', ax=None, taxis=0, title=None, gain=None, t0=0, unit='ms', **kwargs):
        """
        Matplotlib display of traces as a density display using `imshow()`.

        :param w: 2D array (numpy array dimension nsamples, ntraces)
        :param fs: sampling frequency (Hz). [default: 30000]
        :param cmap: Name of MPL colormap to use in `imshow()`. [default: 'Greys_r']
        :param ax: Axis to plot in. If `None`, a new one is created. [default: `None`]
        :param taxis: Time axis of input array (w), 0 or 1. [default: 0]
        :param title: Title to display on plot. [default: `None`]
        :param gain: Gain in dB to display (0 dB is a valid gain). Note: overrides `vmin` and
         `vmax` kwargs to `imshow()`. Default: [`None` (auto)]
        :param t0: Time offset to display in seconds. [default: 0]
        :param unit: Time unit of the display: 'ms' (default) for milliseconds; any other value
         displays seconds and is only used in the axis label
        :param kwargs: Key word arguments passed to `imshow()`
        :raises ValueError: if `taxis` is neither 0 nor 1
        :return: None
        """
        w = w.reshape(w.shape[0], -1)
        t_scalar = 1e3 if unit == 'ms' else 1
        if taxis == 0:
            nech, ntr = w.shape
            tscale = np.array([0, nech - 1]) / fs * t_scalar
            extent = [-0.5, ntr - 0.5, tscale[1] + t0 * t_scalar, tscale[0] + t0 * t_scalar]
            xlabel, ylabel, origin = ('Trace', f'Time ({unit})', 'upper')
        elif taxis == 1:
            ntr, nech = w.shape
            tscale = np.array([0, nech - 1]) / fs * t_scalar
            extent = [tscale[0] + t0 * t_scalar, tscale[1] + t0 * t_scalar, -0.5, ntr - 0.5]
            ylabel, xlabel, origin = ('Trace', f'Time ({unit})', 'lower')
        else:
            # without this, `extent`/labels would be unbound further down
            raise ValueError(f"taxis must be 0 or 1, got {taxis}")
        if ax is None:
            self.figure, ax = plt.subplots()
        else:
            self.figure = ax.get_figure()
        # explicit None check: a gain of 0 dB is falsy but meaningful
        if gain is not None:
            kwargs["vmin"] = - 4 * (10 ** (gain / 20))
            kwargs["vmax"] = -kwargs["vmin"]
        self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        if title:
            ax.set_title(title)
        self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.ax = ax
        self.title = title or None

    def on_key_press(self, event):
        """Scale the displayed image data up ('ctrl+a') or down ('ctrl+z') by sqrt(2) and redraw."""
        if event.key == 'ctrl+a':
            self.im.set_data(self.im.get_array() * np.sqrt(2))
        elif event.key == 'ctrl+z':
            self.im.set_data(self.im.get_array() / np.sqrt(2))
        else:
            return
        self.figure.canvas.draw()

128 

129 

class Traces:
    """
    Interactive display of traces as offset lines, one line per trace.

    Once displayed, 'ctrl+a' stretches every trace around its baseline by sqrt(2) and 'ctrl+z'
    shrinks it by the same factor.
    """

    def __init__(self, w, fs=1, gain=0.71, color='k', ax=None, linewidth=0.5, t0=0, **kwargs):
        """
        Matplotlib display of traces as offset lines (trace i is drawn around x = i)

        :param w: 2D array (numpy array dimension nsamples, ntraces); trailing dimensions are
         flattened into the trace dimension
        :param fs: sampling frequency (Hz)
        :param gain: display gain; traces are multiplied by gain / (2 * dsp.utils.rms(w))
        :param color: color of the traces
        :param ax: axis to plot in. If `None`, a new figure and axis are created
        :param linewidth: width of the trace lines
        :param t0: offset added to the time axis. NOTE(review): the time axis is in ms and t0 is
         added without conversion, so it is presumably expected in ms — confirm with callers
        :param kwargs: keyword arguments passed to `ax.plot()`
        :return: None
        """
        w = w.reshape(w.shape[0], -1)
        nech, ntr = w.shape
        tscale = np.arange(nech) / fs * 1e3  # time axis in ms
        sf = gain / dsp.utils.rms(w.flatten()) / 2
        if ax is None:
            self.figure, ax = plt.subplots()
        else:
            self.figure = ax.get_figure()
        # 2D x against 1D y: one Line2D per column, each offset by its trace index
        self.plot = ax.plot(w * sf + np.arange(ntr), tscale + t0, color,
                            linewidth=linewidth, **kwargs)
        ax.set_xlim(-1, ntr + 1)
        ax.set_ylim(tscale[0] + t0, tscale[-1] + t0)
        ax.set_ylabel('Time (ms)')
        ax.set_xlabel('Trace')
        ax.invert_yaxis()
        self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.ax = ax

    def on_key_press(self, event):
        """Stretch ('ctrl+a') or shrink ('ctrl+z') each line around its baseline x = i, then redraw."""
        if event.key == 'ctrl+a':
            for i, line in enumerate(self.plot):
                line.set_xdata((line.get_xdata() - i) * np.sqrt(2) + i)
        elif event.key == 'ctrl+z':
            for i, line in enumerate(self.plot):
                line.set_xdata((line.get_xdata() - i) / np.sqrt(2) + i)
        else:
            return
        self.figure.canvas.draw()

168 

169 

def squares(tscale, polarity, ax=None, yrange=(-1, 1), **kwargs):
    """
    Matplotlib display of rising and falling fronts in a square-wave pattern

    After sorting by time, front i is drawn as a horizontal segment at the level of its
    polarity, from tscale[i] to tscale[i + 1] (the last front is a single point). Consecutive
    segments are joined by the line, so transitions appear as vertical edges.

    :param tscale: time of indices of fronts (array-like, need not be sorted)
    :param polarity: polarity of front (1: rising, -1:falling), same length as tscale
    :param ax: matplotlib axes object, defaults to the current axes
    :param yrange: (low, high) y values at which polarities -1 and 1 are drawn
    :param kwargs: keyword arguments passed to `ax.plot()`
    :return: None
    """
    if ax is None:
        ax = plt.gca()
    tscale = np.asarray(tscale)
    polarity = np.asarray(polarity)
    if tscale.size == 0:
        # nothing to draw (and tscale[-1] below would fail)
        return
    isort = np.argsort(tscale)
    tscale = tscale[isort]
    polarity = polarity[isort]
    # each front contributes two points: (t_i, p_i) and (t_{i+1}, p_i)
    t_next = np.r_[tscale[1:], tscale[-1]]
    t = np.column_stack((tscale, t_next)).ravel()
    ydata = np.repeat(polarity, 2)
    # map polarity -1 -> yrange[0] and 1 -> yrange[1]
    ydata = (ydata + 1) / 2 * (yrange[1] - yrange[0]) + yrange[0]
    ax.plot(t, ydata, **kwargs)

189 

190 

def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs):
    """
    From an x vector, draw separate vertical lines at each x location ranging from ymin to ymax

    All lines are drawn with a single `ax.plot()` call; the segments are separated by NaN
    points so that they are not joined to each other.

    :param x: numpy array (or array-like) vector of x values where to display lines; integer
     input is accepted and converted to float
    :param ymin: lower end of the lines (scalar)
    :param ymax: higher end of the lines (scalar)
    :param ax: (optional) matplotlib axis instance, defaults to the current axes
    :param kwargs: keyword arguments passed to `ax.plot()`
    :return: None
    """
    # float dtype is required: the NaN separators cannot be stored in an integer array
    x = np.tile(np.asarray(x, dtype=float), (3, 1))
    x[2, :] = np.nan
    y = np.zeros_like(x)
    y[0, :] = ymin
    y[1, :] = ymax
    y[2, :] = np.nan
    if ax is None:
        ax = plt.gca()
    # column-major flattening interleaves (x, x, nan) triplets, one per line
    ax.plot(x.T.flatten(), y.T.flatten(), **kwargs)

210 

211 

def spectrum(w, fs, smooth=None, unwrap=True, axis=0, **kwargs):
    """
    Display spectral density of a signal along a given dimension
    spectrum(w, fs)

    :param w: signal (1D, or N-D with the time samples along `axis`)
    :param fs: sampling frequency (Hz)
    :param smooth: (None) frequency band (Hz) over which amplitude and phase are median-filtered
    :param unwrap: (True) unwraps the phase spectrum along the frequency axis
    :param axis: axis on which to compute the FFT
    :param kwargs: plot arguments to be passed to matplotlib
    :return: matplotlib axes (amplitude axis, phase axis)
    """
    ns = w.shape[axis]
    fscale = dsp.fourier.fscale(ns, 1 / fs, one_sided=True)
    W = scipy.fft.rfft(w, axis=axis)
    amp = 20 * np.log10(np.abs(W))
    phi = np.angle(W)

    if unwrap:
        # unwrap along frequency; np.unwrap defaults to the last axis
        phi = np.unwrap(phi, axis=axis)

    if smooth:
        # odd integer kernel length (in frequency samples), applied along the frequency axis only
        nf = int(np.round(smooth / fscale[1] / 2) * 2 + 1)
        kernel_size = [1] * amp.ndim
        kernel_size[axis] = nf
        amp = scipy.signal.medfilt(amp, kernel_size)
        phi = scipy.signal.medfilt(phi, kernel_size)

    # ax.plot draws along the first dimension: bring the frequency axis there
    amp = np.moveaxis(amp, axis, 0)
    phi = np.moveaxis(phi, axis, 0)

    _, ax = plt.subplots(2, 1, sharex=True)
    ax[0].plot(fscale, amp, **kwargs)
    ax[1].plot(fscale, phi, **kwargs)

    ax[0].set_title('Spectral Density (dB rel to amplitude.Hz^-0.5)')
    ax[0].set_ylabel('Amp (dB)')
    ax[1].set_ylabel('Phase (rad)')
    ax[1].set_xlabel('Frequency (Hz)')
    return ax

251 

252 

def color_cycle(ind=None):
    """
    Return matplotlib's default colour cycle as RGB floats between 0 and 1.

    The values are a hard-coded conversion of the hex colours in matplotlib's default
    `axes.prop_cycle` to floating-point RGB.

    :param ind: (None) index into the cycle; it wraps around modulo the number of colours
    :return: (10, 3) numpy array with every colour when `ind` is None, otherwise the selected
     colour as an (r, g, b) tuple
    """
    palette = np.array([
        [0.12156863, 0.46666667, 0.70588235],
        [1., 0.49803922, 0.05490196],
        [0.17254902, 0.62745098, 0.17254902],
        [0.83921569, 0.15294118, 0.15686275],
        [0.58039216, 0.40392157, 0.74117647],
        [0.54901961, 0.3372549, 0.29411765],
        [0.89019608, 0.46666667, 0.76078431],
        [0.49803922, 0.49803922, 0.49803922],
        [0.7372549, 0.74117647, 0.13333333],
        [0.09019608, 0.74509804, 0.81176471],
    ])
    if ind is not None:
        return tuple(palette[ind % len(palette)])
    return palette

275 

276 

if __name__ == "__main__":
    # quick visual check: random traces shown with both the wiggle and the line display
    w = np.random.rand(500, 40) - 0.5
    wiggle(w, fs=30000)
    Traces(w, fs=30000, color='r')
    # when run as a script the figures are only rendered once the GUI event loop runs
    plt.show()