Coverage for brainbox/metrics/single_units.py: 13%

241 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1""" 

2Computes metrics for assessing quality of single units. 

3 

4Run the following to set-up the workspace to run the docstring examples: 

5>>> import brainbox as bb 

6>>> import one.alf.io as aio 

7>>> import numpy as np 

8>>> import matplotlib.pyplot as plt 

9>>> import ibllib.ephys.spikes as e_spks 

10# (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory): 

11>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out) 

12# Load the alf spikes bunch and clusters bunch, and get a units bunch. 

13>>> spks_b = aio.load_object(path_to_alf_out, 'spikes') 

14>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters') 

15>>> units_b = bb.processing.get_units_bunch(spks_b) # may take a few mins to compute 

16""" 

17 

18import time 

19import logging 

20 

21import numpy as np 

22from scipy.ndimage import gaussian_filter1d 

23import scipy.stats as stats 

24import pandas as pd 

25 

26import spikeglx 

27from phylib.stats import correlograms 

28from iblutil.util import Bunch 

29from iblutil.numerical import ismember, between_sorted, bincount2D 

30from slidingRP import metrics 

31 

32from brainbox import singlecell 

33from brainbox.io.spikeglx import extract_waveforms 

34from brainbox.metrics import electrode_drift 

35 

36 

_logger = logging.getLogger('ibllib')

# Parameters to be used in `quick_unit_metrics`
METRICS_PARAMS = {
    # keyword arguments whose names match the signature of `noise_cutoff`
    'noise_cutoff': dict(quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10),
    # keyword arguments whose names match the signature of `missed_spikes_est`
    'missed_spikes_est': dict(spks_per_bin=10, sigma=4, min_num_bins=50),
    # NOTE(review): presumably the acceptable contamination fraction of the sliding RP metric - confirm
    'acceptable_contamination': 0.1,
    # NOTE(review): presumably the autocorrelogram bin size in ms (cf. `slidingRP_viol`) - confirm
    'bin_size': 0.25,
    'med_amp_thresh_uv': 50,  # units below this threshold are considered noise
    # NOTE(review): likely seconds, matching the `min_isi` argument of `contamination` - confirm
    'min_isi': 0.0001,
    # NOTE(review): likely seconds, cf. `pres_ratio(hist_win=10)` - confirm
    'presence_window': 10,
    # NOTE(review): likely seconds - confirm
    'refractory_period': 0.0015,
    'RPslide_thresh': 0.1,
    'RPmax_confidence': 90,  # a unit needs to pass with at least this confidence percentage (0 - 100)
}

52 

53 

def unit_stability(units_b, units=None, feat_names=('amps',), dist='norm', test='ks'):
    """
    Computes the probability that the empirical spike feature distribution(s), for specified
    feature(s), for all units, comes from a specific theoretical distribution, based on a specified
    statistical test. Also computes, for every unit, the variance-to-mean ratio of the spike
    feature(s).

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all units. It must contain a 'times' field and every name in `feat_names`,
        each keyed by unit ID strings (e.g. `units_b['amps']['1']`). `units_b` is not modified.
    units : array-like (optional)
        A subset of (integer) unit IDs for which to compute the metrics. (If `None`, all units in
        `units_b[feat_names[0]]` are used.)
    feat_names : sequence of strings (optional)
        Names of spike features found in `units_b` for which to compute unit stability.
    dist : string (optional)
        The hypothetical null distribution the empirical spike feature distributions are tested
        against (for the 'ks' test this is the `cdf` argument of `scipy.stats.kstest`, e.g. 'norm').
    test : string (optional)
        The statistical test used to compute the probability that the empirical spike feature
        distributions come from `dist`. Only 'ks' (one-sample Kolmogorov-Smirnov) is implemented;
        any other value raises a KeyError.

    Returns
    -------
    p_vals_b : bunch
        A bunch with `feat_names` as keys, each containing a bunch mapping unit ID strings to the
        p-value (the probability that the unit's empirical spike feature distribution comes from
        `dist` based on `test`). NaN for units without spikes.
    cv_b : bunch
        A bunch with `feat_names` as keys, each containing a bunch mapping unit ID strings to the
        variance divided by the mean of that unit's feature values. NOTE: despite the name this
        is the variance-to-mean ratio (Fano factor), not std / mean. NaN for units without spikes.

    See Also
    --------
    plot.feat_vars

    Examples
    --------
    1) Compute 1) the p-values obtained from running a one-sample ks test on the spike amplitudes
    for each unit, and 2) the variance-to-mean ratios of the empirical spike amplitudes
    distribution for each unit. Create a histogram of the variances of the spike amplitudes for
    each unit, color-coded by depth of channel of max amplitudes. Get cluster IDs of those units
    which have ratios greater than 50.
    >>> p_vals_b, variances_b = bb.metrics.unit_stability(units_b)
    # Plot histograms of variances color-coded by depth of channel of max amplitudes
    >>> fig = bb.plot.feat_vars(units_b, feat_name='amps')
    # Get all unit IDs which have amps variance > 50
    >>> var_vals = np.array(tuple(variances_b['amps'].values()))
    >>> bad_units = np.where(var_vals > 50)
    """
    # Select units without mutating the caller's `units_b` (the previous implementation popped
    # the unselected units out of it in place).
    unit_list = list(units_b[feat_names[0]].keys())
    if units is not None:  # we're using a subset of all units
        unit_list = [unit for unit in unit_list if int(unit) in units]

    # Available statistical tests (in future, more tests can be added to this dict).
    tests = {'ks': lambda x, y: stats.kstest(x, y)}
    test_fun = tests[test]

    p_vals_b = Bunch()
    cv_b = Bunch()
    for feat in feat_names:
        p_vals_feat = Bunch()
        cv_feat = Bunch()
        for unit in unit_list:
            if len(units_b['times'][str(unit)]) == 0:
                # Units without spikes get NaN placeholders.
                p_val = np.nan
                cv = np.nan
            else:
                vals = units_b[feat][unit]
                _, p_val = test_fun(vals, dist)
                cv = np.var(vals) / np.mean(vals)
            p_vals_feat[str(unit)] = p_val
            cv_feat[str(unit)] = cv
        p_vals_b[feat] = p_vals_feat
        cv_b[feat] = cv_feat

    return p_vals_b, cv_b

