Coverage for ibllib/plots/misc.py: 60%

172 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1#!/usr/bin/env python 

2# -*- coding:utf-8 -*- 

3from math import pi 

4 

5import numpy as np 

6import matplotlib.pyplot as plt 

7import scipy 

8 

9import ibldsp as dsp 

10 

11 

def wiggle(w, fs=1, gain=0.71, color='k', ax=None, fill=True, linewidth=0.5, t0=0, clip=2, sf=None,
           **kwargs):
    """
    Matplotlib display of wiggle traces.

    Each trace is drawn vertically (time along the y-axis) around an integer x position equal
    to its trace index, optionally with the positive lobes filled.

    :param w: 2D array (numpy array dimension nsamples, ntraces)
    :param fs: sampling frequency
    :param gain: display gain ; Note that if sf is given, gain is not used
    :param color: ('k') color of traces
    :param ax: (None) matplotlib axes object ; if None the current axes are used
    :param fill: (True) fill variable area above 0
    :param linewidth: (0.5) line width of the traces and of the filled areas outline
    :param t0: (0) timestamp of the first sample
    :param clip: (2) maximum absolute excursion of a scaled trace around its position ;
     None or 0 disables clipping
    :param sf: scaling factor ; if None, uses the gain / SQRT of waveform RMS
    :param kwargs: additional keyword arguments passed to `ax.plot()`
    :return: the matplotlib axes the traces were drawn in
    """
    nech, n_traces = w.shape
    tscale = np.arange(nech) / fs
    if sf is None:
        sf = gain / np.sqrt(dsp.utils.rms(w.flatten()))

    def insert_zeros(trace):
        # Insert zero locations in data trace and time vector based on a linear fit between
        # the two samples surrounding each sign change, so that filled lobes start and end
        # exactly on the trace baseline.
        zc_idx = np.where(np.diff(np.signbit(trace)))[0]
        x1 = tscale[zc_idx]
        x2 = tscale[zc_idx + 1]
        y1 = trace[zc_idx]
        y2 = trace[zc_idx + 1]
        slope = (y2 - y1) / (x2 - x1)
        tt_zero = x1 - y1 / slope
        # np.insert places each new sample right before the sample following the crossing
        tt_zi = np.insert(tscale, zc_idx + 1, tt_zero)
        trace_zi = np.insert(trace, zc_idx + 1, 0)
        return trace_zi, tt_zi

    if ax is None:
        ax = plt.gca()
    for itr in range(n_traces):
        scaled = w[:, itr] * sf
        if fill:
            trace, t_trace = insert_zeros(scaled)
            if clip:
                trace = np.clip(trace, -clip, clip)
            ax.fill_betweenx(t_trace + t0, itr, trace + itr,
                             where=trace >= 0,
                             facecolor=color,
                             linewidth=linewidth)
        # only clip the line when clipping is enabled, consistently with the filled area
        wplot = np.clip(scaled, -clip, clip) if clip else scaled
        ax.plot(wplot + itr, tscale + t0, color, linewidth=linewidth, **kwargs)

    # one unit of margin on each side of the first (0) and last (n_traces - 1) traces
    ax.set_xlim(-1, max(n_traces, 1))
    ax.set_ylim(tscale[0] + t0, tscale[-1] + t0)
    ax.set_ylabel('Time (s)')
    ax.set_xlabel('Trace')
    ax.invert_yaxis()

    return ax

76 

77 

