Coverage for brainbox/ephys_plots.py: 21%
256 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 13:06 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 13:06 +0000
1import numpy as np
2from matplotlib import cm
3import matplotlib.pyplot as plt
4from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
5 plot_image, plot_probe, plot_scatter, arrange_channels2banks)
6from brainbox.processing import compute_cluster_average
7from iblutil.numerical import bincount2D
8from iblatlas.regions import BrainRegions
def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords=None, chn_inds=None, freq_range=(0, 300),
                            avg_across_depth=False, clim=None, cmap='viridis', display=False, title=None, **kwargs):
    """
    Prepare data for 2D image plot of LFP power spectrum along depth of probe

    :param lfp_power: LFP power array indexed as [frequency, channel]
    :param lfp_freq: frequencies (Hz) of the first axis of lfp_power
    :param chn_coords: channel coordinates; column 1 is used for the y axis. If None, the channel
        index is used instead
    :param chn_inds: indices of the channels (columns of lfp_power) to show; all channels if None
    :param freq_range: half-open frequency interval [low, high) in Hz to show
    :param avg_across_depth: Whether to average across channels at same depth
    :param clim: colour limits (low, high); if None, the 10th and 90th percentiles of the finite
        dB values are used
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'LFP Power Spectrum'
    :param kwargs: passed on to plot_image when display=True
    :return: ImagePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_image
    """
    ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
    title = title or 'LFP Power Spectrum'

    y = np.arange(lfp_power.shape[1]) if chn_coords is None else chn_coords[:, 1]
    chn_inds = np.arange(lfp_power.shape[1]) if chn_inds is None else chn_inds

    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    freqs = lfp_freq[freq_idx]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
    lfp_db = 10 * np.log10(lfp)
    # Zero power gives -inf dB; mark those values as missing
    lfp_db[np.isinf(lfp_db)] = np.nan

    # Average across channels that are at the same depth
    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
        # NOTE(review): assumes the two channels sharing a depth sit at consecutive indices (i, i + 1)
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[np.where(chn_count == 2)] += 1
        lfp_db = np.mean([lfp_db[:, chn_idx], lfp_db[:, chn_idx_eq]], axis=0)
        y = chn_depth

    data = ImagePlot(lfp_db, x=freqs, y=y, cmap=cmap)
    data.set_labels(title=title, xlabel='Frequency (Hz)',
                    ylabel=ylabel, clabel='LFP Power (dB)')
    if clim is None:
        # lfp_db may contain NaNs (set above), which would make np.quantile return NaN limits
        clim = np.nanquantile(lfp_db, [0.1, 0.9])
    data.set_clim(clim=clim)

    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
def image_rms_plot(rms_amps, rms_times, chn_coords=None, chn_inds=None, avg_across_depth=False,
                   median_subtract=True, clim=None, cmap='plasma', band='AP', display=False, title=None, **kwargs):
    """
    Prepare data for 2D image plot of RMS data along depth of probe

    :param rms_amps: RMS array indexed as [time, channel]
    :param rms_times: times (s) of the first axis of rms_amps
    :param chn_coords: channel coordinates; column 1 is used for the y axis. If None, the channel
        index is used instead
    :param chn_inds: indices of the channels (columns of rms_amps) to show; all channels if None
    :param avg_across_depth: Whether to average across channels at same depth
    :param median_subtract: Whether to apply median subtraction correction. The median of each time
        sample across channels is removed, and the mean of those medians is added back
    :param clim: colour limits (low, high); if None, the 10th and 90th percentiles of the finite
        dB values are used
    :param cmap: matplotlib colormap name
    :param band: Frequency band of rms data, can be either 'LF' or 'AP' (used in labels only)
    :param display: generate figure
    :param title: plot title, defaults to '<band> RMS'
    :param kwargs: passed on to plot_image when display=True
    :return: ImagePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_image
    """
    ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
    title = title or f'{band} RMS'
    chn_inds = np.arange(rms_amps.shape[1]) if chn_inds is None else chn_inds
    y = np.arange(rms_amps.shape[1]) if chn_coords is None else chn_coords[:, 1]

    rms = 10 * np.log10(rms_amps[:, chn_inds])
    # Zero RMS gives -inf dB, which would poison the median correction and the colour limits;
    # treat those values as missing instead
    rms[np.isinf(rms)] = np.nan
    x = rms_times

    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
        # NOTE(review): assumes the two channels sharing a depth sit at consecutive indices (i, i + 1)
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[np.where(chn_count == 2)] += 1
        rms = np.mean([rms[:, chn_idx], rms[:, chn_idx_eq]], axis=0)
        y = chn_depth

    if median_subtract:
        row_median = np.nanmedian(rms, axis=1, keepdims=True)
        rms = rms - row_median + np.nanmean(row_median)

    data = ImagePlot(rms, x=x, y=y, cmap=cmap)
    data.set_labels(title=title, xlabel='Time (s)', ylabel=ylabel, clabel=f'{band} RMS (dB)')
    if clim is None:
        clim = np.nanquantile(rms, [0.1, 0.9])
    data.set_clim(clim=clim)

    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
