Coverage for brainbox/metrics/single_units.py: 53%
238 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""
2Computes metrics for assessing quality of single units.
4Run the following to set-up the workspace to run the docstring examples:
5>>> import brainbox as bb
6>>> import one.alf.io as aio
7>>> import numpy as np
8>>> import matplotlib.pyplot as plt
9>>> import ibllib.ephys.spikes as e_spks
10# (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
11>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
12# Load the alf spikes bunch and clusters bunch, and get a units bunch.
13>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
14>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
15>>> units_b = bb.processing.get_units_bunch(spks_b) # may take a few mins to compute
16"""
18import time
19import logging
21import numpy as np
22from scipy.ndimage import gaussian_filter1d
23import scipy.stats as stats
24import pandas as pd
26import spikeglx
27from phylib.stats import correlograms
28from iblutil.util import Bunch
29from iblutil.numerical import ismember, between_sorted, bincount2D
30from slidingRP import metrics
32from brainbox import singlecell
33from brainbox.io.spikeglx import extract_waveforms
34from brainbox.metrics import electrode_drift
_logger = logging.getLogger('ibllib')

# Default parameters for `quick_unit_metrics`; also the default `params` argument of
# `spike_sorting_metrics`.
METRICS_PARAMS = dict(
    noise_cutoff=dict(quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10),
    missed_spikes_est=dict(spks_per_bin=10, sigma=4, min_num_bins=50),
    acceptable_contamination=0.1,
    bin_size=0.25,
    med_amp_thresh_uv=50,
    min_isi=0.0001,
    presence_window=10,
    refractory_period=0.0015,
    RPslide_thresh=0.1,
)
def unit_stability(units_b, units=None, feat_names=None, dist='norm', test='ks'):
    """
    Computes the probability that the empirical spike feature distribution(s), for specified
    feature(s), for all units, comes from a specific theoretical distribution, based on a specified
    statistical test. Also computes a dispersion index (variance / mean) of the spike feature(s)
    for all units.

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all units. Each feature field maps a unit key to that unit's values, and
        `units_b['times']` is indexed with `str(unit)`. `units_b` is not modified.
    units : array-like (optional)
        A subset of all units for which to compute the metrics: a unit key is kept when
        `int(key)` is in `units`. (If `None`, all units are used.)
    feat_names : list of strings (optional)
        A list of names of spike features that can be found in `units_b` to specify which
        features to use for calculating unit stability. Defaults to `['amps']`.
    dist : string (optional)
        The hypothetical null distribution that the empirical spike feature distributions are
        tested against; passed as the second argument of `scipy.stats.kstest` (e.g. 'norm').
    test : string (optional)
        The statistical test used to compute the probability that the empirical spike feature
        distributions come from `dist`. Only 'ks' is available; other values raise a KeyError.

    Returns
    -------
    p_vals_b : bunch
        A bunch with `feat_names` as keys, each containing a bunch mapping `str(unit)` to the
        p-value of `test` for that unit's feature values (NaN for units without spikes).
    cv_b : bunch
        A bunch with `feat_names` as keys, each containing a bunch mapping `str(unit)` to
        `var / mean` of that unit's feature values (NaN for units without spikes).
        NOTE(review): despite the name, this is the variance-to-mean ratio, not the coefficient
        of variation (std / mean); kept for backward compatibility.

    See Also
    --------
    plot.feat_vars

    Examples
    --------
    1) Compute 1) the p-values obtained from running a one-sample ks test on the spike amplitudes
    for each unit, and 2) the variance-to-mean ratio of the spike amplitudes of each unit. Create
    a histogram of these values for each unit, color-coded by depth of channel of max amplitudes.
    Get cluster IDs of those units whose value is greater than 50.
    >>> p_vals_b, cv_b = bb.metrics.unit_stability(units_b)
    # Plot histograms of variances color-coded by depth of channel of max amplitudes
    >>> fig = bb.plot.feat_vars(units_b, feat_name='amps')
    # Get all unit IDs which have amps var/mean > 50
    >>> var_vals = np.array(tuple(cv_b['amps'].values()))
    >>> bad_units = np.where(var_vals > 50)
    """
    # Avoid a shared mutable default argument.
    if feat_names is None:
        feat_names = ['amps']

    # Unit keys come from the first feature. Optionally restrict them to `units`, without
    # popping entries out of the caller's `units_b` (the original always failed with a
    # NameError when `units` was None because `unit_list` was never defined).
    unit_list = list(units_b[feat_names[0]].keys())
    if units is not None:
        wanted = set(units)  # O(1) membership instead of scanning `units` per unit
        unit_list = [unit for unit in unit_list if int(unit) in wanted]

    # Available statistical tests (in future, more tests can be added to this dict).
    tests = {'ks': stats.kstest}
    test_fun = tests[test]

    p_vals_b = Bunch()
    cv_b = Bunch()
    for feat in feat_names:
        p_vals_feat = Bunch((unit, 0) for unit in unit_list)
        cv_feat = Bunch((unit, 0) for unit in unit_list)
        for unit in unit_list:
            if len(units_b['times'][str(unit)]) == 0:
                # Units without spikes get NaN placeholders.
                p_val = cv = np.nan
            else:
                vals = units_b[feat][unit]
                _, p_val = test_fun(vals, dist)
                cv = np.var(vals) / np.mean(vals)
            p_vals_feat[str(unit)] = p_val
            cv_feat[str(unit)] = cv
        p_vals_b[feat] = p_vals_feat
        cv_b[feat] = cv_feat

    return p_vals_b, cv_b
def missed_spikes_est(feat, spks_per_bin=20, sigma=5, min_num_bins=50):
    """
    Estimates the fraction of spikes missing from a unit's spike feature distribution,
    assuming that distribution is symmetric around its peak.
    Inspired by metric described in Hill et al. (2011) J Neurosci 31: 8699-8705.

    Parameters
    ----------
    feat : ndarray
        The spikes' feature values (e.g. amplitudes).
    spks_per_bin : int (optional)
        Number of spikes per histogram bin; the histogram uses `int(feat.size / spks_per_bin)`
        bins.
    sigma : int (optional)
        Standard deviation (in bins) of the gaussian kernel used to smooth the density
        histogram into a pdf.
    min_num_bins : int (optional)
        Minimum number of bins required. If `feat.size <= spks_per_bin * min_num_bins`,
        no estimate is made.

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes, capped at 0.5 (NaN when there are too few spikes).
        *Note: if more than 50% of spikes are missing, an accurate estimate isn't possible.
    pdf : ndarray or None
        The smoothed pdf of the spike feature histogram (None when there are too few spikes).
    cutoff_idx : int or None
        Index into `pdf`, at or right of the peak, where the pdf is closest to the density of
        its first bin; returned for plotting purposes (None when there are too few spikes).

    See Also
    --------
    plot.feat_cutoff

    Examples
    --------
    1) Determine the fraction of spikes missing from unit 1 based on the recorded unit's spike
    amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
    >>> feat = units_b['amps']['1']
    >>> fraction_missing, pdf, cutoff_idx = bb.metrics.missed_spikes_est(feat)
    """
    # Not enough spikes to populate the minimum number of bins: no estimate.
    if feat.size <= spks_per_bin * min_num_bins:
        return np.nan, None, None

    # Smoothed density histogram of the feature.
    density, _edges = np.histogram(feat, int(feat.size / spks_per_bin), density=True)
    pdf = gaussian_filter1d(density, sigma)

    # Right of the peak, find the bin whose density best matches the left-most bin: beyond
    # it the distribution is no longer mirrored around the peak.
    peak = np.argmax(pdf)
    cutoff_idx = peak + np.argmin(np.abs(pdf[peak:] - pdf[0]))

    # Share of the pdf mass lying in that asymmetric right tail, capped at one half.
    tail_share = np.sum(pdf[cutoff_idx:]) / np.sum(pdf)
    return min(tail_share, 0.5), pdf, cutoff_idx
def wf_similarity(wf1, wf2):
    """
    Computes a unit normalized spatiotemporal similarity score between two sets of waveforms.
    This score is based on how waveform shape correlates for each pair of spikes between the
    two sets of waveforms across space and time. The shapes of the arrays of the two sets of
    waveforms must be equal.

    For a pair of samples x, y the per-sample term is x * y / sqrt(x**2 * y**2), which equals
    sign(x * y) (0 when either is 0, and 0 for non-finite samples). So the score of a pair of
    spikes is the mean sign agreement over all samples and channels. The returned score is the
    mean of that value over every (spike in `wf1`, spike in `wf2`) pair, in [-1, 1].

    Parameters
    ----------
    wf1 : ndarray
        An array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        An array of shape (#spikes, #samples, #channels).

    Returns
    -------
    s: float
        The unit normalized spatiotemporal similarity score.

    Raises
    ------
    AssertionError
        If `wf1` and `wf2` do not have the same shape.

    See Also
    --------
    io.extract_waveforms
    plot.single_unit_wf_comp

    Examples
    --------
    1) Compute the similarity between the first and last 100 waveforms for unit1, across the 20
    channels around the channel of max amplitude.
    # Get the channels around the max amp channel for the unit, two sets of timestamps for the
    # unit, and the two corresponding sets of waveforms for those two sets of timestamps.
    # Then compute `s`.
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> ts1 = units_b['times']['1'][:100]
    >>> ts2 = units_b['times']['1'][-100:]
    >>> wf1 = bb.io.extract_waveforms(path_to_ephys_file, ts1, ch)
    >>> wf2 = bb.io.extract_waveforms(path_to_ephys_file, ts2, ch)
    >>> s = bb.metrics.wf_similarity(wf1, wf2)

    TODO check `s` calculation:
    take median of waveforms
    xcorr all waveforms with median, and divide by autocorr of all waveforms
    profile
    for two sets of units: xcorr(cl1, cl2) / (sqrt autocorr(cl1) * autocorr(cl2))
    """
    assert wf1.shape == wf2.shape, ('The shapes of the sets of waveforms are inconsistent ({})'
                                    '({})'.format(wf1.shape, wf2.shape))

    n_spks = wf1.shape[0]
    n_samples = wf1.shape[1]
    n_ch = wf1.shape[2]
    n_per_spike = int(np.prod(wf1.shape[1:]))

    def _signs(wf):
        # Per-spike flattened sign of every sample. Working with signs directly avoids the
        # 0/0 (and the global warnings filter the old code installed to hide it), integer
        # overflow in wf**2, and float32 underflow in sqrt(x**2 * y**2). Non-finite samples
        # contribute 0, matching np.nan_to_num of the NaN the original formula produced.
        flat = np.asarray(wf, dtype=float).reshape(n_spks, n_per_spike)
        return np.where(np.isfinite(flat), np.sign(flat), 0.)

    # similarity_matrix[i, j] = sum over samples/channels of sign(wf1[i] * wf2[j]),
    # normalized by the number of samples and channels.
    similarity_matrix = _signs(wf1) @ _signs(wf2).T / (n_samples * n_ch)

    # Return mean of similarity matrix
    s = np.mean(similarity_matrix)
    return s
def firing_rate_coeff_var(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    '''
    Computes the coefficient of variation of the firing rate: the ratio of the standard
    deviation to the mean.

    The instantaneous firing rate is split into `n_bins` equal, contiguous segments (any
    trailing remainder shorter than `n_bins` samples is dropped). A coefficient of variation is
    computed per segment, and their mean is returned.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute a coefficient of variation of the firing rate.

    Returns
    -------
    cv : float
        The mean of the `n_bins` per-bin coefficients of variation. It is NaN if any bin has a
        zero mean rate (0/0).
    cvs : ndarray
        The coefficients of variation of the firing for each bin of `n_bins`.
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the coefficient of variation of the firing rate for unit 1 from the time of its
    first to last spike, and compute the coefficient of variation of the firing rate for unit 2
    from the first to second minute.
    >>> ts_1 = units_b['times']['1']
    >>> ts_2 = units_b['times']['2']
    >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
    >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts_1)
    >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(ts_2)
    '''
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)

    # One row per bin; the tail that doesn't fill a whole bin is discarded.
    seg_len = fr.size // n_bins
    segments = fr[:seg_len * n_bins].reshape(n_bins, seg_len)

    # A bin with a zero mean rate (no spikes) yields NaN here, which propagates into `cv`.
    cvs = segments.std(axis=1) / segments.mean(axis=1)
    return np.mean(cvs), cvs, fr
def firing_rate_fano_factor(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    '''
    Computes the fano factor of the firing rate: the ratio of the variance to the mean.
    (Almost identical to coeff. of variation)

    The instantaneous firing rate is split into `n_bins` equal, contiguous segments (any
    trailing remainder shorter than `n_bins` samples is dropped). A fano factor is computed per
    segment, and their mean is returned.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute a fano factor of the firing rate.

    Returns
    -------
    ff : float
        The mean of the `n_bins` per-bin fano factors. It is NaN if any bin has a zero mean
        rate (0/0).
    ffs : ndarray
        The fano factors of the firing for each bin of `n_bins`.
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the fano factor of the firing rate for unit 1 from the time of its
    first to last spike, and compute the fano factor of the firing rate for unit 2
    from the first to second minute.
    >>> ts_1 = units_b['times']['1']
    >>> ts_2 = units_b['times']['2']
    >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
    >>> ff, ffs, fr = bb.metrics.firing_rate_fano_factor(ts_1)
    >>> ff_2, ffs_2, fr_2 = bb.metrics.firing_rate_fano_factor(ts_2)
    '''
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)

    # One row per bin; up to n_bins - 1 trailing samples that don't fill a bin are discarded.
    seg_len = fr.size // n_bins
    segments = fr[:seg_len * n_bins].reshape(n_bins, seg_len)

    # A bin with a zero mean rate (no spikes) yields NaN here, which propagates into `ff`.
    ffs = segments.var(axis=1) / segments.mean(axis=1)
    return np.mean(ffs), ffs, fr
def average_drift(feat, times):
    """
    Computes the cumulative absolute rate of change of a spike feature, normalized by the number
    of spikes: sum(|diff(feat) / diff(times)|) / len(feat).

    Parameters
    ----------
    feat : ndarray
        The spike feature values from which to compute the drift (usually amplitudes or depths).
    times : ndarray
        The spike times corresponding to `feat`, of the same length. Consecutive equal times
        produce infinite terms (division by zero).

    Returns
    -------
    cd : float
        The normalized cumulative drift of the unit.

    See Also
    --------
    max_drift

    Examples
    --------
    1) Get the depth and amplitude drift for unit 1.
    >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
    >>> times = spks_b['times'][unit_idxs]
    >>> depths = spks_b['depths'][unit_idxs]
    >>> amps = spks_b['amps'][unit_idxs]
    >>> depth_cd = bb.metrics.average_drift(depths, times)
    >>> amp_cd = bb.metrics.average_drift(amps, times)
    """
    feat_steps = np.diff(feat)
    time_steps = np.diff(times)
    total_abs_rate = np.abs(feat_steps / time_steps).sum()
    return total_abs_rate / len(feat)
def pres_ratio(ts, hist_win=10):
    """
    Computes the presence ratio of spike counts: the number of bins where there is at least one
    spike, over the total number of bins, given a specified bin width.

    Bins of width `hist_win` start at 0 and extend past `ts[-1]`, so `ts` is presumably sorted
    (its last element is treated as the latest spike).

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the presence ratio.
    hist_win : float (optional)
        The time window (in s) to use for computing the presence ratio.

    Returns
    -------
    pr : float
        The presence ratio.
    spks_bins : ndarray
        The number of spks in each bin.

    See Also
    --------
    plot.pres_ratio

    Examples
    --------
    1) Compute the presence ratio for unit 1, given a window of 10 s.
    >>> ts = units_b['times']['1']
    >>> pr, pr_bins = bb.metrics.pres_ratio(ts)
    """
    edges = np.arange(0, ts[-1] + hist_win, hist_win)
    counts = np.histogram(ts, edges)[0]
    n_occupied = np.count_nonzero(counts)
    return n_occupied / counts.size, counts
def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    For specified channels, for specified timestamps, computes the mean (peak-to-peak amplitudes /
    the MADs of the background noise).

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    ch : ndarray_like
        The channels on which to extract the waveforms.
    t : numeric (optional)
        The time (in ms) of the waveforms to extract to compute the ptp.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car: bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    ptp_sigma : ndarray
        For each channel in `ch`, the mean peak-to-peak amplitude over all spikes divided by the
        median, over chunks of up to 5e6 samples of the raw file, of the unscaled median absolute
        deviation of that channel.

    Examples
    --------
    1) Compute ptp_over_noise for all spikes on 20 channels around the channel of max amplitude
    for unit 1.
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> p = bb.metrics.ptp_over_noise(ephys_file, ts, ch)
    """
    # Ensure `ch` is an ndarray; a single channel is reshaped to (1, 1).
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Get waveforms (indexed below as wf[spike, sample, channel]).
    wf = extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe, car=car)

    # Mean peak-to-peak amplitude over all spikes, for each channel.
    mean_ptp = np.zeros((ch.size,))
    for cur_ch in range(ch.size):
        mean_ptp[cur_ch] = np.mean(np.max(wf[:, :, cur_ch], axis=1) -
                                   np.min(wf[:, :, cur_ch], axis=1))

    # Compute the MAD of each channel in chunks of the raw data.
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        n_chunk_samples = 5e6  # number of samples per chunk
        n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int')
        # `chunk_sample[i]:chunk_sample[i + 1]` are the samples that make up chunk i.
        chunk_sample = np.arange(0, file_m.shape[0], n_chunk_samples, dtype=int)
        chunk_sample = np.append(chunk_sample, file_m.shape[0])
        # Float storage: MADs are floats (e.g. half-integers for integer data), which the
        # previous int16 array silently truncated.
        mad_chunks = np.zeros((n_chunks, ch.size))
        t0 = time.perf_counter()
        for chunk in range(n_chunks):
            # `median_abs_deviation` (scale=1.0: raw MAD) replaces the deprecated and since
            # removed `median_absolute_deviation`.
            mad_chunks[chunk, :] = stats.median_abs_deviation(
                file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1.0)
            if chunk == 0:
                # Time estimate from the first chunk instead of computing it twice.
                dt = time.perf_counter() - t0
                _logger.info('Performing MAD computation. Estimated time is %.2f mins. (%s)',
                             dt * n_chunks / 60, time.ctime())
        _logger.info('Done. (%s)', time.ctime())

    # Return `mean_ptp` over the median MAD of all chunks.
    mad = np.median(mad_chunks, axis=0)
    ptp_sigma = mean_ptp / mad
    return ptp_sigma
