Coverage for brainbox/plot.py: 11%
244 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 13:06 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 13:06 +0000
"""
Plots metrics that assess the quality of single units. Some functions here generate plots for the
output of functions in the brainbox `single_units.py` module.

Run the following to set up the workspace to run the docstring examples. `path_to_ks_out` and
`path_to_alf_out` are placeholders for the kilosort output directory and the ALF directory:
>>> import brainbox as bb
>>> import brainbox.plot  # makes `bb.plot` available, as used in the examples below
>>> from brainbox import processing
>>> import one.alf.io as alfio
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import ibllib.ephys.spikes as e_spks
# (*Note: if there is no 'alf' directory, make an 'alf' directory from the 'ks2' output directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Load the alf spikes bunch and clusters bunch, and get a units bunch.
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
>>> units_b = processing.get_units_bunch(spks_b)  # may take a few mins to compute
"""
19import time
20from warnings import warn
22import matplotlib.pyplot as plt
23import seaborn as sns
24import numpy as np
26# from matplotlib.ticker import StrMethodFormatter
27from brainbox import singlecell
28from brainbox.metrics import single_units
29from brainbox.io.spikeglx import extract_waveforms
30from iblutil.numerical import bincount2D
31import spikeglx
def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm',
              ax=None):
    '''
    Plots the coefficients of variation of a particular spike feature for all units as a bar plot,
    where each bar is color-coded by the mean depth of the respective unit's spikes.

    Units without any spikes are left out of the plot (but not out of the returned values). The
    bars are stacked bottom-to-top in order of increasing mean depth, and each bar is labelled with
    its unit number.

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all units. It must contain at least the 'depths' and 'times' fields, each keyed
        by unit. It is not modified.
    units : array-like (optional)
        A subset of all units for which to create the bar plot. (If `None`, all units are used)
    feat_name : string (optional)
        The spike feature to plot.
    dist : string (optional)
        The type of hypothetical null distribution from which the empirical spike feature
        distributions are presumed to belong to.
    test : string (optional)
        The statistical test used to calculate the probability that the empirical spike feature
        distributions come from `dist`.
    cmap_name : string (optional)
        The name of the colormap associated with the plot.
    ax : axessubplot (optional)
        The axis handle to plot the bar plot on. (if `None`, a new figure and axis is created)

    Returns
    -------
    cv_vals : ndarray
        The coefficients of variation of `feat_name` for each unit (scaled by 1e6 if `feat_name`
        is 'amps', i.e. converted to uV).
    p_vals : ndarray
        The probabilites that the distribution for `feat_name` for each unit comes from a
        `dist` distribution based on the `test` statistical test.

    See Also
    --------
    single_units.unit_stability

    Examples
    --------
    1) Create a bar plot of the coefficients of variation of the spike amplitudes for all units.
    >>> cv_vals, p_vals = bb.plot.feat_vars(units_b)
    '''

    # Units to consider (keys of the `units_b` dicts). The list is built by filtering rather than
    # by popping entries out of the caller's `units_b`.
    unit_list = list(units_b['depths'].keys())
    if units is not None:  # we're using a subset of all units
        unit_list = [unit for unit in unit_list if int(unit) in units]

    # Calculate coefficients of variation (and p-values) for the selected units.
    p_vals_b, cv_b = single_units.unit_stability(
        units_b, units=units, feat_names=[feat_name], dist=dist, test=test)
    cv_vals = np.array(tuple(cv_b[feat_name].values()))
    cv_vals = cv_vals * 1e6 if feat_name == 'amps' else cv_vals  # convert to uV if amps
    p_vals = np.array(tuple(p_vals_b[feat_name].values()))

    # Remove units without spikes from the plot. This is done AFTER the calculations above so that
    # `cv_vals` keeps one entry per unit of `unit_list`.
    # NOTE(review): assumes `unit_stability` returns its values in the same order as `unit_list`
    # (the previous implementation relied on the same "direct indexing") -- confirm.
    non_empty = np.array([len(units_b['times'][unit]) > 0 for unit in unit_list], dtype=bool)
    good_units = [unit for unit, keep in zip(unit_list, non_empty) if keep]
    if not good_units:
        warn('No units with spikes to plot.')
        return cv_vals, p_vals
    good_cvs = cv_vals[non_empty]

    # Mean spike depth of each good unit, and the permutation sorting units by increasing depth.
    depths = np.asarray([np.mean(units_b['depths'][str(unit)]) for unit in good_units])
    order = np.argsort(depths)
    max_d = np.max(depths)

    # Bar colours: each unit's depth mapped into [0, 1] for the colormap.
    # (`plt.get_cmap` is used because `matplotlib.cm.get_cmap` was removed in matplotlib 3.9.)
    cmap = plt.get_cmap(cmap_name)
    rgba = cmap(depths[order] / max_d)

    # One horizontal bar per unit; widths, colours and labels all use the same depth order.
    if ax is None:
        fig, ax = plt.subplots()
    fig = ax.figure
    positions = np.arange(len(good_units))
    ax.barh(y=positions, width=good_cvs[order], color=rgba)
    ax.set_yticks(positions)
    ax.set_yticklabels([str(good_units[i]) for i in order])

    # Colorbar in depth units, spanning the same [0, max depth] range as the bar colours.
    mappable = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=max_d), cmap=cmap)
    cbar = fig.colorbar(mappable, ax=ax)
    ax.set_title('CV of {feat}'.format(feat=feat_name))
    ax.set_ylabel('Unit Number (sorted by depth)')
    ax.set_xlabel('CV')
    cbar.set_label('Depth', rotation=-90)

    return cv_vals, p_vals
