Coverage for brainbox/plot.py: 11%

244 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

"""
Plots metrics that assess the quality of single units. Some functions here generate plots for the
output of functions in the brainbox `single_units.py` module.

Run the following to set up the workspace to run the docstring examples:
>>> from brainbox import processing
>>> import one.alf.io as alfio
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import ibllib.ephys.spikes as e_spks
# (*Note: if there is no 'alf' directory, make an 'alf' directory from the 'ks2' output
# directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Load the alf spikes bunch and clusters bunch, and get a units bunch.
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
>>> units_b = processing.get_units_bunch(spks_b)  # may take a few mins to compute
"""

18 

19import time 

20from warnings import warn 

21 

22import matplotlib.pyplot as plt 

23import seaborn as sns 

24import numpy as np 

25 

26# from matplotlib.ticker import StrMethodFormatter 

27from brainbox import singlecell 

28from brainbox.metrics import single_units 

29from brainbox.io.spikeglx import extract_waveforms 

30from iblutil.numerical import bincount2D 

31import spikeglx 

32 

33 

def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm',
              ax=None):
    '''
    Plots the coefficients of variation of a particular spike feature for all units as a bar plot,
    where each bar is color-coded by the mean spike depth of the respective unit.

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all units. It must contain at least 'times' and 'depths', each a dict keyed by
        unit label. `units_b['depths']` is read but not modified by this function.
    units : array-like (optional)
        A subset of all units for which to create the bar plot. (If `None`, all units are used)
    feat_name : string (optional)
        The spike feature to plot.
    dist : string (optional)
        The type of hypothetical null distribution from which the empirical spike feature
        distributions are presumed to belong to.
    test : string (optional)
        The statistical test used to calculate the probability that the empirical spike feature
        distributions come from `dist`.
    cmap_name : string (optional)
        The name of the colormap associated with the plot.
    ax : axessubplot (optional)
        The axis handle to plot the bar plot on. (if `None`, a new figure and axis is created)

    Returns
    -------
    cv_vals : ndarray
        The coefficients of variation of `feat_name` for each selected unit (multiplied by 1e6
        when `feat_name` is 'amps'). Units without spikes are included here but not plotted.
    p_vals : ndarray
        The probabilities that the distribution for `feat_name` for each unit comes from a
        `dist` distribution based on the `test` statistical test.

    See Also
    --------
    metrics.unit_stability

    Examples
    --------
    1) Create a bar plot of the coefficients of variation of the spike amplitudes for all units.
    >>> cv_vals, p_vals = bb.plot.feat_vars(units_b)
    '''
    depths_b = units_b['depths']
    # Select the units to consider without popping entries from the caller's `units_b`.
    if units is not None:  # we're using a subset of all units
        unit_list = [unit for unit in depths_b.keys() if int(unit) in units]
    else:
        unit_list = list(depths_b.keys())

    # Calculate coefficients of variation (and p-values) for all selected units.
    p_vals_b, cv_b = single_units.unit_stability(
        units_b, units=units, feat_names=[feat_name], dist=dist, test=test)
    cv_vals = np.array(tuple(cv_b[feat_name].values()))
    cv_vals = cv_vals * 1e6 if feat_name == 'amps' else cv_vals  # convert to uV if amps
    p_vals = np.array(tuple(p_vals_b[feat_name].values()))

    # Exclude units without spikes from the plot only. A boolean mask over `unit_list` keeps each
    # CV paired with its own unit (the returned arrays still cover every selected unit).
    has_spikes = np.array([len(units_b['times'][unit]) > 0 for unit in unit_list], dtype=bool)
    good_units = [unit for unit, keep in zip(unit_list, has_spikes) if keep]
    good_cvs = cv_vals[has_spikes]

    # Mean spike depth of each plotted unit, and the plotting order (shallowest first).
    depths = np.asarray([np.mean(depths_b[str(unit)]) for unit in good_units], dtype=float)
    order = np.argsort(depths)
    max_d = np.max(depths) if depths.size else 0.
    # Normalize depths to [0, 1] for the colormap, guarding against a zero maximum depth.
    depths_norm = depths / max_d if max_d else np.zeros_like(depths)
    # `plt.get_cmap` is used because `matplotlib.cm.get_cmap` was removed in matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name)
    rgba = cmap(depths_norm[order])

    # One horizontal bar per unit, rows sorted by depth: each bar's width is that unit's CV and
    # its color encodes that unit's depth.
    if ax is None:
        _, ax = plt.subplots()
    y_pos = np.arange(len(good_units))
    ax.barh(y=y_pos, width=good_cvs[order], color=rgba)
    ax.set_yticks(y_pos)
    ax.set_yticklabels([str(good_units[i]) for i in order])
    fig = ax.figure
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax)
    # Set the tick positions explicitly so they always match the six depth labels.
    ticks = (0, 0.2, 0.4, 0.6, 0.8, 1.0)
    cbar.set_ticks(ticks)
    cbar.set_ticklabels([int(max_d * tick) for tick in ticks])
    ax.set_title('CV of {feat}'.format(feat=feat_name))
    ax.set_ylabel('Unit Number (sorted by depth)')
    ax.set_xlabel('CV')
    cbar.set_label('Depth', rotation=-90)

    return cv_vals, p_vals