def contamination_alt(ts, rp=0.002):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and time between the first and last spike.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    The estimate is the smallest-magnitude root of -x**2 + x + c = 0, with
    c = T * n_viol / (2 * rp * N**2). Here T is the time from first to last spike, n_viol is
    the number of inter-spike intervals shorter than `rp`, and N is the number of spikes.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    rp : float (optional)
        The refractory period (in s).

    Returns
    -------
    ce : float
        An estimate of the fraction of contamination.

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute contamination estimate for unit 1.
    >>> ts = units_b['times']['1']
    >>> ce = bb.metrics.contamination_alt(ts)
    """
    n_spikes = ts.size
    n_violations = np.count_nonzero(np.diff(ts) < rp)
    duration = ts[-1] - ts[0]

    # Constant term of the quadratic; its smallest-magnitude root is the estimate.
    const_term = (duration * n_violations) / (2 * rp * n_spikes ** 2)
    roots = np.roots([-1, 1, const_term])
    return np.min(np.abs(roots))
def contamination(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, the number of isi violations, and the time window
    [`min_time`, `max_time`].
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Modified by Dan Denman from cortex-lab/sortingQuality GitHub by Nick Steinmetz.

    Spikes that follow their predecessor by at most `min_isi` are treated as duplicates and
    dropped. The rate of refractory violations (ISIs < `rp`) is then divided by the overall
    firing rate.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    min_time : float
        The minimum time (in s) that a potential spike occurred.
    max_time : float
        The maximum time (in s) that a potential spike occurred.
    rp : float (optional)
        The refractory period (in s).
    min_isi : float (optional)
        The minimum interspike-interval (in s) for counting duplicate spikes.

    Returns
    -------
    ce : float
        An estimate of the contamination.
        A perfect unit has a ce = 0
        A unit with some contamination has a ce < 0.5
        A unit with lots of contamination has a ce > 1.0
    num_violations : int
        The total number of isi violations.

    See Also
    --------
    contamination_alt

    Examples
    --------
    1) Compute contamination estimate for unit 1, with a minimum isi for counting duplicate
    spikes of 0.1 ms.
    >>> ts = units_b['times']['1']
    >>> ce, num_violations = bb.metrics.contamination(ts, ts[0], ts[-1], min_isi=0.0001)
    """
    # Drop the later spike of every pair closer than `min_isi` (duplicates).
    duplicates = np.flatnonzero(np.diff(ts) <= min_isi)
    kept = np.delete(ts, duplicates + 1)

    n_spikes = kept.size
    num_violations = np.sum(np.diff(kept) < rp)

    overall_rate = n_spikes / (max_time - min_time)
    violation_rate = num_violations / (2 * n_spikes * (rp - min_isi))
    return violation_rate / overall_rate, num_violations