def missed_spikes_est(feat, feat_name, spks_per_bin=20, sigma=5, min_num_bins=50, ax=None):
    '''
    Plots the pdf of an estimated symmetric spike feature distribution, with a vertical cutoff line
    that indicates the approximate fraction of spikes missing from the distribution, assuming the
    true distribution is symmetric.

    Parameters
    ----------
    feat : ndarray
        The spikes' feature values.
    feat_name : string
        The spike feature to plot.
    spks_per_bin : int (optional)
        The number of spikes per bin from which to compute the spike feature histogram.
    sigma : int (optional)
        The standard deviation for the gaussian kernel used to compute the pdf from the spike
        feature histogram.
    min_num_bins : int (optional)
        The minimum number of bins used to compute the spike feature histogram.
    ax : axessubplot or sequence of axessubplot (optional)
        Where to plot. If two axes are given, the histogram of `feat` is plotted on the first and
        the pdf with its cutoff line on the second. If a single axis (or a sequence holding any
        other number of axes) is given, only the pdf is plotted, on the first axis. (If `None`, a
        new figure with two axes is created)

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an
        accurate estimate isn't possible.

    See Also
    --------
    single_units.feature_cutoff

    Examples
    --------
    1) Plot cutoff line indicating the fraction of spikes missing from a unit based on the recorded
    unit's spike amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
    >>> feat = units_b['amps']['1']
    >>> fraction_missing = bb.plot.missed_spikes_est(feat, feat_name='amps')
    '''

    # Calculate the feature distribution histogram and fraction of spikes missing.
    fraction_missing, pdf, cutoff_idx = \
        single_units.missed_spikes_est(feat, spks_per_bin, sigma, min_num_bins)

    # Normalise `ax` to a flat list of axes, so a bare Axes object is accepted as well as a
    # list/array of axes.
    if ax is None:  # create two axes
        fig, ax = plt.subplots(nrows=1, ncols=2)
    axes = list(np.ravel(ax)) if isinstance(ax, (list, tuple, np.ndarray)) else [ax]

    if len(axes) == 2:  # histogram on the first axis, pdf on the second
        # Honour `min_num_bins` so that few spikes never give a histogram with 0 bins.
        num_bins = max(int(np.size(feat) / spks_per_bin), min_num_bins)
        axes[0].hist(feat, bins=num_bins)
        axes[0].set_xlabel('{0}'.format(feat_name))
        axes[0].set_ylabel('Count')
        axes[0].set_title('Histogram of {0}'.format(feat_name))
        pdf_ax = axes[1]
    else:  # just plot pdf
        pdf_ax = axes[0]

    pdf_ax.plot(pdf)
    pdf_ax.vlines(cutoff_idx, 0, np.max(pdf), colors='r')
    pdf_ax.set_xlabel('Bin Number')
    pdf_ax.set_ylabel('Density')
    pdf_ax.set_title('PDF Symmetry Cutoff\n'
                     '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100))

    return fraction_missing
def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', car=True,
            col=('b', 'r'), ax=None):
    '''
    Plots two different sets of waveforms across specified channels after (optionally)
    common-average-referencing. In this way, waveforms can be compared to see if there is,
    e.g. drift during the recording, or if two units should be merged, or one unit should be split.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts1 : array_like
        A set of timestamps for which to compare waveforms with `ts2`.
    ts2: array_like
        A set of timestamps for which to compare waveforms with `ts1`.
    ch : array-like
        The channels to use for extracting and plotting the waveforms.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    dtype: str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car: bool (optional)
        A flag for whether or not to perform common-average-referencing before extracting waveforms
    col: sequence of strings or float arrays (optional)
        Two elements, where each specifies the color the `ts1` and `ts2` waveforms will be plotted
        in, respectively.
    ax : 2-D array-like of axessubplot (optional)
        The axes to plot on, one row per channel and two columns (left: all waveforms, right: mean
        waveforms). (if `None`, a new figure and axes are created)

    Returns
    -------
    wf1 : ndarray
        The waveforms for the spikes in `ts1`: an array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        The waveforms for the spikes in `ts2`: an array of shape (#spikes, #samples, #channels).
    s : float
        The similarity score between the two sets of waveforms, calculated by
        `single_units.wf_similarity`

    See Also
    --------
    io.extract_waveforms
    single_units.wf_similarity

    Examples
    --------
    1) Compare first and last 100 spike waveforms for unit1, across 20 channels around the channel
    of max amplitude, and compare the waveforms in the first minute to the waveforms in the fourth
    minute for unit2, across 10 channels around the mean.
    >>> n_c_ch, n_ch_probe = 20, 385
    # Get first and last 100 spikes, and 20 channels around channel of max amp for unit 1:
    >>> ts1 = units_b['times']['1'][:100]
    >>> ts2 = units_b['times']['1'][-100:]
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> wf1, wf2, s = bb.plot.wf_comp(path_to_ephys_file, ts1, ts2, ch)
    # Plot waveforms for unit2 from the first and fourth minutes across 10 channels.
    >>> ts = units_b['times']['2']
    >>> ts1_2 = ts[np.where(ts < 60)[0]]
    >>> ts2_2 = ts[np.where(ts > 180)[0][:len(ts1_2)]]
    >>> max_ch = clstrs_b['channels'][2]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 10)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 5, max_ch + 5)
    >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
    '''

    # Ensure `ch` is ndarray (a single channel is passed on to `extract_waveforms` as 2-D, as
    # before); `ch_labels` is the flat list of channel numbers used for plotting.
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
    ch_labels = ch.ravel()

    # Extract the waveforms for these timestamps and compute similarity score.
    wf1 = extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                            car=car)
    wf2 = extract_waveforms(ephys_file, ts2, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                            car=car)
    s = single_units.wf_similarity(wf1, wf2)

    # Plot these waveforms against each other: one row per channel, left column is all waveforms,
    # right column is the mean waveform.
    n_ch = ch_labels.size
    if ax is None:
        # `squeeze=False` keeps a 2-D array of axes even when there is a single channel.
        fig, ax = plt.subplots(nrows=n_ch, ncols=2, squeeze=False)
    axes = np.asarray(ax, dtype=object).reshape(n_ch, 2)
    for i_ch, cur_ch in enumerate(ch_labels):
        axes[i_ch, 0].plot(wf1[:, :, i_ch].T, c=col[0])
        axes[i_ch, 0].plot(wf2[:, :, i_ch].T, c=col[1])
        axes[i_ch, 1].plot(np.mean(wf1[:, :, i_ch], axis=0), c=col[0])
        axes[i_ch, 1].plot(np.mean(wf2[:, :, i_ch], axis=0), c=col[1])
        axes[i_ch, 0].set_ylabel('Ch {0}'.format(cur_ch))
    axes[0, 0].set_title('All Waveforms. S = {:.2f}'.format(s))
    axes[0, 1].set_title('Mean Waveforms')
    # The last mean-waveform axis holds exactly the two mean lines, in ts1/ts2 order.
    axes[-1, 1].legend(['1st spike set', '2nd spike set'])

    return wf1, wf2, s
