Coverage for brainbox/processing.py: 45%
96 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""Process data from one form into another.
3For example, taking spike times and binning them into non-overlapping bins and convolving spike
4times with a gaussian kernel.
5"""
7import numpy as np
8import pandas as pd
9from scipy import interpolate, sparse
10from brainbox import core
11from iblutil.numerical import bincount2D
12from iblutil.util import Bunch
13import logging
15_logger = logging.getLogger(__name__)
def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zero',
         fillval=np.nan):
    """
    Resample one or several time series onto a single, evenly-spaced time base with step dt.
    Uses interpolation to find values.

    Can be used on raw numpy arrays of timestamps and values using the 'times' and 'values' kwargs
    and/or on brainbox.core.TimeSeries objects passed to the 'timeseries' kwarg. If passing both
    TimeSeries objects and numpy arrays, the offsets passed should be for the TS objects first and
    then the numpy arrays.

    Uses scipy's interpolation library to perform interpolation.
    See scipy.interpolate.interp1d for more information regarding interp and fillval parameters.

    :param dt: Separation of points which the output timeseries will be sampled at
    :type dt: float
    :param times: time stamps for the observations in 'values'
    :type times: np.ndarray or list of np.ndarrays
    :param values: observations corresponding to the timestamps in 'times'
    :type values: np.ndarray or list of np.ndarrays
    :param timeseries: A group of time series to perform alignment or a single time series.
        Must have time stamps. The passed container is never modified.
    :type timeseries: iterable of TimeSeries objects, or a single TimeSeries object.
    :param offsets: offsets added to the time stamps of each time series before resampling, one
        per series. Offsets for passed TimeSeries objects first, then offsets for passed numpy
        arrays. defaults to None
    :type offsets: tuple of floats, optional
    :param interp: Type of interpolation to use. Refer to scipy.interpolate.interp1d for possible
        values, defaults to 'zero'
    :type interp: str
    :param fillval: Fill values to use when interpolating outside of range of data. See interp1d
        for possible values, defaults to np.nan
    :return: TimeSeries object with each row representing synchronized values of all
        input TimeSeries. Will carry column names from input time series if all of them have
        column names.
    :raises TypeError: if 'timeseries' contains something other than TimeSeries objects, or if
        only one of 'times' and 'values' is passed.
    :raises ValueError: if 'times' and 'values' do not match, if no time series is passed at all,
        if the number of offsets differs from the number of time series, or if any time stamp is
        NaN or infinite.
    """
    #########################################
    # Checks on inputs and input processing #
    #########################################

    # Always work on a private list: tuples (which have no extend/append) are accepted and the
    # caller's own container is never mutated.
    if timeseries is None:
        timeseries = []
    elif isinstance(timeseries, core.TimeSeries):
        # If a single time series is passed for resampling, wrap it in a list
        timeseries = [timeseries]
    else:
        timeseries = list(timeseries)
        # Yell at the user if they try to pass stuff that isn't a TimeSeries object
        if not all(isinstance(ts, core.TimeSeries) for ts in timeseries):
            raise TypeError('All elements of \'timeseries\' argument must be brainbox.core.TimeSeries '
                            'objects. Please uses \'times\' and \'values\' for np.ndarray args.')

    # Check that if something is passed to times or values, there is a corresponding equal-length
    # argument for the other element.
    if (times is not None) or (values is not None):
        if times is None or values is None:
            raise TypeError('\'times\' and \'values\' must be passed together.')
        if len(times) != len(values):
            raise ValueError('\'times\' and \'values\' must have the same number of elements.')
        if isinstance(times[0], np.ndarray):
            # Several arrays were passed: each times/values pair becomes its own TimeSeries
            if not all(t.shape == v.shape for t, v in zip(times, values)):
                raise ValueError('All arrays in \'times\' must match the shape of the'
                                 ' corresponding entry in \'values\'.')
            timeseries.extend(core.TimeSeries(t, v) for t, v in zip(times, values))
        else:
            # A single pair of times and values was passed
            timeseries.append(core.TimeSeries(times, values))

    if not timeseries:
        raise ValueError('No time series passed: use \'timeseries\' and/or \'times\' and '
                         '\'values\'.')

    # Shift the time stamps of each series by its offset if necessary
    if offsets is not None:
        offsets = list(offsets)
        if len(offsets) != len(timeseries):
            # zip would otherwise silently drop the series without a matching offset
            raise ValueError('\'offsets\' must contain exactly one offset per time series.')
        tstamps = [ts.times + offset for ts, offset in zip(timeseries, offsets)]
    else:
        tstamps = [ts.times for ts in timeseries]

    # If all input timeseries have column names, put them together for the output TS
    if all(ts.columns is not None for ts in timeseries):
        colnames = []
        for ts in timeseries:
            colnames.extend(ts.columns)
    else:
        colnames = None

    #################
    # Main function #
    #################

    # Get the min and max values for all timeseries combined after offsetting
    tbounds = np.array([(np.amin(stamps), np.amax(stamps)) for stamps in tstamps])
    if not np.all(np.isfinite(tbounds)):
        # If there is a np.inf or np.nan in the time stamps for any of the timeseries this will
        # break any further code so we check for all finite values and throw an informative error.
        raise ValueError('NaN or inf encountered in passed timeseries. '
                         'Please either drop or fill these values.')
    tmin, tmax = np.amin(tbounds[:, 0]), np.amax(tbounds[:, 1])

    # fillval may be an array or tuple, so only compare it to the keyword when it is a string
    if isinstance(fillval, str) and fillval == 'extrapolate':
        # If extrapolation is enabled we can ensure we have a full coverage of the data by
        # extending the t max to be an whole integer multiple of dt above tmin.
        # The 0.01% fudge factor is to account for floating point arithmetic errors.
        newt = np.arange(tmin, tmax + 1.0001 * (dt - (tmax - tmin) % dt), dt)
    else:
        newt = np.arange(tmin, tmax, dt)

    # Interpolate against the offset time stamps so that the offsets actually shift the data
    # (without offsets these are simply the original time stamps).
    # NOTE(review): per the scipy docs interp1d raises on out-of-range queries unless
    # bounds_error=False or fill_value='extrapolate' -- confirm whether 'fillval' is meant to
    # apply when series cover different time ranges.
    tsinterps = [interpolate.interp1d(stamps, ts.values, kind=interp, fill_value=fillval, axis=0)
                 for stamps, ts in zip(tstamps, timeseries)]
    syncd = core.TimeSeries(newt, np.hstack([tsi(newt) for tsi in tsinterps]), columns=colnames)
    return syncd
def compute_cluster_average(spike_clusters, spike_var):
    """
    Quickish way to compute the average of some quantity across spikes in each cluster given
    quantity for each spike

    The per-cluster sums are accumulated with a weighted bincount over the cluster index of each
    spike, then divided by the number of spikes in each cluster. Empty inputs return three empty
    arrays.

    :param spike_clusters: cluster idx of each spike
    :param spike_var: variable of each spike (e.g spike amps or spike depths)
    :return: cluster id, average of quantity for each cluster, no. of spikes per cluster
    """
    clust, inverse, counts = np.unique(spike_clusters, return_inverse=True, return_counts=True)
    # `inverse` maps every spike to the position of its cluster in `clust`, so a weighted
    # bincount yields the sum of `spike_var` per cluster; minlength keeps the output aligned with
    # `clust` even when there are no spikes at all.
    spike_var_sum = np.bincount(inverse, weights=spike_var, minlength=clust.size)
    spike_var_avg = spike_var_sum / counts

    return clust, spike_var_avg, counts
def bin_spikes(spikes, binsize, interval_indices=False):
    """
    Wrapper for bincount2D which is intended to take in a TimeSeries object of spike times
    and cluster identities and spit out spike counts in bins of a specified width binsize, also in
    another TimeSeries object. Can either return a TS object with each row labeled with the
    corresponding interval or the value of the left edge of the bin.

    :param spikes: Spike times and cluster identities of sorted spikes
    :type spikes: TimeSeries object with 'clusters' column and timestamps
    :param binsize: Width of the non-overlapping bins in which to bin spikes
    :type binsize: float
    :param interval_indices: Whether to use intervals as the time stamps for binned spikes, rather
        than the left edge value of the bins, defaults to False
    :type interval_indices: bool, optional
    :return: Object with 2D array of shape T x N, for T timesteps and N clusters, and the
        associated time stamps.
    :rtype: TimeSeries object
    :raises TypeError: if spikes is not a TimeSeries (or a subclass of it)
    :raises AttributeError: if spikes has no 'clusters' attribute
    """
    # isinstance (rather than an exact type comparison) also accepts TimeSeries subclasses, in
    # line with the check performed by sync
    if not isinstance(spikes, core.TimeSeries):
        raise TypeError('Input spikes need to be in TimeSeries object format')

    if not hasattr(spikes, 'clusters'):
        raise AttributeError('Input spikes need to have a clusters attribute. Make sure you set '
                             'columns=(\'clusters\',)) when constructing spikes.')

    rates, tbins, clusters = bincount2D(spikes.times, spikes.clusters, binsize)
    if interval_indices:
        # Left-closed intervals spanning consecutive bin values; the last row of counts is dropped
        # so that the number of rows matches the number of intervals
        intervals = pd.interval_range(tbins[0], tbins[-1], freq=binsize, closed='left')
        return core.TimeSeries(times=intervals, values=rates.T[:-1], columns=clusters)
    return core.TimeSeries(times=tbins, values=rates.T, columns=clusters)
def get_units_bunch(spks_b, *args):
    """
    Returns a bunch, where the bunch keys are keys from `spks` with labels of spike information
    (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
    each unit: these arrays are ordered and can be indexed by unit id.

    Parameters
    ----------
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. unit IDs, times, features,
        etc.) for all spikes. Must contain a 'clusters' key with the unit id of each spike.
    features : list of strings (optional positional arg)
        A list of names of labels of spike information (which must be keys in `spks`) that specify
        which labels to return as keys in `units`. A single name may also be passed as a plain
        string. If not provided, all keys in `spks` are returned as keys in `units`.

    Returns
    -------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are bunches that hold, under the unit id converted to a string, the array of
        values for each unit. The entries for each key are ordered by unit ID. If `spks_b`
        contains no spikes, each requested key maps to an empty bunch.

    Examples
    --------
    1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
    bunch.
        >>> from brainbox import processing
        >>> import one.alf.io as alfio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = processing.get_units_bunch(spks_b)
        # Get amplitudes for unit 4.
        >>> amps = units_b['amps']['4']

    TODO add computation time estimate?
    """

    # Initialize `units`
    units_b = Bunch()
    # Get the keys to return for `units`:
    if not args:
        feat_keys = list(spks_b.keys())
    else:
        feat_keys = args[0]
        if isinstance(feat_keys, str):
            # A bare string would otherwise be iterated character by character
            feat_keys = [feat_keys]

    spks_unit_id = np.asarray(spks_b['clusters'])
    if spks_unit_id.size == 0:
        # No spikes at all: there are no units to report for any feature
        for feat in feat_keys:
            units_b[feat] = Bunch()
        return units_b

    # Group the spike indices by unit once, instead of scanning all spikes for every unit and
    # every feature. The stable sort keeps the indices of each unit in ascending order, i.e. the
    # order in which the spikes of that unit appear in `spks_b`.
    order = np.argsort(spks_unit_id, kind='stable')
    units, first = np.unique(spks_unit_id[order], return_index=True)
    unit_keys = [str(unit) for unit in units]
    unit_idxs = np.split(order, first[1:])

    # Unit ids below the highest one may have no spikes (empty clusters, due to a "wontfix" bug
    # in ks2): those ids get an empty array. The highest id always has spikes, so it is added in
    # the loop over `units` below.
    placeholder_keys = [str(unit) for unit in np.arange(units[-1])]

    # For each key in `units`, get each unit's values and add them to a bunch, `feat_bunch`.
    # After going through all units, add `feat_bunch` as a key to `units`:
    for feat in feat_keys:
        feat_vals = spks_b[feat]
        # Initialize `feat_bunch` with a key for each unit.
        feat_bunch = Bunch((key, np.array([])) for key in placeholder_keys)
        for key, idxs in zip(unit_keys, unit_idxs):
            feat_bunch[key] = feat_vals[idxs]
        units_b[feat] = feat_bunch
    return units_b
def filter_units(units_b, t, **kwargs):
    """
    Filters units according to some parameters. **kwargs are the keyword parameters used to filter
    the units.

    Units without any spike are always removed. The remaining units are then filtered, in this
    order, on mean amplitude, firing rate and false positive rate.

    Parameters
    ----------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are ordered
        by unit ID. The 'amps' and 'times' keys are required.
    t : float
        Duration of time used to calculate the false positive rate. It is also used as the
        duration for the firing rate of units whose first and last spike times coincide (e.g.
        units with a single spike); for all other units the firing rate is computed over the time
        between the first and last spike of the unit.

    Keyword Parameters
    ------------------
    min_amp : float
        The minimum mean amplitude (in V) of the spikes in the unit. Default value is 50e-6.
    min_fr : float
        The minimum firing rate (in Hz) of the unit. Default value is 0.5.
    max_fpr : float
        The maximum false positive rate of the unit (using the fp formula in Hill et al. (2011)
        J Neurosci 31: 8699-8705). Default value is 0.2.
    rp : float
        The refractory period (in s) of the unit. Used to calculate the false positive rate.
        Default value is 0.002.

    Any other keyword is ignored.

    Returns
    -------
    filt_units : ndarray
        The ids of the filtered units.

    See Also
    --------
    get_units_bunch

    Examples
    --------
    1) Filter units according to the default parameters.
        >>> from brainbox import processing
        >>> import one.alf.io as alfio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        # Get a spikes bunch, units bunch, and filter the units.
        >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
        >>> T = spks_b['times'][-1] - spks_b['times'][0]
        >>> filtered_units = processing.filter_units(units_b, T)

    2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and a max false
    positive rate of 0.2, given a refractory period of 2 ms.
        >>> filtered_units = processing.filter_units(units_b, T, min_amp=0, min_fr=1)

    TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
    are in `clstrs_b['metrics']`
    """

    # Set params
    params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults
    params.update(kwargs)  # update from **kwargs

    units = np.asarray(list(units_b.amps.keys()))
    # Remove empty clusters
    is_empty = np.asarray([len(units_b.amps[unit]) == 0 for unit in units], dtype=bool)
    filt_units = units[~is_empty]

    # Keep units with mean amp > `'min_amp'`
    mean_amps = np.asarray([np.mean(units_b.amps[unit]) for unit in filt_units])
    filt_units = filt_units[mean_amps > params['min_amp']]

    # Keep units with fr > `'min_fr'`
    n_spks_all = np.asarray([len(units_b.amps[unit]) for unit in filt_units], dtype=float)
    spans = np.asarray([units_b.times[unit][-1] - units_b.times[unit][0]
                        for unit in filt_units], dtype=float)
    # A unit whose first and last spikes coincide has a zero time span, which would give an
    # infinite firing rate that passes any threshold: use the overall duration `t` instead.
    durations = np.where(spans == 0, t, spans)
    fr = n_spks_all / durations
    filt_units = filt_units[fr > params['min_fr']]

    # Keep units with fpr < `'max_fpr'`
    fpr = np.zeros(len(filt_units), dtype=float)
    for i, unit in enumerate(filt_units):
        n_spks = len(units_b.amps[unit])
        # Number of inter-spike intervals shorter than the refractory period
        n_isi_viol = int(np.count_nonzero(np.diff(units_b.times[unit]) < params['rp']))
        # fpr is min of roots of solved quadratic equation (Hill, et al. 2011).
        # NOTE(review): the sign of the constant term should be checked against the paper.
        c = (t * n_isi_viol) / (2 * params['rp'] * n_spks**2)  # 3rd term in quadratic
        fpr[i] = np.min(np.abs(np.roots([-1, 1, c])))  # solve quadratic
    filt_units = filt_units[fpr < params['max_fpr']]

    return filt_units.astype(int)