123 

124 

def missed_spikes_est(feat, feat_name, spks_per_bin=20, sigma=5, min_num_bins=50, ax=None):
    '''
    Plots the pdf of an estimated symmetric spike feature distribution, with a vertical cutoff line
    that indicates the approximate fraction of spikes missing from the distribution, assuming the
    true distribution is symmetric.

    Parameters
    ----------
    feat : ndarray
        The spikes' feature values.
    feat_name : string
        The spike feature to plot.
    spks_per_bin : int (optional)
        The number of spikes per bin from which to compute the spike feature histogram.
    sigma : int (optional)
        The standard deviation for the gaussian kernel used to compute the pdf from the spike
        feature histogram.
    min_num_bins : int (optional)
        The minimum number of bins used to compute the spike feature histogram.
    ax : axessubplot, or sequence of axessubplot (optional)
        The axes to plot on. If `None`, a figure with two axes is created. If two axes are given,
        the histogram is drawn on the first and the pdf on the second. If a single axis (or a
        sequence of another length) is given, only the pdf is drawn, on the first axis.

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an
        accurate estimate isn't possible.

    See Also
    --------
    single_units.feature_cutoff

    Examples
    --------
    1) Plot cutoff line indicating the fraction of spikes missing from a unit based on the recorded
    unit's spike amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
    >>> feat = units_b['amps']['1']
    >>> fraction_missing = bb.plot.missed_spikes_est(feat, feat_name='amps')
    '''
    # Calculate the feature distribution pdf and fraction of spikes missing.
    fraction_missing, pdf, cutoff_idx = \
        single_units.missed_spikes_est(feat, spks_per_bin, sigma, min_num_bins)

    if ax is None:  # create two axes
        _, ax = plt.subplots(nrows=1, ncols=2)
    # Normalize `ax` to a flat list so a single Axes is accepted as well as a sequence/array.
    axes = [ax] if isinstance(ax, plt.Axes) else list(np.ravel(ax))

    if len(axes) == 2:  # plot histogram and pdf on two separate axes
        feat = np.asarray(feat)
        # Never use fewer than `min_num_bins` bins; this also avoids `bins=0` when there are
        # fewer spikes than `spks_per_bin`.
        num_bins = max(int(feat.size / spks_per_bin), min_num_bins)
        axes[0].hist(feat, bins=num_bins)
        axes[0].set_xlabel('{0}'.format(feat_name))
        axes[0].set_ylabel('Count')
        axes[0].set_title('Histogram of {0}'.format(feat_name))
        pdf_ax = axes[1]
    else:  # just plot pdf
        pdf_ax = axes[0]

    pdf_ax.plot(pdf)
    pdf_ax.vlines(cutoff_idx, 0, np.max(pdf), colors='r')
    pdf_ax.set_xlabel('Bin Number')
    pdf_ax.set_ylabel('Density')
    pdf_ax.set_title('PDF Symmetry Cutoff\n'
                     '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100))

    return fraction_missing

194 

195 

def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', car=True,
            col=('b', 'r'), ax=None):
    '''
    Plots two different sets of waveforms across specified channels after (optionally)
    common-average-referencing. In this way, waveforms can be compared to see if there is,
    e.g. drift during the recording, or if two units should be merged, or one unit should be split.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts1 : array_like
        A set of timestamps for which to compare waveforms with `ts2`.
    ts2: array_like
        A set of timestamps for which to compare waveforms with `ts1`.
    ch : array-like
        The channels to use for extracting and plotting the waveforms.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    dtype: str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car: bool (optional)
        A flag for whether or not to perform common-average-referencing before extracting waveforms
    col: sequence of strings or float arrays (optional)
        Two elements, where each specifies the color the `ts1` and `ts2` waveforms will be plotted
        in, respectively.
    ax : 2-D array of axessubplot (optional)
        Axes indexed as `ax[row][col]`, with one row per channel and two columns (all waveforms,
        mean waveforms). (if `None`, a new figure and axes are created)

    Returns
    -------
    wf1 : ndarray
        The waveforms for the spikes in `ts1`: an array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        The waveforms for the spikes in `ts2`: an array of shape (#spikes, #samples, #channels).
    s : float
        The similarity score between the two sets of waveforms, calculated by
        `single_units.wf_similarity`

    See Also
    --------
    io.extract_waveforms
    single_units.wf_similarity

    Examples
    --------
    1) Compare first and last 100 spike waveforms for unit1, across 20 channels around the channel
    of max amplitude, and compare the waveforms in the first minute to the waveforms in the fourth
    minute for unit2, across 10 channels around the mean.
    # Get first and last 100 spikes, and 20 channels around channel of max amp for unit 1:
    >>> n_c_ch, n_ch_probe = 20, 385
    >>> ts1 = units_b['times']['1'][:100]
    >>> ts2 = units_b['times']['1'][-100:]
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> wf1, wf2, s = bb.plot.wf_comp(path_to_ephys_file, ts1, ts2, ch)
    # Plot waveforms for unit2 from the first and fourth minutes across 10 channels.
    >>> n_c_ch = 10
    >>> ts = units_b['times']['2']
    >>> ts1_2 = ts[np.where(ts < 60)[0]]
    >>> ts2_2 = ts[np.where(ts > 180)[0][:len(ts1_2)]]
    >>> max_ch = clstrs_b['channels'][2]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 10)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 5, max_ch + 5)
    >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
    '''
    # Ensure `ch` is ndarray (a single channel is passed on as a (1, 1) array).
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Extract the waveforms for these timestamps and compute similarity score.
    wf1 = extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                            car=car)
    wf2 = extract_waveforms(ephys_file, ts2, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                            car=car)
    s = single_units.wf_similarity(wf1, wf2)

    # Plot these waveforms against each other: left column is all waveforms, right col is mean.
    channels = ch.ravel()
    if ax is None:
        # `squeeze=False` keeps a 2-D axes array even when a single channel is plotted.
        _, ax = plt.subplots(nrows=channels.size, ncols=2, squeeze=False)
    for row, cur_ch in enumerate(channels):
        ax[row][0].plot(wf1[:, :, row].T, c=col[0])
        ax[row][0].plot(wf2[:, :, row].T, c=col[1])
        ax[row][1].plot(np.mean(wf1[:, :, row], axis=0), c=col[0])
        ax[row][1].plot(np.mean(wf2[:, :, row], axis=0), c=col[1])
        ax[row][0].set_ylabel('Ch {0}'.format(cur_ch))
    ax[0][0].set_title('All Waveforms. S = {:.2f}'.format(s))
    ax[0][1].set_title('Mean Waveforms')
    # Put the legend on the last mean-waveform axis (two lines: set 1, set 2) rather than on
    # whatever pyplot's "current" axes happens to be.
    ax[-1][1].legend(['1st spike set', '2nd spike set'])

    return wf1, wf2, s