149 

150 

def missed_spikes_est(feat, spks_per_bin=20, sigma=5, min_num_bins=50):
    """
    Computes the approximate fraction of spikes missing from a spike feature distribution for a
    given unit, assuming the distribution is symmetric.
    Inspired by metric described in Hill et al. (2011) J Neurosci 31: 8699-8705.

    The feature histogram is smoothed with a gaussian kernel to obtain a pdf. The point to the
    right of the peak where the pdf comes closest to its value at the leftmost bin is taken as
    the mirror image of the (presumed) left cut-off; the pdf mass beyond that point estimates the
    missing fraction.

    Parameters
    ----------
    feat : array_like
        The spikes' feature values (e.g. amplitudes).
    spks_per_bin : int (optional)
        The number of spikes per bin from which to compute the spike feature histogram.
    sigma : int (optional)
        The standard deviation (in bins) of the gaussian kernel used to compute the pdf from the
        spike feature histogram.
    min_num_bins : int (optional)
        The minimum number of bins used to compute the spike feature histogram. If `feat` has
        `spks_per_bin * min_num_bins` values or fewer, no estimate is computed.

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes (0-0.5), or NaN if there are too few spikes. *Note: If more
        than 50% of spikes are missing, an accurate estimate isn't possible and 0.5 is returned.
    pdf : ndarray or None
        The computed pdf of the spike feature histogram (None if there are too few spikes).
    cutoff_idx : int or None
        The index for `pdf` at which point `pdf` is no longer symmetrical around the peak. (This
        is returned for plotting purposes; None if there are too few spikes).

    See Also
    --------
    plot.feat_cutoff

    Examples
    --------
    1) Determine the fraction of spikes missing from unit 1 based on the recorded unit's spike
    amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
    # Get unit 1 amplitudes from a unit bunch, and compute fraction spikes missing.
    >>> feat = units_b['amps']['1']
    >>> fraction_missing, pdf, cutoff_idx = bb.metrics.missed_spikes_est(feat)
    """
    feat = np.asarray(feat)

    # Ensure minimum number of spikes requirement is met, return NaN otherwise
    if feat.size <= (spks_per_bin * min_num_bins):
        return np.nan, None, None

    # compute the spike feature histogram and pdf:
    num_bins = int(feat.size / spks_per_bin)
    hist, _ = np.histogram(feat, num_bins, density=True)
    pdf = gaussian_filter1d(hist, sigma)

    # Find where the distribution stops being symmetric around the peak: the first bin right of
    # the peak whose pdf value is closest to the pdf value at the leftmost bin.
    peak_idx = np.argmax(pdf)
    max_idx_sym_around_peak = np.argmin(np.abs(pdf[peak_idx:] - pdf[0]))
    cutoff_idx = peak_idx + max_idx_sym_around_peak

    # compute fraction missing from the tail of the pdf (the area where pdf stops being
    # symmetric around peak); above 0.5 the estimate is meaningless, so clamp it.
    fraction_missing = np.sum(pdf[cutoff_idx:]) / np.sum(pdf)
    fraction_missing = min(fraction_missing, 0.5)

    return fraction_missing, pdf, cutoff_idx

212 

213 