class Density:
    def __init__(self, w, fs=30_000, cmap='Greys_r', ax=None, taxis=0, title=None, gain=None, t0=0, unit='ms', **kwargs):
        """
        Matplotlib display of traces as a density display using `imshow()`.

        Once displayed, `ctrl+a` / `ctrl+z` key presses on the figure multiply / divide the
        displayed data by sqrt(2).

        :param w: 2D array (numpy array dimension nsamples, ntraces) ; arrays with more
         dimensions are flattened to 2D, keeping the first dimension
        :param fs: sampling frequency (Hz). [default: 30000]
        :param cmap: Name of MPL colormap to use in `imshow()`. [default: 'Greys_r']
        :param ax: Axis to plot in. If `None`, a new one is created. [default: `None`]
        :param taxis: Time axis of input array (w), 0 or 1. [default: 0]
        :param title: Title to display on plot. [default: `None`]
        :param gain: Gain in dB to display. Note: overrides `vmin` and `vmax` kwargs to `imshow()`.
         Default: [`None` (auto)]
        :param t0: Time offset to display in seconds. [default: 0]
        :param unit: Time unit of the display: 'ms' scales times by 1e3, any other value
         leaves times in seconds. [default: 'ms']
        :param kwargs: Key word arguments passed to `imshow()`
        :raises ValueError: if taxis is neither 0 nor 1
        :return: None
        """
        if taxis not in (0, 1):
            raise ValueError(f'taxis must be 0 or 1, got {taxis!r}')
        w = w.reshape(w.shape[0], -1)
        t_scalar = 1e3 if unit == 'ms' else 1
        if taxis == 0:
            # time runs down the y-axis
            nech, ntr = w.shape
            tscale = np.array([0, nech - 1]) / fs * t_scalar
            extent = [-0.5, ntr - 0.5, tscale[1] + t0 * t_scalar, tscale[0] + t0 * t_scalar]
            xlabel, ylabel, origin = ('Trace', f'Time ({unit})', 'upper')
        else:
            # time runs along the x-axis
            ntr, nech = w.shape
            tscale = np.array([0, nech - 1]) / fs * t_scalar
            extent = [tscale[0] + t0 * t_scalar, tscale[1] + t0 * t_scalar, -0.5, ntr - 0.5]
            ylabel, xlabel, origin = ('Trace', f'Time ({unit})', 'lower')
        if ax is None:
            self.figure, ax = plt.subplots()
        else:
            self.figure = ax.get_figure()
        # explicit None check: a gain of 0 dB is a valid request (color range of +/- 4)
        if gain is not None:
            kwargs["vmin"] = - 4 * (10 ** (gain / 20))
            kwargs["vmax"] = -kwargs["vmin"]
        self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        if title:
            ax.set_title(title)
        self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.ax = ax
        self.title = title or None

    def on_key_press(self, event):
        """
        Key press callback: `ctrl+a` multiplies the image data by sqrt(2), `ctrl+z` divides
        it by sqrt(2) ; the figure is redrawn afterwards. Other keys are ignored.

        :param event: matplotlib key press event
        :return: None
        """
        if event.key == 'ctrl+a':
            self.im.set_data(self.im.get_array() * np.sqrt(2))
        elif event.key == 'ctrl+z':
            self.im.set_data(self.im.get_array() / np.sqrt(2))
        else:
            return
        self.figure.canvas.draw()

130 

131 

class Traces:
    def __init__(self, w, fs=1, gain=0.71, color='k', ax=None, linewidth=0.5, t0=0, **kwargs):
        """
        Matplotlib display of traces as vertical line plots, one line per trace, centered
        on the trace index along the x-axis, with time in milliseconds along the y-axis.

        Once displayed, `ctrl+a` / `ctrl+z` key presses on the figure multiply / divide the
        excursion of each trace around its position by sqrt(2).

        :param w: 2D array (numpy array dimension nsamples, ntraces) ; arrays with more
         dimensions are flattened to 2D, keeping the first dimension
        :param fs: sampling frequency (Hz)
        :param gain: display gain ; traces are scaled by gain / RMS of the whole array / 2
        :param color: ('k') color of traces
        :param ax: axis to plot in ; if None a new figure is created
        :param linewidth: (0.5) line width of the traces
        :param t0: (0) offset added as-is to the time axis, which is expressed in ms
        :param kwargs: additional keyword arguments passed to `ax.plot()`
        :return: None
        """
        w = w.reshape(w.shape[0], -1)
        n_samples, n_traces = w.shape
        t_ms = np.arange(n_samples) / fs * 1e3
        scale = gain / dsp.utils.rms(w.flatten()) / 2
        if ax is not None:
            self.figure = ax.get_figure()
        else:
            self.figure, ax = plt.subplots()
        # each column is shifted by its trace index so that traces sit side by side
        self.plot = ax.plot(w * scale + np.arange(n_traces), t_ms + t0, color,
                            linewidth=linewidth, **kwargs)
        ax.set_xlim(-1, n_traces + 1)
        ax.set_ylim(t_ms[0] + t0, t_ms[-1] + t0)
        ax.set_ylabel('Time (ms)')
        ax.set_xlabel('Trace')
        ax.invert_yaxis()
        self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.ax = ax

    def on_key_press(self, event):
        """
        Key press callback: `ctrl+a` multiplies and `ctrl+z` divides the excursion of every
        trace around its x position by sqrt(2) ; the figure is redrawn afterwards.
        Other keys are ignored.

        :param event: matplotlib key press event
        :return: None
        """
        key = event.key
        if key not in ('ctrl+a', 'ctrl+z'):
            return
        for position, line in enumerate(self.plot):
            # remove the trace offset before scaling, then put it back
            centered = line.get_xdata() - position
            if key == 'ctrl+a':
                line.set_xdata(centered * np.sqrt(2) + position)
            else:
                line.set_xdata(centered / np.sqrt(2) + position)
        self.figure.canvas.draw()