298 

299 

def _spatial_noise(file_m, ch, first_sample, last_sample, n_chunk_samples=int(5e6)):
    '''
    Per-channel spatial noise estimate used for common-average-referencing in `amp_heatmap`.

    The samples in [`first_sample`, `last_sample`) are split into chunks of `n_chunk_samples`.
    The per-channel median is taken within each chunk, and the median over chunks is returned
    (one value per channel in `ch`). A time estimate is printed after the first chunk.
    '''
    # Guarantee at least one non-empty chunk (e.g. when all timestamps fall on the same sample).
    first_sample = int(first_sample)
    last_sample = max(int(last_sample), first_sample + 1)
    # `chunk_sample[k]:chunk_sample[k + 1]` are the samples that make up chunk k.
    chunk_sample = np.append(
        np.arange(first_sample, last_sample, n_chunk_samples, dtype=int), last_sample)
    n_chunks = chunk_sample.size - 1
    # Float storage: medians of integer samples can be fractional.
    noise_s_chunks = np.zeros((n_chunks, ch.size))
    # Time the first chunk to give an estimate of the total time.
    t0 = time.perf_counter()
    noise_s_chunks[0, :] = np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
    dt = time.perf_counter() - t0
    print('Performing spatial CAR before waveform extraction. Estimated time is {:.2f} mins.'
          ' ({})'.format(dt * n_chunks / 60, time.ctime()))
    for chunk in range(1, n_chunks):
        noise_s_chunks[chunk, :] = np.median(
            file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
    return np.median(noise_s_chunks, axis=0)


def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cmap_name='RdBu',
                car=True, ax=None):
    '''
    Plots a heatmap of the normalized voltage values over time and space for given timestamps and
    channels, after (optionally) common-average-referencing.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts: array_like
        A set of timestamps (in s) for which to get the voltage values.
    ch : array-like
        The channels to use for extracting the voltage values.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording. (Currently unused here.)
    dtype: str (optional)
        The datatype represented by the bytes in `ephys_file`. (Currently unused here.)
    cmap_name : string (optional)
        The name of the colormap associated with the plot.
    car: bool (optional)
        A flag for whether or not to subtract a per-channel median (spatial noise) estimate from
        the voltage values.
    ax : axessubplot (optional)
        The axis handle to plot the heatmap on. (if `None`, a new figure and axis is created)

    Returns
    -------
    v_vals : ndarray
        The voltage values, shape (#timestamps, #channels), before normalization.

    Examples
    --------
    1) Plot a heatmap of the spike amplitudes across 20 channels around the channel of max
    amplitude for all spikes in unit 1.
    >>> n_c_ch, n_ch_probe = 20, 385
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch)
    '''
    # Ensure `ch` is ndarray (a single channel is kept as a (1, 1) array for indexing).
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
    channels = ch.ravel()  # flat channel list for axis extent / ticks
    ts = np.asarray(ts)
    max_amp_samples = (ts * sr).astype(int)

    # Read from the memmapped `ephys_file`; the reader is closed even if reading fails.
    s_reader = spikeglx.Reader(ephys_file, open=True)
    try:
        file_m = s_reader.data
        # Voltage at each peak-amplitude sample for `ch`. Sample-by-sample slicing is used
        # because indexing with multiple values is currently unsupported.
        v_vals = np.zeros((max_amp_samples.size, ch.size))
        for row, sample in enumerate(max_amp_samples):
            v_vals[row] = file_m[sample:sample + 1, ch]
        if car:  # compute spatial noise in chunks, and subtract from `v_vals`.
            noise_s = _spatial_noise(file_m, ch, max_amp_samples.min(), max_amp_samples.max())
            v_vals -= noise_s[None, :]
            print('Done. ({})'.format(time.ctime()))
    finally:
        s_reader.close()

    # Plot heatmap.
    if ax is None:
        _, ax = plt.subplots()
    peak = np.max(np.abs(v_vals))
    # Normalize to [-1, 1]; leave an all-zero array as-is instead of dividing by zero.
    v_vals_norm = (v_vals / peak if peak > 0 else v_vals).T
    cbar_map = ax.imshow(v_vals_norm, cmap=cmap_name, aspect='auto',
                         extent=[ts[0], ts[-1], channels[0], channels[-1]], origin='lower')
    ax.set_yticks(np.arange(channels[0], channels[-1], 5))
    ax.set_ylabel('Channel Numbers')
    ax.set_xlabel('Time (s)')
    ax.set_title('Voltage Heatmap')
    fig = ax.figure
    cbar = fig.colorbar(cbar_map, ax=ax)
    cbar.set_label('V', rotation=-90)

    return v_vals

