Coverage for ibllib/plots/misc.py: 32%
149 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1#!/usr/bin/env python
2# -*- coding:utf-8 -*-
3import numpy as np
4import matplotlib.pyplot as plt
5import scipy
7import neurodsp as dsp
def wiggle(w, fs=1, gain=0.71, color='k', ax=None, fill=True, linewidth=0.5, t0=0, clip=2, sf=None,
           **kwargs):
    """
    Matplotlib display of wiggle traces

    :param w: 2D array (numpy array dimension nsamples, ntraces)
    :param fs: sampling frequency
    :param gain: display gain ; Note that if sf is given, gain is not used
    :param color: ('k') color of traces, passed to ax.plot as its format argument
    :param ax: (None) matplotlib axes object; defaults to plt.gca()
    :param fill: (True) fill variable area above 0
    :param linewidth: (0.5) line width of the traces
    :param t0: (0) timestamp of the first sample
    :param clip: (2) scaled traces are clipped to [-clip, clip]; a falsy value disables clipping
    :param sf: scaling factor ; if None, uses the gain / SQRT of waveform RMS
    :param kwargs: extra arguments forwarded to ax.plot
    :return: the matplotlib axes used for the display
    """
    nech, ntr = w.shape
    tscale = np.arange(nech) / fs
    if sf is None:
        sf = gain / np.sqrt(dsp.utils.rms(w.flatten()))

    def insert_zeros(trace):
        # Insert zero locations in data trace and time vector based on a linear fit,
        # so that the variable-area fill stops exactly at the zero crossings.
        # Find indices i where the sign bit flips between samples i and i + 1
        zc_idx = np.where(np.diff(np.signbit(trace)))[0]
        x1 = tscale[zc_idx]
        x2 = tscale[zc_idx + 1]
        y1 = trace[zc_idx]
        y2 = trace[zc_idx + 1]
        with np.errstate(divide='ignore', invalid='ignore'):
            slope = (y2 - y1) / (x2 - x1)
            # a sign-bit flip between +0.0 and -0.0 has zero slope: the crossing sits at x1
            tt_zero = np.where(slope != 0, x1 - y1 / slope, x1)
        # insert every interpolated crossing (value 0) right after its left sample, in one pass
        tt_zi = np.insert(tscale, zc_idx + 1, tt_zero)
        trace_zi = np.insert(trace.astype(float), zc_idx + 1, 0.)
        return trace_zi, tt_zi

    if ax is None:
        ax = plt.gca()
    for itr in range(ntr):
        trace = w[:, itr] * sf
        if fill:
            tfill, t_trace = insert_zeros(trace)
            if clip:
                tfill = np.maximum(np.minimum(tfill, clip), -clip)
            ax.fill_betweenx(t_trace + t0, itr, tfill + itr,
                             where=tfill >= 0,
                             facecolor=color,
                             linewidth=linewidth)
        # the plotted line honours the same clip switch as the fill
        wplot = np.minimum(np.maximum(trace, -clip), clip) if clip else trace
        ax.plot(wplot + itr, tscale + t0, color, linewidth=linewidth, **kwargs)

    # traces are drawn at x offsets 0 .. ntr - 1
    ax.set_xlim(-1, ntr)
    ax.set_ylim(tscale[0] + t0, tscale[-1] + t0)
    ax.set_ylabel('Time (s)')
    ax.set_xlabel('Trace')
    ax.invert_yaxis()

    return ax
class Density:
    """
    Image ("density") display of a 2D array of traces with keyboard gain control.

    Key bindings connected on the figure canvas: 'ctrl+a' multiplies the displayed data
    by sqrt(2), 'ctrl+z' divides it by sqrt(2).
    """

    def __init__(self, w, fs=1, cmap='Greys_r', ax=None, taxis=0, title=None, **kwargs):
        """
        Matplotlib display of traces as a density display

        :param w: 2D array; (nsamples, ntraces) when taxis=0, (ntraces, nsamples) when taxis=1.
         w is reshaped to (w.shape[0], -1), so a 1D array becomes a single column.
        :param fs: sampling frequency (Hz); the time axis is labelled in ms
        :param cmap: ('Greys_r') matplotlib colormap
        :param ax: axis to plot in; a new figure and axis are created when None
        :param taxis: (0) index of the time dimension of w, 0 or 1
        :param title: (None) stored as self.title; not drawn by this class
        :param kwargs: extra arguments forwarded to ax.imshow
        :raises ValueError: if taxis is neither 0 nor 1
        :return: None
        """
        w = w.reshape(w.shape[0], -1)
        if taxis == 0:
            nech, ntr = w.shape
            tscale = np.array([0, nech - 1]) / fs * 1e3
            # time runs downwards, traces along x
            extent = [-0.5, ntr - 0.5, tscale[1], tscale[0]]
            xlabel, ylabel, origin = ('Trace', 'Time (ms)', 'upper')
        elif taxis == 1:
            ntr, nech = w.shape
            tscale = np.array([0, nech - 1]) / fs * 1e3
            # time along x, traces upwards
            extent = [tscale[0], tscale[1], -0.5, ntr - 0.5]
            ylabel, xlabel, origin = ('Trace', 'Time (ms)', 'lower')
        else:
            # without this, extent/origin would be unbound and fail with a cryptic error below
            raise ValueError(f"taxis must be 0 or 1, got {taxis!r}")
        if ax is None:
            self.figure, ax = plt.subplots()
        else:
            self.figure = ax.get_figure()
        self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.ax = ax
        self.title = title or None

    def on_key_press(self, event):
        """
        Rescale the displayed image: 'ctrl+a' multiplies by sqrt(2), 'ctrl+z' divides by sqrt(2);
        any other key is ignored and the canvas is not redrawn.
        """
        if event.key == 'ctrl+a':
            self.im.set_data(self.im.get_array() * np.sqrt(2))
        elif event.key == 'ctrl+z':
            self.im.set_data(self.im.get_array() / np.sqrt(2))
        else:
            return
        self.figure.canvas.draw()