170 

171 

def squares(tscale, polarity, ax=None, yrange=(-1, 1), **kwargs):
    """
    Matplotlib display of rising and falling fronts in a square-wave pattern

    Fronts are sorted by time ; after each front the signal holds the level given by its
    polarity until the next front. Nothing is drawn if there are no fronts.

    :param tscale: time of indices of fronts (array-like)
    :param polarity: polarity of front (1: rising, -1:falling) (array-like, same size as tscale)
    :param ax: matplotlib axes object ; if None the current axes are used
    :param yrange: (low, high) levels displayed for polarities -1 and 1 respectively
    :param kwargs: additional keyword arguments passed to `ax.plot()`
    :return: None
    """
    # accept lists and tuples as well as numpy arrays
    tscale = np.asarray(tscale)
    polarity = np.asarray(polarity)
    if tscale.size == 0:
        return
    if ax is None:
        ax = plt.gca()
    isort = np.argsort(tscale)
    tscale = tscale[isort]
    polarity = polarity[isort]
    # each front contributes two points: its own time and the time of the next front
    # (the last front is paired with itself and has zero width)
    t = np.column_stack((tscale, np.r_[tscale[1:], tscale[-1]])).ravel()
    ydata = np.repeat(polarity, 2)
    # map polarities -1 / 1 onto yrange[0] / yrange[1]
    ydata = (ydata + 1) / 2 * (yrange[1] - yrange[0]) + yrange[0]
    ax.plot(t, ydata, **kwargs)

191 

192 

def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs):
    """
    From an x vector, draw separate vertical lines at each x location ranging from ymin to ymax

    All lines are drawn with a single `plot()` call: consecutive segments are separated by
    NaN values so that they are not joined together.

    :param x: vector (array-like) of x values where to display lines
    :param ymin: lower end of the lines (scalar)
    :param ymax: higher end of the lines (scalar)
    :param ax: (optional) matplotlib axis instance ; if None the current axes are used
    :param kwargs: additional keyword arguments passed to `ax.plot()`
    :return: None
    """
    # cast to float: NaN separators can not be stored in integer arrays
    x = np.asarray(x, dtype=float).ravel()
    # x0, x0, nan, x1, x1, nan, ...
    xdata = np.column_stack((x, x, np.full(x.size, np.nan))).ravel()
    # ymin, ymax, nan, ymin, ymax, nan, ...
    ydata = np.tile(np.array([ymin, ymax, np.nan], dtype=float), x.size)
    if ax is None:
        ax = plt.gca()
    ax.plot(xdata, ydata, **kwargs)

212 

213 

def spectrum(w, fs, smooth=None, unwrap=True, axis=0, **kwargs):
    """
    Display spectral density of a signal along a given dimension
    spectrum(w, fs)

    Creates a new figure with two stacked axes sharing the frequency axis: amplitude in dB
    on top and phase in radians below.

    :param w: signal (numpy array)
    :param fs: sampling frequency (Hz)
    :param smooth: (None) width in Hz of the median filter applied to amplitude and phase
     along the frequency axis ; None or 0 disables smoothing
    :param unwrap: (True) unwraps the phase spectrum along the frequency axis
    :param axis: axis on which to compute the FFT
    :param kwargs: plot arguments to be passed to matplotlib
    :return: array of the two matplotlib axes (amplitude, phase)
    """
    ns = w.shape[axis]
    fscale = dsp.fourier.fscale(ns, 1 / fs, one_sided=True)
    # bring the frequency axis first so that it lines up with fscale when plotting
    W = np.moveaxis(scipy.fft.rfft(w, axis=axis), axis, 0)
    amp = 20 * np.log10(np.abs(W))
    phi = np.angle(W)

    if unwrap:
        phi = np.unwrap(phi, axis=0)

    if smooth:
        # medfilt requires an odd integer kernel size ; only smooth along frequencies
        nf = int(np.round(smooth / fscale[1] / 2)) * 2 + 1
        kernel = [nf] + [1] * (amp.ndim - 1)
        amp = scipy.signal.medfilt(amp, kernel)
        phi = scipy.signal.medfilt(phi, kernel)

    fig, ax = plt.subplots(2, 1, sharex=True)
    ax[0].plot(fscale, amp, **kwargs)
    ax[1].plot(fscale, phi, **kwargs)

    ax[0].set_title('Spectral Density (dB rel to amplitude.Hz^-0.5)')
    ax[0].set_ylabel('Amp (dB)')
    ax[1].set_ylabel('Phase (rad)')
    ax[1].set_xlabel('Frequency (Hz)')
    return ax