402 

403 

def firing_rate(ts, hist_win=0.01, fr_win=0.5, n_bins=10, show_fr_cv=True, ax=None):
    '''
    Plots the instantaneous firing rate of given spike timestamps over time, and optionally
    overlays the value of the coefficient of variation of the firing rate for a specified number
    of bins.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute coefficients of variation of the firing rate.
    show_fr_cv : bool (optional)
        A flag for whether or not to compute and show the coefficients of variation of the firing
        rate for `n_bins`.
    ax : axessubplot (optional)
        The axis handle to plot the firing rate on. (if `None`, a new figure and axis is created)

    Returns
    -------
    fr: ndarray
        The instantaneous firing rate over time (in hz).
    cv: float
        The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients
        computed. Only returned if `show_fr_cv` is True.
    cvs: ndarray
        The coefficients of variation of the firing for each bin of `n_bins`. Only returned if
        `show_fr_cv` is True.

    See Also
    --------
    single_units.firing_rate_cv
    singlecell.firing_rate

    Examples
    --------
    1) Plot the firing rate for unit 1 from the time of its first to last spike, showing the cv
    of the firing rate for 10 evenly spaced bins.
    >>> ts = units_b['times']['1']
    >>> fr, cv, cvs = bb.plot.firing_rate(ts)
    '''
    if ax is None:
        _, ax = plt.subplots()
    if show_fr_cv:  # compute firing rate and coefficients of variation
        cv, cvs, fr = single_units.firing_rate_coeff_var(ts, hist_win=hist_win, fr_win=fr_win,
                                                         n_bins=n_bins)
    else:  # compute just the firing rate
        fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    x = np.arange(fr.size) * hist_win
    ax.plot(x, fr)
    ax.set_title('Firing Rate')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Rate (s$^{-1}$)')

    if not show_fr_cv:
        return fr

    # Show coefficients of variation.
    y_max = np.max(fr) * 1.05
    # Duration (in s) of each CV bin. This equals `x[int(x.size / n_bins)]` but does not index
    # past the end of `x` when `n_bins == 1`.
    x_l = int(x.size / n_bins) * hist_win
    # Vertical lines separating the plot into `n_bins`.
    for i in range(1, n_bins):
        ax.vlines(x_l * i, 0, y_max, linestyles='dashed', linewidth=2)
    # Text with the cv of the firing rate for each bin, right-aligned at the bin's end.
    for i in range(n_bins):
        ax.text(x_l * (i + 1), y_max, 'cv={0:.2f}'.format(cvs[i]), fontsize=9, ha='right')
    return fr, cv, cvs