def scatter_raster_plot(spike_amps, spike_depths, spike_times, n_amp_bins=10, cmap='BuPu',
                        subsample_factor=100, display=False, title=None, **kwargs):
    """
    Prepare data for 2D raster plot of spikes with colour and size indicative of spike amplitude

    :param spike_amps: amplitude of each spike (scaled by 1e6 for the uV colour axis)
    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike
    :param n_amp_bins: no. of colour and size bins into which to split amplitude data
    :param cmap: matplotlib colormap name used for the amplitude colours
    :param subsample_factor: keep every subsample_factor-th spike, to keep the number of
        points manageable
    :param display: generate figure
    :param title: plot title, defaults to 'Spike times vs Spike depths'
    :param kwargs: passed on to plot_scatter when display=True
    :return: ScatterPlot object, if display=True returns (data dict, fig, ax) using the handles from plot_scatter
    """
    title = title or 'Spike times vs Spike depths'
    # Amplitude bins span the minimum to the 90th percentile of all (not subsampled) spikes
    amp_range = np.quantile(spike_amps, [0, 0.9])
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    color_bin = np.linspace(0.0, 1.0, n_amp_bins + 1)
    # RGB part of the colormap; matplotlib.cm.get_cmap was removed in matplotlib 3.9
    colors = plt.get_cmap(cmap)(color_bin)[:, :3]

    # Subsample all spike arrays identically (keeps the last spike when it falls on the stride)
    spike_amps = spike_amps[::subsample_factor]
    spike_times = spike_times[::subsample_factor]
    spike_depths = spike_depths[::subsample_factor]

    spike_colors = np.zeros((spike_amps.size, 3))
    spike_size = np.zeros(spike_amps.size)
    last_bin = amp_bins.size - 1
    for i_bin in range(amp_bins.size):
        if i_bin == last_bin:
            idx = np.where(spike_amps > amp_bins[i_bin])[0]
            # Make saturated spikes the darkest colour
            spike_colors[idx] = colors[-1]
        else:
            idx = np.where((spike_amps > amp_bins[i_bin]) & (spike_amps <= amp_bins[i_bin + 1]))[0]
            spike_colors[idx] = colors[i_bin]
        spike_size[idx] = i_bin / (n_amp_bins / 8)

    data = ScatterPlot(x=spike_times, y=spike_depths, c=spike_amps * 1e6, cmap=cmap)
    data.set_ylim((0, 3840))
    data.set_color(color=spike_colors)
    data.set_clim(clim=amp_range * 1e6)
    data.set_marker_size(marker_size=spike_size)
    data.set_labels(title=title, xlabel='Time (s)',
                    ylabel='Distance from probe tip (um)', clabel='Spike amplitude (uV)')

    if display:
        ax, fig = plot_scatter(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
def image_fr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=5, cmap='binary',
                  display=False, title=None, **kwargs):
    """
    Prepare data for a 2D image of firing rate over the recording, binned in time and depth.

    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param t_bin: time bin size passed to iblutil.numerical.bincount2D
    :param d_bin: depth bin size passed to iblutil.numerical.bincount2D
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Firing Rate'
    :param kwargs: passed on to plot_image when display=True
    :return: ImagePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_image
    """
    plot_title = title or 'Firing Rate'
    depth_limit = np.max(chn_coords[:, 1])
    counts, time_scale, depth_scale = bincount2D(spike_times, spike_depths, t_bin, d_bin,
                                                 ylim=[0, depth_limit])
    # Spike counts per bin divided by the bin duration -> rate in Hz
    rate = counts.T / t_bin
    # Colour limits come from the axis-0 average of the rate image
    avg_rate = np.mean(rate, axis=0)

    image = ImagePlot(rate, x=time_scale, y=depth_scale, cmap=cmap)
    image.set_labels(title=plot_title, xlabel='Time (s)',
                     ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
    image.set_clim(clim=(np.min(avg_rate), np.max(avg_rate)))

    if not display:
        return image
    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def image_crosscorr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=40,
                         cmap='viridis', display=False, title=None, **kwargs):
    """
    Prepare data for a 2D image of the correlation between depth bins.

    Spikes are counted in (t_bin, d_bin) bins with iblutil.numerical.bincount2D. np.corrcoef
    correlates the rows of the count matrix, and the result is plotted against the depth scale on
    both axes. Undefined correlations (NaN) are replaced by 0.

    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param t_bin: time bin size passed to bincount2D
    :param d_bin: depth bin size passed to bincount2D
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Correlation'
    :param kwargs: passed on to plot_image when display=True
    :return: ImagePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_image
    """
    plot_title = title or 'Correlation'
    depth_limit = np.max(chn_coords[:, 1])
    counts, _, depth_scale = bincount2D(spike_times, spike_depths, t_bin, d_bin,
                                        ylim=[0, depth_limit])
    correlation = np.corrcoef(counts)
    correlation[np.isnan(correlation)] = 0

    image = ImagePlot(correlation, x=depth_scale, y=depth_scale, cmap=cmap)
    image.set_labels(title=plot_title, xlabel='Distance from probe tip (um)',
                     ylabel='Distance from probe tip (um)', clabel='Correlation')

    if not display:
        return image
    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def scatter_amp_depth_fr_plot(spike_amps, spike_clusters, spike_depths, spike_times, cmap='hot',
                              display=False, title=None, **kwargs):
    """
    Prepare data for a 2D scatter of cluster amplitude (x) vs cluster depth (y), coloured by
    cluster firing rate.

    :param spike_amps: amplitude of each spike (cluster means are scaled by 1e6 for the uV axis)
    :param spike_clusters: cluster id of each spike
    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike; the largest is used as the recording duration
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Cluster depth vs amp vs firing rate'
    :param kwargs: passed on to plot_scatter when display=True
    :return: ScatterPlot object, if display=True returns (data dict, fig, ax) using the handles from plot_scatter
    """
    plot_title = title or 'Cluster depth vs amp vs firing rate'

    # TODO use pandas here instead, much quicker
    _, depth_per_cluster, spikes_per_cluster = compute_cluster_average(spike_clusters, spike_depths)
    amp_per_cluster = compute_cluster_average(spike_clusters, spike_amps)[1] * 1e6
    # Firing rate: number of spikes divided by the time of the last spike
    rate_per_cluster = spikes_per_cluster / np.max(spike_times)

    scatter = ScatterPlot(x=amp_per_cluster, y=depth_per_cluster, c=rate_per_cluster, cmap=cmap)
    amp_lo = 0.9 * np.min(amp_per_cluster)
    amp_hi = 1.1 * np.max(amp_per_cluster)
    scatter.set_xlim((amp_lo, amp_hi))
    scatter.set_labels(title=plot_title, xlabel='Cluster Amplitude (uV)',
                       ylabel='Distance from probe tip (um)', clabel='Firing rate (Hz)')

    if not display:
        return scatter
    ax, fig = plot_scatter(scatter.convert2dict(), **kwargs)
    return scatter.convert2dict(), fig, ax