def wf_similarity(wf1, wf2):
    """
    Computes a unit normalized spatiotemporal similarity score between two sets of waveforms.
    This score is based on how waveform shape correlates for each pair of spikes between the
    two sets of waveforms across space and time. The shapes of the arrays of the two sets of
    waveforms must be equal.

    For a pair of spikes, the score is the mean over all (sample, channel) points of
    ``x * y / sqrt(x**2 * y**2)``, i.e. of ``sign(x) * sign(y)`` (0 where either value is 0 or
    NaN). The returned value is the mean of that score over all pairs (spike of `wf1`, spike of
    `wf2`), so it lies in [-1, 1].

    Parameters
    ----------
    wf1 : ndarray
        An array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        An array of shape (#spikes, #samples, #channels).

    Returns
    -------
    s: float
        The unit normalized spatiotemporal similarity score (NaN for empty waveform arrays).

    Raises
    ------
    AssertionError
        If `wf1` and `wf2` have different shapes.

    See Also
    --------
    io.extract_waveforms
    plot.single_unit_wf_comp

    Examples
    --------
    1) Compute the similarity between the first and last 100 waveforms for unit1, across the 20
    channels around the channel of max amplitude.
    # Get the channels around the max amp channel for the unit, two sets of timestamps for the
    # unit, and the two corresponding sets of waveforms for those two sets of timestamps.
    # Then compute `s`.
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> ts1 = units_b['times']['1'][:100]
    >>> ts2 = units_b['times']['1'][-100:]
    >>> wf1 = bb.io.extract_waveforms(path_to_ephys_file, ts1, ch)
    >>> wf2 = bb.io.extract_waveforms(path_to_ephys_file, ts2, ch)
    >>> s = bb.metrics.wf_similarity(wf1, wf2)

    TODO check `s` calculation:
    take median of waveforms
    xcorr all waveforms with median, and divide by autocorr of all waveforms
    profile
    for two sets of units: xcorr(cl1, cl2) / (sqrt autocorr(cl1) * autocorr(cl2))
    """
    wf1 = np.asarray(wf1)
    wf2 = np.asarray(wf2)
    assert wf1.shape == wf2.shape, ('The shapes of the sets of waveforms are inconsistent '
                                    '({}) vs ({})'.format(wf1.shape, wf2.shape))

    # Get number of spikes, samples, and channels of waveforms.
    n_spks = wf1.shape[0]
    n_samples = wf1.shape[1]
    n_ch = wf1.shape[2]
    if wf1.size == 0:
        return np.nan

    # x * y / sqrt(x**2 * y**2) == sign(x) * sign(y) (0/0 -> NaN -> 0 in the former), so use the
    # signs directly: no divide-by-zero warnings, no global warning-filter changes, and no integer
    # overflow when squaring int16 waveforms. Cast to float so the sums below cannot overflow.
    sign1 = np.nan_to_num(np.sign(wf1.astype(np.float64))).reshape(n_spks, -1)
    sign2 = np.nan_to_num(np.sign(wf2.astype(np.float64))).reshape(n_spks, -1)

    # The similarity of spikes (i, j) is sign1[i] @ sign2[j] / (n_samples * n_ch). The mean over
    # all n_spks ** 2 pairs equals the dot product of the per-point sums, normalized once; this
    # replaces the O(n_spks ** 2) Python double loop.
    s = np.dot(sign1.sum(axis=0), sign2.sum(axis=0)) / (n_spks ** 2 * n_samples * n_ch)
    return s

291 

292 