475 

476 

def peri_event_time_histogram(
        spike_times, spike_clusters, events, cluster_id,  # Everything you need for a basic plot
        t_before=0.2, t_after=0.5, bin_size=0.025, smoothing=0.025, as_rate=True,
        include_raster=False, n_rasters=None, error_bars='std', ax=None,
        pethline_kwargs={'color': 'blue', 'lw': 2},
        errbar_kwargs={'color': 'blue', 'alpha': 0.5},
        eventline_kwargs={'color': 'black', 'alpha': 0.5},
        raster_kwargs={'color': 'black', 'lw': 0.5}, **kwargs):
    """
    Plot peri-event time histograms, with the mean firing rate of units centered on a given
    series of events. Can optionally add a raster underneath the PETH plot of individual spike
    trains about the events.

    Parameters
    ----------
    spike_times : array_like
        Spike times (in seconds)
    spike_clusters : array-like
        Cluster identities for each element of spikes
    events : array-like
        Times to align the histogram(s) to
    cluster_id : int
        Identity of the cluster for which to plot a PETH

    t_before : float, optional
        Time before event to plot (default: 0.2s)
    t_after : float, optional
        Time after event to plot (default: 0.5s)
    bin_size :float, optional
        Width of bin for histograms (default: 0.025s)
    smoothing : float, optional
        Sigma of gaussian smoothing to use in histograms. (default: 0.025s)
    as_rate : bool, optional
        Whether to use spike counts or rates in the plot (default: `True`, uses rates)
    include_raster : bool, optional
        Whether to put a raster below the PETH of individual spike trains (default: `False`)
    n_rasters : int, optional
        If include_raster is True, the number of rasters to include. If `None`
        will default to plotting rasters around all provided events. (default: `None`)
    error_bars : {'std', 'sem', 'none'}, optional
        Defines which type of error bars to plot. Options are:
        -- `'std'` for 1 standard deviation
        -- `'sem'` for standard error of the mean
        -- `'none'` for only plotting the mean value
        (default: `'std'`)
    ax : matplotlib axes, optional
        If passed, the function will plot on the passed axes. (default: `None`, a new figure
        is created)
    pethline_kwargs : dict, optional
        Dict containing line properties to define PETH plot line. Default
        is a blue line with weight of 2. Needs to have color. See matplotlib plot documentation
        for more options.
        (default: `{'color': 'blue', 'lw': 2}`)
    errbar_kwargs : dict, optional
        Dict containing fill-between properties to define PETH error bars.
        Default is a blue fill with 50 percent opacity. Needs to have color. See matplotlib
        fill_between documentation for more options.
        (default: `{'color': 'blue', 'alpha': 0.5}`)
    eventline_kwargs : dict, optional
        Dict containing vlines properties to define the line at the event.
        Default is a black line with 50 percent opacity. Needs to have color. See matplotlib
        vlines documentation for more options.
        (default: `{'color': 'black', 'alpha': 0.5}`)
    raster_kwargs : dict, optional
        Dict containing properties defining lines in the raster plot.
        Default is black lines with line width of 0.5. See matplotlib vlines for more options.
        (default: `{'color': 'black', 'lw': 0.5}`)
    **kwargs
        Ignored (not used by this function).

    Returns
    -------
    ax : matplotlib axes
        Axes with all of the plots requested.

    Raises
    ------
    ValueError
        If `spike_times` and `spike_clusters` differ in length, fewer than two events are given,
        `error_bars` is invalid, or `events` contains NaN/inf values.
    """
    # Check to make sure if we fail, we fail in an informative way
    if not len(spike_times) == len(spike_clusters):
        raise ValueError('Spike times and clusters are not of the same shape')
    if len(events) == 0:
        raise ValueError('Cannot make a PETH without any events.')
    if len(events) == 1:
        raise ValueError('Cannot make a PETH with only one event.')
    if error_bars not in ('std', 'sem', 'none'):
        raise ValueError('Invalid error bar type was passed.')
    # Work on arrays so list inputs also support the boolean masking used for the raster below.
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    events = np.asarray(events)
    if not np.all(np.isfinite(events)):
        raise ValueError('There are NaN or inf values in the list of events passed. '
                         ' Please remove non-finite data points and try again.')

    # Compute peths
    peths, binned_spikes = singlecell.calculate_peths(spike_times, spike_clusters, [cluster_id],
                                                      events, t_before, t_after, bin_size,
                                                      smoothing, as_rate)
    # Construct an axis object if none passed
    if ax is None:
        plt.figure()
        ax = plt.gca()
    # Plot the curve and add error bars
    mean = peths.means[0, :]
    ax.plot(peths.tscale, mean, **pethline_kwargs)
    if error_bars == 'std':
        bars = peths.stds[0, :]
    elif error_bars == 'sem':
        bars = peths.stds[0, :] / np.sqrt(len(events))
    else:
        bars = np.zeros_like(mean)
    if error_bars != 'none':
        ax.fill_between(peths.tscale, mean - bars, mean + bars, **errbar_kwargs)

    # Plot the event marker line. Extends to 5% higher than max value of means plus any error bar.
    plot_edge = (mean.max() + bars[mean.argmax()]) * 1.05
    if not plot_edge > 0:
        # e.g. no spikes around the events: use a unit extent so the y-limits and the raster
        # spacing below (which divides `plot_edge`) stay valid.
        plot_edge = 1.
    ax.vlines(0., 0., plot_edge, **eventline_kwargs)
    # Set the limits on the axes to t_before and t_after. Either set the ylim to the 0 and max
    # values of the PETH, or if we want to plot a spike raster below, create an equal amount of
    # blank space below the zero where the raster will go.
    ax.set_xlim([-t_before, t_after])
    ax.set_ylim([-plot_edge if include_raster else 0., plot_edge])
    # Put y ticks only at min, max, and zero
    if mean.min() != 0:
        ax.set_yticks([0, mean.min(), mean.max()])
    else:
        ax.set_yticks([0., mean.max()])
    # Move the x axis line from the bottom of the plotting space to zero if including a raster,
    # Then plot the raster
    if include_raster:
        if n_rasters is None:
            n_rasters = len(events)
        if n_rasters > 60:
            warn("Number of raster traces is greater than 60. This might look bad on the plot.")
        ax.axhline(0., color='black')
        raster_events = events[:n_rasters]
        tickheight = plot_edge / len(raster_events)  # How much space per trace
        tickedges = np.arange(0., -plot_edge - 1e-5, -tickheight)
        clu_spks = spike_times[spike_clusters == cluster_id]
        for i, t in enumerate(raster_events):
            idx = (clu_spks >= t - t_before) & (clu_spks <= t + t_after)
            event_spks = clu_spks[idx]
            ax.vlines(event_spks - t, tickedges[i + 1], tickedges[i], **raster_kwargs)
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes', y=0.75)
    else:
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Time (s) after event')
    return ax