def _max_acceptable_cont(FR, RP, rec_duration, acceptableCont, thresh):
    """
    Maximum acceptable number of refractory period violations, used by slidingRP_viol.

    The expected violation count is `acceptableCont * (RP * 2 * FR * rec_duration)`. The limit
    is the `thresh` quantile of a Poisson distribution with that mean. When that quantile is 0
    (while a zero count has non-zero probability), -1 is returned so that even zero observed
    violations exceed the limit.
    """
    expected_viol = acceptableCont * (RP * 2 * FR * rec_duration)
    limit = stats.poisson.ppf(thresh, expected_viol)
    zero_limit = limit == 0 and stats.poisson.pmf(0, expected_viol) > 0
    return -1 if zero_limit else limit
def slidingRP_viol(ts, bin_size=0.25, thresh=0.1, acceptThresh=0.1):
    """
    A binary metric which determines whether there is an acceptable level of
    refractory period violations by using a sliding refractory period.

    This takes into account the firing rate of the neuron and computes a
    maximum acceptable level of contamination at different possible values of
    the refractory period. If the unit has less than the maximum contamination
    at any of the possible values of the refractory period, the unit passes.

    A neuron will always fail this metric for very low firing rates, and thus
    this metric takes into account both firing rate and refractory period
    violations.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    bin_size : float
        The size (in ms) of the autocorrelogram bins.
        NOTE(review): the tested bin indices (up to 40) assume the 0.25 ms default.
    thresh : float
        Quantile passed to `scipy.stats.poisson.ppf` (via `_max_acceptable_cont`) to
        compute the maximum acceptable number of violations.
    acceptThresh : float
        The fraction of contamination we are willing to accept (default value
        set to 0.1, or 10% contamination)

    Returns
    -------
    didpass : int
        0 if unit didn't pass
        1 if unit did pass

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute whether a unit has too much refractory period contamination at
    any possible value of a refractory period, for a 0.25 ms bin, with a
    threshold of 10% acceptable contamination
    >>> ts = units_b['times']['1']
    >>> didpass = bb.metrics.slidingRP_viol(ts, bin_size=0.25, thresh=0.1,
                                            acceptThresh=0.1)
    """
    # Candidate refractory periods (s): bin edges from 0 to 10 ms, offset by 1 us.
    rp_edges = np.arange(0, 10.25, bin_size) / 1000 + 1e-6
    test_idx = [5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 32, 36, 40]
    test_rps = [rp_edges[i] for i in test_idx]

    # Units with no spikes, or whose last spike is not after the first, fail outright.
    if not (len(ts) > 0 and ts[-1] > ts[0]):
        return 0

    rec_duration = (ts[-1] - ts[0])
    # One-sided auto-correlogram of the unit (all spikes assigned to pseudo-cluster 0).
    acg = correlograms(ts, np.zeros(len(ts), dtype='int8'), cluster_ids=[0],
                       bin_size=bin_size / 1000, sample_rate=20000,
                       window_size=2,
                       symmetrize=False)[0, 0]
    # Number of acg counts from 0 up to the end of each tested bin.
    observed = np.cumsum(acg)[test_idx]
    n_spikes = len(ts)

    # Normalize each bin's count by the spike count and the bin size. The firing rate is
    # the mean of the normalized counts over the second half of the 2 s window (1 to 2 s),
    # which gives a better estimate than len(ts) / duration.
    rate_per_bin = acg / n_spikes / bin_size * 1000
    n_bins_2s = len(acg)
    n_bins_1s = int(n_bins_2s / 2)
    fr = np.sum(rate_per_bin[n_bins_1s:n_bins_2s]) / n_bins_1s

    # Maximum allowed count per tested refractory period; pass if any is not exceeded.
    max_allowed = np.vectorize(_max_acceptable_cont)(fr, test_rps, rec_duration,
                                                     fr * acceptThresh, thresh)
    return int(np.any(observed <= max_allowed))