def _car_spatial_noise(file_m, samples, ch, n_chunk_samples=5000000):
    '''
    Estimate the per-channel spatial noise subtracted by `amp_heatmap` for common-average-
    referencing.

    The samples between the earliest and the latest of `samples` are split into chunks of
    `n_chunk_samples` samples. The median voltage of each channel is taken within each chunk, and
    the median over all chunks is returned.

    Parameters
    ----------
    file_m : array-like
        The ephys data, indexed as `file_m[sample_slice, channels]`.
    samples : ndarray
        Sample indices of the spikes (need not be sorted).
    ch : ndarray
        The channels for which to compute the noise.
    n_chunk_samples : int (optional)
        The number of samples per chunk.

    Returns
    -------
    noise_s : ndarray
        The median noise of each channel in `ch`.
    '''
    first, last = int(np.min(samples)), int(np.max(samples))
    # Chunk boundaries: chunk i is made of samples chunk_sample[i]:chunk_sample[i + 1].
    if last == first:  # all spikes on one sample: use that single sample as the only chunk
        chunk_sample = np.array([first, first + 1])
    else:
        chunk_sample = np.append(np.arange(first, last, n_chunk_samples, dtype=int), last)
    n_chunks = chunk_sample.size - 1
    # Float storage, so that fractional medians are not truncated.
    noise_s_chunks = np.zeros((n_chunks, ch.size))
    # Time the first chunk (and keep its result) to give an estimate of the total duration.
    t0 = time.perf_counter()
    noise_s_chunks[0, :] = np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
    dt = time.perf_counter() - t0
    print('Performing spatial CAR before waveform extraction. Estimated time is {:.2f} mins.'
          ' ({})'.format(dt * n_chunks / 60, time.ctime()))
    # Compute noise for each remaining chunk, then take the median noise of all chunks.
    for chunk in range(1, n_chunks):
        noise_s_chunks[chunk, :] = np.median(
            file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
    return np.median(noise_s_chunks, axis=0)


def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cmap_name='RdBu',
                car=True, ax=None):
    '''
    Plots a heatmap of the normalized voltage values over time and space for given timestamps and
    channels, after (optionally) common-average-referencing.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts: array_like
        A set of timestamps (in s) for which to get the voltage values.
    ch : array-like
        The channels to use for extracting the voltage values.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording. (Not used by this function)
    dtype: str (optional)
        The datatype represented by the bytes in `ephys_file`. (Not used by this function)
    cmap_name : string (optional)
        The name of the colormap associated with the plot.
    car: bool (optional)
        A flag for whether or not to subtract a per-channel median (computed in chunks over the
        span of `ts`) from the voltage values.
    ax : axessubplot (optional)
        The axis handle to plot the heatmap on. (if `None`, a new figure and axis is created)

    Returns
    -------
    v_vals : ndarray
        The voltage values, of shape (#timestamps, #channels). The heatmap shows these values
        divided by their maximum absolute value.

    Examples
    --------
    1) Plot a heatmap of the spike amplitudes across 20 channels around the channel of max
    amplitude for all spikes in unit 1.
    >>> n_c_ch, n_ch_probe = 20, 385
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch)
    '''
    # Ensure `ch` is ndarray (kept 2-D for a single channel when indexing the data, as before);
    # `ch_flat` gives plain channel numbers for the axis extent and ticks.
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
    ch_flat = ch.ravel()
    ts = np.asarray(ts)

    # Sample index of each timestamp.
    max_amp_samples = (ts * sr).astype(int)

    # Read the voltage values; the reader is closed even if reading or CAR fails.
    s_reader = spikeglx.Reader(ephys_file, open=True)
    try:
        file_m = s_reader.data
        # Indexing with multiple sample values at once is not supported, so read one sample at
        # a time.
        v_vals = np.zeros((max_amp_samples.size, ch.size))
        for i, sample in enumerate(max_amp_samples):
            v_vals[i] = file_m[sample:sample + 1, ch]
        if car:  # compute spatial noise in chunks, and subtract it from `v_vals`.
            noise_s = _car_spatial_noise(file_m, max_amp_samples, ch)
            v_vals -= noise_s[None, :]
            print('Done. ({})'.format(time.ctime()))
    finally:
        s_reader.close()

    # Plot heatmap, normalised by the largest absolute value (guarding against all-zero data).
    if ax is None:
        fig, ax = plt.subplots()
    peak = np.max(np.abs(v_vals))
    v_vals_norm = (v_vals / peak if peak > 0 else v_vals).T
    cbar_map = ax.imshow(v_vals_norm, cmap=cmap_name, aspect='auto',
                         extent=[ts[0], ts[-1], ch_flat[0], ch_flat[-1]], origin='lower')
    ax.set_yticks(np.arange(ch_flat[0], ch_flat[-1], 5))
    ax.set_ylabel('Channel Numbers')
    ax.set_xlabel('Time (s)')
    ax.set_title('Voltage Heatmap')
    fig = ax.figure
    cbar = fig.colorbar(cbar_map, ax=ax)
    cbar.set_label('V', rotation=-90)

    return v_vals