618 

619 

def driftmap(ts, feat, ax=None, plot_style='bincount',
             t_bin=0.01, d_bin=20, weights=None, vmax=None, **kwargs):
    """
    Plots the values of a spike feature array (y-axis) over time (x-axis).
    Two arguments can be given for the plot_style of the drift map:
    - 'scatter' : whereby each value is plotted as a marker (only for fewer than 100'000 data
      points; with more points the 'bincount' style is used instead)
    - 'bincount' : whereby the values are binned (optimised to represent spike raster)

    Parameters
    ----------
    ts : ndarray
        The spike timestamps (in s).
    feat : ndarray
        The spikes' feature values. NaN values are excluded in the 'bincount' style.
    ax : axessubplot (optional)
        The axis handle to plot on. (if `None`, a new figure and axis is created)
    plot_style : str (optional)
        'scatter' or 'bincount'.
    t_bin : float (optional)
        Time bin used when plot_style='bincount'.
    d_bin : float (optional)
        Feature (depth) bin used when plot_style='bincount'.
    weights : ndarray (optional)
        Per-spike weights passed to `bincount2D` when plot_style='bincount'.
    vmax : float (optional)
        Upper color limit of the 'bincount' image; if not given (or 0), 4 standard deviations
        of the binned values are used.
    **kwargs
        Passed to `ax.plot` for the 'scatter' style, or to `ax.imshow` for the 'bincount' style.

    Returns
    -------
    ax : axessubplot
        The axis the driftmap was drawn on.

    Examples
    --------
    1) Plot the amplitude driftmap for unit 1.
    >>> ts = units_b['times']['1']
    >>> amps = units_b['amps']['1']
    >>> ax = bb.plot.driftmap(ts, amps)
    2) Plot the depth driftmap for unit 1.
    >>> ts = units_b['times']['1']
    >>> depths = units_b['depths']['1']
    >>> ax = bb.plot.driftmap(ts, depths)
    """
    ts = np.asarray(ts)
    feat = np.asarray(feat)
    iok = ~np.isnan(feat)
    if ax is None:
        _, ax = plt.subplots()

    if plot_style == 'scatter' and len(ts) < 100000:
        # One marker per spike, without connecting lines, unless the caller overrides these
        # (checking aliases too, as matplotlib rejects e.g. both `color` and `c`).
        if 'color' not in kwargs and 'c' not in kwargs:
            kwargs['color'] = 'k'
        if 'marker' not in kwargs:
            kwargs['marker'] = '.'
        if 'linestyle' not in kwargs and 'ls' not in kwargs:
            kwargs['linestyle'] = 'none'
        ax.plot(ts, feat, **kwargs)
    else:
        # compute raster map as a function of site depth
        R, times, depths = bincount2D(
            ts[iok], feat[iok], t_bin, d_bin,
            weights=np.asarray(weights)[iok] if weights is not None else None)
        # plot raster map
        ax.imshow(R, aspect='auto', cmap='binary', vmin=0, vmax=vmax or np.std(R) * 4,
                  extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs)
    ax.set_xlabel('time (secs)')
    ax.set_ylabel('depth (um)')
    return ax