def probe_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 4),
                            display=False, pad=True, x_offset=1, **kwargs):
    """
    Prepare data for 2D probe plot of LFP power spectrum along depth of probe

    The power in dB is averaged over the frequencies in freq_range, giving one value per channel.

    :param lfp_power: LFP power array indexed as [frequency, channel]
    :param lfp_freq: frequencies (Hz) of the first axis of lfp_power
    :param chn_coords: channel coordinates, passed to arrange_channels2banks
    :param chn_inds: indices of the channels (columns of lfp_power) to show
    :param freq_range: half-open frequency interval [low, high) in Hz to average over
    :param display: generate figure
    :param pad: whether to add nans around the individual image plots. For matplotlib use pad=True,
        for pyqtgraph use pad=False
    :param x_offset: Distance between the channel banks in x direction
    :param kwargs: passed on to plot_probe when display=True
    :return: ProbePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_probe
    """
    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
    lfp_db = 10 * np.log10(lfp)
    # Zero power gives -inf dB; mark those values as missing
    lfp_db[np.isinf(lfp_db)] = np.nan
    lfp_db = np.mean(lfp_db, axis=0)

    data_bank, x_bank, y_bank = arrange_channels2banks(lfp_db, chn_coords, depth=None,
                                                       pad=pad, x_offset=x_offset)
    data = ProbePlot(data_bank, x=x_bank, y=y_bank)
    # Label reflects the actual frequency band (default (0, 4) -> 'PSD 0-4 Hz (dB)')
    data.set_labels(ylabel='Distance from probe tip (um)',
                    clabel=f'PSD {freq_range[0]:g}-{freq_range[1]:g} Hz (dB)')
    # Colour limits from all bank values, ignoring the NaNs (padding / missing values)
    clim = np.nanquantile(np.concatenate([np.ravel(d) for d in data_bank]), [0.1, 0.9])
    data.set_clim(clim)

    if display:
        ax, fig = plot_probe(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
def probe_rms_plot(rms_amps, chn_coords, chn_inds, cmap='plasma', band='AP',
                   display=False, pad=True, x_offset=1, **kwargs):
    """
    Prepare data for a 2D probe plot of the average RMS of each channel.

    :param rms_amps: RMS array; averaged over axis 0 and indexed by chn_inds along axis 1
    :param chn_coords: channel coordinates, passed to arrange_channels2banks
    :param chn_inds: indices of the channels to show
    :param cmap: matplotlib colormap name
    :param band: frequency band name used in the colour-bar label (e.g. 'AP' or 'LF')
    :param display: generate figure
    :param pad: whether to add nans around the individual image plots. For matplotlib use pad=True,
        for pyqtgraph use pad=False
    :param x_offset: Distance between the channel banks in x direction
    :param kwargs: passed on to plot_probe when display=True
    :return: ProbePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_probe
    """
    # Average over axis 0, keep the requested channels, scale by 1e6 (label is in uV)
    channel_rms = np.mean(rms_amps, axis=0)[chn_inds] * 1e6

    banks, bank_x, bank_y = arrange_channels2banks(channel_rms, chn_coords, depth=None,
                                                   pad=pad, x_offset=x_offset)
    probe = ProbePlot(banks, x=bank_x, y=bank_y, cmap=cmap)
    probe.set_labels(ylabel='Distance from probe tip (um)', clabel=f'{band} RMS (uV)')
    all_values = np.concatenate([np.squeeze(np.ravel(bank)) for bank in banks]).ravel()
    probe.set_clim(np.nanquantile(all_values, [0.1, 0.9]))

    if not display:
        return probe
    ax, fig = plot_probe(probe.convert2dict(), **kwargs)
    return probe.convert2dict(), fig, ax
def line_fr_plot(spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None, **kwargs):
    """
    Prepare data for a 1D line plot of average firing rate across depth.

    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike; the largest is used as the recording duration
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param d_bin: depth bin size passed to iblutil.numerical.bincount2D
    :param display: generate figure
    :param title: plot title, defaults to 'Avg Firing Rate'
    :param kwargs: passed on to plot_line when display=True
    :return: LinePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_line
    """
    plot_title = title or 'Avg Firing Rate'
    duration = np.max(spike_times)
    depth_limit = np.max(chn_coords[:, 1])
    # Time bin as long as the latest spike time; only the first time bin is used below
    counts, _, depth_scale = bincount2D(spike_times, spike_depths, duration, d_bin,
                                        ylim=[0, depth_limit])
    avg_rate = counts[:, 0] / duration

    line = LinePlot(x=avg_rate, y=depth_scale)
    line.set_xlim((0, np.max(avg_rate)))
    line.set_labels(title=plot_title, xlabel='Firing Rate (Hz)',
                    ylabel='Distance from probe tip (um)')

    if not display:
        return line
    ax, fig = plot_line(line.convert2dict(), **kwargs)
    return line.convert2dict(), fig, ax
def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None,
                  min_spikes=50, **kwargs):
    """
    Prepare data for 1D line plot of average spike amplitude across depth

    :param spike_amps: amplitude of each spike (the mean is scaled by 1e6 for the uV axis)
    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike; the largest is used as the time bin size
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param d_bin: depth bin size passed to iblutil.numerical.bincount2D
    :param display: generate figure
    :param title: plot title, defaults to 'Avg Amplitude'
    :param min_spikes: depth bins with fewer spikes than this are set to 0 (default 50)
    :param kwargs: passed on to plot_line when display=True
    :return: LinePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_line
    """
    title = title or 'Avg Amplitude'
    t_bin = np.max(spike_times)
    ylim = [0, np.max(chn_coords[:, 1])]
    # Spike counts and summed amplitudes per depth bin; only the first time bin is used below
    n, _, _ = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=ylim)
    amp, _, y = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=ylim, weights=spike_amps)

    counts = n[:, 0]
    # Mean amplitude per depth bin; empty bins stay 0 instead of producing 0/0 NaNs and warnings
    mean_amp = np.divide(amp[:, 0], counts, out=np.zeros(counts.shape, dtype=float),
                         where=counts > 0) * 1e6
    # Blank out bins with too few spikes for a reliable average
    mean_amp[counts < min_spikes] = 0

    data = LinePlot(x=mean_amp, y=y)
    data.set_xlim((0, np.max(mean_amp)))
    data.set_labels(title=title, xlabel='Amplitude (uV)',
                    ylabel='Distance from probe tip (um)')
    if display:
        ax, fig = plot_line(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None,
                       title=None, label='left', **kwargs):
    """
    Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx

    :param channel_ids: atlas ids for each channel
    :param channel_depths: depth along probe for each channel; channel index is used if None
    :param brain_regions: BrainRegions object (a new BrainRegions() is created if None)
    :param display: whether to output plot
    :param ax: axis to plot on; a new figure is created if None
    :param title: title for plot
    :param label: side of the region acronym tick labels, 'left' or 'right'; None for no labels
    :param kwargs: additional keyword arguments for bar plot, overriding the defaults
        edgecolor='w' and width=1
    :return: if display, (fig, ax); otherwise (regions, region_labels, region_colours) where
        regions holds the [start, end] depth of each region, region_labels holds
        [label position, acronym] and region_colours the region rgb values
    """
    if channel_depths is not None:
        assert channel_ids.shape[0] == channel_depths.shape[0]
    else:
        channel_depths = np.arange(channel_ids.shape[0])

    br = BrainRegions() if brain_regions is None else brain_regions

    region_info = br.get(channel_ids)
    # Index of the last channel of each run of identical region ids
    boundaries = np.where(np.diff(region_info.id) != 0)[0]
    boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]

    # One [start, end] row per region, converted from channel indices to depths
    regions = channel_depths[np.c_[boundaries[0:-1], boundaries[1:]]]
    label_positions = np.mean(regions, axis=1)
    region_labels = np.c_[label_positions, region_info.acronym[boundaries[1:]]]
    region_colours = region_info.rgb[boundaries[1:]]

    if not display:
        return regions, region_labels, region_colours

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Defaults for the bars, overridable via kwargs
    bar_kwargs = dict(edgecolor='w', width=1)
    bar_kwargs.update(**kwargs)
    for reg, col in zip(regions, region_colours):
        height = np.abs(reg[1] - reg[0])
        ax.bar(x=0.5, height=height, color=col / 255, bottom=reg[0], **bar_kwargs)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    if label is not None:
        # Use the numeric positions directly rather than round-tripping through region_labels,
        # whose dtype depends on the acronym array (a string dtype would break astype(int))
        ax.set_yticks(label_positions.astype(int))
        ax.yaxis.set_tick_params(labelsize=8)
        ax.set_ylim(np.nanmin(channel_depths), np.nanmax(channel_depths))
        ax.get_xaxis().set_visible(False)
        ax.set_yticklabels(region_labels[:, 1])
        if label == 'right':
            ax.yaxis.tick_right()
            ax.spines['left'].set_visible(False)
        else:
            ax.spines['right'].set_visible(False)

    if title:
        ax.set_title(title)

    return fig, ax