def firing_rate(ts, hist_win=0.01, fr_win=0.5, n_bins=10, show_fr_cv=True, ax=None):
    '''
    Plots the instantaneous firing rate for given spike timestamps over time, and optionally
    overlays the value of the coefficient of variation of the firing rate for a specified number
    of bins.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute coefficients of variation of the firing rate.
    show_fr_cv : bool (optional)
        A flag for whether or not to compute and show the coefficients of variation of the firing
        rate for `n_bins`.
    ax : axessubplot (optional)
        The axis handle to plot the firing rate on. (if `None`, a new figure and axis is created)

    Returns
    -------
    fr: ndarray
        The instantaneous firing rate over time (in hz).
    cv: float
        The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients
        computed. Can only be returned if `show_fr_cv` is True.
    cvs: ndarray
        The coefficients of variation of the firing for each bin of `n_bins`. Can only be returned
        if `show_fr_cv` is True.

    See Also
    --------
    single_units.firing_rate_coeff_var
    singlecell.firing_rate

    Examples
    --------
    1) Plot the firing rate for unit 1 from the time of its first to last spike, showing the cv
    of the firing rate for 10 evenly spaced bins.
    >>> ts = units_b['times']['1']
    >>> fr, cv, cvs = bb.plot.firing_rate(ts)
    '''

    if ax is None:
        fig, ax = plt.subplots()
    if not show_fr_cv:  # compute just the firing rate
        fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    else:  # compute firing rate and coefficients of variation
        cv, cvs, fr = single_units.firing_rate_coeff_var(ts, hist_win=hist_win, fr_win=fr_win,
                                                         n_bins=n_bins)
    x = np.arange(fr.size) * hist_win
    ax.plot(x, fr)
    ax.set_title('Firing Rate')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Rate (s$^{-1}$)')  # braces so the whole "-1" is superscripted

    if not show_fr_cv:
        return fr

    # Show the coefficient of variation of each of the `n_bins` bins.
    y_max = np.max(fr) * 1.05
    # Width (in s) of one bin. Computed as k * hist_win (the value of x[k]) without indexing `x`,
    # which overflowed for `n_bins == 1`.
    x_l = int(x.size / n_bins) * hist_win
    # Vertical lines separating the plot into `n_bins`.
    ax.vlines(x_l * np.arange(1, n_bins), 0, y_max, linestyles='dashed', linewidth=2)
    # Text with the cv of the firing rate at the right edge of each bin.
    for i in range(n_bins):
        ax.text(x_l * (i + 1), y_max, 'cv={0:.2f}'.format(cvs[i]), fontsize=9, ha='right')
    return fr, cv, cvs