683 

684 

def pres_ratio(ts, hist_win=10, ax=None):
    '''
    Plots, bin by bin, whether a unit fired, and returns its presence ratio: the number of bins
    containing at least one spike divided by the total number of bins, for a given bin width.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the presence ratio.
    hist_win : float
        The bin width (in s) used for computing the presence ratio.
    ax : axessubplot (optional)
        The axis handle to plot on. (if `None`, a new figure and axis is created)

    Returns
    -------
    pr : float
        The presence ratio.
    spks_bins : ndarray
        The number of spikes in each bin.

    See Also
    --------
    metrics.pres_ratio

    Examples
    --------
    1) Plot the presence ratio for unit 1, given a window of 10 s.
    >>> ts = units_b['times']['1']
    >>> pr, spks_bins = bb.plot.pres_ratio(ts)
    '''
    pr, spks_bins = single_units.pres_ratio(ts, hist_win)
    # Binary presence trace: 1 where a bin holds at least one spike, 0 otherwise.
    is_present = (spks_bins > 0).astype(int)

    target_ax = plt.subplots()[1] if ax is None else ax
    target_ax.plot(is_present)
    target_ax.set_xlabel('Bin Number (width={:.1f}s)'.format(hist_win))
    target_ax.set_ylabel('Presence')
    target_ax.set_title('Presence Ratio')

    return pr, spks_bins