def noise_cutoff(amps, quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10):
    """
    A metric to determine whether a unit's amplitude distribution is cut off
    (at floor), without assuming a Gaussian distribution.

    A histogram of the amplitudes is built. The mean and standard deviation of the non-empty
    bin counts in an upper part of the distribution (above the peak) are computed. The count
    in the second non-empty bin is then expressed as a number of those standard deviations away
    from that mean.

    Parameters
    ----------
    amps : ndarray_like
        The amplitudes (in uV) of the spikes.
    quantile_length : float
        Sets where the upper part begins: `2 * quantile_length` of the way through the
        non-empty bins between the peak and the last bin.
    n_bins : int
        Number of edges, from 0 to max(amps), of the amplitude histogram (giving
        `n_bins - 1` bins).
    nc_threshold: float
        The unit fails if `cutoff` is greater than this (and the `percent_threshold`
        condition also holds).
    percent_threshold: float
        The unit fails only if the second non-empty bin holds more than
        `percent_threshold` times the peak bin count (and `cutoff > nc_threshold`).

    Returns
    -------
    nc_pass : numpy.bool_
        True if the unit passes. It is False when there are fewer than two amplitudes or
        no spread in the upper part of the distribution.
    cutoff : float
        Number of standard deviations that the second non-empty bin count lies from the
        mean count of the upper part (NaN if not computed).
    first_low_quantile : float
        The count in the second non-empty histogram bin (NaN if not computed).

    See Also
    --------
    missed_spikes_est

    Examples
    --------
    1) Compute whether a unit's amplitude distribution is cut off
    >>> amps = spks_b['amps'][unit_idxs]
    >>> nc_pass, cutoff, first_low_quantile = bb.metrics.noise_cutoff(amps, quantile_length=.25, n_bins=100)
    """
    not_computed = np.float64(np.nan)

    # Not enough amplitudes to analyze: the unit fails.
    if amps.size <= 1:
        return np.False_, not_computed, not_computed

    counts, _edges = np.histogram(amps, bins=np.linspace(0, np.max(amps), n_bins))
    peak = np.argmax(counts)
    # Non-empty bins from the peak up to (but excluding) the last bin.
    n_nonzero_above_peak = np.count_nonzero(counts[peak:-1] > 0)
    first_high_bin = int(np.ceil(2 * quantile_length * n_nonzero_above_peak + peak))
    high_counts = counts[np.arange(first_high_bin, len(counts))]

    # No bins in the upper part: the unit fails.
    if len(high_counts) == 0:
        return np.False_, not_computed, not_computed

    occupied = high_counts[high_counts >= 1]
    mean_high = np.mean(occupied)
    std_high = np.std(occupied)

    # No spread among the upper bin counts: the unit fails.
    if not std_high > 0:
        return np.False_, not_computed, not_computed

    second_nonzero = counts[counts != 0][1]  # take the second non-empty bin
    cutoff = (second_nonzero - mean_high) / std_high
    percent_of_peak = percent_threshold * np.max(counts)
    failed = (cutoff > nc_threshold) & (second_nonzero > percent_of_peak)
    return ~failed, cutoff, second_nonzero