def peri_event_time_histogram(
        spike_times, spike_clusters, events, cluster_id,  # Everything you need for a basic plot
        t_before=0.2, t_after=0.5, bin_size=0.025, smoothing=0.025, as_rate=True,
        include_raster=False, n_rasters=None, error_bars='std', ax=None,
        pethline_kwargs=None, errbar_kwargs=None, eventline_kwargs=None, raster_kwargs=None,
        **kwargs):
    """
    Plot peri-event time histograms, with the mean firing rate of units centered on a given
    series of events. Can optionally add a raster underneath the PETH plot of individual spike
    trains about the events.

    Parameters
    ----------
    spike_times : array_like
        Spike times (in seconds)
    spike_clusters : array-like
        Cluster identities for each element of spikes
    events : array-like
        Times to align the histogram(s) to. At least two finite values are required.
    cluster_id : int
        Identity of the cluster for which to plot a PETH

    t_before : float, optional
        Time before event to plot (default: 0.2s)
    t_after : float, optional
        Time after event to plot (default: 0.5s)
    bin_size :float, optional
        Width of bin for histograms (default: 0.025s)
    smoothing : float, optional
        Sigma of gaussian smoothing to use in histograms. (default: 0.025s)
    as_rate : bool, optional
        Whether to use spike counts or rates in the plot (default: `True`, uses rates)
    include_raster : bool, optional
        Whether to put a raster below the PETH of individual spike trains (default: `False`)
    n_rasters : int, optional
        If include_raster is True, the number of rasters to include. If `None`
        will default to plotting rasters around all provided events. (default: `None`)
    error_bars : {'std', 'sem', 'none'}, optional
        Defines which type of error bars to plot. Options are:
        -- `'std'` for 1 standard deviation
        -- `'sem'` for standard error of the mean
        -- `'none'` for only plotting the mean value
        (default: `'std'`)
    ax : matplotlib axes, optional
        If passed, the function will plot on the passed axes.
        (default: `None`, a new figure is created)
    pethline_kwargs : dict, optional
        Dict containing line properties to define PETH plot line. Needs to have color. See
        matplotlib plot documentation for more options.
        (default: `None`, meaning `{'color': 'blue', 'lw': 2}`)
    errbar_kwargs : dict, optional
        Dict containing fill-between properties to define PETH error bars. Needs to have color.
        See matplotlib fill_between documentation for more options.
        (default: `None`, meaning `{'color': 'blue', 'alpha': 0.5}`)
    eventline_kwargs : dict, optional
        Dict containing properties to define the line at the event. Needs to have color. See
        matplotlib vlines documentation for more options.
        (default: `None`, meaning `{'color': 'black', 'alpha': 0.5}`)
    raster_kwargs : dict, optional
        Dict containing properties defining lines in the raster plot. See matplotlib vlines for
        more options.
        (default: `None`, meaning `{'color': 'black', 'lw': 0.5}`)
    **kwargs
        Accepted for backward compatibility; not used.

    Returns
    -------
    ax : matplotlib axes
        Axes with all of the plots requested.

    Raises
    ------
    ValueError
        If `spike_times` and `spike_clusters` differ in length, fewer than two events are given,
        `error_bars` is invalid, or `events` contains NaN or inf values.
    """
    # Default line properties (None in the signature to avoid shared mutable default arguments).
    if pethline_kwargs is None:
        pethline_kwargs = {'color': 'blue', 'lw': 2}
    if errbar_kwargs is None:
        errbar_kwargs = {'color': 'blue', 'alpha': 0.5}
    if eventline_kwargs is None:
        eventline_kwargs = {'color': 'black', 'alpha': 0.5}
    if raster_kwargs is None:
        raster_kwargs = {'color': 'black', 'lw': 0.5}

    # Accept any array-like input (e.g. lists), so that the boolean masks below work.
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    events = np.asarray(events)

    # Check to make sure if we fail, we fail in an informative way
    if not len(spike_times) == len(spike_clusters):
        raise ValueError('Spike times and clusters are not of the same shape')
    if len(events) < 2:
        raise ValueError('Cannot make a PETH with fewer than two events.')
    if error_bars not in ('std', 'sem', 'none'):
        raise ValueError('Invalid error bar type was passed.')
    if not np.all(np.isfinite(events)):
        raise ValueError('There are NaN or inf values in the list of events passed. '
                         ' Please remove non-finite data points and try again.')

    # Compute peths
    peths, binned_spikes = singlecell.calculate_peths(spike_times, spike_clusters, [cluster_id],
                                                      events, t_before, t_after, bin_size,
                                                      smoothing, as_rate)
    # Construct an axis object if none passed
    if ax is None:
        plt.figure()
        ax = plt.gca()
    # Plot the curve and add error bars
    mean = peths.means[0, :]
    ax.plot(peths.tscale, mean, **pethline_kwargs)
    if error_bars == 'std':
        bars = peths.stds[0, :]
    elif error_bars == 'sem':
        bars = peths.stds[0, :] / np.sqrt(len(events))
    else:
        bars = np.zeros_like(mean)
    if error_bars != 'none':
        ax.fill_between(peths.tscale, mean - bars, mean + bars, **errbar_kwargs)

    # Plot the event marker line. Extends to 5% higher than max value of means plus any error bar.
    plot_edge = (mean.max() + bars[mean.argmax()]) * 1.05
    ax.vlines(0., 0., plot_edge, **eventline_kwargs)
    # Set the limits on the axes to t_before and t_after. Either set the ylim to the 0 and max
    # values of the PETH, or if we want to plot a spike raster below, create an equal amount of
    # blank space below the zero where the raster will go.
    ax.set_xlim([-t_before, t_after])
    ax.set_ylim([-plot_edge if include_raster else 0., plot_edge])
    # Put y ticks only at min, max, and zero
    if mean.min() != 0:
        ax.set_yticks([0, mean.min(), mean.max()])
    else:
        ax.set_yticks([0., mean.max()])
    # Move the x axis line from the bottom of the plotting space to zero if including a raster,
    # Then plot the raster
    if include_raster:
        if n_rasters is None:
            n_rasters = len(events)
        if n_rasters > 60:
            warn("Number of raster traces is greater than 60. This might look bad on the plot.")
        ax.axhline(0., color='black')
        raster_events = events[:n_rasters]
        # Band edges, one band per trace: trace i spans tickedges[i + 1]..tickedges[i]. `linspace`
        # always yields exactly n + 1 edges (a float-step `arange` could yield one too few).
        tickedges = np.linspace(0., -plot_edge, len(raster_events) + 1)
        clu_spks = spike_times[spike_clusters == cluster_id]
        for i, t in enumerate(raster_events):
            idx = (clu_spks >= t - t_before) & (clu_spks <= t + t_after)
            event_spks = clu_spks[idx]
            ax.vlines(event_spks - t, tickedges[i + 1], tickedges[i], **raster_kwargs)
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes', y=0.75)
    else:
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Time (s) after event')
    return ax
