Coverage for brainbox/metrics/single_units.py: 53%

238 statements  

coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1""" 

2Computes metrics for assessing quality of single units. 

3 

4Run the following to set up the workspace to run the docstring examples: 

5>>> import brainbox as bb 

6>>> import one.alf.io as aio 

7>>> import numpy as np 

8>>> import matplotlib.pyplot as plt 

9>>> import ibllib.ephys.spikes as e_spks 

10# (*Note: if there is no 'alf' directory, create the 'alf' directory from the 'ks2' output directory): 

11>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out) 

12# Load the alf spikes bunch and clusters bunch, and get a units bunch. 

13>>> spks_b = aio.load_object(path_to_alf_out, 'spikes') 

14>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters') 

15>>> units_b = bb.processing.get_units_bunch(spks_b) # may take a few mins to compute 

16""" 

17 

18import time 

19import logging 

20 

21import numpy as np 

22from scipy.ndimage import gaussian_filter1d 

23import scipy.stats as stats 

24import pandas as pd 

25 

26import spikeglx 

27from phylib.stats import correlograms 

28from iblutil.util import Bunch 

29from iblutil.numerical import ismember, between_sorted, bincount2D 

30from slidingRP import metrics 

31 

32from brainbox import singlecell 

33from brainbox.io.spikeglx import extract_waveforms 

34from brainbox.metrics import electrode_drift 

35 

36 

37_logger = logging.getLogger('ibllib') 

38 

39# Parameters to be used in `quick_unit_metrics` 

40METRICS_PARAMS = { 

41 'noise_cutoff': dict(quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10), 

42 'missed_spikes_est': dict(spks_per_bin=10, sigma=4, min_num_bins=50), 

43 'acceptable_contamination': 0.1, 

44 'bin_size': 0.25, 

45 'med_amp_thresh_uv': 50, 

46 'min_isi': 0.0001, 

47 'presence_window': 10, 

48 'refractory_period': 0.0015, 

49 'RPslide_thresh': 0.1, 

50} 
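# Editor's sketch (not part of the original module): `METRICS_PARAMS` is consumed as the
# default `params` argument of `quick_unit_metrics` and `spike_sorting_metrics`. To run the
# metrics with, e.g., a stricter refractory period, copy the dict and override the relevant
# key (the value below is an assumption for illustration only):
# >>> custom_params = {**METRICS_PARAMS, 'refractory_period': 0.002}
# >>> # df_units, rec_qc = spike_sorting_metrics(times, clusters, amps, depths,
# >>> #                                          params=custom_params)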

51 

52 

53def unit_stability(units_b, units=None, feat_names=['amps'], dist='norm', test='ks'): 

54 """ 

55 Computes the probability that the empirical spike feature distribution(s), for specified 

56 feature(s), for all units, comes from a specific theoretical distribution, based on a specified 

57 statistical test. Also computes the coefficients of variation of the spike feature(s) for all 

58 units. 

59 

60 Parameters 

61 ---------- 

62 units_b : bunch 

63 A units bunch containing fields with spike information (e.g. cluster IDs, times, features, 

64 etc.) for all units. 

65 units : array-like (optional) 

66 A subset of all units for which to compute the metrics. (If `None`, all units are used) 

67 feat_names : list of strings (optional) 

68 A list of names of spike features that can be found in `spks` to specify which features to 

69 use for calculating unit stability. 

70 dist : string (optional) 

71 The type of hypothetical null distribution to which the empirical spike feature 

72 distributions are presumed to belong. 

73 test : string (optional) 

74 The statistical test used to compute the probability that the empirical spike feature 

75 distributions come from `dist`. 

76 

77 Returns 

78 ------- 

79 p_vals_b : bunch 

80 A bunch with `feat_names` as keys, containing a ndarray with p-values (the probabilities 

81 that the empirical spike feature distribution for each unit comes from `dist` based on 

82 `test`) for each unit for all `feat_names`. 

83 cv_b : bunch 

84 A bunch with `feat_names` as keys, containing a ndarray with the coefficients of variation 

85 of each unit's empirical spike feature distribution for all features. 

86 

87 See Also 

88 -------- 

89 plot.feat_vars 

90 

91 Examples 

92 -------- 

93 1) Compute the p-values obtained from running a one-sample ks test on the spike amplitudes 

94 for each unit, and the variances of the empirical spike amplitude distribution for each 

95 unit. Create a histogram of the variances of the spike amplitudes for each unit, color-coded by 

96 the depth of the channel of max amplitude. Get the cluster IDs of those units whose variances 

97 are greater than 50. 

98 >>> p_vals_b, variances_b = bb.metrics.unit_stability(units_b) 

99 # Plot histograms of variances color-coded by depth of channel of max amplitudes 

100 >>> fig = bb.plot.feat_vars(units_b, feat_name='amps') 

101 # Get all unit IDs which have amps variance > 50 

102 >>> var_vals = np.array(tuple(variances_b['amps'].values())) 

103 >>> bad_units = np.where(var_vals > 50) 

104 """ 

105 

106 # Get units. 

107 if not (units is None): # we're using a subset of all units 

108 unit_list = list(units_b[feat_names[0]].keys()) 

109 # for each `feat` and unit in `unit_list`, remove unit from `units_b` if not in `units` 

110 for feat in feat_names: 

111 [units_b[feat].pop(unit) for unit in unit_list if not (int(unit) in units)] 

112 unit_list = list(units_b[feat_names[0]].keys()) # get new `unit_list` after removing units 

113 

114 # Initialize `p_vals` and `variances`. 

115 p_vals_b = Bunch() 

116 cv_b = Bunch() 

117 

118 # Set the test as a lambda function (in future, more tests can be added to this dict) 

119 tests = \ 

120 { 

121 'ks': lambda x, y: stats.kstest(x, y) 

122 } 

123 test_fun = tests[test] 

124 

125 # Compute the statistical tests and variances. For each feature, iteratively get each unit's 

126 # p-values and variances, and add them as keys to the respective bunches `p_vals_feat` and 

127 # `variances_feat`. After iterating through all units, add these bunches as keys to their 

128 # respective parent bunches, `p_vals` and `variances`. 

129 for feat in feat_names: 

130 p_vals_feat = Bunch((unit, 0) for unit in unit_list) 

131 cv_feat = Bunch((unit, 0) for unit in unit_list) 

132 for unit in unit_list: 

133 # If we're missing units/features, create a NaN placeholder and skip them: 

134 if len(units_b['times'][str(unit)]) == 0: 

135 p_val = np.nan 

136 cv = np.nan 

137 else: 

138 # compute p_val and var for current feature 

139 _, p_val = test_fun(units_b[feat][unit], dist) 

140 cv = np.var(units_b[feat][unit]) / np.mean(units_b[feat][unit]) 

141 # Append current unit's values to list of units' values for current feature: 

142 p_vals_feat[str(unit)] = p_val 

143 cv_feat[str(unit)] = cv 

144 p_vals_b[feat] = p_vals_feat 

145 cv_b[feat] = cv_feat 

146 

147 return p_vals_b, cv_b 
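# Editor's sketch (synthetic data, assumed for illustration; not part of the original module):
# build a minimal units bunch for two units with gaussian spike amplitudes and test it against
# the default 'norm' null distribution.
# >>> rng = np.random.default_rng(0)
# >>> demo_b = Bunch({'amps': {'0': rng.normal(50, 5, 1000), '1': rng.normal(60, 5, 1000)},
# ...                 'times': {'0': np.sort(rng.uniform(0, 100, 1000)),
# ...                           '1': np.sort(rng.uniform(0, 100, 1000))}})
# >>> p_vals_b, cv_b = unit_stability(demo_b)   # one p-value and one c.v. per unit under 'amps'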

148 

149 

150def missed_spikes_est(feat, spks_per_bin=20, sigma=5, min_num_bins=50): 

151 """ 

152 Computes the approximate fraction of spikes missing from a spike feature distribution for a 

153 given unit, assuming the distribution is symmetric. 

154 Inspired by metric described in Hill et al. (2011) J Neurosci 31: 8699-8705. 

155 

156 Parameters 

157 ---------- 

158 feat : ndarray 

159 The spikes' feature values (e.g. amplitudes) 

160 spks_per_bin : int (optional) 

161 The number of spikes per bin from which to compute the spike feature histogram. 

162 sigma : int (optional) 

163 The standard deviation for the gaussian kernel used to compute the pdf from the spike 

164 feature histogram. 

165 min_num_bins : int (optional) 

166 The minimum number of bins used to compute the spike feature histogram. 

167 

168 Returns 

169 ------- 

170 fraction_missing : float 

171 The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an 

172 accurate estimate isn't possible. 

173 pdf : ndarray 

174 The computed pdf of the spike feature histogram. 

175 cutoff_idx : int 

176 The index for `pdf` at which point `pdf` is no longer symmetrical around the peak. (This 

177 is returned for plotting purposes). 

178 

179 See Also 

180 -------- 

181 plot.feat_cutoff 

182 Examples 

183 -------- 

184 1) Determine the fraction of spikes missing from unit 1 based on the recorded unit's spike 

185 amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric. 

186 # Get unit 1 amplitudes from a units bunch, and compute the fraction of spikes missing. 

187 >>> feat = units_b['amps']['1'] 

188 >>> fraction_missing, pdf, cutoff_idx = bb.metrics.missed_spikes_est(feat) 

189 """ 

190 

191 # Ensure the minimum number of spikes requirement is met; return NaN otherwise 

192 if feat.size <= (spks_per_bin * min_num_bins): 1a

193 return np.nan, None, None 1a

194 

195 # compute the spike feature histogram and pdf: 

196 num_bins = int(feat.size / spks_per_bin) 1a

197 hist, bins = np.histogram(feat, num_bins, density=True) 1a

198 pdf = gaussian_filter1d(hist, sigma) 1a

199 

200 # Find where the distribution stops being symmetric around the peak: 

201 peak_idx = np.argmax(pdf) 1a

202 max_idx_sym_around_peak = np.argmin(np.abs(pdf[peak_idx:] - pdf[0])) 1a

203 cutoff_idx = peak_idx + max_idx_sym_around_peak 1a

204 

205 # compute fraction missing from the tail of the pdf (the area where pdf stops being 

206 # symmetric around peak). 

207 fraction_missing = np.sum(pdf[cutoff_idx:]) / np.sum(pdf) 1a

208 fraction_missing = 0.5 if (fraction_missing > 0.5) else fraction_missing 1a

209 

210 return fraction_missing, pdf, cutoff_idx 1a
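# Editor's sketch (synthetic amplitudes, assumed; not part of the original module): simulate an
# amplitude distribution whose lower tail is clipped at a detection floor, then estimate the
# fraction of spikes lost below that floor.
# >>> rng = np.random.default_rng(0)
# >>> amps_full = rng.normal(100, 20, 20000)
# >>> amps_clipped = amps_full[amps_full > 90]   # spikes below 90 are treated as "missed"
# >>> frac, pdf_demo, cut_idx = missed_spikes_est(amps_clipped, spks_per_bin=20, sigma=5)
# >>> # `frac` lies in [0, 0.5]; a value of 0.5 means the estimate has saturated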

211 

212 

213def wf_similarity(wf1, wf2): 

214 """ 

215 Computes a unit normalized spatiotemporal similarity score between two sets of waveforms. 

216 This score is based on how waveform shape correlates for each pair of spikes between the 

217 two sets of waveforms across space and time. The shapes of the arrays of the two sets of 

218 waveforms must be equal. 

219 

220 Parameters 

221 ---------- 

222 wf1 : ndarray 

223 An array of shape (#spikes, #samples, #channels). 

224 wf2 : ndarray 

225 An array of shape (#spikes, #samples, #channels). 

226 

227 Returns 

228 ------- 

229 s: float 

230 The unit normalized spatiotemporal similarity score. 

231 

232 See Also 

233 -------- 

234 io.extract_waveforms 

235 plot.single_unit_wf_comp 

236 

237 Examples 

238 -------- 

239 1) Compute the similarity between the first and last 100 waveforms for unit1, across the 20 

240 channels around the channel of max amplitude. 

241 # Get the channels around the max amp channel for the unit, two sets of timestamps for the 

242 # unit, and the two corresponding sets of waveforms for those two sets of timestamps. 

243 # Then compute `s`. 

244 >>> max_ch = clstrs_b['channels'][1] 

245 >>> if max_ch < 10: # take only channels greater than `max_ch`. 

246 >>> ch = np.arange(max_ch, max_ch + 20) 

247 >>> elif (max_ch + 10) > 385: # take only channels less than `max_ch`. 

248 >>> ch = np.arange(max_ch - 20, max_ch) 

249 >>> else: # take `n_c_ch` around `max_ch`. 

250 >>> ch = np.arange(max_ch - 10, max_ch + 10) 

251 >>> ts1 = units_b['times']['1'][:100] 

252 >>> ts2 = units_b['times']['1'][-100:] 

253 >>> wf1 = bb.io.extract_waveforms(path_to_ephys_file, ts1, ch) 

254 >>> wf2 = bb.io.extract_waveforms(path_to_ephys_file, ts2, ch) 

255 >>> s = bb.metrics.wf_similarity(wf1, wf2) 

256 

257 TODO check `s` calculation: 

258 take median of waveforms 

259 xcorr all waveforms with median, and divide by autocorr of all waveforms 

260 profile 

261 for two sets of units: xcorr(cl1, cl2) / (sqrt autocorr(cl1) * autocorr(cl2)) 

262 """ 

263 

264 # Remove warning for dividing by 0 when calculating `s` (this is resolved by using 

265 # `np.nan_to_num`) 

266 import warnings 

267 warnings.filterwarnings('ignore', r'invalid value encountered in true_divide') 

268 assert wf1.shape == wf2.shape, ('The shapes of the sets of waveforms are inconsistent ({})' 

269 '({})'.format(wf1.shape, wf2.shape)) 

270 

271 # Get number of spikes, samples, and channels of waveforms. 

272 n_spks = wf1.shape[0] 

273 n_samples = wf1.shape[1] 

274 n_ch = wf1.shape[2] 

275 

276 # Create a matrix that will hold the similarity values of each spike in `wf1` to `wf2`. 

277 # Iterate over both sets of spikes, computing `s` for each pair. 

278 similarity_matrix = np.zeros((n_spks, n_spks)) 

279 for spk1 in range(n_spks): 

280 for spk2 in range(n_spks): 

281 s_spk = \ 

282 np.sum(np.nan_to_num( 

283 wf1[spk1, :, :] * wf2[spk2, :, :] / 

284 np.sqrt(wf1[spk1, :, :] ** 2 * wf2[spk2, :, :] ** 2))) / (n_samples * n_ch) 

285 similarity_matrix[spk1, spk2] = s_spk 

286 

287 # Return mean of similarity matrix 

288 s = np.mean(similarity_matrix) 

289 return s 
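# Editor's sketch (random waveforms, assumed; not part of the original module): two equally
# shaped waveform arrays (5 spikes, 30 samples, 4 channels) are enough to exercise the score.
# Each matrix entry is effectively the mean sign agreement between a pair of spikes, so
# unrelated random waveforms score near 0.
# >>> rng = np.random.default_rng(0)
# >>> wf_a = rng.normal(size=(5, 30, 4))
# >>> wf_b = rng.normal(size=(5, 30, 4))
# >>> s_demo = wf_similarity(wf_a, wf_b)   # close to 0 for independent random inputs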

290 

291 

292def firing_rate_coeff_var(ts, hist_win=0.01, fr_win=0.5, n_bins=10): 

293 ''' 

294 Computes the coefficient of variation of the firing rate: the ratio of the standard 

295 deviation to the mean. 

296 

297 Parameters 

298 ---------- 

299 ts : ndarray 

300 The spike timestamps from which to compute the firing rate. 

301 hist_win : float (optional) 

302 The time window (in s) to use for computing spike counts. 

303 fr_win : float (optional) 

304 The time window (in s) to use as a moving slider to compute the instantaneous firing rate. 

305 n_bins : int (optional) 

306 The number of bins in which to compute a coefficient of variation of the firing rate. 

307 

308 Returns 

309 ------- 

310 cv : float 

311 The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients 

312 computed. 

313 cvs : ndarray 

314 The coefficients of variation of the firing rate for each bin of `n_bins`. 

315 fr : ndarray 

316 The instantaneous firing rate over time (in Hz). 

317 

318 See Also 

319 -------- 

320 singlecell.firing_rate 

321 plot.firing_rate 

322 

323 Examples 

324 -------- 

325 1) Compute the coefficient of variation of the firing rate for unit 1 from the time of its 

326 first to last spike, and compute the coefficient of variation of the firing rate for unit 2 

327 from the first to second minute. 

328 >>> ts_1 = units_b['times']['1'] 

329 >>> ts_2 = units_b['times']['2'] 

330 >>> ts_2 = ts_2[np.logical_and(ts_2 > 60, ts_2 < 120)] 

331 >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts_1) 

332 >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(ts_2) 

333 ''' 

334 

335 # Compute overall instantaneous firing rate and firing rate for each bin. 

336 fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win) 

337 bin_sz = int(fr.size / n_bins) 

338 fr_binned = np.array([fr[(b * bin_sz):(b * bin_sz + bin_sz)] for b in range(n_bins)]) 

339 

340 # Compute the coefficient of variation of the firing rate for each bin, and the mean c.v. 

341 cvs = np.std(fr_binned, axis=1) / np.mean(fr_binned, axis=1) 

342 # NaNs can occur if the neuron doesn't spike in a bin; they are currently left as NaN 

343 # cvs[np.isnan(cvs)] = 0 

344 cv = np.mean(cvs) 

345 

346 return cv, cvs, fr 
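# Editor's sketch (synthetic spike train, assumed; not part of the original module): a
# stationary ~20 spikes/s train over 100 s gives one c.v. value per bin (10 by default)
# plus their mean.
# >>> rng = np.random.default_rng(0)
# >>> ts_demo = np.sort(rng.uniform(0, 100, 2000))
# >>> cv_demo, cvs_demo, fr_demo = firing_rate_coeff_var(ts_demo)   # cvs_demo.size == 10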

347 

348 

349def firing_rate_fano_factor(ts, hist_win=0.01, fr_win=0.5, n_bins=10): 

350 ''' 

351 Computes the fano factor of the firing rate: the ratio of the variance to the mean. 

352 (Almost identical to coeff. of variation) 

353 

354 Parameters 

355 ---------- 

356 ts : ndarray 

357 The spike timestamps from which to compute the firing rate. 

358 hist_win : float 

359 The time window (in s) to use for computing spike counts. 

360 fr_win : float 

361 The time window (in s) to use as a moving slider to compute the instantaneous firing rate. 

362 n_bins : int (optional) 

363 The number of bins in which to compute a fano factor of the firing rate. 

364 

365 Returns 

366 ------- 

367 ff : float 

368 The mean fano factor of the firing rate of the `n_bins` number of factors 

369 computed. 

370 ffs : ndarray 

371 The fano factors of the firing rate for each bin of `n_bins`. 

372 fr : ndarray 

373 The instantaneous firing rate over time (in Hz). 

374 

375 See Also 

376 -------- 

377 singlecell.firing_rate 

378 plot.firing_rate 

379 

380 Examples 

381 -------- 

382 1) Compute the fano factor of the firing rate for unit 1 from the time of its 

383 first to last spike, and compute the fano factor of the firing rate for unit 2 

384 from the first to second minute. 

385 >>> ts_1 = units_b['times']['1'] 

386 >>> ts_2 = units_b['times']['2'] 

387 >>> ts_2 = ts_2[np.logical_and(ts_2 > 60, ts_2 < 120)] 

388 >>> ff, ffs, fr = bb.metrics.firing_rate_fano_factor(ts_1) 

389 >>> ff_2, ffs_2, fr_2 = bb.metrics.firing_rate_fano_factor(ts_2) 

390 ''' 

391 

392 # Compute overall instantaneous firing rate and firing rate for each bin. 

393 fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win) 

394 # this procedure can cut off data at the end, up to n_bins last timesteps 

395 bin_sz = int(fr.size / n_bins) 

396 fr_binned = np.array([fr[(b * bin_sz):(b * bin_sz + bin_sz)] for b in range(n_bins)]) 

397 

398 # Compute fano factor of firing rate for each bin, and the mean fano factor 

399 ffs = np.var(fr_binned, axis=1) / np.mean(fr_binned, axis=1) 

400 # ffs[np.isnan(ffs)] = 0 nan's can happen if neuron doesn't spike in a bin 

401 ff = np.mean(ffs) 

402 

403 return ff, ffs, fr 
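# Editor's sketch (synthetic spike train, assumed; not part of the original module): the fano
# factor differs from the coefficient of variation above only in using the variance rather
# than the standard deviation in the numerator.
# >>> rng = np.random.default_rng(1)
# >>> ts_demo = np.sort(rng.uniform(0, 100, 2000))
# >>> ff_demo, ffs_demo, fr_demo = firing_rate_fano_factor(ts_demo)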

404 

405 

406def average_drift(feat, times): 

407 """ 

408 Computes the cumulative drift (normalized by the total number of spikes) of a spike feature 

409 array. 

410 

411 Parameters 

412 ---------- 

413 feat : ndarray 

414 The spike feature values (usually depths or amplitudes) from which to compute the 

415 cumulative drift. 
 times : ndarray 
 The spike timestamps (in s) at which `feat` was measured. 

416 

417 Returns 

418 ------- 

419 cd : float 

420 The cumulative drift of the unit. 

421 

422 See Also 

423 -------- 

424 max_drift 

425 

426 Examples 

427 -------- 

428 1) Get the cumulative depth drift for unit 1. 

429 >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0] 

430 >>> depths = spks_b['depths'][unit_idxs] 

431 >>> amps = spks_b['amps'][unit_idxs] 
 >>> times = spks_b['times'][unit_idxs] 

432 >>> depth_cd = bb.metrics.average_drift(depths, times) 

433 >>> amp_cd = bb.metrics.average_drift(amps, times) 

434 """ 

435 

436 cd = np.sum(np.abs(np.diff(feat) / np.diff(times))) / len(feat) 

437 return cd 
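# Editor's worked example (assumed numbers, for illustration only): for a unit drifting
# steadily by 20 um over 100 s, every |d(feat)/d(t)| term equals 0.2 um/s, so the
# spike-count-normalised sum returned by `average_drift` is also ~0.2.
# >>> times_demo = np.linspace(0, 100, 1001)
# >>> depths_demo = np.linspace(0, 20, 1001)
# >>> average_drift(depths_demo, times_demo)   # ~0.2 um/s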

438 

439 

440def pres_ratio(ts, hist_win=10): 

441 """ 

442 Computes the presence ratio of spike counts: the number of bins where there is at least one 

443 spike, over the total number of bins, given a specified bin width. 

444 

445 Parameters 

446 ---------- 

447 ts : ndarray 

448 The spike timestamps from which to compute the presence ratio. 

449 hist_win : float (optional) 

450 The time window (in s) to use for computing the presence ratio. 

451 

452 Returns 

453 ------- 

454 pr : float 

455 The presence ratio. 

456 spks_bins : ndarray 

457 The number of spks in each bin. 

458 

459 See Also 

460 -------- 

461 plot.pres_ratio 

462 

463 Examples 

464 -------- 

465 1) Compute the presence ratio for unit 1, given a window of 10 s. 

466 >>> ts = units_b['times']['1'] 

467 >>> pr, pr_bins = bb.metrics.pres_ratio(ts) 

468 """ 

469 

470 bins = np.arange(0, ts[-1] + hist_win, hist_win) 

471 spks_bins, _ = np.histogram(ts, bins) 

472 pr = len(np.where(spks_bins)[0]) / len(spks_bins) 

473 return pr, spks_bins 
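# Editor's sketch (synthetic spike times, assumed; not part of the original module): spikes
# confined to the first half of the recording fill only part of the 10 s bins.
# >>> rng = np.random.default_rng(0)
# >>> ts_demo = np.sort(np.append(rng.uniform(0, 50, 500), 100.))
# >>> pr_demo, bins_demo = pres_ratio(ts_demo, hist_win=10)
# >>> # pr_demo == 0.6 here: five early bins plus the bin holding the final spike at 100 s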

474 

475 

476def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True): 

477 """ 

478 For specified channels, for specified timestamps, computes the mean (peak-to-peak amplitudes / 

479 the MADs of the background noise). 

480 

481 Parameters 

482 ---------- 

483 ephys_file : string 

484 The file path to the binary ephys data. 

485 ts : ndarray_like 

486 The timestamps (in s) of the spikes. 

487 ch : ndarray_like 

488 The channels on which to extract the waveforms. 

489 t : numeric (optional) 

490 The time (in ms) of the waveforms to extract to compute the ptp. 

491 sr : int (optional) 

492 The sampling rate (in hz) that the ephys data was acquired at. 

493 n_ch_probe : int (optional) 

494 The number of channels of the recording. 

495 car: bool (optional) 

496 A flag to perform common-average-referencing before extracting waveforms. 

497 

498 Returns 

499 ------- 

500 ptp_sigma : ndarray 

501 An array containing the mean ptp_over_noise values for the specified `ts` and `ch`. 

502 

503 Examples 

504 -------- 

505 1) Compute ptp_over_noise for all spikes on 20 channels around the channel of max amplitude 

506 for unit 1. 

507 >>> ts = units_b['times']['1'] 

508 >>> max_ch = clstrs_b['channels'][1] 

509 >>> if max_ch < 10: # take only channels greater than `max_ch`. 

510 >>> ch = np.arange(max_ch, max_ch + 20) 

511 >>> elif (max_ch + 10) > 385: # take only channels less than `max_ch`. 

512 >>> ch = np.arange(max_ch - 20, max_ch) 

513 >>> else: # take `n_c_ch` around `max_ch`. 

514 >>> ch = np.arange(max_ch - 10, max_ch + 10) 

515 >>> p = bb.metrics.ptp_over_noise(ephys_file, ts, ch) 

516 """ 

517 

518 # Ensure `ch` is ndarray 

519 ch = np.asarray(ch) 

520 ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch 

521 

522 # Get waveforms. 

523 wf = extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe, car=car) 

524 

525 # Initialize `mean_ptp` based on `ch`, and compute mean ptp of all spikes for each ch. 

526 mean_ptp = np.zeros((ch.size,)) 

527 for cur_ch in range(ch.size, ): 

528 mean_ptp[cur_ch] = np.mean(np.max(wf[:, :, cur_ch], axis=1) - 

529 np.min(wf[:, :, cur_ch], axis=1)) 

530 

531 # Compute MAD for `ch` in chunks. 

532 with spikeglx.Reader(ephys_file) as s_reader: 

533 file_m = s_reader.data # the memmapped array 

534 n_chunk_samples = 5e6 # number of samples per chunk 

535 n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int') 

536 # Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the 

537 # samples that make up the first chunk. 

538 chunk_sample = np.arange(0, file_m.shape[0], n_chunk_samples, dtype=int) 

539 chunk_sample = np.append(chunk_sample, file_m.shape[0]) 

540 # Give time estimate for computing MAD. 

541 t0 = time.perf_counter() 

542 stats.median_absolute_deviation(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0) 

543 dt = time.perf_counter() - t0 

544 print('Performing MAD computation. Estimated time is {:.2f} mins.' 

545 ' ({})'.format(dt * n_chunks / 60, time.ctime())) 

546 # Compute MAD for each chunk, then take the median MAD of all chunks. 

547 mad_chunks = np.zeros((n_chunks, ch.size), dtype=np.int16) 

548 for chunk in range(n_chunks): 

549 mad_chunks[chunk, :] = stats.median_absolute_deviation( 

550 file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1) 

551 print('Done. ({})'.format(time.ctime())) 

552 

553 # Return `mean_ptp` over `mad` 

554 mad = np.median(mad_chunks, axis=0) 

555 ptp_sigma = mean_ptp / mad 

556 return ptp_sigma 

557 

558 

559def contamination_alt(ts, rp=0.002): 

560 """ 

561 An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on 

562 the number of spikes, number of isi violations, and time between the first and last spike. 

563 (see Hill et al. (2011) J Neurosci 31: 8699-8705). 

564 

565 Parameters 

566 ---------- 

567 ts : ndarray_like 

568 The timestamps (in s) of the spikes. 

569 rp : float (optional) 

570 The refractory period (in s). 

571 

572 Returns 

573 ------- 

574 ce : float 

575 An estimate of the fraction of contamination. 

576 

577 See Also 

578 -------- 

579 contamination 

580 

581 Examples 

582 -------- 

583 1) Compute contamination estimate for unit 1. 

584 >>> ts = units_b['times']['1'] 

585 >>> ce = bb.metrics.contamination_alt(ts) 

586 """ 

587 

588 # Get number of spikes, number of isi violations, and time from first to final spike. 

589 n_spks = ts.size 1a

590 n_isi_viol = np.size(np.where(np.diff(ts) < rp)[0]) 1a

591 t = ts[-1] - ts[0] 1a

592 

593 # `ce` is min of roots of solved quadratic equation. 

594 c = (t * n_isi_viol) / (2 * rp * n_spks ** 2) # 3rd term in quadratic 1a

595 ce = np.min(np.abs(np.roots([-1, 1, c]))) # solve quadratic 1a

596 return ce 1a
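# Editor's worked example (assumed numbers, for illustration only): with 10000 spikes over
# t = 1000 s and 40 ISIs shorter than the 2 ms refractory period, the constant term is
# c = (1000 * 40) / (2 * 0.002 * 10000 ** 2) = 0.1, and the smaller root of -x**2 + x + c = 0
# gives an estimated contamination of roughly 9%.
# >>> np.min(np.abs(np.roots([-1, 1, 0.1])))   # ~0.0916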

597 

598 

599def contamination(ts, min_time, max_time, rp=0.002, min_isi=0.0001): 

600 """ 

601 An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on 

602 the number of spikes, number of isi violations, and time between the first and last spike. 

603 (see Hill et al. (2011) J Neurosci 31: 8699-8705). 

604 

605 Modified by Dan Denman from cortex-lab/sortingQuality GitHub by Nick Steinmetz. 

606 

607 Parameters 

608 ---------- 

609 ts : ndarray_like 

610 The timestamps (in s) of the spikes. 

611 min_time : float 

612 The minimum time (in s) that a potential spike occurred. 

613 max_time : float 

614 The maximum time (in s) that a potential spike occurred. 

615 rp : float (optional) 

616 The refractory period (in s). 

617 min_isi : float (optional) 

618 The minimum interspike-interval (in s) for counting duplicate spikes. 

619 

620 Returns 

621 ------- 

622 ce : float 

623 An estimate of the contamination. 

624 A perfect unit has a ce = 0 

625 A unit with some contamination has a ce < 0.5 

626 A unit with lots of contamination has a ce > 1.0 

627 num_violations : int 

628 The total number of isi violations. 

629 

630 See Also 

631 -------- 

632 contamination_alt 

633 

634 Examples 

635 -------- 

636 1) Compute contamination estimate for unit 1, with a minimum isi for counting duplicate 

637 spikes of 0.1 ms. 

638 >>> ts = units_b['times']['1'] 

639 >>> ce, n_viol = bb.metrics.contamination(ts, ts.min(), ts.max(), min_isi=0.0001) 

640 """ 

641 

642 duplicate_spikes = np.where(np.diff(ts) <= min_isi)[0] 1a

643 

644 ts = np.delete(ts, duplicate_spikes + 1) 1a

645 isis = np.diff(ts) 1a

646 

647 num_spikes = ts.size 1a

648 num_violations = np.sum(isis < rp) 1a

649 violation_time = 2 * num_spikes * (rp - min_isi) 1a

650 total_rate = ts.size / (max_time - min_time) 1a

651 violation_rate = num_violations / violation_time 1a

652 ce = violation_rate / total_rate 1a

653 

654 return ce, num_violations 1a
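# Editor's worked example (assumed numbers, for illustration only): 10000 spikes over a 1000 s
# recording give a total rate of 10 Hz; the violation window is 2 * 10000 * (0.002 - 0.0001)
# = 38 s, so 40 violations correspond to a violation rate of ~1.05 Hz and ce ~= 0.105.
# >>> (40 / (2 * 10000 * (0.002 - 0.0001))) / (10000 / 1000)   # ~0.105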

655 

656 

657def _max_acceptable_cont(FR, RP, rec_duration, acceptableCont, thresh): 

658 """ 

659 Helper to compute the maximum acceptable refractory period contamination; 

660 called by `slidingRP_viol`. 

661 """ 

662 

663 time_for_viol = RP * 2 * FR * rec_duration 

664 expected_count_for_acceptable_limit = acceptableCont * time_for_viol 

665 max_acceptable = stats.poisson.ppf(thresh, expected_count_for_acceptable_limit) 

666 if max_acceptable == 0 and stats.poisson.pmf(0, expected_count_for_acceptable_limit) > 0: 

667 max_acceptable = -1 

668 return max_acceptable 

669 

670 

671def slidingRP_viol(ts, bin_size=0.25, thresh=0.1, acceptThresh=0.1): 

672 """ 

673 A binary metric which determines whether there is an acceptable level of 

674 refractory period violations by using a sliding refractory period: 

675 

676 This takes into account the firing rate of the neuron and computes a 

677 maximum acceptable level of contamination at different possible values of 

678 the refractory period. If the unit has less than the maximum contamination 

679 at any of the possible values of the refractory period, the unit passes. 

680 

681 A neuron will always fail this metric for very low firing rates, and thus 

682 this metric takes into account both firing rate and refractory period 

683 violations. 

684 

685 

686 Parameters 

687 ---------- 

688 ts : ndarray_like 

689 The timestamps (in s) of the spikes. 

690 bin_size : float 

691 The bin size (in ms) used for the autocorrelogram. 

692 thresh : float 

693 Probability threshold applied to the poisson distribution of expected violation counts 

694 (used to compute the maximum acceptable contamination, see _max_acceptable_cont) 

695 acceptThresh : float 

696 The fraction of contamination we are willing to accept (default value 

697 set to 0.1, or 10% contamination) 

698 

699 Returns 

700 ------- 

701 didpass : int 

702 0 if unit didn't pass 

703 1 if unit did pass 

704 

705 See Also 

706 -------- 

707 contamination 

708 

709 Examples 

710 -------- 

711 1) Compute whether a unit has too much refractory period contamination at 

712 any possible value of a refractory period, for a 0.25 ms bin, with a 

713 threshold of 10% acceptable contamination 

714 >>> ts = units_b['times']['1'] 

715 >>> didpass = bb.metrics.slidingRP_viol(ts, bin_size=0.25, thresh=0.1, 

716 acceptThresh=0.1) 

717 """ 

718 

719 b = np.arange(0, 10.25, bin_size) / 1000 + 1e-6 # bins in seconds 

720 bTestIdx = [5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 32, 36, 40] 

721 bTest = [b[i] for i in bTestIdx] 

722 

723 if len(ts) > 0 and ts[-1] > ts[0]: # only do this for units with samples 

724 recDur = (ts[-1] - ts[0]) 

725 # compute acg 

726 c0 = correlograms(ts, np.zeros(len(ts), dtype='int8'), cluster_ids=[0], 

727 bin_size=bin_size / 1000, sample_rate=20000, 

728 window_size=2, 

729 symmetrize=False) 

730 # cumulative sum of acg, i.e. total number of spikes occurring from 0 

731 # to end of that bin 

732 cumsumc0 = np.cumsum(c0[0, 0, :]) 

733 # cumulative sum at each of the testing bins 

734 res = cumsumc0[bTestIdx] 

735 total_spike_count = len(ts) 

736 

737 # divide each bin's count by the total spike count and the bin size 

738 bin_count_normalized = c0[0, 0] / total_spike_count / bin_size * 1000 

739 num_bins_2s = len(c0[0, 0]) # number of total bins that equal 2 secs 

740 num_bins_1s = int(num_bins_2s / 2) # number of bins that equal 1 sec 

741 # compute fr based on the mean of bin_count_normalized from 1 to 2 s 

742 # instead of as before (len(ts)/recDur) for a better estimate 

743 fr = np.sum(bin_count_normalized[num_bins_1s:num_bins_2s]) / num_bins_1s 

744 mfunc = np.vectorize(_max_acceptable_cont) 

745 # compute the maximum allowed number of spikes per testing bin 

746 m = mfunc(fr, bTest, recDur, fr * acceptThresh, thresh) 

747 # did the unit pass (resulting number of spikes less than maximum 

748 # allowed spikes) at any of the testing bins? 

749 didpass = int(np.any(np.less_equal(res, m))) 

750 else: 

751 didpass = 0 

752 

753 return didpass 
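# Editor's sketch (synthetic spike train, assumed; not part of the original module): a
# well-isolated unit, i.e. one whose ISIs all exceed the refractory period, has zero
# autocorrelogram counts in the short test bins and is expected to pass.
# >>> rng = np.random.default_rng(0)
# >>> isis_demo = rng.exponential(1 / 20, 20000) + 0.003   # roughly 19 Hz, every ISI > 3 ms
# >>> ts_demo = np.cumsum(isis_demo)
# >>> slidingRP_viol(ts_demo)   # expected to return 1 (pass)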

754 

755 

756def noise_cutoff(amps, quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10): 

757 """ 

758 A new metric to determine whether a unit's amplitude distribution is cut off 

759 (at floor), without assuming a Gaussian distribution. 

760 This metric takes the amplitude distribution, computes the mean and std 

761 of an upper quartile of the distribution, and determines how many standard 

762 deviations away from that mean a lower quartile lies. 

763 Parameters 

764 ---------- 

765 amps : ndarray_like 

766 The amplitudes (in uV) of the spikes. 

767 quantile_length : float 

768 The size of the upper quartile of the amplitude distribution. 

769 n_bins : int 

770 The number of bins used to compute a histogram of the amplitude 

771 distribution. 

775 nc_threshold: float 

776 the unit fails if the noise cutoff result is higher than this threshold 

777 percent_threshold: float 

778 the first low-amplitude bin has to exceed this fraction of the peak bin height for the unit to fail 

779 Returns 

780 ------- 

781 nc_pass : bool 
 Whether the unit passes the noise cutoff criterion (True means pass). 
 cutoff : float 

782 Number of standard deviations that the first low-amplitude bin count lies outside of the 

783 mean of the upper quartile. 
 first_low_quantile : float 
 The spike count of the first non-zero low-amplitude bin. 

784 See Also 

785 -------- 

786 missed_spikes_est 

787 Examples 

788 -------- 

789 1) Compute whether a unit's amplitude distribution is cut off 

790 >>> amps = spks_b['amps'][unit_idxs] 

791 >>> nc_pass, cutoff, first_low_quantile = bb.metrics.noise_cutoff(amps, quantile_length=.25, n_bins=100) 

792 """ 

793 cutoff = np.float64(np.nan) 1a

794 first_low_quantile = np.float64(np.nan) 1a

795 fail_criteria = np.ones(1).astype(bool)[0] 1a

796 

797 if amps.size > 1: # ensure there are amplitudes available to analyze 1a

798 bins_list = np.linspace(0, np.max(amps), n_bins) # list of bins to compute the amplitude histogram 1a

799 n, bins = np.histogram(amps, bins=bins_list) # construct amplitude histogram 1a

800 idx_peak = np.argmax(n) # peak of amplitude distribution 1a

801 # number of non-zero bins in the top half of the distribution (above the peak), ignoring zero bins 

802 length_top_half = len(np.where(n[idx_peak:-1] > 0)[0]) 1a

803 # the remaining part of the distribution, which we will compare the low quantile to 

804 high_quantile = 2 * quantile_length 1a

805 # the first bin (index) of the high quantile part of the distribution 

806 high_quantile_start_ind = int(np.ceil(high_quantile * length_top_half + idx_peak)) 1a

807 # bins to consider in the high quantile (of all non-zero bins) 

808 indices_bins_high_quantile = np.arange(high_quantile_start_ind, len(n)) 1a

809 idx_use = np.where(n[indices_bins_high_quantile] >= 1)[0] 1a

810 

811 if len(n[indices_bins_high_quantile]) > 0: # ensure there are amplitudes in these bins 1a

812 # mean count of the non-empty high-quantile bins 

813 mean_high_quantile = np.mean(n[indices_bins_high_quantile][idx_use]) 1a

814 std_high_quantile = np.std(n[indices_bins_high_quantile][idx_use]) 1a

815 if std_high_quantile > 0: 1a

816 first_low_quantile = n[(n != 0)][1] # take the second non-zero bin 1a

817 cutoff = (first_low_quantile - mean_high_quantile) / std_high_quantile 1a

818 peak_bin_height = np.max(n) 1a

819 percent_of_peak = percent_threshold * peak_bin_height 1a

820 

821 fail_criteria = (cutoff > nc_threshold) & (first_low_quantile > percent_of_peak) 1a

822 

823 nc_pass = ~fail_criteria 1a

824 return nc_pass, cutoff, first_low_quantile 1a
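# Editor's sketch (synthetic amplitudes, assumed; not part of the original module): a symmetric
# amplitude distribution sitting well above the detection floor leaves the low-amplitude bins
# nearly empty, so the unit is expected to pass.
# >>> rng = np.random.default_rng(0)
# >>> amps_demo = rng.normal(60e-6, 10e-6, 5000)   # amplitudes in V (~60 uV)
# >>> nc_pass_demo, cutoff_demo, low_bin_demo = noise_cutoff(amps_demo)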

825 

826 

827def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS): 

828 """ 

829 Computes: 

830 - cell level metrics (cf quick_unit_metrics) 

831 - label the metrics according to quality thresholds 

832 - estimates drift as a function of time 

833 :param times: vector of spike times 

834 :param clusters: vector of spike cluster ids 

835 :param amps: vector of spike amplitudes 

836 :param depths: vector of spike depths 

837 :param cluster_ids (optional): set of clusters (if None the output dataframe will match 

838 the unique set of clusters represented in spike clusters) 

839 :param params: dict (optional) parameters for qc computation 

840 (see constant at the top of the module for default values and keys) 

841 :return: data_frame of metrics (cluster records, columns are qc attributes) 

842 :return: dictionary of recording qc (keys 'time_scale' and 'drift_um') 

843 """ 

844 # compute metrics and convert to `DataFrame` 

845 df_units = quick_unit_metrics( 1a

846 clusters, times, amps, depths, cluster_ids=cluster_ids, params=params) 

847 df_units = pd.DataFrame(df_units) 1a

848 # compute drift as a function of time and put in a dictionary 

849 drift, ts = electrode_drift.estimate_drift(times, amps, depths) 1a

850 rec_qc = {'time_scale': ts, 'drift_um': drift} 1a

851 return df_units, rec_qc 1a
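# Editor's usage sketch (assumes the ALF spikes object loaded as in the module docstring;
# variable names are illustrative only):
# >>> df_units, rec_qc = spike_sorting_metrics(spks_b['times'], spks_b['clusters'],
# ...                                          spks_b['amps'], spks_b['depths'])
# >>> rec_qc['drift_um']   # electrode drift estimate over rec_qc['time_scale']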

852 

853 

854def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths, 

855 params=METRICS_PARAMS, cluster_ids=None, tbounds=None): 

856 """ 

857 Computes single unit metrics from only the spike times, amplitudes, and 

858 depths for a set of units. 

859 

860 Metrics computed: 

861 'amp_max', 

862 'amp_min', 

863 'amp_median', 

864 'amp_std_dB', 

865 'contamination', 

866 'contamination_alt', 

867 'drift', 

868 'missed_spikes_est', 

869 'noise_cutoff', 

870 'presence_ratio', 

871 'presence_ratio_std', 

872 'slidingRP_viol', 

873 'spike_count' 

874 

875 Parameters (see the METRICS_PARAMS constant) 

876 ---------- 

877 spike_clusters : ndarray_like 

878 A vector of the unit ids for a set of spikes. 

879 spike_times : ndarray_like 

880 A vector of the timestamps for a set of spikes. 

881 spike_amps : ndarray_like 

882 A vector of the amplitudes for a set of spikes. 

883 spike_depths : ndarray_like 

884 A vector of the depths for a set of spikes. 

885 cluster_ids: (optional) list of cluster ids. If not all clusters are represented in 

886 spike_clusters (i.e. a cluster has no spikes), this will ensure the output size is consistent 

887 with the input arrays. 

888 tbounds: (optional) list or 2-element array containing a time-selection over which to 

889 perform the metrics computation. 

890 params : dict (optional) 

891 Parameters used for computing some of the metrics in the function: 

892 'presence_window': float 

893 The time window (in s) used to look for spikes when computing the presence ratio. 

894 'refractory_period': float 

895 The refractory period used when computing isi violations and the contamination 

896 estimate. 

897 'min_isi': float 

898 The minimum interspike-interval (in s) for counting duplicate spikes when computing 

899 the contamination estimate. 

900 'missed_spikes_est': dict 

901 Parameters for the missed spikes estimate: 'spks_per_bin', the number of spikes 

902 per bin used to compute the spike amplitude pdf for a unit; 'sigma', the standard 

903 deviation of the gaussian kernel used to smooth that pdf; and 'min_num_bins', the 

904 minimum number of bins used to compute the pdf. 

909 

910 Returns 

911 ------- 

912 r : bunch 

913 A bunch whose keys are the computed spike metrics. 

914 

915 Notes 

916 ----- 

917 This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf` 

918 during alf extraction of an ephys dataset in the ibl ephys extraction pipeline. 

919 

920 Examples 

921 -------- 

922 1) Compute quick metrics from a ks2 output directory: 

923 >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path 

924 >>> m = phy_model_from_ks2_path(path_to_ks2_out) 

925 >>> cluster_ids = m.spike_clusters 

926 >>> ts = m.spike_times 

927 >>> amps = m.amplitudes 

928 >>> depths = m.depths 

929 >>> r = bb.metrics.quick_unit_metrics(cluster_ids, ts, amps, depths) 

930 """ 

931 metrics_list = [ 1a

932 'cluster_id', 

933 'amp_max', 

934 'amp_min', 

935 'amp_median', 

936 'amp_std_dB', 

937 'contamination', 

938 'contamination_alt', 

939 'drift', 

940 'missed_spikes_est', 

941 'noise_cutoff', 

942 'presence_ratio', 

943 'presence_ratio_std', 

944 'slidingRP_viol', 

945 'spike_count' 

946 ] 

947 if tbounds: 1a

948 ispi = between_sorted(spike_times, tbounds) 

949 spike_times = spike_times[ispi] 

950 spike_clusters = spike_clusters[ispi] 

951 spike_amps = spike_amps[ispi] 

952 spike_depths = spike_depths[ispi] 

953 

954 if cluster_ids is None: 1a

955 cluster_ids = np.unique(spike_clusters) 

956 nclust = cluster_ids.size 1a

957 

958 r = Bunch({k: np.full((nclust,), np.nan) for k in metrics_list}) 1a

959 r['cluster_id'] = cluster_ids 1a

960 

961 # vectorized computation of basic metrics such as presence ratio and firing rate 

962 tmin = spike_times[0] 1a

963 tmax = spike_times[-1] 1a

964 presence_ratio = bincount2D(spike_times, spike_clusters, 1a

965 xbin=params['presence_window'], 

966 ybin=cluster_ids, xlim=[tmin, tmax])[0] 

967 r.presence_ratio = np.sum(presence_ratio > 0, axis=1) / presence_ratio.shape[1] 1a

968 r.presence_ratio_std = np.std(presence_ratio, axis=1) 1a

969 r.spike_count = np.sum(presence_ratio, axis=1) 1a

970 r.firing_rate = r.spike_count / (tmax - tmin) 1a

971 

972 # computing amplitude statistical indicators by aggregating over cluster id 

973 camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps), spike_clusters], 1a

974 columns=['amps', 'log_amps', 'clusters']) 

975 camp = camp.groupby('clusters') 1a

976 ir, ib = ismember(r.cluster_id, camp.clusters.unique()) 1a

977 r.amp_min[ir] = np.array(camp['amps'].min()) 1a

978 r.amp_max[ir] = np.array(camp['amps'].max()) 1a

979 # this is the geometric median 

980 r.amp_median[ir] = np.array(10 ** (camp['log_amps'].median() / 20)) 1a

981 r.amp_std_dB[ir] = np.array(camp['log_amps'].std()) 1a

982 srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters, 1a

983 **{'sampleRate': 30000, 'binSizeCorr': 1 / 30000}) 

984 r.slidingRP_viol[srp['cidx']] = srp['value'] 1a

985 

986 # loop over each cluster to compute the rest of the metrics 

987 for ic in np.arange(nclust): 1a

988 # slice the spike_times array 

989 ispikes = spike_clusters == cluster_ids[ic] 1a

990 if np.all(~ispikes): # if this cluster has no spikes, continue 1a

991 continue 

992 ts = spike_times[ispikes] 1a

993 amps = spike_amps[ispikes] 1a

994 depths = spike_depths[ispikes] 1a

995 # compute metrics 

996 r.contamination_alt[ic] = contamination_alt(ts, rp=params['refractory_period']) 1a

997 r.contamination[ic], _ = contamination( 1a

998 ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi']) 

999 _, r.noise_cutoff[ic], _ = noise_cutoff(amps, **params['noise_cutoff']) 1a

1000 r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est']) 1a

1001 # wonder if there is a need to low-cut this 

1002 r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600 1a

1003 

1004 r.label = compute_labels(r) 1a

1005 return r 1a
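# Editor's sketch (synthetic spikes for two clusters, assumed; not part of the original
# module): the returned bunch holds one array per metric, each of length nclust.
# >>> rng = np.random.default_rng(0)
# >>> n_spk = 20000
# >>> spk_t = np.sort(rng.uniform(0, 600, n_spk))
# >>> spk_c = rng.integers(0, 2, n_spk)
# >>> spk_a = rng.normal(60e-6, 10e-6, n_spk)
# >>> spk_d = rng.normal(2000, 20, n_spk)
# >>> r_demo = quick_unit_metrics(spk_c, spk_t, spk_a, spk_d)
# >>> r_demo['label']   # proportion of qcs passed, one value per cluster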

1006 

1007 

1008def compute_labels(r, params=METRICS_PARAMS, return_details=False): 

1009 """ 

1010 From a dataframe or a dictionary of unit metrics, compute a label 

1011 :param r: dictionary or pandas dataframe containing unit qcs 

1012 :param return_details: if True, also return a dictionary with the pass/fail vector of each qc 

1013 :return: vector of the proportion of qcs passed, between 0 and 1, where 1 denotes an all pass 

1014 """ 

1015 # right now the score is a value between 0 and 1 denoting the proportion of passing qcs 

1016 # we could eventually do a bitwise qc 

1017 labels = np.c_[ 1a

1018 r.slidingRP_viol, 

1019 r.noise_cutoff < params['noise_cutoff']['nc_threshold'], 

1020 r.amp_median > params['med_amp_thresh_uv'] / 1e6, 

1021 ] 

1022 if not return_details: 1a

1023 return np.mean(labels, axis=1) 1a

1024 column_names = ['slidingRP_viol', 'noise_cutoff', 'amp_median'] 

1025 qcdict = {} 

1026 for c in np.arange(labels.shape[1]): 

1027 qcdict[column_names[c]] = labels[:, c] 

1028 return np.mean(labels, axis=1), qcdict
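# Editor's worked example (assumed values, for illustration only): a unit that passes the
# sliding refractory-period check, fails the noise cutoff and passes the median amplitude
# threshold gets a label of 2/3.
# >>> demo_r = Bunch(slidingRP_viol=np.array([1.]),
# ...                noise_cutoff=np.array([20.]),    # above nc_threshold of 5 -> fail
# ...                amp_median=np.array([80e-6]))    # 80 uV > 50 uV threshold -> pass
# >>> compute_labels(demo_r)   # array([0.6666...])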