729 

730 

def driftmap_color(
        clusters_depths, spikes_times,
        spikes_amps, spikes_depths, spikes_clusters,
        ax=None, axesoff=False, return_lims=False):
    '''
    Plots the driftmap of a session or a trial

    The plot shows the spike times vs spike depths.
    Each dot is a spike, whose color indicates the cluster
    and opacity indicates the spike amplitude.

    Parameters
    -------------
    clusters_depths: ndarray
        depths of all clusters
    spikes_times: ndarray
        spike times of all clusters
    spikes_amps: ndarray
        amplitude of each spike
    spikes_depths: ndarray
        depth of each spike
    spikes_clusters: ndarray
        cluster of each spike. The i-th unique value in `spikes_clusters` is paired with the
        i-th entry of `clusters_depths`.
    ax: matplotlib.axes.Axes object (optional)
        The axis object to plot the driftmap on
        (if `None`, a new figure and axis is created, and the figure is saved to
        'driftmap.png' in the current working directory)
    axesoff: bool (optional)
        If True, hide the axes.
    return_lims: bool (optional)
        If True, also return the x and y limits.

    Return
    ---
    ax: matplotlib.axes.Axes object
        The axis object with driftmap plotted
    x_lim: list of two elements
        range of x axis (only if `return_lims` is True)
    y_lim: list of two elements
        range of y axis (only if `return_lims` is True)
    '''
    # 500 'hls' colors, interleaved so that consecutive indices cycle through 5 hue groups that
    # are 100 colors apart (presumably to make clusters with neighbouring ranks distinguishable).
    color_bins = sns.color_palette("hls", 500)
    new_color_bins = np.vstack(
        np.transpose(np.reshape(color_bins, [5, 100, 3]), [1, 0, 2]))

    # Depth rank of each cluster; a cluster's color is chosen from its rank.
    sorted_idx = np.argsort(np.argsort(clusters_depths))
    # Position of each spike's cluster among the unique cluster values. Indexing per spike keeps
    # colors aligned with `spikes_times`/`spikes_depths` even if spikes are not sorted by cluster.
    _, spike_cluster_idx = np.unique(spikes_clusters, return_inverse=True)
    colors = new_color_bins[np.mod(sorted_idx[np.ravel(spike_cluster_idx)], 500), :]

    # Opacity: amplitudes scaled linearly between their 10th and 90th percentiles, clipped to
    # [0, 1].
    spikes_amps = np.asarray(spikes_amps)
    max_amp = np.percentile(spikes_amps, 90)
    min_amp = np.percentile(spikes_amps, 10)
    amp_range = max_amp - min_amp
    if amp_range > 0:
        opacity = np.divide(spikes_amps - min_amp, amp_range)
    else:
        # Degenerate amplitude distribution: draw every spike fully opaque.
        opacity = np.ones(spikes_amps.shape)
    opacity = np.clip(opacity, 0, 1)

    colorvec = np.zeros([len(opacity), 4], dtype='float16')
    colorvec[:, 3] = opacity.astype('float16')
    colorvec[:, 0:3] = colors.astype('float16')

    x = np.asarray(spikes_times).astype('float32')
    y = np.asarray(spikes_depths).astype('float32')

    args = dict(color=colorvec, edgecolors='none')

    savefig = False
    if ax is None:
        fig = plt.Figure(dpi=200, frameon=False, figsize=[10, 10])
        ax = plt.Axes(fig, [0.1, 0.1, 0.9, 0.9])
        ax.set_xlabel('Time (sec)')
        ax.set_ylabel('Distance from the probe tip (um)')
        savefig = True
        args.update(s=0.1)

    ax.scatter(x, y, **args)
    x_min, x_max = x.min(), x.max()
    x_edge = (x_max - x_min) * 0.05
    x_lim = [x_min - x_edge, x_max + x_edge]
    y_lim = [y.min() - 50, y.max() + 100]
    ax.set_xlim(x_lim[0], x_lim[1])
    ax.set_ylim(y_lim[0], y_lim[1])

    if axesoff:
        ax.axis('off')

    if savefig:
        # Only for a figure created here: attach the axes and write the image to disk.
        fig.add_axes(ax)
        fig.savefig('driftmap.png')

    if return_lims:
        return ax, x_lim, y_lim
    return ax