def driftmap(ts, feat, ax=None, plot_style='bincount',
             t_bin=0.01, d_bin=20, weights=None, vmax=None, **kwargs):
    """
    Plots the values of a spike feature array (y-axis) over time (x-axis).
    Two arguments can be given for the plot_style of the drift map:
     - 'scatter' : whereby each value is plotted as a marker (only when there are fewer than
       100'000 data points; otherwise 'bincount' is used and a warning is issued)
     - 'bincount' : whereby the values are binned (optimised to represent spike raster)

    Parameters
    ----------
    ts : ndarray
        The spike timestamps.
    feat : ndarray
        The spikes' feature values. NaN values are excluded from the 'bincount' map.
    ax : axessubplot (optional)
        The axis handle to plot on. (if `None`, a new figure and axis is created)
    plot_style: 'scatter', 'bincount'
    t_bin: time bin used when plot_style='bincount'
    d_bin: depth bin used when plot_style='bincount'
    weights : ndarray (optional)
        Per-spike weights passed to the 2D binning when plot_style='bincount'.
    vmax : float (optional)
        Upper limit of the color scale for 'bincount' (if `None`, 4 times the standard deviation
        of the binned map is used).
    **kwargs: matplotlib.plot arguments for 'scatter', matplotlib.imshow arguments for 'bincount'

    Returns
    -------
    ax : axessubplot
        The axis the driftmap was plotted on.

    Examples
    --------
    1) Plot the amplitude driftmap for unit 1.
    >>> ts = units_b['times']['1']
    >>> amps = units_b['amps']['1']
    >>> ax = bb.plot.driftmap(ts, amps)
    2) Plot the depth driftmap for unit 1.
    >>> ts = units_b['times']['1']
    >>> depths = units_b['depths']['1']
    >>> ax = bb.plot.driftmap(ts, depths)
    """
    iok = ~np.isnan(feat)
    if ax is None:
        fig, ax = plt.subplots()

    if plot_style == 'scatter' and len(ts) < 100000:
        # One marker per spike with no connecting line; caller's kwargs override the '.' format.
        kwargs.setdefault('color', 'k')
        ax.plot(ts, feat, '.', **kwargs)
    else:
        if plot_style == 'scatter':
            warn("Too many data points for plot_style='scatter'; using 'bincount' instead.")
        # compute raster map as a function of site depth
        R, times, depths = bincount2D(
            ts[iok], feat[iok], t_bin, d_bin, weights=weights[iok] if weights is not None else None)
        # plot raster map (an explicit vmax, even 0, is honoured)
        ax.imshow(R, aspect='auto', cmap='binary', vmin=0,
                  vmax=vmax if vmax is not None else np.std(R) * 4,
                  extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs)
    ax.set_xlabel('time (secs)')
    ax.set_ylabel('depth (um)')
    return ax