class Traces:
    """
    Line display of traces, one line per trace offset by its index, with keyboard gain control.

    Key bindings connected on the figure canvas: 'ctrl+a' multiplies the trace amplitudes
    by sqrt(2), 'ctrl+z' divides them by sqrt(2).
    """

    def __init__(self, w, fs=1, gain=0.71, color='k', ax=None, linewidth=0.5, t0=0, sf=None,
                 **kwargs):
        """
        Matplotlib display of traces as overlaid lines, trace i drawn around x = i

        :param w: 2D array (numpy array dimension nsamples, ntraces); reshaped to
         (w.shape[0], -1) so a 1D array becomes a single trace
        :param fs: sampling frequency (Hz); the time axis is in ms
        :param gain: (0.71) display gain, ignored when sf is given
        :param color: ('k') color passed to ax.plot as its format argument
        :param ax: axis to plot in; a new figure and axis are created when None
        :param linewidth: (0.5) line width
        :param t0: (0) offset added to the time axis (same unit as the axis, ms)
        :param sf: (None) scaling factor; if None, uses gain / dsp.utils.rms(w) / 2
        :param kwargs: extra arguments forwarded to ax.plot
        :return: None
        """
        w = w.reshape(w.shape[0], -1)
        nech, ntr = w.shape
        tscale = np.arange(nech) / fs * 1e3
        if sf is None:
            sf = gain / dsp.utils.rms(w.flatten()) / 2
        if ax is None:
            self.figure, ax = plt.subplots()
        else:
            self.figure = ax.get_figure()
        # each column is offset by its trace index so traces sit side by side
        self.plot = ax.plot(w * sf + np.arange(ntr), tscale + t0, color,
                            linewidth=linewidth, **kwargs)
        ax.set_xlim(-1, ntr + 1)
        ax.set_ylim(tscale[0] + t0, tscale[-1] + t0)
        ax.set_ylabel('Time (ms)')
        ax.set_xlabel('Trace')
        ax.invert_yaxis()
        self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
        self.ax = ax

    def on_key_press(self, event):
        """
        Rescale every trace around its own offset: 'ctrl+a' multiplies by sqrt(2),
        'ctrl+z' divides by sqrt(2); any other key is ignored and the canvas is not redrawn.
        """
        if event.key == 'ctrl+a':
            for i, line in enumerate(self.plot):
                line.set_xdata((line.get_xdata() - i) * np.sqrt(2) + i)
        elif event.key == 'ctrl+z':
            for i, line in enumerate(self.plot):
                line.set_xdata((line.get_xdata() - i) / np.sqrt(2) + i)
        else:
            return
        self.figure.canvas.draw()
def squares(tscale, polarity, ax=None, yrange=(-1, 1), **kwargs):
    """
    Matplotlib display of rising and falling fronts in a square-wave pattern

    :param tscale: time of indices of fronts (array-like)
    :param polarity: polarity of front (1: rising, -1:falling), same length as tscale
    :param ax: matplotlib axes object; defaults to plt.gca()
    :param yrange: ((-1, 1)) y values drawn for polarity -1 and 1 respectively
    :param kwargs: extra arguments forwarded to ax.plot
    :return: None
    """
    tscale = np.asarray(tscale)
    polarity = np.asarray(polarity)
    if tscale.size == 0:
        # nothing to draw
        return
    if ax is None:
        ax = plt.gca()
    isort = np.argsort(tscale)
    tscale = tscale[isort]
    polarity = polarity[isort]
    # each front holds its level from its own time to the next front's time;
    # the last front becomes a zero-length segment
    t_end = np.r_[tscale[1:], tscale[-1]]
    t = np.column_stack((tscale, t_end)).ravel()
    ydata = np.repeat(polarity, 2)
    # map polarity -1 / 1 onto yrange[0] / yrange[1]
    ydata = (ydata + 1) / 2 * (yrange[1] - yrange[0]) + yrange[0]
    ax.plot(t, ydata, **kwargs)
