Coverage for brainbox/metrics/single_units.py: 13%
241 statements
coverage.py v7.7.0, created at 2025-03-17 09:55 +0000
1"""
2Computes metrics for assessing quality of single units.
4Run the following to set-up the workspace to run the docstring examples:
5>>> import brainbox as bb
6>>> import one.alf.io as aio
7>>> import numpy as np
8>>> import matplotlib.pyplot as plt
9>>> import ibllib.ephys.spikes as e_spks
10# (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
11>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
12# Load the alf spikes bunch and clusters bunch, and get a units bunch.
13>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
14>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
15>>> units_b = bb.processing.get_units_bunch(spks_b) # may take a few mins to compute
16"""
import time
import logging

import numpy as np
from scipy.ndimage import gaussian_filter1d
import scipy.stats as stats
import pandas as pd

import spikeglx
from phylib.stats import correlograms
from iblutil.util import Bunch
from iblutil.numerical import ismember, between_sorted, bincount2D
from slidingRP import metrics

from brainbox import singlecell
from brainbox.io.spikeglx import extract_waveforms
from brainbox.metrics import electrode_drift

_logger = logging.getLogger('ibllib')
# Parameters to be used in `quick_unit_metrics`
METRICS_PARAMS = {
    'noise_cutoff': dict(quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10),
    'missed_spikes_est': dict(spks_per_bin=10, sigma=4, min_num_bins=50),
    'acceptable_contamination': 0.1,
    'bin_size': 0.25,
    'med_amp_thresh_uv': 50,  # units below this threshold are considered noise
    'min_isi': 0.0001,
    'presence_window': 10,
    'refractory_period': 0.0015,
    'RPslide_thresh': 0.1,
    'RPmax_confidence': 90,  # a unit needs to pass with at least this confidence percentage (0 - 100)
}
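
# (Illustrative, not part of the original module.) To run `quick_unit_metrics` with a
# different threshold, pass a modified copy of this dict rather than editing it in place, e.g.
#   custom_params = {**METRICS_PARAMS, 'med_amp_thresh_uv': 40}
#   r = quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
#                          params=custom_params)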


def unit_stability(units_b, units=None, feat_names=['amps'], dist='norm', test='ks'):
    """
    Computes the probability that the empirical spike feature distribution(s), for specified
    feature(s), for all units, comes from a specific theoretical distribution, based on a
    specified statistical test. Also computes the coefficients of variation of the spike
    feature(s) for all units.

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times,
        features, etc.) for all units.
    units : array-like (optional)
        A subset of all units for which to compute the metrics. (If `None`, all units are used)
    feat_names : list of strings (optional)
        A list of names of spike features that can be found in `spks` to specify which features
        to use for calculating unit stability.
    dist : string (optional)
        The type of hypothetical null distribution to which the empirical spike feature
        distributions are presumed to belong.
    test : string (optional)
        The statistical test used to compute the probability that the empirical spike feature
        distributions come from `dist`.

    Returns
    -------
    p_vals_b : bunch
        A bunch with `feat_names` as keys, containing a ndarray with p-values (the probabilities
        that the empirical spike feature distribution for each unit comes from `dist` based on
        `test`) for each unit for all `feat_names`.
    cv_b : bunch
        A bunch with `feat_names` as keys, containing a ndarray with the coefficients of
        variation of each unit's empirical spike feature distribution for all features.

    See Also
    --------
    plot.feat_vars

    Examples
    --------
    1) Compute 1) the p-values obtained from running a one-sample ks test on the spike
    amplitudes for each unit, and 2) the variances of the empirical spike amplitude
    distribution for each unit. Create a histogram of the variances of the spike amplitudes
    for each unit, color-coded by depth of channel of max amplitude. Get cluster IDs of those
    units which have variances greater than 50.
    >>> p_vals_b, variances_b = bb.metrics.unit_stability(units_b)
    # Plot histograms of variances color-coded by depth of channel of max amplitudes
    >>> fig = bb.plot.feat_vars(units_b, feat_name='amps')
    # Get all unit IDs which have amps variance > 50
    >>> var_vals = np.array(tuple(variances_b['amps'].values()))
    >>> bad_units = np.where(var_vals > 50)
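    2) (Illustrative sketch, not part of the original docstring.) Run the same test on a small
    synthetic units dictionary with normally distributed amplitudes.
    >>> rng = np.random.default_rng(0)
    >>> units_b_sim = {'amps': {'0': rng.normal(size=1000)}, 'times': {'0': np.arange(1000)}}
    >>> p_vals_sim, cv_sim = bb.metrics.unit_stability(units_b_sim, feat_names=['amps'])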
105 """
107 # Get units.
108 if not (units is None): # we're using a subset of all units
109 unit_list = list(units_b[feat_names[0]].keys())
110 # for each `feat` and unit in `unit_list`, remove unit from `units_b` if not in `units`
111 for feat in feat_names:
112 [units_b[feat].pop(unit) for unit in unit_list if not (int(unit) in units)]
113 unit_list = list(units_b[feat_names[0]].keys()) # get new `unit_list` after removing units
115 # Initialize `p_vals` and `variances`.
116 p_vals_b = Bunch()
117 cv_b = Bunch()
119 # Set the test as a lambda function (in future, more tests can be added to this dict)
120 tests = \
121 {
122 'ks': lambda x, y: stats.kstest(x, y)
123 }
124 test_fun = tests[test]
126 # Compute the statistical tests and variances. For each feature, iteratively get each unit's
127 # p-values and variances, and add them as keys to the respective bunches `p_vals_feat` and
128 # `variances_feat`. After iterating through all units, add these bunches as keys to their
129 # respective parent bunches, `p_vals` and `variances`.
130 for feat in feat_names:
131 p_vals_feat = Bunch((unit, 0) for unit in unit_list)
132 cv_feat = Bunch((unit, 0) for unit in unit_list)
133 for unit in unit_list:
134 # If we're missing units/features, create a NaN placeholder and skip them:
135 if len(units_b['times'][str(unit)]) == 0:
136 p_val = np.nan
137 cv = np.nan
138 else:
139 # compute p_val and var for current feature
140 _, p_val = test_fun(units_b[feat][unit], dist)
141 cv = np.var(units_b[feat][unit]) / np.mean(units_b[feat][unit])
142 # Append current unit's values to list of units' values for current feature:
143 p_vals_feat[str(unit)] = p_val
144 cv_feat[str(unit)] = cv
145 p_vals_b[feat] = p_vals_feat
146 cv_b[feat] = cv_feat
148 return p_vals_b, cv_b


def missed_spikes_est(feat, spks_per_bin=20, sigma=5, min_num_bins=50):
    """
    Computes the approximate fraction of spikes missing from a spike feature distribution for a
    given unit, assuming the distribution is symmetric.
    Inspired by metric described in Hill et al. (2011) J Neurosci 31: 8699-8705.

    Parameters
    ----------
    feat : ndarray
        The spikes' feature values (e.g. amplitudes)
    spks_per_bin : int (optional)
        The number of spikes per bin from which to compute the spike feature histogram.
    sigma : int (optional)
        The standard deviation for the gaussian kernel used to compute the pdf from the spike
        feature histogram.
    min_num_bins : int (optional)
        The minimum number of bins used to compute the spike feature histogram.

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing,
        an accurate estimate isn't possible.
    pdf : ndarray
        The computed pdf of the spike feature histogram.
    cutoff_idx : int
        The index for `pdf` at which point `pdf` is no longer symmetrical around the peak. (This
        is returned for plotting purposes).

    See Also
    --------
    plot.feat_cutoff

    Examples
    --------
    1) Determine the fraction of spikes missing from unit 1 based on the recorded unit's spike
    amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
    # Get unit 1 amplitudes from a unit bunch, and compute fraction spikes missing.
    >>> feat = units_b['amps']['1']
    >>> fraction_missing, pdf, cutoff_idx = bb.metrics.missed_spikes_est(feat)
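    2) (Illustrative sketch, not part of the original docstring.) Estimate the missing fraction
    for a synthetic amplitude distribution truncated at a detection floor.
    >>> rng = np.random.default_rng(0)
    >>> amps_sim = rng.normal(60, 10, 10000)
    >>> amps_sim = amps_sim[amps_sim > 55]  # simulate spikes lost below the detection threshold
    >>> fraction_missing, pdf, cutoff_idx = bb.metrics.missed_spikes_est(amps_sim)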
190 """
192 # Ensure minimum number of spikes requirement is met, return Nan otherwise
193 if feat.size <= (spks_per_bin * min_num_bins):
194 return np.nan, None, None
196 # compute the spike feature histogram and pdf:
197 num_bins = int(feat.size / spks_per_bin)
198 hist, bins = np.histogram(feat, num_bins, density=True)
199 pdf = gaussian_filter1d(hist, sigma)
201 # Find where the distribution stops being symmetric around the peak:
202 peak_idx = np.argmax(pdf)
203 max_idx_sym_around_peak = np.argmin(np.abs(pdf[peak_idx:] - pdf[0]))
204 cutoff_idx = peak_idx + max_idx_sym_around_peak
206 # compute fraction missing from the tail of the pdf (the area where pdf stops being
207 # symmetric around peak).
208 fraction_missing = np.sum(pdf[cutoff_idx:]) / np.sum(pdf)
209 fraction_missing = 0.5 if (fraction_missing > 0.5) else fraction_missing
211 return fraction_missing, pdf, cutoff_idx


def wf_similarity(wf1, wf2):
    """
    Computes a unit normalized spatiotemporal similarity score between two sets of waveforms.
    This score is based on how waveform shape correlates for each pair of spikes between the
    two sets of waveforms across space and time. The shapes of the arrays of the two sets of
    waveforms must be equal.

    Parameters
    ----------
    wf1 : ndarray
        An array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        An array of shape (#spikes, #samples, #channels).

    Returns
    -------
    s : float
        The unit normalized spatiotemporal similarity score.

    See Also
    --------
    io.extract_waveforms
    plot.single_unit_wf_comp

    Examples
    --------
    1) Compute the similarity between the first and last 100 waveforms for unit1, across the 20
    channels around the channel of max amplitude.
    # Get the channels around the max amp channel for the unit, two sets of timestamps for the
    # unit, and the two corresponding sets of waveforms for those two sets of timestamps.
    # Then compute `s`.
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> ts1 = units_b['times']['1'][:100]
    >>> ts2 = units_b['times']['1'][-100:]
    >>> wf1 = bb.io.extract_waveforms(path_to_ephys_file, ts1, ch)
    >>> wf2 = bb.io.extract_waveforms(path_to_ephys_file, ts2, ch)
    >>> s = bb.metrics.wf_similarity(wf1, wf2)
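    2) (Illustrative sketch, not part of the original docstring.) Similarity of two small
    synthetic waveform sets with matching shapes.
    >>> rng = np.random.default_rng(0)
    >>> wf1_sim = rng.normal(size=(10, 82, 20))
    >>> wf2_sim = wf1_sim + 0.1 * rng.normal(size=(10, 82, 20))
    >>> s_sim = bb.metrics.wf_similarity(wf1_sim, wf2_sim)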

    TODO check `s` calculation:
    take median of waveforms
    xcorr all waveforms with median, and divide by autocorr of all waveforms
    profile
    for two sets of units: xcorr(cl1, cl2) / (sqrt autocorr(cl1) * autocorr(cl2))
    """

    # Remove warning for dividing by 0 when calculating `s` (this is resolved by using
    # `np.nan_to_num`)
    import warnings
    warnings.filterwarnings('ignore', r'invalid value encountered in true_divide')
    assert wf1.shape == wf2.shape, ('The shapes of the sets of waveforms are inconsistent ({})'
                                    '({})'.format(wf1.shape, wf2.shape))

    # Get number of spikes, samples, and channels of waveforms.
    n_spks = wf1.shape[0]
    n_samples = wf1.shape[1]
    n_ch = wf1.shape[2]

    # Create a matrix that will hold the similarity values of each spike in `wf1` to `wf2`.
    # Iterate over both sets of spikes, computing `s` for each pair.
    similarity_matrix = np.zeros((n_spks, n_spks))
    for spk1 in range(n_spks):
        for spk2 in range(n_spks):
            s_spk = \
                np.sum(np.nan_to_num(
                    wf1[spk1, :, :] * wf2[spk2, :, :] /
                    np.sqrt(wf1[spk1, :, :] ** 2 * wf2[spk2, :, :] ** 2))) / (n_samples * n_ch)
            similarity_matrix[spk1, spk2] = s_spk

    # Return mean of similarity matrix
    s = np.mean(similarity_matrix)
    return s


def firing_rate_coeff_var(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    '''
    Computes the coefficient of variation of the firing rate: the ratio of the standard
    deviation to the mean.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing
        rate.
    n_bins : int (optional)
        The number of bins in which to compute a coefficient of variation of the firing rate.

    Returns
    -------
    cv : float
        The mean coefficient of variation of the firing rate of the `n_bins` number of
        coefficients computed.
    cvs : ndarray
        The coefficients of variation of the firing rate for each bin of `n_bins`.
    fr : ndarray
        The instantaneous firing rate over time (in Hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the coefficient of variation of the firing rate for unit 1 from the time of its
    first to last spike, and compute the coefficient of variation of the firing rate for unit 2
    from the first to second minute.
    >>> ts_1 = units_b['times']['1']
    >>> ts_2 = units_b['times']['2']
    >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
    >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts_1)
    >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(ts_2)
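    2) (Illustrative sketch, not part of the original docstring.) A perfectly regular synthetic
    spike train should give a low coefficient of variation.
    >>> ts_sim = np.arange(0, 100, 0.02)  # 50 Hz, perfectly regular
    >>> cv_sim, cvs_sim, fr_sim = bb.metrics.firing_rate_coeff_var(ts_sim)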
    '''

    # Compute overall instantaneous firing rate and firing rate for each bin.
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    bin_sz = int(fr.size / n_bins)
    fr_binned = np.array([fr[(b * bin_sz):(b * bin_sz + bin_sz)] for b in range(n_bins)])

    # Compute coefficient of variations of firing rate for each bin, and the mean c.v.
    cvs = np.std(fr_binned, axis=1) / np.mean(fr_binned, axis=1)
    # NaNs from zero spikes are turned into 0's
    # cvs[np.isnan(cvs)] = 0  # NaNs can happen if the neuron doesn't spike in a bin
    cv = np.mean(cvs)

    return cv, cvs, fr


def firing_rate_fano_factor(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    '''
    Computes the fano factor of the firing rate: the ratio of the variance to the mean.
    (Almost identical to coeff. of variation)

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing
        rate.
    n_bins : int (optional)
        The number of bins in which to compute a fano factor of the firing rate.

    Returns
    -------
    ff : float
        The mean fano factor of the firing rate of the `n_bins` number of factors computed.
    ffs : ndarray
        The fano factors of the firing rate for each bin of `n_bins`.
    fr : ndarray
        The instantaneous firing rate over time (in Hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the fano factor of the firing rate for unit 1 from the time of its first to last
    spike, and compute the fano factor of the firing rate for unit 2 from the first to second
    minute.
    >>> ts_1 = units_b['times']['1']
    >>> ts_2 = units_b['times']['2']
    >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
    >>> ff, ffs, fr = bb.metrics.firing_rate_fano_factor(ts_1)
    >>> ff_2, ffs_2, fr_2 = bb.metrics.firing_rate_fano_factor(ts_2)
    '''

    # Compute overall instantaneous firing rate and firing rate for each bin.
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    # this procedure can cut off data at the end, up to n_bins last timesteps
    bin_sz = int(fr.size / n_bins)
    fr_binned = np.array([fr[(b * bin_sz):(b * bin_sz + bin_sz)] for b in range(n_bins)])

    # Compute fano factor of firing rate for each bin, and the mean fano factor
    ffs = np.var(fr_binned, axis=1) / np.mean(fr_binned, axis=1)
    # ffs[np.isnan(ffs)] = 0  # NaNs can happen if the neuron doesn't spike in a bin
    ff = np.mean(ffs)

    return ff, ffs, fr


def average_drift(feat, times):
    """
    Computes the cumulative drift (normalized by the total number of spikes) of a spike feature
    array.

    Parameters
    ----------
    feat : ndarray
        The spike feature values from which to compute the drift. Usually amplitudes or depths.
    times : ndarray
        The spike timestamps (in s) corresponding to `feat`.

    Returns
    -------
    cd : float
        The cumulative drift of the unit.

    See Also
    --------
    max_drift

    Examples
    --------
    1) Get the cumulative depth drift for unit 1.
    >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
    >>> times = spks_b['times'][unit_idxs]
    >>> depths = spks_b['depths'][unit_idxs]
    >>> amps = spks_b['amps'][unit_idxs]
    >>> depth_cd = bb.metrics.average_drift(depths, times)
    >>> amp_cd = bb.metrics.average_drift(amps, times)
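    2) (Illustrative sketch, not part of the original docstring.) A unit drifting linearly at
    1 um/s gives an average drift of about 1.
    >>> times_sim = np.arange(0, 100, 0.1)
    >>> depths_sim = 100 + times_sim  # depth increases by 1 um every second
    >>> cd_sim = bb.metrics.average_drift(depths_sim, times_sim)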
435 """
437 cd = np.sum(np.abs(np.diff(feat) / np.diff(times))) / len(feat)
438 return cd


def pres_ratio(ts, hist_win=10):
    """
    Computes the presence ratio of spike counts: the number of bins where there is at least one
    spike, over the total number of bins, given a specified bin width.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the presence ratio.
    hist_win : float (optional)
        The time window (in s) to use for computing the presence ratio.

    Returns
    -------
    pr : float
        The presence ratio.
    spks_bins : ndarray
        The number of spikes in each bin.

    See Also
    --------
    plot.pres_ratio

    Examples
    --------
    1) Compute the presence ratio for unit 1, given a window of 10 s.
    >>> ts = units_b['times']['1']
    >>> pr, pr_bins = bb.metrics.pres_ratio(ts)
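    2) (Illustrative sketch, not part of the original docstring.) A synthetic unit that is
    silent between 20 s and 40 s of a ~100 s recording has a presence ratio of about 0.8.
    >>> rng = np.random.default_rng(0)
    >>> ts_sim = np.sort(np.r_[rng.uniform(0, 20, 200), rng.uniform(40, 100, 600)])
    >>> pr_sim, spks_bins_sim = bb.metrics.pres_ratio(ts_sim, hist_win=10)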
469 """
471 bins = np.arange(0, ts[-1] + hist_win, hist_win)
472 spks_bins, _ = np.histogram(ts, bins)
473 pr = len(np.where(spks_bins)[0]) / len(spks_bins)
474 return pr, spks_bins


def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    For specified channels, for specified timestamps, computes the mean (peak-to-peak amplitudes /
    the MADs of the background noise).

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    ch : ndarray_like
        The channels on which to extract the waveforms.
    t : numeric (optional)
        The time (in ms) of the waveforms to extract to compute the ptp.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car : bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    ptp_sigma : ndarray
        An array containing the mean ptp_over_noise values for the specified `ts` and `ch`.

    Examples
    --------
    1) Compute ptp_over_noise for all spikes on 20 channels around the channel of max amplitude
    for unit 1.
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> p = bb.metrics.ptp_over_noise(ephys_file, ts, ch)
    """

    # Ensure `ch` is ndarray
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Get waveforms.
    wf = extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe, car=car)

    # Initialize `mean_ptp` based on `ch`, and compute mean ptp of all spikes for each ch.
    mean_ptp = np.zeros((ch.size,))
    for cur_ch in range(ch.size):
        mean_ptp[cur_ch] = np.mean(np.max(wf[:, :, cur_ch], axis=1) -
                                   np.min(wf[:, :, cur_ch], axis=1))

    # Compute MAD for `ch` in chunks.
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        n_chunk_samples = 5e6  # number of samples per chunk
        n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int')
        # Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the
        # samples that make up the first chunk.
        chunk_sample = np.arange(0, file_m.shape[0], n_chunk_samples, dtype=int)
        chunk_sample = np.append(chunk_sample, file_m.shape[0])
        # Give time estimate for computing MAD.
        t0 = time.perf_counter()
        stats.median_abs_deviation(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
        dt = time.perf_counter() - t0
        print('Performing MAD computation. Estimated time is {:.2f} mins.'
              ' ({})'.format(dt * n_chunks / 60, time.ctime()))
        # Compute MAD for each chunk, then take the median MAD of all chunks.
        mad_chunks = np.zeros((n_chunks, ch.size), dtype=np.int16)
        for chunk in range(n_chunks):
            mad_chunks[chunk, :] = stats.median_abs_deviation(
                file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1)
        print('Done. ({})'.format(time.ctime()))

    # Return `mean_ptp` over `mad`
    mad = np.median(mad_chunks, axis=0)
    ptp_sigma = mean_ptp / mad
    return ptp_sigma


def contamination_alt(ts, rp=0.002):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and time between the first and last spike.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    rp : float (optional)
        The refractory period (in s).

    Returns
    -------
    ce : float
        An estimate of the fraction of contamination.

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute contamination estimate for unit 1.
    >>> ts = units_b['times']['1']
    >>> ce = bb.metrics.contamination_alt(ts)
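    2) (Illustrative sketch, not part of the original docstring.) Contamination estimate for a
    synthetic Poisson-like spike train.
    >>> rng = np.random.default_rng(0)
    >>> ts_sim = np.sort(rng.uniform(0, 600, 12000))  # ~20 Hz over 10 minutes
    >>> ce_sim = bb.metrics.contamination_alt(ts_sim, rp=0.002)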
587 """
589 # Get number of spikes, number of isi violations, and time from first to final spike.
590 n_spks = ts.size
591 n_isi_viol = np.size(np.where(np.diff(ts) < rp)[0])
592 t = ts[-1] - ts[0]
594 # `ce` is min of roots of solved quadratic equation.
595 c = (t * n_isi_viol) / (2 * rp * n_spks ** 2) # 3rd term in quadratic
596 ce = np.min(np.abs(np.roots([-1, 1, c]))) # solve quadratic
597 return ce


def contamination(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and time between the first and last spike.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Modified by Dan Denman from cortex-lab/sortingQuality GitHub by Nick Steinmetz.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    min_time : float
        The minimum time (in s) that a potential spike occurred.
    max_time : float
        The maximum time (in s) that a potential spike occurred.
    rp : float (optional)
        The refractory period (in s).
    min_isi : float (optional)
        The minimum interspike-interval (in s) for counting duplicate spikes.

    Returns
    -------
    ce : float
        An estimate of the contamination.
        A perfect unit has a ce = 0
        A unit with some contamination has a ce < 0.5
        A unit with lots of contamination has a ce > 1.0
    num_violations : int
        The total number of isi violations.

    See Also
    --------
    contamination_alt

    Examples
    --------
    1) Compute contamination estimate for unit 1, with a minimum isi for counting duplicate
    spikes of 0.1 ms.
    >>> ts = units_b['times']['1']
    >>> ce, num_violations = bb.metrics.contamination(ts, ts.min(), ts.max(), min_isi=0.0001)
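    2) (Illustrative sketch, not part of the original docstring.) The same estimate on a
    synthetic spike train with a known recording span.
    >>> rng = np.random.default_rng(0)
    >>> ts_sim = np.sort(rng.uniform(0, 600, 12000))
    >>> ce_sim, nv_sim = bb.metrics.contamination(ts_sim, 0, 600, rp=0.002, min_isi=0.0001)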
641 """
643 duplicate_spikes = np.where(np.diff(ts) <= min_isi)[0]
645 ts = np.delete(ts, duplicate_spikes + 1)
646 isis = np.diff(ts)
648 num_spikes = ts.size
649 num_violations = np.sum(isis < rp)
650 violation_time = 2 * num_spikes * (rp - min_isi)
651 total_rate = ts.size / (max_time - min_time)
652 violation_rate = num_violations / violation_time
653 ce = violation_rate / total_rate
655 return ce, num_violations


def _max_acceptable_cont(FR, RP, rec_duration, acceptableCont, thresh):
    """
    Function to compute the maximum acceptable refractory period contamination
    called during slidingRP_viol
    """

    time_for_viol = RP * 2 * FR * rec_duration
    expected_count_for_acceptable_limit = acceptableCont * time_for_viol
    max_acceptable = stats.poisson.ppf(thresh, expected_count_for_acceptable_limit)
    if max_acceptable == 0 and stats.poisson.pmf(0, expected_count_for_acceptable_limit) > 0:
        max_acceptable = -1
    return max_acceptable
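
# (Illustrative worked example, not part of the original module.) For FR = 20 Hz, RP = 2 ms,
# rec_duration = 3600 s and acceptableCont = 0.1 * FR = 2 Hz (the values passed by
# slidingRP_viol with acceptThresh=0.1):
#   time_for_viol = 0.002 * 2 * 20 * 3600 = 288 s of "violation time" around all spikes,
#   expected_count_for_acceptable_limit = 2 * 288 = 576 expected violations at the acceptable
#   contamination level, and `max_acceptable` is the `thresh` quantile of a Poisson
#   distribution with that mean, i.e. the largest observed violation count still considered
#   consistent with acceptable contamination.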


def slidingRP_viol(ts, bin_size=0.25, thresh=0.1, acceptThresh=0.1):
    """
    A binary metric which determines whether there is an acceptable level of
    refractory period violations by using a sliding refractory period:

    This takes into account the firing rate of the neuron and computes a
    maximum acceptable level of contamination at different possible values of
    the refractory period. If the unit has less than the maximum contamination
    at any of the possible values of the refractory period, the unit passes.

    A neuron will always fail this metric for very low firing rates, and thus
    this metric takes into account both firing rate and refractory period
    violations.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    bin_size : float
        The size of binning for the autocorrelogram.
    thresh : float
        The quantile of the poisson distribution used to compute the maximum
        acceptable contamination (see _max_acceptable_cont)
    acceptThresh : float
        The fraction of contamination we are willing to accept (default value
        set to 0.1, or 10% contamination)

    Returns
    -------
    didpass : int
        0 if unit didn't pass
        1 if unit did pass

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute whether a unit has too much refractory period contamination at
    any possible value of a refractory period, for a 0.25 ms bin, with a
    threshold of 10% acceptable contamination.
    >>> ts = units_b['times']['1']
    >>> didpass = bb.metrics.slidingRP_viol(ts, bin_size=0.25, thresh=0.1,
    ...                                     acceptThresh=0.1)
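    2) (Illustrative sketch, not part of the original docstring.) A regular-spiking synthetic
    unit with no refractory period violations should pass.
    >>> ts_sim = np.arange(0, 600, 0.05)  # 20 Hz, perfectly regular
    >>> didpass_sim = bb.metrics.slidingRP_viol(ts_sim)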
718 """
720 b = np.arange(0, 10.25, bin_size) / 1000 + 1e-6 # bins in seconds
721 bTestIdx = [5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 32, 36, 40]
722 bTest = [b[i] for i in bTestIdx]
724 if len(ts) > 0 and ts[-1] > ts[0]: # only do this for units with samples
725 recDur = (ts[-1] - ts[0])
726 # compute acg
727 c0 = correlograms(ts, np.zeros(len(ts), dtype='int8'), cluster_ids=[0],
728 bin_size=bin_size / 1000, sample_rate=20000,
729 window_size=2,
730 symmetrize=False)
731 # cumulative sum of acg, i.e. number of total spikes occuring from 0
732 # to end of that bin
733 cumsumc0 = np.cumsum(c0[0, 0, :])
734 # cumulative sum at each of the testing bins
735 res = cumsumc0[bTestIdx]
736 total_spike_count = len(ts)
738 # divide each bin's count by the total spike count and the bin size
739 bin_count_normalized = c0[0, 0] / total_spike_count / bin_size * 1000
740 num_bins_2s = len(c0[0, 0]) # number of total bins that equal 2 secs
741 num_bins_1s = int(num_bins_2s / 2) # number of bins that equal 1 sec
742 # compute fr based on the mean of bin_count_normalized from 1 to 2 s
743 # instead of as before (len(ts)/recDur) for a better estimate
744 fr = np.sum(bin_count_normalized[num_bins_1s:num_bins_2s]) / num_bins_1s
745 mfunc = np.vectorize(_max_acceptable_cont)
746 # compute the maximum allowed number of spikes per testing bin
747 m = mfunc(fr, bTest, recDur, fr * acceptThresh, thresh)
748 # did the unit pass (resulting number of spikes less than maximum
749 # allowed spikes) at any of the testing bins?
750 didpass = int(np.any(np.less_equal(res, m)))
751 else:
752 didpass = 0
754 return didpass


def noise_cutoff(amps, quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10):
    """
    A new metric to determine whether a unit's amplitude distribution is cut off
    (at floor), without assuming a Gaussian distribution.
    This metric takes the amplitude distribution, computes the mean and std
    of an upper quartile of the distribution, and determines how many standard
    deviations away from that mean a lower quartile lies.

    Parameters
    ----------
    amps : ndarray_like
        The amplitudes (in uV) of the spikes.
    quantile_length : float
        The size of the upper quartile of the amplitude distribution.
    n_bins : int
        The number of bins used to compute a histogram of the amplitude
        distribution.
    nc_threshold : float
        The threshold on the noise cutoff result: the unit fails the cutoff criterion if
        `cutoff` is above this value (and the first-bin criterion below is also met).
    percent_threshold : float
        The first low-quantile bin has to be greater than this fraction of the peak bin height
        for the neuron to fail.

    Returns
    -------
    nc_pass : bool
        True if the unit passes the noise cutoff criteria.
    cutoff : float
        Number of standard deviations that the lower mean is outside of the
        mean of the upper quartile.
    first_low_quantile : float
        The height of the first low-quantile bin (compared against `percent_threshold`).

    See Also
    --------
    missed_spikes_est

    Examples
    --------
    1) Compute whether a unit's amplitude distribution is cut off
    >>> amps = spks_b['amps'][unit_idxs]
    >>> nc_pass, cutoff, first_low_quantile = bb.metrics.noise_cutoff(amps, quantile_length=.25, n_bins=100)
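    2) (Illustrative sketch, not part of the original docstring.) A synthetic amplitude
    distribution with everything below the mode removed should be flagged as cut off.
    >>> rng = np.random.default_rng(0)
    >>> amps_sim = rng.normal(60, 10, 10000)
    >>> amps_sim = amps_sim[amps_sim > 60]  # remove the lower half of the distribution
    >>> nc_pass_sim, cutoff_sim, first_bin_sim = bb.metrics.noise_cutoff(amps_sim)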
793 """
794 cutoff = np.float64(np.nan)
795 first_low_quantile = np.float64(np.nan)
796 fail_criteria = np.ones(1).astype(bool)[0]
798 if amps.size > 1: # ensure there are amplitudes available to analyze
799 bins_list = np.linspace(0, np.max(amps), n_bins) # list of bins to compute the amplitude histogram
800 n, bins = np.histogram(amps, bins=bins_list) # construct amplitude histogram
801 idx_peak = np.argmax(n) # peak of amplitude distribution
802 # don't count zeros #len(n) - idx_peak, compute the length of the top half of the distribution -- ignoring zero bins
803 length_top_half = len(np.where(n[idx_peak:-1] > 0)[0])
804 # the remaining part of the distribution, which we will compare the low quantile to
805 high_quantile = 2 * quantile_length
806 # the first bin (index) of the high quantile part of the distribution
807 high_quantile_start_ind = int(np.ceil(high_quantile * length_top_half + idx_peak))
808 # bins to consider in the high quantile (of all non-zero bins)
809 indices_bins_high_quantile = np.arange(high_quantile_start_ind, len(n))
810 idx_use = np.where(n[indices_bins_high_quantile] >= 1)[0]
812 if len(n[indices_bins_high_quantile]) > 0: # ensure there are amplitudes in these bins
813 # mean of all amp values in high quantile bins
814 mean_high_quantile = np.mean(n[indices_bins_high_quantile][idx_use])
815 std_high_quantile = np.std(n[indices_bins_high_quantile][idx_use])
816 if std_high_quantile > 0:
817 first_low_quantile = n[(n != 0)][1] # take the second bin
818 cutoff = (first_low_quantile - mean_high_quantile) / std_high_quantile
819 peak_bin_height = np.max(n)
820 percent_of_peak = percent_threshold * peak_bin_height
822 fail_criteria = (cutoff > nc_threshold) & (first_low_quantile > percent_of_peak)
824 nc_pass = ~fail_criteria
825 return nc_pass, cutoff, first_low_quantile


def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS):
    """
    Computes:
    - cell level metrics (cf quick_unit_metrics)
    - label the metrics according to quality thresholds
    - estimates drift as a function of time
    :param times: vector of spike times
    :param clusters: vector of spike cluster ids
    :param amps: vector of spike amplitudes
    :param depths: vector of spike depths
    :param cluster_ids (optional): set of clusters (if None the output dataframe will match
     the unique set of clusters represented in spike clusters)
    :param params: dict (optional) parameters for qc computation (
     see constant at the top of the module for default values and keys)
    :return: data_frame of metrics (cluster records, columns are qc attributes)
    :return: dictionary of recording qc (keys 'time_scale' and 'drift_um')
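    Example (illustrative, not part of the original docstring), using the ALF spikes bunch from
    the module-level docstring:
    >>> df_units, rec_qc = bb.metrics.spike_sorting_metrics(
    ...     spks_b['times'], spks_b['clusters'], spks_b['amps'], spks_b['depths'])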
844 """
845 # compute metrics and convert to `DataFrame`
846 df_units = quick_unit_metrics(
847 clusters, times, amps, depths, cluster_ids=cluster_ids, params=params)
848 df_units = pd.DataFrame(df_units)
849 # compute drift as a function of time and put in a dictionary
850 drift, ts = electrode_drift.estimate_drift(times, amps, depths)
851 rec_qc = {'time_scale': ts, 'drift_um': drift}
852 return df_units, rec_qc


def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
                       params=METRICS_PARAMS, cluster_ids=None, tbounds=None):
    """
    Computes single unit metrics from only the spike times, amplitudes, and
    depths for a set of units.

    Metrics computed:
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count'

    Parameters (see the METRICS_PARAMS constant)
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    cluster_ids : (optional) list of cluster ids. If not all clusters are represented in
        spike_clusters (i.e. a cluster has no spikes), this will ensure the output size is
        consistent with the input arrays.
    tbounds : (optional) list or 2-element array containing a time-selection to perform the
        metrics computation on.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
        'presence_window': float
            The time window (in s) used to look for spikes when computing the presence ratio.
        'refractory_period': float
            The refractory period used when computing isi violations and the contamination
            estimate.
        'min_isi': float
            The minimum interspike-interval (in s) for counting duplicate spikes when computing
            the contamination estimate.
        'missed_spikes_est': dict
            The parameters used to compute the spike amplitude pdf for a unit when computing
            the missed spikes estimate: 'spks_per_bin' (number of spikes per bin), 'sigma'
            (standard deviation of the gaussian smoothing kernel) and 'min_num_bins' (minimum
            number of bins).

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
    >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
    >>> m = phy_model_from_ks2_path(path_to_ks2_out)
    >>> spike_clusters = m.spike_clusters
    >>> ts = m.spike_times
    >>> amps = m.amplitudes
    >>> depths = m.depths
    >>> r = bb.metrics.quick_unit_metrics(spike_clusters, ts, amps, depths)
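    2) (Illustrative sketch, not part of the original docstring.) Compute the metrics directly
    from an ALF spikes bunch (see the module-level docstring for the setup):
    >>> r = bb.metrics.quick_unit_metrics(spks_b['clusters'], spks_b['times'],
    ...                                   spks_b['amps'], spks_b['depths'])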
931 """
932 metrics_list = [
933 'cluster_id',
934 'amp_max',
935 'amp_min',
936 'amp_median',
937 'amp_std_dB',
938 'contamination',
939 'contamination_alt',
940 'drift',
941 'missed_spikes_est',
942 'noise_cutoff',
943 'presence_ratio',
944 'presence_ratio_std',
945 'slidingRP_viol',
946 'spike_count',
947 'slidingRP_viol_forced',
948 'max_confidence',
949 'min_contamination',
950 'n_spikes_below2'
951 ]
952 if tbounds:
953 ispi = between_sorted(spike_times, tbounds)
954 spike_times = spike_times[ispi]
955 spike_clusters = spike_clusters[ispi]
956 spike_amps = spike_amps[ispi]
957 spike_depths = spike_depths[ispi]
959 if cluster_ids is None:
960 cluster_ids = np.unique(spike_clusters)
961 nclust = cluster_ids.size
963 r = Bunch({k: np.full((nclust,), np.nan) for k in metrics_list})
964 r['cluster_id'] = cluster_ids
966 # vectorized computation of basic metrics such as presence ratio and firing rate
967 tmin = spike_times[0]
968 tmax = spike_times[-1]
969 presence_ratio = bincount2D(spike_times, spike_clusters,
970 xbin=params['presence_window'],
971 ybin=cluster_ids, xlim=[tmin, tmax])[0]
972 r.presence_ratio = np.sum(presence_ratio > 0, axis=1) / presence_ratio.shape[1]
973 r.presence_ratio_std = np.std(presence_ratio, axis=1)
974 r.spike_count = np.sum(presence_ratio, axis=1)
975 r.firing_rate = r.spike_count / (tmax - tmin)
977 # computing amplitude statistical indicators by aggregating over cluster id
978 camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps), spike_clusters],
979 columns=['amps', 'log_amps', 'clusters'])
980 camp = camp.groupby('clusters')
981 ir, ib = ismember(r.cluster_id, camp.clusters.unique())
982 r.amp_min[ir] = np.array(camp['amps'].min())
983 r.amp_max[ir] = np.array(camp['amps'].max())
984 # this is the geometric median
985 r.amp_median[ir] = np.array(10 ** (camp['log_amps'].median() / 20))
986 r.amp_std_dB[ir] = np.array(camp['log_amps'].std())
987 srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
988 sampleRate=30000, binSizeCorr=1 / 30000)
989 r.slidingRP_viol[ir] = srp['value']
990 r.slidingRP_viol_forced[ir] = srp['value_forced']
991 r.max_confidence[ir] = srp['max_confidence']
992 r.min_contamination[ir] = srp['min_contamination']
993 r.n_spikes_below2[ir] = srp['n_spikes_below2']
995 # loop over each cluster to compute the rest of the metrics
996 for ic in np.arange(nclust):
997 # slice the spike_times array
998 ispikes = spike_clusters == cluster_ids[ic]
999 if np.all(~ispikes): # if this cluster has no spikes, continue
1000 continue
1001 ts = spike_times[ispikes]
1002 amps = spike_amps[ispikes]
1003 depths = spike_depths[ispikes]
1004 # compute metrics
1005 r.contamination_alt[ic] = contamination_alt(ts, rp=params['refractory_period'])
1006 r.contamination[ic], _ = contamination(
1007 ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi'])
1008 _, r.noise_cutoff[ic], _ = noise_cutoff(amps, **params['noise_cutoff'])
1009 r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
1010 # wonder if there is a need to low-cut this
1011 r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600
1012 r.label, r.bitwise_fail = compute_labels(r, return_bitwise=True)
1013 return r


def compute_labels(r, params=METRICS_PARAMS, return_bitwise=False):
    """
    From a dataframe or a dictionary of unit metrics, compute a label
    :param r: dictionary or pandas dataframe containing unit qcs
    :param params: dict of qc thresholds (defaults to the module-level METRICS_PARAMS)
    :param return_bitwise: bool, if True also return a bitwise fail vector (one bit per metric)
    :return: vector of proportion of qcs passed between 0 and 1, where 1 denotes an all pass
    """
    # right now the score is a value between 0 and 1 denoting the proportion of passing qcs,
    # where 1 means passing and 0 means failing
    labels = np.c_[
        r['max_confidence'] >= params['RPmax_confidence'],  # this is the least significant bit
        r.noise_cutoff < params['noise_cutoff']['nc_threshold'],
        r.amp_median > params['med_amp_thresh_uv'] / 1e6,
        # add a new metric here on higher significant bits
    ]
    # The first column takes binary values 001 or 000 to represent fail or pass,
    # the second, 010 or 000, the third, 100 or 000 etc.
    # The bitwise or "sum" produces 111 if all metrics fail, or 000 if all metrics pass
    # All other permutations are also captured, i.e. 110 == 000 || 010 || 100 means
    # the second and third metrics failed and the first metric was a pass
    score = np.mean(labels, axis=1)
    if return_bitwise:
        # note the cast to uint8 casts nan to 0
        # a nan implies no metric was computed which we mark as a failure here
        n_criteria = labels.shape[1]
        bitwise = np.bitwise_or.reduce(2 ** np.arange(n_criteria) * (~ labels.astype(bool)).astype(np.uint8), axis=1)
        return score, bitwise.astype(np.uint8)
    else:
        return score
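

# (Illustrative worked example, not part of the original module.) With the three criteria above,
# a unit that passes the sliding refractory period test (bit 0) but fails the noise cutoff
# (bit 1) and the median amplitude threshold (bit 2) gets:
#   labels row = [True, False, False]  ->  label (score) = 1 / 3
#   bitwise_fail = 0b110 = 6 (one bit set per failing criterion)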