def pres_ratio(ts, hist_win=10, ax=None):
    '''
    Plots, bin by bin, whether a unit fired: 1 for every time bin containing at least one spike,
    0 otherwise. The presence ratio is the fraction of bins that contain a spike.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the presence ratio.
    hist_win : float (optional)
        The bin width (in s) used to compute the presence ratio.
    ax : axessubplot (optional)
        The axis handle to plot on. (if `None`, a new figure and axis is created)

    Returns
    -------
    pr : float
        The presence ratio, as computed by `single_units.pres_ratio`.
    spks_bins : ndarray
        The number of spikes in each bin.

    See Also
    --------
    single_units.pres_ratio

    Examples
    --------
    1) Plot the presence ratio for unit 1, given a window of 10 s.
    >>> ts = units_b['times']['1']
    >>> pr, spks_bins = bb.plot.pres_ratio(ts)
    '''
    pr, spks_bins = single_units.pres_ratio(ts, hist_win)
    # 1 where the bin holds at least one spike, 0 where it is empty.
    presence = (spks_bins > 0).astype(int)

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(presence)
    ax.set_xlabel('Bin Number (width={:.1f}s)'.format(hist_win))
    ax.set_ylabel('Presence')
    ax.set_title('Presence Ratio')

    return pr, spks_bins