def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
             display=False, cmap='hot', ax=None):
    """
    Plot cumulative amplitude of spikes across depth

    For each depth bin, the value at amplitude bin i is the number of spikes with amplitude
    >= amp_bins[i], divided by the largest spike time. Zero entries are set to NaN.

    :param spike_amps: amplitude of each spike (bins are scaled by 1e6 for the uV axis)
    :param spike_depths: depth of each spike along the probe
    :param spike_times: time of each spike; the largest is used to convert counts to rates
    :param n_amp_bins: number of amplitude bins to use
    :param d_bin: the value of the depth bins in um (default is 40 um)
    :param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps
        (0th to 90th percentile)
    :param d_range: depth range to use, by default [0, 3840]
    :param display: whether or not to display plot
    :param cmap: matplotlib colormap name
    :param ax: matplotlib axis passed to plot_image when display=True
    :return: ImagePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_image
    """
    amp_range = np.quantile(spike_amps, (0, 0.9)) if amp_range is None else amp_range
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    d_range = [0, 3840] if d_range is None else d_range
    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
    t_bin = np.max(spike_times)

    def histc(x, bins):
        """Count x per bin: bin i holds bins[i] <= x < bins[i + 1], the last bin all x >= bins[-1].

        Values below bins[0] are not counted.
        """
        bin_idx = np.digitize(x, bins)
        # np.digitize returns 0 for x < bins[0]; drop those instead of letting index -1 wrap them
        # into the top (largest amplitude) bin
        bin_idx = bin_idx[bin_idx > 0] - 1
        return np.bincount(bin_idx, minlength=bins.size).astype(float)

    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
    for d in range(len(depth_bins) - 1):
        spikes = (spike_depths > depth_bins[d]) & (spike_depths <= depth_bins[d + 1])
        h = histc(spike_amps[spikes], amp_bins) / t_bin
        # Reverse cumulative sum: entry i counts every spike in amplitude bin i or above
        cdfs[d, :] = np.cumsum(h[::-1])[::-1]

    cdfs[cdfs == 0] = np.nan

    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
    data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
                    ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')

    if display:
        ax, fig = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]}, ax=ax)
        return data.convert2dict(), fig, ax

    return data