def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs):
    """
    From a x vector, draw separate vertical lines at each x location ranging from ymin to ymax

    All lines are drawn with a single ax.plot call, separated by NaN breaks.

    :param x: numpy array (or array-like) vector of x values where to display lines
    :param ymin: lower end of the lines (scalar)
    :param ymax: higher end of the lines (scalar)
    :param ax: (optional) matplotlib axis instance; defaults to plt.gca()
    :param kwargs: extra arguments forwarded to ax.plot
    :return: None
    """
    # float cast: NaN cannot be written into an integer array
    x = np.tile(np.asarray(x, dtype=float), (3, 1))
    x[2, :] = np.nan
    y = np.zeros_like(x)
    y[0, :] = ymin
    y[1, :] = ymax
    y[2, :] = np.nan
    if ax is None:
        ax = plt.gca()
    # column-wise flattening yields (x, ymin), (x, ymax), (nan, nan) per line
    ax.plot(x.T.flatten(), y.T.flatten(), **kwargs)
def spectrum(w, fs, smooth=None, unwrap=True, axis=0, **kwargs):
    """
    Display spectral density of a signal along a given dimension
    spectrum(w, fs)

    Plots the amplitude (20 * log10 |FFT|, in dB) and the phase (rad) of the one-sided FFT
    of `w` on two stacked axes sharing the frequency axis.

    :param w: signal (numpy array)
    :param fs: sampling frequency (Hz)
    :param smooth: (None) width in Hz of the median filter applied to amplitude and phase
     along the frequency axis; no smoothing when None or 0
    :param unwrap: (True) unwraps the phase spectrum along `axis`
    :param axis: (0) axis on which to compute the FFT
    :param kwargs: plot arguments to be passed to matplotlib
    :return: array of the 2 matplotlib axes (amplitude, phase)
    """
    # local imports make sure the scipy submodules are loaded; a bare `import scipy`
    # does not necessarily import them
    import scipy.fft
    import scipy.signal

    ns = w.shape[axis]
    fscale = dsp.fourier.fscale(ns, 1 / fs, one_sided=True)
    W = scipy.fft.rfft(w, axis=axis)
    amp = 20 * np.log10(np.abs(W))
    phi = np.angle(W)

    if unwrap:
        # np.unwrap defaults to the last axis: unwrap along the frequency axis instead
        phi = np.unwrap(phi, axis=axis)

    if smooth:
        # medfilt needs an odd integer kernel length (in frequency samples)
        nf = int(np.round(smooth / fscale[1] / 2) * 2 + 1)
        # filter along the frequency axis only, leaving the other dimensions untouched
        kernel = [1] * amp.ndim
        kernel[axis] = nf
        amp = scipy.signal.medfilt(amp, kernel)
        phi = scipy.signal.medfilt(phi, kernel)

    # matplotlib plots along the first dimension: put frequencies first
    amp = np.moveaxis(amp, axis, 0)
    phi = np.moveaxis(phi, axis, 0)

    _, ax = plt.subplots(2, 1, sharex=True)
    ax[0].plot(fscale, amp, **kwargs)
    ax[1].plot(fscale, phi, **kwargs)

    ax[0].set_title('Spectral Density (dB rel to amplitude.Hz^-0.5)')
    ax[0].set_ylabel('Amp (dB)')
    ax[1].set_ylabel('Phase (rad)')
    ax[1].set_xlabel('Frequency (Hz)')
    return ax
def color_cycle(ind=None):
    """
    Return matplotlib's default color cycle as RGB floats between 0 and 1.

    :param ind: (None) optional index into the cycle, wrapped modulo the number of colors
    :return: the full (10, 3) numpy array when ind is None, otherwise the (r, g, b) tuple
     of color ind % 10
    """
    # Hard-coded copy of matplotlib's default property cycle ("tab10": the first row is
    # #1f77b4), so no rcParams lookup is needed.
    palette = np.array([
        [0.12156863, 0.46666667, 0.70588235],
        [1., 0.49803922, 0.05490196],
        [0.17254902, 0.62745098, 0.17254902],
        [0.83921569, 0.15294118, 0.15686275],
        [0.58039216, 0.40392157, 0.74117647],
        [0.54901961, 0.3372549, 0.29411765],
        [0.89019608, 0.46666667, 0.76078431],
        [0.49803922, 0.49803922, 0.49803922],
        [0.7372549, 0.74117647, 0.13333333],
        [0.09019608, 0.74509804, 0.81176471],
    ])
    if ind is not None:
        return tuple(palette[ind % len(palette), :])
    return palette
if __name__ == "__main__":
    # quick visual check of the wiggle and Traces displays on uniform random noise
    w = np.random.rand(500, 40) - 0.5
    wiggle(w, fs=30000)
    Traces(w, fs=30000, color='r')
    # in a non-interactive session nothing is displayed unless show() is called
    plt.show()