def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS):
    """
    Computes:
        - cell level metrics (cf quick_unit_metrics)
        - label the metrics according to quality thresholds
        - estimates drift as a function of time
    :param times: vector of spike times
    :param clusters: vector of spike cluster ids
    :param amps: vector of spike amplitudes
    :param depths: vector of spike depths
    :param cluster_ids (optional): set of clusters (if None the output dataframe will match
     the unique set of clusters represented in spike clusters)
    :param params: dict (optional) parameters for qc computation (
     see constant at the top of the module for default values and keys)
    :return: data_frame of metrics (cluster records, columns are qc attributes)
    :return: dictionary of recording qc (keys 'time_scale' and 'drift_um')
    """
    # Per-unit metrics as a DataFrame.
    unit_metrics = pd.DataFrame(quick_unit_metrics(
        clusters, times, amps, depths, cluster_ids=cluster_ids, params=params))
    # Electrode drift over the recording.
    drift_um, time_scale = electrode_drift.estimate_drift(times, amps, depths)
    return unit_metrics, {'time_scale': time_scale, 'drift_um': drift_um}
def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
                       params=METRICS_PARAMS, cluster_ids=None, tbounds=None):
    """
    Computes single unit metrics from only the spike times, amplitudes, and
    depths for a set of units.

    Metrics computed:
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count'
    The output also carries 'cluster_id', 'firing_rate' (spike count divided by the time
    between the first and last spike) and 'label' (see `compute_labels`).

    Parameters (see the METRICS_PARAMS constant)
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes. Assumed sorted: the first and last
        elements are used as the recording bounds, and `between_sorted` is used for `tbounds`.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
        'presence_window': float
            The time window (in s) used to look for spikes when computing the presence ratio.
        'refractory_period': float
            The refractory period used when computing isi violations and the contamination
            estimate.
        'min_isi': float
            The minimum interspike-interval (in s) for counting duplicate spikes when computing
            the contamination estimate.
        'noise_cutoff': dict
            Keyword arguments forwarded to `noise_cutoff` (also read by `compute_labels`).
        'missed_spikes_est': dict
            Keyword arguments forwarded to `missed_spikes_est`.
        Any key needed by `compute_labels` that is missing from `params` is taken from
        METRICS_PARAMS.
    cluster_ids : array_like (optional)
        Cluster ids to report, in output order (list or array). If some clusters have no spike
        in `spike_clusters`, this keeps the output size consistent; the metrics computed from
        their spikes are left as NaN. Defaults to the sorted unique values of `spike_clusters`.
    tbounds : list or 2-element array (optional)
        [tstart, tstop] time selection on which to compute the metrics.

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics, one value per cluster id.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
        >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
        >>> m = phy_model_from_ks2_path(path_to_ks2_out)
        >>> cluster_ids = m.spike_clusters
        >>> ts = m.spike_times
        >>> amps = m.amplitudes
        >>> depths = m.depths
        >>> r = bb.metrics.quick_unit_metrics(cluster_ids, ts, amps, depths)
    """
    metrics_list = [
        'cluster_id',
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count'
    ]
    # `is not None` rather than truthiness: a 2-element ndarray has no truth value
    if tbounds is not None:
        ispi = between_sorted(spike_times, tbounds)
        spike_times = spike_times[ispi]
        spike_clusters = spike_clusters[ispi]
        spike_amps = spike_amps[ispi]
        spike_depths = spike_depths[ispi]

    if cluster_ids is None:
        cluster_ids = np.unique(spike_clusters)
    # accept plain lists as documented
    cluster_ids = np.asarray(cluster_ids)
    nclust = cluster_ids.size

    r = Bunch({k: np.full((nclust,), np.nan) for k in metrics_list})
    r['cluster_id'] = cluster_ids

    # vectorized computation of basic metrics such as presence ratio and firing rate
    tmin = spike_times[0]
    tmax = spike_times[-1]
    presence_ratio = bincount2D(spike_times, spike_clusters,
                                xbin=params['presence_window'],
                                ybin=cluster_ids, xlim=[tmin, tmax])[0]
    r.presence_ratio = np.sum(presence_ratio > 0, axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)
    r.spike_count = np.sum(presence_ratio, axis=1)
    r.firing_rate = r.spike_count / (tmax - tmin)

    # computing amplitude statistical indicators by aggregating over cluster id
    camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps), spike_clusters],
                        columns=['amps', 'log_amps', 'clusters'])
    amp_stats = camp.groupby('clusters').agg(
        amp_min=('amps', 'min'),
        amp_max=('amps', 'max'),
        log_amp_median=('log_amps', 'median'),
        amp_std_dB=('log_amps', 'std'),
    )
    # map each aggregated group (sorted by cluster id) onto its row in `r`: this stays
    # correct when `cluster_ids` is unsorted or when spikes belong to unlisted clusters
    ir, ib = ismember(r.cluster_id, amp_stats.index.values)
    r.amp_min[ir] = amp_stats['amp_min'].values[ib]
    r.amp_max[ir] = amp_stats['amp_max'].values[ib]
    # this is the geometric median
    r.amp_median[ir] = 10 ** (amp_stats['log_amp_median'].values[ib] / 20)
    r.amp_std_dB[ir] = amp_stats['amp_std_dB'].values[ib]

    srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
                                **{'sampleRate': 30000, 'binSizeCorr': 1 / 30000})
    # NOTE(review): assumes srp['cidx'] holds cluster ids (not positions in `cluster_ids`);
    # map them onto the rows of `r` instead of using them as array indices
    isrp, jsrp = ismember(r.cluster_id, np.asarray(srp['cidx']))
    r.slidingRP_viol[isrp] = np.asarray(srp['value'])[jsrp]

    # loop over each cluster to compute the rest of the metrics
    for ic, cid in enumerate(cluster_ids):
        ispikes = spike_clusters == cid
        if not np.any(ispikes):  # no spikes for this cluster: leave its metrics as NaN
            continue
        ts = spike_times[ispikes]
        amps = spike_amps[ispikes]
        depths = spike_depths[ispikes]
        # compute metrics
        r.contamination_alt[ic] = contamination_alt(ts, rp=params['refractory_period'])
        r.contamination[ic], _ = contamination(
            ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi'])
        _, r.noise_cutoff[ic], _ = noise_cutoff(amps, **params['noise_cutoff'])
        r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
        # summed absolute depth change between consecutive spikes, normalised by the recording
        # span and scaled by 3600 (per hour, assuming times in seconds)
        # wonder if there is a need to low-cut this
        r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600

    # label with the caller's thresholds, falling back to the module defaults for missing keys
    r.label = compute_labels(r, params={**METRICS_PARAMS, **params})
    return r
def compute_labels(r, params=None, return_details=False):
    """
    From a dataframe or a dictionary of unit metrics, compute a label per unit.

    The label is the proportion of the following QCs that pass (1 denotes an all pass):
        - 'slidingRP_viol': the value of r['slidingRP_viol'] is used as-is (expected 0/1)
        - 'noise_cutoff': passes when r['noise_cutoff'] < params['noise_cutoff']['nc_threshold']
        - 'amp_median': passes when r['amp_median'] > params['med_amp_thresh_uv'] / 1e6
          (the threshold is scaled by 1e-6, i.e. converted from uV as the key name suggests)

    :param r: dict, Bunch or pandas DataFrame of unit qcs with the keys/columns
        'slidingRP_viol', 'noise_cutoff' and 'amp_median' (arrays or lists, one value per unit)
    :param params: dict of thresholds (keys 'noise_cutoff' -> {'nc_threshold'} and
        'med_amp_thresh_uv'); defaults to METRICS_PARAMS
    :param return_details: if True, also return a dictionary with the per-qc pass values
    :return: vector of proportion of qcs passed between 0 and 1; if `return_details` is True,
        a tuple (proportion vector, dict mapping qc name to its per-unit pass column)
    """
    # resolve the default here instead of sharing a mutable dict as a default argument
    params = METRICS_PARAMS if params is None else params
    # right now the score is a value between 0 and 1 denoting the proportion of passing qcs
    # we could eventually do a bitwise qc
    # item access (not attribute access) so plain dicts work as well as Bunch / DataFrame
    labels = np.c_[
        np.asarray(r['slidingRP_viol']),
        np.asarray(r['noise_cutoff']) < params['noise_cutoff']['nc_threshold'],
        np.asarray(r['amp_median']) > params['med_amp_thresh_uv'] / 1e6,
    ]
    score = np.mean(labels, axis=1)
    if not return_details:
        return score
    column_names = ['slidingRP_viol', 'noise_cutoff', 'amp_median']
    qcdict = {name: labels[:, i] for i, name in enumerate(column_names)}
    return score, qcdict