def image_raw_data(raw, fs, chn_coords=None, cmap='bone', title=None, display=False, gain=-90, **kwargs):
    """
    Prepare data for a 2D image of a snippet of raw data.

    :param raw: 2D array; axis 0 is samples (converted to ms using fs), axis 1 is channels
    :param fs: sampling frequency in Hz
    :param chn_coords: channel coordinates; column 1 is used for the y axis. If None, the channel
        index is used instead
    :param cmap: matplotlib colormap name
    :param title: plot title, defaults to 'Raw data'
    :param display: generate figure
    :param gain: display gain in dB; the colour limits are +/- 4 * 10 ** (gain / 20)
    :param kwargs: passed on to plot_image when display=True
    :return: ImagePlot object, if display=True returns (data dict, fig, ax) using the handles from plot_image
    """
    if chn_coords is None:
        y_label = 'Channel index'
        y_values = np.arange(raw.shape[1])
    else:
        y_label = 'Distance from probe tip (um)'
        y_values = chn_coords[:, 1]
    plot_title = title or 'Raw data'

    # First and last sample times in ms
    time_extent = np.array([0, raw.shape[0] - 1]) / fs * 1e3
    # Symmetric colour limits derived from the gain in dB
    colour_limits = 10 ** (gain / 20) * 4 * np.array([-1, 1])

    image = ImagePlot(raw, y=y_values, cmap=cmap)
    image.set_labels(title=plot_title, xlabel='Time (ms)',
                     ylabel=y_label, clabel='Power (uV)')
    image.set_clim(clim=colour_limits)
    image.set_xlim(xlim=time_extent)
    image.set_ylim()

    if not display:
        return image
    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax