Coverage for brainbox/processing.py: 44%
105 statements
'''
Processes data from one form into another, e.g. taking spike times and binning them into
non-overlapping bins or convolving spike times with a Gaussian kernel.
'''
import numpy as np
import pandas as pd
from scipy import interpolate, sparse
from brainbox import core
from iblutil.numerical import bincount2D as _bincount2D
from iblutil.util import Bunch
import logging
import warnings
import traceback

_logger = logging.getLogger(__name__)

def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zero',
         fillval=np.nan):
    """
    Resample one or more time series to a single, evenly-spaced time base with separation dt
    between observations. Uses interpolation to find values.

    Can be used on raw numpy arrays of timestamps and values using the 'times' and 'values' kwargs
    and/or on brainbox.core.TimeSeries objects passed to the 'timeseries' kwarg. If passing both
    TimeSeries objects and numpy arrays, the offsets passed should be for the TS objects first and
    then the numpy arrays.

    Interpolation is performed with scipy's interpolation library. See scipy.interpolate.interp1d
    for more information regarding the interp and fillval parameters.

    :param dt: Separation of the points at which the output timeseries will be sampled
    :type dt: float
    :param timeseries: A group of time series to align, or a single time series.
        Must have time stamps.
    :type timeseries: tuple of TimeSeries objects, or a single TimeSeries object
    :param times: time stamps for the observations in 'values'
    :type times: np.ndarray or list of np.ndarrays
    :param values: observations corresponding to the timestamps in 'times'
    :type values: np.ndarray or list of np.ndarrays
    :param offsets: tuple of offsets for the time stamps of each time series. Offsets for passed
        TimeSeries objects first, then offsets for passed numpy arrays, defaults to None
    :type offsets: tuple of floats, optional
    :param interp: Type of interpolation to use. Refer to scipy.interpolate.interp1d for possible
        values, defaults to 'zero'
    :type interp: str
    :param fillval: Fill values to use when interpolating outside of the range of the data. See
        interp1d for possible values, defaults to np.nan
    :return: TimeSeries object with each row representing synchronized values of all
        input TimeSeries. Will carry column names from the input time series if all of them have
        column names.
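
    Example
    -------
    A minimal usage sketch (the data here are illustrative): resample a single
    irregularly-sampled series onto an even 0.1 s grid with linear interpolation.
        >>> import numpy as np
        >>> t = np.cumsum(np.random.uniform(0.05, 0.15, 200))  # irregular time stamps
        >>> v = np.sin(t)
        >>> ts = sync(0.1, times=t, values=v, interp='linear', fillval='extrapolate')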
    """
    #########################################
    # Checks on inputs and input processing #
    #########################################

    # Initialize a list to contain times/values pairs if no TS objs are passed
    if timeseries is None:
        timeseries = []
    # If a single time series is passed for resampling, wrap it in an iterable
    elif isinstance(timeseries, core.TimeSeries):
        timeseries = [timeseries]
    # Yell at the user if they try to pass stuff to timeseries that isn't a TimeSeries object
    elif not all([isinstance(ts, core.TimeSeries) for ts in timeseries]):
        raise TypeError('All elements of \'timeseries\' argument must be brainbox.core.TimeSeries '
                        'objects. Please use \'times\' and \'values\' for np.ndarray args.')
    # Check that if something is passed to times or values, there is a corresponding equal-length
    # argument for the other element.
    if (times is not None) or (values is not None):
        if len(times) != len(values):
            raise ValueError('\'times\' and \'values\' must have the same number of elements.')
        if type(times[0]) is np.ndarray:
            if not all([t.shape == v.shape for t, v in zip(times, values)]):
                raise ValueError('All arrays in \'times\' must match the shape of the'
                                 ' corresponding entry in \'values\'.')
            # If all checks are passed, convert all times and values args into TimeSeries objects
            timeseries.extend([core.TimeSeries(t, v) for t, v in zip(times, values)])
        else:
            # If times and values are single numpy arrays rather than lists of arrays,
            # pair them and add
            timeseries.append(core.TimeSeries(times, values))

    # Adjust each timeseries by the associated offset if necessary then load into a list
    if offsets is not None:
        tstamps = [ts.times + os for ts, os in zip(timeseries, offsets)]
    else:
        tstamps = [ts.times for ts in timeseries]
    # If all input timeseries have column names, put them together for the output TS
    if all([ts.columns is not None for ts in timeseries]):
        colnames = []
        for ts in timeseries:
            colnames.extend(ts.columns)
    else:
        colnames = None

    #################
    # Main function #
    #################

    # Get the min and max values for all timeseries combined after offsetting
    tbounds = np.array([(np.amin(ts), np.amax(ts)) for ts in tstamps])
    if not np.all(np.isfinite(tbounds)):
        # If there is a np.inf or np.nan in the time stamps for any of the timeseries this will
        # break any further code so we check for all finite values and throw an informative error.
        raise ValueError('NaN or inf encountered in passed timeseries. '
                         'Please either drop or fill these values.')
    tmin, tmax = np.amin(tbounds[:, 0]), np.amax(tbounds[:, 1])
    if fillval == 'extrapolate':
        # If extrapolation is enabled we can ensure full coverage of the data by
        # extending tmax to a whole integer multiple of dt above tmin.
        # The 0.01% fudge factor is to account for floating point arithmetic errors.
        newt = np.arange(tmin, tmax + 1.0001 * (dt - (tmax - tmin) % dt), dt)
    else:
        newt = np.arange(tmin, tmax, dt)
    tsinterps = [interpolate.interp1d(ts.times, ts.values, kind=interp, fill_value=fillval, axis=0)
                 for ts in timeseries]
    syncd = core.TimeSeries(newt, np.hstack([tsi(newt) for tsi in tsinterps]), columns=colnames)
    return syncd


def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None):
    """
    Computes a 2D histogram by aggregating values in a 2D array.

    :param x: values to bin along the 2nd dimension (c-contiguous)
    :param y: values to bin along the 1st dimension
    :param xbin:
        scalar: bin size along 2nd dimension
        0: aggregate according to unique values
        array: aggregate according to exact values (count reduce operation)
    :param ybin:
        scalar: bin size along 1st dimension
        0: aggregate according to unique values
        array: aggregate according to exact values (count reduce operation)
    :param xlim: (optional) 2 values (array or list) that restrict range along 2nd dimension
    :param ylim: (optional) 2 values (array or list) that restrict range along 1st dimension
    :param weights: (optional) defaults to None, weights to apply to each value for aggregation
    :return: 3 numpy arrays MAP [ny, nx] image, xscale [nx], yscale [ny]
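
    Examples
    --------
    This function is deprecated; import bincount2D from iblutil instead. A minimal sketch with
    illustrative values:
        >>> import numpy as np
        >>> from iblutil.numerical import bincount2D
        >>> x = np.array([0.1, 0.2, 0.2, 0.9])  # e.g. spike times
        >>> y = np.array([0, 1, 1, 0])          # e.g. cluster ids
        >>> counts, xscale, yscale = bincount2D(x, y, xbin=0.1, ybin=0)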
    """
    # Print the call stack so the site of the deprecated call is easy to locate.
    for line in traceback.format_stack():
        print(line.strip())
    warning_text = """Deprecation warning: bincount2D() is now a part of iblutil.
                      brainbox.processing.bincount2D is deprecated and will be removed in
                      future versions. Please replace imports with iblutil.numerical.bincount2D."""
    _logger.warning(warning_text)
    warnings.warn(warning_text, DeprecationWarning)
    return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights)


def compute_cluster_average(spike_clusters, spike_var):
    """
    Quickish way to compute the average of some quantity across spikes in each cluster, given
    the value of that quantity for each spike.

    :param spike_clusters: cluster idx of each spike
    :param spike_var: variable of each spike (e.g. spike amps or spike depths)
    :return: cluster id, average of quantity for each cluster, no. of spikes per cluster
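
    Example
    -------
    A minimal sketch with illustrative values, averaging an amplitude-like quantity over spikes
    from clusters 0 and 2:
        >>> import numpy as np
        >>> clust, avg, n = compute_cluster_average(
        ...     np.array([0, 0, 2, 2, 2]), np.array([1., 3., 2., 4., 6.]))
        # clust -> array([0, 2]); avg -> array([2., 4.]); n -> array([2, 3])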
    """
    clust, inverse, counts = np.unique(spike_clusters, return_inverse=True, return_counts=True)
    # Building a CSR matrix sums duplicate (row, col) entries, which totals spike_var per cluster.
    _spike_var = sparse.csr_matrix((spike_var, (inverse, np.zeros(inverse.size, dtype=int))))
    spike_var_avg = np.ravel(_spike_var.toarray()) / counts

    return clust, spike_var_avg, counts


def bin_spikes(spikes, binsize, interval_indices=False):
    """
    Wrapper for bincount2D which is intended to take in a TimeSeries object of spike times
    and cluster identities and return spike counts in non-overlapping bins of a specified width
    binsize, also in a TimeSeries object. Can either return a TS object with each row labeled
    with the corresponding interval or with the value of the left edge of the bin.

    :param spikes: Spike times and cluster identities of sorted spikes
    :type spikes: TimeSeries object with 'clusters' column and timestamps
    :param binsize: Width of the non-overlapping bins in which to bin spikes
    :type binsize: float
    :param interval_indices: Whether to use intervals as the time stamps for binned spikes, rather
        than the left edge value of the bins, defaults to False
    :type interval_indices: bool, optional
    :return: Object with 2D array of shape T x N, for T timesteps and N clusters, and the
        associated time stamps.
    :rtype: TimeSeries object
    """
    if type(spikes) is not core.TimeSeries:
        raise TypeError('Input spikes need to be in TimeSeries object format')

    if not hasattr(spikes, 'clusters'):
        raise AttributeError('Input spikes need to have a clusters attribute. Make sure you set '
                             'columns=(\'clusters\',) when constructing spikes.')

    rates, tbins, clusters = bincount2D(spikes.times, spikes.clusters, binsize)
    if interval_indices:
        intervals = pd.interval_range(tbins[0], tbins[-1], freq=binsize, closed='left')
        return core.TimeSeries(times=intervals, values=rates.T[:-1], columns=clusters)
    else:
        return core.TimeSeries(times=tbins, values=rates.T, columns=clusters)


def get_units_bunch(spks_b, *args):
    '''
    Returns a bunch, where the bunch keys are keys from `spks_b` with labels of spike information
    (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
    each unit: these arrays are ordered and can be indexed by unit id.

    Parameters
    ----------
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. unit IDs, times, features,
        etc.) for all spikes.
    features : list of strings (optional positional arg)
        A list of names of labels of spike information (which must be keys in `spks_b`) that
        specify which labels to return as keys in `units_b`. If not provided, all keys in
        `spks_b` are returned as keys in `units_b`.

    Returns
    -------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are
        ordered by unit ID.

    Examples
    --------
    1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
    bunch.
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make the 'alf' directory from the 'ks2' output
        directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = bb.processing.get_units_bunch(spks_b)
        # Get amplitudes for unit 4.
        >>> amps = units_b['amps']['4']

    TODO add computation time estimate?
    '''
    # Initialize `units_b`
    units_b = Bunch()
    # Get the keys to return for `units_b`:
    if not args:
        feat_keys = list(spks_b.keys())
    else:
        feat_keys = args[0]
    # Get unit id for each spike and number of units. *Note: `n_units` might not equal
    # `len(units)` because some clusters may be empty (due to a "wontfix" bug in ks2).
    spks_unit_id = spks_b['clusters']
    n_units = np.max(spks_unit_id)
    units = np.unique(spks_b['clusters'])
    # For each key in `feat_keys`, iteratively get each unit's values and add them as a key to a
    # bunch, `feat_bunch`. After iterating through all units, add `feat_bunch` as a key to
    # `units_b`:
    for feat in feat_keys:
        # Initialize `feat_bunch` with a key for each unit.
        feat_bunch = Bunch((str(unit), np.array([])) for unit in np.arange(n_units))
        for unit in units:
            unit_idxs = np.where(spks_unit_id == unit)[0]
            feat_bunch[str(unit)] = spks_b[feat][unit_idxs]
        units_b[feat] = feat_bunch
    return units_b


def filter_units(units_b, t, **kwargs):
    '''
    Filters units according to some parameters. **kwargs are the keyword parameters used to filter
    the units.

    Parameters
    ----------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are
        ordered by unit ID.
    t : float
        Duration of time over which to calculate the firing rate and false positive rate.

    Keyword Parameters
    ------------------
    min_amp : float
        The minimum mean amplitude (in V) of the spikes in the unit. Default value is 50e-6.
    min_fr : float
        The minimum firing rate (in Hz) of the unit. Default value is 0.5.
    max_fpr : float
        The maximum false positive rate of the unit (using the fp formula in Hill et al. (2011)
        J Neurosci 31: 8699-8705). Default value is 0.2.
    rp : float
        The refractory period (in s) of the unit. Used to calculate the false positive rate.
        Default value is 0.002.

    Returns
    -------
    filt_units : ndarray
        The ids of the filtered units.

    See Also
    --------
    get_units_bunch

    Examples
    --------
    1) Filter units according to the default parameters.
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make the 'alf' directory from the 'ks2' output
        directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        # Get a spikes bunch, units bunch, and filter the units.
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
        >>> T = spks_b['times'][-1] - spks_b['times'][0]
        >>> filtered_units = bb.processing.filter_units(units_b, T)

    2) Filter units with no minimum amplitude and a minimum firing rate of 1 Hz, keeping the
    default max false positive rate of 0.2 and refractory period of 2 ms.
        >>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1)
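
    3) A worked instance of the false positive rate estimate used below, with illustrative
    numbers: for t = 3600 s, 5000 spikes, 5 ISI violations, and rp = 0.002 s, the constant term
    is c = (3600 * 5) / (2 * 0.002 * 5000 ** 2) = 0.18, and the smaller absolute root of
    -fpr**2 + fpr + c = 0 gives an fpr of roughly 0.16, so such a unit would pass the default
    `max_fpr` of 0.2.
        >>> import numpy as np
        >>> c = (3600 * 5) / (2 * 0.002 * 5000 ** 2)
        >>> fpr = np.min(np.abs(np.roots([-1, 1, c])))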

    TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
    are in `clstrs_b['metrics']`
    '''
    # Set params
    params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults
    params.update(kwargs)  # update from **kwargs

    # Iteratively filter the units for each filter param #
    # -------------------------------------------------- #
    units = np.asarray(list(units_b.amps.keys()))
    # Remove empty clusters
    empty_cl = np.where([len(units_b.amps[unit]) == 0 for unit in units])[0]
    filt_units = np.delete(units, empty_cl)
    for param in params.keys():
        if param == 'min_amp':  # keep units with mean amp > `'min_amp'`
            mean_amps = np.asarray([np.mean(units_b.amps[unit]) for unit in filt_units])
            filt_idxs = np.where(mean_amps > params['min_amp'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'min_fr':  # keep units with fr > `'min_fr'`
            fr = np.asarray([len(units_b.amps[unit]) /
                             (units_b.times[unit][-1] - units_b.times[unit][0])
                             for unit in filt_units])
            filt_idxs = np.where(fr > params['min_fr'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'max_fpr':  # keep units with fpr < `'max_fpr'`
            fpr = np.zeros_like(filt_units, dtype='float')
            for i, unit in enumerate(filt_units):
                n_spks = len(units_b.amps[unit])
                n_isi_viol = len(np.where(np.diff(units_b.times[unit]) < params['rp'])[0])
                # fpr is the min of the roots of the solved quadratic equation (Hill et al. 2011).
                c = (t * n_isi_viol) / (2 * params['rp'] * n_spks ** 2)  # 3rd term in quadratic
                fpr[i] = np.min(np.abs(np.roots([-1, 1, c])))  # solve quadratic
            filt_idxs = np.where(fpr < params['max_fpr'])[0]
            filt_units = filt_units[filt_idxs]
    return filt_units.astype(int)