253 

254 

def color_cycle(ind=None):
    """
    Gets the matplotlib color-cycle as RGB numpy array of floats between 0 and 1

    The ten colors are hard-coded in this function rather than read from matplotlib.

    :param ind: (None) index of the color to return ; wraps around the number of colors
    :return: the full (10, 3) array of colors if ind is None, otherwise the tuple of RGB
     values of the requested color
    """
    palette = np.array([
        [0.12156863, 0.46666667, 0.70588235],
        [1., 0.49803922, 0.05490196],
        [0.17254902, 0.62745098, 0.17254902],
        [0.83921569, 0.15294118, 0.15686275],
        [0.58039216, 0.40392157, 0.74117647],
        [0.54901961, 0.3372549, 0.29411765],
        [0.89019608, 0.46666667, 0.76078431],
        [0.49803922, 0.49803922, 0.49803922],
        [0.7372549, 0.74117647, 0.13333333],
        [0.09019608, 0.74509804, 0.81176471],
    ])
    if ind is not None:
        n_colors = palette.shape[0]
        return tuple(palette[ind % n_colors, :])
    return palette

277 

278 

def starplot(labels, radii, ticks=None, ax=None, ylim=None, color=None, title=None):
    """
    Function to create a star plot (also known as a spider plot, polar plot, or radar chart).

    Parameters:
    labels (list): A list of labels for the variables to be plotted along the axes.
    radii (numpy array or sequence): The values to be plotted for each variable.
    ticks (numpy array, optional): A list of values to be used for the radial ticks.
    If None, 5 ticks will be created spanning ylim.
    ax (matplotlib.axes._subplots.PolarAxesSubplot, optional): A polar axis object to plot on.
    If None, a new figure and axis will be created.
    ylim (tuple, optional): A tuple specifying the lower and upper limits of the y-axis.
    If None, the limits will be set to 0 and the maximum value of radii.
    color (str, optional): A string specifying the color of the plot.
    If None, the color will be determined by the current matplotlib color cycle.
    title (str, optional): A string specifying the title of the plot.
    If None, no title will be displayed.

    Note: labels, ticks and ylim are only applied when a new axis is created (ax is None) ;
    an axis passed by the caller keeps its own ticks, labels and limits.

    Returns:
    ax (matplotlib.axes._subplots.PolarAxesSubplot): The polar axis object containing the plot.
    """
    # accept lists and tuples as well as numpy arrays
    radii = np.asarray(radii)
    n_vars = radii.size

    # What will be the angle of each axis in the plot? (we divide the plot / number of variable)
    angles = [n / float(n_vars) * 2 * pi for n in range(n_vars)]
    # repeat the first angle to close the polygon
    angles += angles[:1]

    if ax is None:
        # Initialise the spider plot
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111, polar=True)
        # If you want the first axis to be on top:
        ax.set_theta_offset(pi / 2)
        ax.set_theta_direction(-1)
        # Draw one axe per variable + add labels
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(labels)
        # Draw ylabels
        ax.set_rlabel_position(0)
        if ylim is None:
            ylim = (0, np.max(radii))
        if ticks is None:
            ticks = np.linspace(ylim[0], ylim[1], 5)
        ax.set_yticks(ticks)
        ax.set_yticklabels([f'{t:2.2f}' for t in ticks], color="grey", size=7)
        ax.set_ylim(ylim)

    # repeat the first value to close the polygon
    r = np.r_[radii, radii[0]]
    p = ax.plot(angles, r, linewidth=1, linestyle='solid', label="group A", color=color)
    # fill with the color actually used for the outline
    ax.fill(angles, r, alpha=0.1, color=p[0].get_color())
    if title is not None:
        ax.set_title(title)
    return ax