def driftmap_color(
        clusters_depths, spikes_times,
        spikes_amps, spikes_depths, spikes_clusters,
        ax=None, axesoff=False, return_lims=False):
    '''
    Plots the driftmap of a session or a trial

    The plot shows the spike times vs spike depths.
    Each dot is a spike, whose color indicates the cluster
    and opacity indicates the spike amplitude.

    Parameters
    -------------
    clusters_depths: ndarray
        depths of all clusters, one per cluster in the order of the sorted unique values of
        `spikes_clusters` (the cluster colours follow the depth rank of each cluster)
    spikes_times: ndarray
        spike times of all clusters
    spikes_amps: ndarray
        amplitude of each spike (the 10th-90th percentile range is mapped to opacity 0-1)
    spikes_depths: ndarray
        depth of each spike
    spikes_clusters: ndarray
        cluster idx of each spike (spikes need not be sorted by cluster)
    ax: matplotlib.axes.Axes object (optional)
        The axis object to plot the driftmap on
        (if `None`, a new figure and axis is created, and the figure is saved to
        'driftmap.png' in the current working directory)
    axesoff: bool (optional)
        If True, hide the axes.
    return_lims: bool (optional)
        If True, also return the x and y limits.

    Return
    ---
    ax: matplotlib.axes.Axes object
        The axis object with driftmap plotted
    x_lim: list of two elements
        range of x axis (only if `return_lims` is True)
    y_lim: list of two elements
        range of y axis (only if `return_lims` is True)
    '''

    # 500 "hls" colours, reordered (5 x 100 interleave) so that neighbouring indices get
    # distinct hues.
    color_bins = sns.color_palette("hls", 500)
    new_color_bins = np.vstack(
        np.transpose(np.reshape(color_bins, [5, 100, 3]), [1, 0, 2]))

    # Colour of each cluster, from the rank of its depth.
    sorted_idx = np.argsort(np.argsort(clusters_depths))
    cluster_colors = new_color_bins[np.mod(sorted_idx, 500), :]
    # Per-spike colour, looked up through each spike's position among the unique clusters, so the
    # result is correct even when spikes are not sorted by cluster.
    _, spike_cluster_idx = np.unique(spikes_clusters, return_inverse=True)
    colors = cluster_colors[np.ravel(spike_cluster_idx)]

    # Opacity: amplitudes between the 10th and 90th percentiles are mapped linearly to [0, 1].
    spikes_amps = np.asarray(spikes_amps)
    max_amp = np.percentile(spikes_amps, 90)
    min_amp = np.percentile(spikes_amps, 10)
    if max_amp > min_amp:
        opacity = np.clip((spikes_amps - min_amp) / (max_amp - min_amp), 0, 1)
    else:  # degenerate amplitude range: avoid 0/0 and draw every spike fully opaque
        opacity = np.ones(spikes_amps.shape)

    colorvec = np.zeros([len(opacity), 4], dtype='float16')
    colorvec[:, 3] = opacity.astype('float16')
    colorvec[:, 0:3] = colors.astype('float16')

    x = spikes_times.astype('float32')
    y = spikes_depths.astype('float32')

    args = dict(color=colorvec, edgecolors='none')

    # Only a figure created here is saved to disk; a caller-provided `ax` is left to the caller.
    savefig = False
    if ax is None:
        fig = plt.Figure(dpi=200, frameon=False, figsize=[10, 10])
        ax = plt.Axes(fig, [0.1, 0.1, 0.9, 0.9])
        ax.set_xlabel('Time (sec)')
        ax.set_ylabel('Distance from the probe tip (um)')
        savefig = True
        args.update(s=0.1)

    ax.scatter(x, y, **args)
    x_edge = (x.max() - x.min()) * 0.05
    x_lim = [x.min() - x_edge, x.max() + x_edge]
    y_lim = [y.min() - 50, y.max() + 100]
    ax.set_xlim(x_lim[0], x_lim[1])
    ax.set_ylim(y_lim[0], y_lim[1])

    if axesoff:
        ax.axis('off')

    if savefig:
        fig.add_axes(ax)
        fig.savefig('driftmap.png')

    if return_lims:
        return ax, x_lim, y_lim
    else:
        return ax