def firing_rate_coeff_var(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    '''
    Computes the coefficient of variation of the firing rate: the ratio of the standard
    deviation to the mean.

    The instantaneous firing rate is split into `n_bins` contiguous segments of equal length
    (trailing samples that do not fill a whole segment are dropped) and a coefficient of
    variation is computed for each segment.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute a coefficient of variation of the firing rate.

    Returns
    -------
    cv : float
        The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients
        computed. NaN if any bin has a zero mean firing rate.
    cvs : ndarray
        The coefficients of variation of the firing for each bin of `n_bins` (NaN for a bin where
        the neuron does not spike, since 0 / 0).
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the coefficient of variation of the firing rate for unit 1 from the time of its
    first to last spike, and compute the coefficient of variation of the firing rate for unit 2
    from the first to second minute.
    >>> ts_1 = units_b['times']['1']
    >>> ts_2 = units_b['times']['2']
    >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
    >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts_1)
    >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(ts_2)
    '''
    # Compute overall instantaneous firing rate and firing rate for each bin.
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    bin_sz = int(fr.size / n_bins)
    # Equal-length contiguous segments; up to n_bins - 1 trailing samples are dropped.
    fr_binned = fr[:bin_sz * n_bins].reshape(n_bins, bin_sz)

    # Coefficient of variation per bin, and their mean. A bin without spikes has a zero mean,
    # giving NaN, which then also makes `cv` NaN.
    cvs = np.std(fr_binned, axis=1) / np.mean(fr_binned, axis=1)
    cv = np.mean(cvs)

    return cv, cvs, fr

348 

349 

def firing_rate_fano_factor(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    '''
    Computes the fano factor of the firing rate: the ratio of the variance to the mean.
    (Almost identical to coeff. of variation)

    The instantaneous firing rate is split into `n_bins` contiguous segments of equal length
    (trailing samples that do not fill a whole segment are dropped) and a fano factor is
    computed for each segment.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute a fano factor of the firing rate.

    Returns
    -------
    ff : float
        The mean fano factor of the firing rate of the `n_bins` number of factors
        computed. NaN if any bin has a zero mean firing rate.
    ffs : ndarray
        The fano factors of the firing for each bin of `n_bins` (NaN for a bin where the neuron
        does not spike, since 0 / 0).
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the fano factor of the firing rate for unit 1 from the time of its
    first to last spike, and compute the fano factor of the firing rate for unit 2
    from the first to second minute.
    >>> ts_1 = units_b['times']['1']
    >>> ts_2 = units_b['times']['2']
    >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
    >>> ff, ffs, fr = bb.metrics.firing_rate_fano_factor(ts_1)
    >>> ff_2, ffs_2, fr_2 = bb.metrics.firing_rate_fano_factor(ts_2)
    '''
    # Compute overall instantaneous firing rate and firing rate for each bin.
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    bin_sz = int(fr.size / n_bins)
    # Equal-length contiguous segments; up to n_bins - 1 trailing samples are dropped.
    fr_binned = fr[:bin_sz * n_bins].reshape(n_bins, bin_sz)

    # Fano factor per bin, and their mean. A bin without spikes has a zero mean, giving NaN,
    # which then also makes `ff` NaN.
    ffs = np.var(fr_binned, axis=1) / np.mean(fr_binned, axis=1)
    ff = np.mean(ffs)

    return ff, ffs, fr

405 

406 

def average_drift(feat, times):
    """
    Computes the cumulative drift (normalized by the total number of spikes) of a spike feature
    array: the sum of the absolute rates of change of `feat` between consecutive spikes, divided
    by the number of spikes.

    Parameters
    ----------
    feat : ndarray
        The spike feature values from which to compute the drift.
        Usually amplitudes or depths.
    times : ndarray
        The spike times corresponding to `feat` (same length). Consecutive spikes with identical
        times make the rate of change a division by zero (inf / NaN in the result).

    Returns
    -------
    cd : float
        The cumulative drift of the unit, per spike.

    See Also
    --------
    max_drift

    Examples
    --------
    1) Get the cumulative depth and amplitude drift for unit 1.
    >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
    >>> times = spks_b['times'][unit_idxs]
    >>> depths = spks_b['depths'][unit_idxs]
    >>> amps = spks_b['amps'][unit_idxs]
    >>> depth_cd = bb.metrics.average_drift(depths, times)
    >>> amp_cd = bb.metrics.average_drift(amps, times)
    """
    # Rate of change of the feature between consecutive spikes, summed in absolute value and
    # normalized by the number of spikes.
    cd = np.sum(np.abs(np.diff(feat) / np.diff(times))) / len(feat)
    return cd

439 

440 

def pres_ratio(ts, hist_win=10):
    """
    Computes the presence ratio of spike counts: the fraction of time bins (of width `hist_win`,
    starting at 0 s) that contain at least one spike.

    Parameters
    ----------
    ts : ndarray
        The (sorted) spike timestamps from which to compute the presence ratio.
    hist_win : float (optional)
        The bin width (in s) used to compute the presence ratio.

    Returns
    -------
    pr : float
        The presence ratio (number of non-empty bins / number of bins).
    spks_bins : ndarray
        The number of spikes in each bin.

    See Also
    --------
    plot.pres_ratio

    Examples
    --------
    1) Compute the presence ratio for unit 1, given a window of 10 s.
    >>> ts = units_b['times']['1']
    >>> pr, pr_bins = bb.metrics.pres_ratio(ts)
    """
    # Edges from 0 s in steps of `hist_win`, extended so that they reach the last spike time.
    edges = np.arange(0, ts[-1] + hist_win, hist_win)
    counts = np.histogram(ts, edges)[0]
    n_occupied = np.count_nonzero(counts)
    return n_occupied / counts.size, counts

475 

476 

def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    For specified channels, for specified timestamps, computes the mean (peak-to-peak amplitudes /
    the MADs of the background noise).

    For each channel in `ch`, the peak-to-peak amplitude of every extracted waveform is averaged
    over spikes, then divided by the noise level: the median, over chunks of the raw binary
    file, of the per-chunk median absolute deviation (MAD, unscaled) of that channel.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    ch : ndarray_like
        The channels on which to extract the waveforms.
    t : numeric (optional)
        The time (in ms) of the waveforms to extract to compute the ptp.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car: bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    ptp_sigma : ndarray
        An array containing the mean ptp_over_noise values for the specified `ts` and `ch`
        (inf for a channel whose MAD is zero).

    Examples
    --------
    1) Compute ptp_over_noise for all spikes on 20 channels around the channel of max amplitude
    for unit 1.
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> p = bb.metrics.ptp_over_noise(ephys_file, ts, ch)
    """
    # Ensure `ch` is ndarray
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Get waveforms, assumed (#spikes, #samples, #channels) as indexed below.
    wf = extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe, car=car)

    # Mean (over spikes) of the peak-to-peak amplitude (over samples) for each requested channel.
    wf_ch = wf[:, :, :ch.size]
    mean_ptp = np.mean(np.max(wf_ch, axis=1) - np.min(wf_ch, axis=1), axis=0)

    # Compute MAD for `ch` in chunks.
    n_chunk_samples = int(5e6)  # number of samples per chunk
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        n_samples_total = file_m.shape[0]
        n_chunks = int(np.ceil(n_samples_total / n_chunk_samples))
        # Chunk boundaries: chunk `i` spans samples chunk_sample[i]:chunk_sample[i + 1].
        chunk_sample = np.append(np.arange(0, n_samples_total, n_chunk_samples, dtype=int),
                                 n_samples_total)
        # Float storage: an integer array would truncate fractional MADs (e.g. 5.5 -> 5).
        mad_chunks = np.zeros((n_chunks, ch.size))
        for chunk in range(n_chunks):
            t0 = time.perf_counter()
            # `median_absolute_deviation` was removed from SciPy (1.9); `median_abs_deviation`
            # with scale=1.0 is the unscaled MAD used before.
            mad_chunks[chunk, :] = stats.median_abs_deviation(
                file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1.0)
            if chunk == 0:
                # Time estimate from the first chunk (which is reused, not recomputed).
                dt = time.perf_counter() - t0
                print('Performing MAD computation. Estimated time is {:.2f} mins.'
                      ' ({})'.format(dt * n_chunks / 60, time.ctime()))
        print('Done. ({})'.format(time.ctime()))

    # Return `mean_ptp` over `mad` (median MAD over all chunks).
    mad = np.median(mad_chunks, axis=0)
    ptp_sigma = mean_ptp / mad
    return ptp_sigma

558 

559 

def contamination_alt(ts, rp=0.002):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and time between the first and last spike.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Solves Hill et al.'s relation ``n_isi_viol = 2 * rp * N**2 * ce * (1 - ce) / T`` (N spikes,
    T time between first and last spike) for the smaller root ``ce``.

    Parameters
    ----------
    ts : ndarray_like
        The (sorted) timestamps (in s) of the spikes.
    rp : float (optional)
        The refractory period (in s).

    Returns
    -------
    ce : float
        An estimate of the fraction of contamination, in [0, 0.5]. NaN if `ts` is empty, or if
        there are more violations than the model can explain (no real solution).

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute contamination estimate for unit 1.
    >>> ts = units_b['times']['1']
    >>> ce = bb.metrics.contamination_alt(ts)
    """
    ts = np.asarray(ts)
    if ts.size == 0:
        return np.nan

    # Get number of spikes, number of isi violations, and time from first to final spike.
    n_spks = ts.size
    n_isi_viol = np.count_nonzero(np.diff(ts) < rp)
    t = ts[-1] - ts[0]

    # ce * (1 - ce) = k  <=>  ce**2 - ce + k = 0. (The previous implementation solved
    # -ce**2 + ce + k = 0, i.e. ce * (1 + ce) = k, which underestimates contamination.)
    k = (t * n_isi_viol) / (2 * rp * n_spks ** 2)
    discriminant = 1 - 4 * k
    if discriminant < 0:
        # k > 0.25: too many violations for any contamination level in this model.
        return np.nan
    # Smaller root (1 - sqrt(d)) / 2, written in a cancellation-free form for small k.
    ce = 2 * k / (1 + np.sqrt(discriminant))
    return ce

598 

599 

def contamination(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and the duration `max_time - min_time`.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Modified by Dan Denman from cortex-lab/sortingQuality GitHub by Nick Steinmetz.

    Parameters
    ----------
    ts : ndarray_like
        The (sorted) timestamps (in s) of the spikes.
    min_time : float
        The minimum time (in s) that a potential spike occurred.
    max_time : float
        The maximum time (in s) that a potential spike occurred.
    rp : float (optional)
        The refractory period (in s).
    min_isi : float (optional)
        The minimum interspike-interval (in s) for counting duplicate spikes. Of any two
        consecutive spikes at most `min_isi` apart, the later one is discarded before counting.

    Returns
    -------
    ce : float
        An estimate of the contamination: the rate of isi violations divided by the overall
        firing rate.
        A perfect unit has a ce = 0
        A unit with some contamination has a ce < 0.5
        A unit with lots of contamination has a ce > 1.0
    num_violations : int
        The total number of isi violations (after removing duplicate spikes).

    See Also
    --------
    contamination_alt

    Examples
    --------
    1) Compute contamination estimate for unit 1, with a minimum isi for counting duplicate
    spikes of 0.1 ms.
    >>> ts = units_b['times']['1']
    >>> ce, num_violations = bb.metrics.contamination(ts, ts[0], ts[-1], min_isi=0.0001)
    """
    ts = np.asarray(ts)

    # Drop the later spike of every pair of consecutive spikes at most `min_isi` apart.
    duplicate_spikes = np.where(np.diff(ts) <= min_isi)[0]
    ts = np.delete(ts, duplicate_spikes + 1)
    isis = np.diff(ts)

    num_spikes = ts.size
    num_violations = np.sum(isis < rp)
    # Total time (on both sides of every spike) in which a violating spike could have occurred.
    violation_time = 2 * num_spikes * (rp - min_isi)
    total_rate = num_spikes / (max_time - min_time)
    violation_rate = num_violations / violation_time
    ce = violation_rate / total_rate

    return ce, num_violations

656 

657 

658def _max_acceptable_cont(FR, RP, rec_duration, acceptableCont, thresh): 

659 """ 

660 Function to compute the maximum acceptable refractory period contamination 

661 called during slidingRP_viol 

662 """ 

663 

664 time_for_viol = RP * 2 * FR * rec_duration 

665 expected_count_for_acceptable_limit = acceptableCont * time_for_viol 

666 max_acceptable = stats.poisson.ppf(thresh, expected_count_for_acceptable_limit) 

667 if max_acceptable == 0 and stats.poisson.pmf(0, expected_count_for_acceptable_limit) > 0: 

668 max_acceptable = -1 

669 return max_acceptable 

670 

671 

def slidingRP_viol(ts, bin_size=0.25, thresh=0.1, acceptThresh=0.1):
    """
    A binary metric which determines whether there is an acceptable level of
    refractory period violations by using a sliding refractory period:

    This takes into account the firing rate of the neuron and computes a
    maximum acceptable level of contamination at different possible values of
    the refractory period. If the unit has less than the maximum contamination
    at any of the possible values of the refractory period, the unit passes.

    A neuron will always fail this metric for very low firing rates, and thus
    this metric takes into account both firing rate and refractory period
    violations.


    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    bin_size : float
        The size (in ms) of binning for the autocorrelogram. NOTE(review): the hard-coded test
        bin indices below (up to 40) require at least 41 bins in [0, 10.25) ms, i.e. a bin size
        of about 0.25 ms or smaller; larger values raise an IndexError.
    thresh : float
        The quantile passed to `scipy.stats.poisson.ppf` (via _max_acceptable_cont) to compute
        the maximum acceptable number of violations.
    acceptThresh : float
        The fraction of contamination we are willing to accept (default value
        set to 0.1, or 10% contamination); multiplied by the estimated firing rate to give the
        acceptable contaminating rate.

    Returns
    -------
    didpass : int
        0 if unit didn't pass (also for empty `ts` or when all spikes share one time)
        1 if unit did pass

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute whether a unit has too much refractory period contamination at
    any possible value of a refractory period, for a 0.25 ms bin, with a
    threshold of 10% acceptable contamination
    >>> ts = units_b['times']['1']
    >>> didpass = bb.metrics.slidingRP_viol(ts, bin_size=0.25, thresh=0.1,
                                            acceptThresh=0.1)
    """

    b = np.arange(0, 10.25, bin_size) / 1000 + 1e-6  # bins in seconds
    # Candidate refractory periods, as indices into `b` (and into the ACG cumulative sum).
    bTestIdx = [5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 32, 36, 40]
    bTest = [b[i] for i in bTestIdx]

    if len(ts) > 0 and ts[-1] > ts[0]:  # only do this for units with samples
        recDur = (ts[-1] - ts[0])
        # compute acg (one-sided, since symmetrize=False)
        c0 = correlograms(ts, np.zeros(len(ts), dtype='int8'), cluster_ids=[0],
                          bin_size=bin_size / 1000, sample_rate=20000,
                          window_size=2,
                          symmetrize=False)
        # cumulative sum of acg, i.e. number of total spikes occurring from 0
        # to end of that bin
        cumsumc0 = np.cumsum(c0[0, 0, :])
        # cumulative sum at each of the testing bins
        res = cumsumc0[bTestIdx]
        total_spike_count = len(ts)

        # divide each bin's count by the total spike count and the bin size
        bin_count_normalized = c0[0, 0] / total_spike_count / bin_size * 1000
        # NOTE(review): with symmetrize=False, phylib appears to return only the non-negative lags
        # of the 2 s window, so this may span ~1 s rather than 2 s - confirm.
        num_bins_2s = len(c0[0, 0])  # number of total bins that equal 2 secs
        num_bins_1s = int(num_bins_2s / 2)  # number of bins that equal 1 sec
        # compute fr based on the mean of bin_count_normalized over the second half of the
        # ACG bins, instead of as before (len(ts)/recDur) for a better estimate
        fr = np.sum(bin_count_normalized[num_bins_1s:num_bins_2s]) / num_bins_1s
        mfunc = np.vectorize(_max_acceptable_cont)
        # compute the maximum allowed number of spikes per testing bin
        m = mfunc(fr, bTest, recDur, fr * acceptThresh, thresh)
        # did the unit pass (resulting number of spikes less than or equal to maximum
        # allowed spikes) at any of the testing bins?
        didpass = int(np.any(np.less_equal(res, m)))
    else:
        didpass = 0

    return didpass

755 

756 

def noise_cutoff(amps, quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10):
    """
    A new metric to determine whether a unit's amplitude distribution is cut off
    (at floor), without assuming a Gaussian distribution.
    This metric takes the amplitude distribution, computes the mean and std
    of an upper quartile of the distribution, and determines how many standard
    deviations away from that mean a lower quartile lies.

    Concretely: amplitudes are histogrammed into `n_bins - 1` equal bins from 0 to max(amps).
    The "high quantile" consists of the non-empty bins starting `2 * quantile_length` of the way
    through the non-empty bins above the peak. The low value is the count of the second
    non-empty bin, and the cutoff is its distance from the high-quantile mean, in high-quantile
    standard deviations.

    Parameters
    ----------
    amps : ndarray_like
        The amplitudes (in uV) of the spikes.
    quantile_length : float
        The size of the upper quartile of the amplitude distribution.
    n_bins : int
        The number of bin edges (yielding `n_bins - 1` bins) used to compute a histogram of the
        amplitude distribution.
    nc_threshold: float
        the noise cutoff result has to be greater than this for a neuron to fail
    percent_threshold: float
        the count of the low (second non-empty) bin has to be greater than this fraction of the
        peak bin count for the neuron to fail

    Returns
    -------
    nc_pass : numpy.bool_
        True if the unit passes, i.e. it does not meet both fail criteria. False when the metric
        cannot be computed (fewer than 2 amplitudes, or no spread in the high quantile).
    cutoff : float
        Number of standard deviations that the lower mean is outside of the
        mean of the upper quartile (NaN when not computed).
    first_low_quantile : float
        The count of the second non-empty histogram bin (NaN when not computed).

    See Also
    --------
    missed_spikes_est

    Examples
    --------
    1) Compute whether a unit's amplitude distribution is cut off
    >>> amps = spks_b['amps'][unit_idxs]
    >>> nc_pass, cutoff, first_low = bb.metrics.noise_cutoff(amps, quantile_length=.25, n_bins=100)
    """
    amps = np.asarray(amps)
    cutoff = np.float64(np.nan)
    first_low_quantile = np.float64(np.nan)
    # numpy bool (not Python bool) so that `~` below is a logical, not bitwise-int, negation
    fail_criteria = np.bool_(True)

    if amps.size > 1:  # ensure there are amplitudes available to analyze
        bins_list = np.linspace(0, np.max(amps), n_bins)  # bin edges of the amplitude histogram
        n, _ = np.histogram(amps, bins=bins_list)  # construct amplitude histogram
        idx_peak = np.argmax(n)  # peak of amplitude distribution
        # length of the top half of the distribution, ignoring empty bins (and the last bin)
        length_top_half = np.count_nonzero(n[idx_peak:-1] > 0)
        # the remaining part of the distribution, which we will compare the low quantile to
        high_quantile = 2 * quantile_length
        # the first bin (index) of the high quantile part of the distribution
        high_quantile_start_ind = int(np.ceil(high_quantile * length_top_half + idx_peak))
        # non-empty bins of the high quantile
        high_counts = n[np.arange(high_quantile_start_ind, len(n))]
        high_counts = high_counts[high_counts >= 1]

        if high_counts.size > 0:  # ensure there are amplitudes in these bins
            mean_high_quantile = np.mean(high_counts)
            std_high_quantile = np.std(high_counts)
            if std_high_quantile > 0:
                # std > 0 implies at least two non-empty bins, so index 1 exists
                first_low_quantile = n[n != 0][1]  # take the second non-empty bin
                cutoff = (first_low_quantile - mean_high_quantile) / std_high_quantile
                percent_of_peak = percent_threshold * np.max(n)

                fail_criteria = (cutoff > nc_threshold) & (first_low_quantile > percent_of_peak)

    nc_pass = ~fail_criteria
    return nc_pass, cutoff, first_low_quantile

826 

827 

def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS):
    """
    Compute the per-unit metrics table and the recording-level drift estimate.

    Steps:
     - cell level metrics (cf `quick_unit_metrics`), which also labels each unit
       according to quality thresholds
     - drift as a function of time (cf `electrode_drift.estimate_drift`)

    :param times: vector of spike times
    :param clusters: vector of spike cluster ids (one per spike)
    :param amps: vector of spike amplitudes
    :param depths: vector of spike depths
    :param cluster_ids: (optional) set of clusters; if None the output dataframe will match
     the unique set of clusters represented in spike clusters
    :param params: dict (optional) parameters for qc computation (see the METRICS_PARAMS
     constant at the top of the module for default values and keys)
    :return: pandas DataFrame of metrics (cluster records, columns are qc attributes)
    :return: dictionary of recording qc (keys 'time_scale' and 'drift_um')
    """
    unit_metrics = pd.DataFrame(
        quick_unit_metrics(clusters, times, amps, depths, cluster_ids=cluster_ids, params=params)
    )
    drift_um, time_scale = electrode_drift.estimate_drift(times, amps, depths)
    recording_qc = dict(time_scale=time_scale, drift_um=drift_um)
    return unit_metrics, recording_qc

853 

854 

def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
                       params=METRICS_PARAMS, cluster_ids=None, tbounds=None):
    """
    Computes single unit metrics from only the spike times, amplitudes, and
    depths for a set of units.

    Metrics computed (one value per cluster):
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'firing_rate',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'slidingRP_viol_forced',
        'max_confidence',
        'min_contamination',
        'n_spikes_below2',
        'spike_count'
    plus 'label' and 'bitwise_fail' as returned by `compute_labels`.

    Parameters (see the METRICS_PARAMS constant)
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    cluster_ids : array_like (optional)
        List or array of cluster ids. If not all clusters are represented in
        spike_clusters (ie. cluster has no spike), this will ensure the output size is
        consistent with the input arrays. Metrics of clusters without spikes are left NaN.
    tbounds : list or 2-element array (optional)
        A time-selection to perform the metrics computation on.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
        'presence_window': float
            The time window (in s) used to bin spikes when computing the presence ratio.
        'refractory_period': float
            The refractory period used when computing the contamination estimates.
        'min_isi': float
            The minimum interspike-interval (in s) for counting duplicate spikes when computing
            the contamination estimate.
        'noise_cutoff': dict
            Keyword arguments forwarded to `noise_cutoff`.
        'missed_spikes_est': dict
            Keyword arguments forwarded to `missed_spikes_est` (the METRICS_PARAMS default
            uses the keys 'spks_per_bin', 'sigma' and 'min_num_bins').

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.
    The 'label' / 'bitwise_fail' fields are computed with the default thresholds of
    `compute_labels`, not with `params`.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
    >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
    >>> m = phy_model_from_ks2_path(path_to_ks2_out)
    >>> cluster_ids = m.spike_clusters
    >>> ts = m.spike_times
    >>> amps = m.amplitudes
    >>> depths = m.depths
    >>> r = bb.metrics.quick_unit_metrics(cluster_ids, ts, amps, depths)
    """
    metrics_list = [
        'cluster_id',
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count',
        'slidingRP_viol_forced',
        'max_confidence',
        'min_contamination',
        'n_spikes_below2'
    ]
    # test explicitly rather than by truthiness: `if tbounds:` raises for a 2-element ndarray.
    # An empty selection is still ignored, as before.
    if tbounds is not None and np.size(tbounds) > 0:
        ispi = between_sorted(spike_times, tbounds)
        spike_times = spike_times[ispi]
        spike_clusters = spike_clusters[ispi]
        spike_amps = spike_amps[ispi]
        spike_depths = spike_depths[ispi]

    if cluster_ids is None:
        cluster_ids = np.unique(spike_clusters)
    else:
        # accept any array-like (e.g. a list), as documented; `.size` is used below
        cluster_ids = np.asarray(cluster_ids)
    nclust = cluster_ids.size

    r = Bunch({k: np.full((nclust,), np.nan) for k in metrics_list})
    r['cluster_id'] = cluster_ids

    # vectorized computation of basic metrics such as presence ratio and firing rate
    # (assumes spike_times is sorted: the first/last spikes give the recording span)
    tmin = spike_times[0]
    tmax = spike_times[-1]
    presence_ratio = bincount2D(spike_times, spike_clusters,
                                xbin=params['presence_window'],
                                ybin=cluster_ids, xlim=[tmin, tmax])[0]
    r.presence_ratio = np.sum(presence_ratio > 0, axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)
    r.spike_count = np.sum(presence_ratio, axis=1)
    r.firing_rate = r.spike_count / (tmax - tmin)

    # computing amplitude statistical indicators by aggregating over cluster id
    camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps), spike_clusters],
                        columns=['amps', 'log_amps', 'clusters'])
    camp = camp.groupby('clusters')
    # `ir` selects the output rows of clusters that have at least one spike
    ir, _ = ismember(r.cluster_id, camp.clusters.unique())
    r.amp_min[ir] = np.array(camp['amps'].min())
    r.amp_max[ir] = np.array(camp['amps'].max())
    # this is the geometric median
    r.amp_median[ir] = np.array(10 ** (camp['log_amps'].median() / 20))
    r.amp_std_dB[ir] = np.array(camp['log_amps'].std())
    # sliding refractory period metrics (sample rate hard-coded to 30 kHz here)
    srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
                                sampleRate=30000, binSizeCorr=1 / 30000)
    r.slidingRP_viol[ir] = srp['value']
    r.slidingRP_viol_forced[ir] = srp['value_forced']
    r.max_confidence[ir] = srp['max_confidence']
    r.min_contamination[ir] = srp['min_contamination']
    # index like the other slidingRP fields so the column keeps one entry per cluster id
    r.n_spikes_below2[ir] = srp['n_spikes_below2']

    # loop over each cluster to compute the rest of the metrics
    for ic in np.arange(nclust):
        # slice the spike_times array
        ispikes = spike_clusters == cluster_ids[ic]
        if not np.any(ispikes):  # if this cluster has no spikes, leave its metrics as NaN
            continue
        ts = spike_times[ispikes]
        amps = spike_amps[ispikes]
        depths = spike_depths[ispikes]
        # compute metrics
        r.contamination_alt[ic] = contamination_alt(ts, rp=params['refractory_period'])
        r.contamination[ic], _ = contamination(
            ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi'])
        _, r.noise_cutoff[ic], _ = noise_cutoff(amps, **params['noise_cutoff'])
        r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
        # wonder if there is a need to low-cut this
        r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600
    # NOTE: `params` is not forwarded here, so labels use compute_labels' default thresholds
    r.label, r.bitwise_fail = compute_labels(r, return_bitwise=True)
    return r

1014 

1015 

def compute_labels(r, params=METRICS_PARAMS, return_bitwise=False):
    """
    From a dataframe or a dictionary of unit metrics, compute a quality label per unit.

    :param r: dictionary (or Bunch) or pandas DataFrame containing the unit qcs; the keys
     'max_confidence', 'noise_cutoff' and 'amp_median' are read
    :param params: dict of thresholds; the keys 'RPmax_confidence', 'noise_cutoff' (a dict
     with key 'nc_threshold') and 'med_amp_thresh_uv' are read
    :param return_bitwise: if True, also return the bitwise failure code of each unit
    :return: vector of proportion of qcs passed between 0 and 1, where 1 denotes an all pass
    :return: (only if return_bitwise) uint8 vector where bit 0 is set if the sliding refractory
     period criterion failed, bit 1 if the noise cutoff failed, bit 2 if the median amplitude
     criterion failed (0 means all criteria passed)
    """
    # item access (rather than attribute access) works for plain dicts, Bunch and DataFrames
    max_confidence = np.asarray(r['max_confidence'])
    nc = np.asarray(r['noise_cutoff'])
    amp_median = np.asarray(r['amp_median'])
    # right now the score is a value between 0 and 1 denoting the proportion of passing qcs,
    # where 1 means passing and 0 means failing.
    # A NaN metric compares False, so an uncomputed metric counts as a failure.
    labels = np.c_[
        max_confidence >= params['RPmax_confidence'],  # this is the least significant bit
        nc < params['noise_cutoff']['nc_threshold'],
        # the threshold is given in uV and divided by 1e6: amp_median is presumably in V
        amp_median > params['med_amp_thresh_uv'] / 1e6,
        # add a new metric here on higher significant bits
    ]
    # The first column takes binary values 001 or 000 to represent fail or pass,
    # the second, 010 or 000, the third, 100 or 000 etc.
    # The bitwise or "sum" produces 111 if all metrics fail, or 000 if all metrics pass
    # All other permutations are also captured, i.e. 110 == 000 || 010 || 100 means
    # the second and third metrics failed and the first metric was a pass
    score = np.mean(labels, axis=1)
    if return_bitwise:
        n_criteria = labels.shape[1]
        bitwise = np.bitwise_or.reduce(2 ** np.arange(n_criteria) * (~ labels.astype(bool)).astype(np.uint8), axis=1)
        return score, bitwise.astype(np.uint8)
    else:
        return score