Coverage for brainbox/processing.py: 44%

'''
Processes data from one form into another, e.g. binning spike times into non-overlapping
bins or convolving spike times with a Gaussian kernel.
'''

import logging
import traceback
import warnings

import numpy as np
import pandas as pd
from scipy import interpolate, sparse

from brainbox import core
from iblutil.numerical import bincount2D as _bincount2D
from iblutil.util import Bunch

_logger = logging.getLogger(__name__)


def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zero',
         fillval=np.nan):
    """
    Resamples a single time series or multiple time series onto a single time base with an
    even spacing dt between observations, using interpolation to find the values.

    Can be used on raw numpy arrays of timestamps and values using the 'times' and 'values'
    kwargs and/or on brainbox.core.TimeSeries objects passed to the 'timeseries' kwarg. If
    passing both TimeSeries objects and numpy arrays, the offsets passed should be for the
    TimeSeries objects first and then the numpy arrays.

    Uses scipy's interpolation library to perform interpolation.
    See scipy.interpolate.interp1d for more information regarding the interp and fillval
    parameters.

    :param dt: Separation of points at which the output timeseries will be sampled
    :type dt: float
    :param timeseries: A group of time series to align, or a single time series.
        Must have time stamps.
    :type timeseries: tuple of TimeSeries objects, or a single TimeSeries object
    :param times: time stamps for the observations in 'values'
    :type times: np.ndarray or list of np.ndarrays
    :param values: observations corresponding to the timestamps in 'times'
    :type values: np.ndarray or list of np.ndarrays
    :param offsets: tuple of offsets for the time stamps of each time series. Offsets for the
        passed TimeSeries objects first, then offsets for the passed numpy arrays, defaults
        to None
    :type offsets: tuple of floats, optional
    :param interp: Type of interpolation to use. Refer to scipy.interpolate.interp1d for
        possible values, defaults to 'zero'
    :type interp: str
    :param fillval: Fill value to use when interpolating outside of the range of the data.
        See scipy.interpolate.interp1d for possible values, defaults to np.nan
    :return: TimeSeries object with each row representing synchronized values of all
        input TimeSeries. Will carry column names from the input time series if all of them
        have column names.
    """
    #########################################
    # Checks on inputs and input processing #
    #########################################

    # Initialize a list to contain times/values pairs if no TimeSeries objects are passed
    if timeseries is None:
        timeseries = []
    # If a single time series is passed for resampling, wrap it in an iterable
    elif isinstance(timeseries, core.TimeSeries):
        timeseries = [timeseries]
    # Raise if anything passed to timeseries is not a TimeSeries object
    elif not all([isinstance(ts, core.TimeSeries) for ts in timeseries]):
        raise TypeError('All elements of \'timeseries\' argument must be brainbox.core.TimeSeries '
                        'objects. Please use \'times\' and \'values\' for np.ndarray args.')
    # Check that if something is passed to times or values, there is a corresponding equal-length
    # argument for the other element.
    if (times is not None) or (values is not None):
        if len(times) != len(values):
            raise ValueError('\'times\' and \'values\' must have the same number of elements.')
        if type(times[0]) is np.ndarray:
            if not all([t.shape == v.shape for t, v in zip(times, values)]):
                raise ValueError('All arrays in \'times\' must match the shape of the'
                                 ' corresponding entry in \'values\'.')
            # If all checks pass, convert all times and values args into TimeSeries objects
            timeseries.extend([core.TimeSeries(t, v) for t, v in zip(times, values)])
        else:
            # If times and values are single numpy arrays, pair them and add them
            timeseries.append(core.TimeSeries(times, values))

    # Adjust each timeseries by the associated offset if necessary, then load into a list
    if offsets is not None:
        tstamps = [ts.times + os for ts, os in zip(timeseries, offsets)]
    else:
        tstamps = [ts.times for ts in timeseries]
    # If all input timeseries have column names, put them together for the output TimeSeries
    if all([ts.columns is not None for ts in timeseries]):
        colnames = []
        for ts in timeseries:
            colnames.extend(ts.columns)
    else:
        colnames = None

    #################
    # Main function #
    #################

    # Get the min and max values for all timeseries combined after offsetting
    tbounds = np.array([(np.amin(ts), np.amax(ts)) for ts in tstamps])
    if not np.all(np.isfinite(tbounds)):
        # A np.inf or np.nan in the time stamps of any of the timeseries would break the code
        # below, so check that all values are finite and throw an informative error.
        raise ValueError('NaN or inf encountered in passed timeseries. '
                         'Please either drop or fill these values.')
    tmin, tmax = np.amin(tbounds[:, 0]), np.amax(tbounds[:, 1])
    if fillval == 'extrapolate':
        # If extrapolation is enabled we can ensure full coverage of the data by extending
        # tmax to a whole-number multiple of dt above tmin.
        # The 0.01% fudge factor accounts for floating point arithmetic errors.
        newt = np.arange(tmin, tmax + 1.0001 * (dt - (tmax - tmin) % dt), dt)
    else:
        newt = np.arange(tmin, tmax, dt)
    tsinterps = [interpolate.interp1d(ts.times, ts.values, kind=interp, fill_value=fillval, axis=0)
                 for ts in timeseries]
    syncd = core.TimeSeries(newt, np.hstack([tsi(newt) for tsi in tsinterps]), columns=colnames)
    return syncd
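
# Example usage of sync() (a hedged sketch with made-up numbers, not from the original
# module): synchronize two TimeSeries sampled on different clocks onto a common 10 ms
# time base. Assumes core.TimeSeries accepts (times, values, columns=...) as used above.
# >>> ts_a = core.TimeSeries(np.array([0., 0.013, 0.024]), np.array([[1.], [2.], [3.]]),
# ...                        columns=('a',))
# >>> ts_b = core.TimeSeries(np.array([0.005, 0.019]), np.array([[10.], [20.]]),
# ...                        columns=('b',))
# >>> synced = sync(0.01, timeseries=(ts_a, ts_b), fillval='extrapolate')
# synced.values then holds one column per input column ('a', 'b'), sampled every 10 ms.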


def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None):
    """
    Computes a 2D histogram by aggregating values in a 2D array.

    :param x: values to bin along the 2nd dimension (c-contiguous)
    :param y: values to bin along the 1st dimension
    :param xbin:
        scalar: bin size along 2nd dimension
        0: aggregate according to unique values
        array: aggregate according to exact values (count-reduce operation)
    :param ybin:
        scalar: bin size along 1st dimension
        0: aggregate according to unique values
        array: aggregate according to exact values (count-reduce operation)
    :param xlim: (optional) 2 values (array or list) that restrict the range along the 2nd dimension
    :param ylim: (optional) 2 values (array or list) that restrict the range along the 1st dimension
    :param weights: (optional) defaults to None, weights to apply to each value for aggregation
    :return: 3 numpy arrays: MAP [ny, nx] image, xscale [nx], yscale [ny]
    """
    # Print the call stack so users can locate the deprecated import
    for line in traceback.format_stack():
        print(line.strip())
    warning_text = """Deprecation warning: bincount2D() is now a part of iblutil.
    brainbox.processing.bincount2D is deprecated and will be removed in
    future versions. Please replace imports with iblutil.numerical.bincount2D."""
    _logger.warning(warning_text)
    warnings.warn(warning_text, DeprecationWarning)
    return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights)
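
# Migration sketch for the deprecation above: new code should import the maintained
# implementation from iblutil directly rather than going through this wrapper.
# >>> from iblutil.numerical import bincount2D
# >>> r, xscale, yscale = bincount2D(np.array([0., 1., 1., 2.]), np.array([0, 0, 1, 1]),
# ...                                xbin=1, ybin=1)
# r is the [ny, nx] count image; xscale and yscale give the bin locations along each axis.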


def compute_cluster_average(spike_clusters, spike_var):
    """
    Quickish way to compute the average of some quantity across spikes in each cluster,
    given the quantity for each spike.

    :param spike_clusters: cluster idx of each spike
    :param spike_var: variable of each spike (e.g. spike amps or spike depths)
    :return: cluster id, average of the quantity for each cluster, no. of spikes per cluster
    """
    clust, inverse, counts = np.unique(spike_clusters, return_inverse=True, return_counts=True)
    # Build a sparse column vector indexed by cluster; duplicate (row, col) entries are summed
    # on construction, giving the per-cluster sum of spike_var in a single pass.
    _spike_var = sparse.csr_matrix((spike_var, (inverse, np.zeros(inverse.size, dtype=int))))
    spike_var_avg = np.ravel(_spike_var.toarray()) / counts

    return clust, spike_var_avg, counts
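
# Worked example for compute_cluster_average() (a minimal sketch with made-up spikes):
# >>> spike_clusters = np.array([0, 0, 3, 3, 3])
# >>> spike_amps = np.array([1., 3., 2., 2., 5.])
# >>> ids, avg_amps, n_spikes = compute_cluster_average(spike_clusters, spike_amps)
# ids -> array([0, 3]); avg_amps -> array([2., 3.]); n_spikes -> array([2, 3]),
# since cluster 0 sums to 4 over 2 spikes and cluster 3 sums to 9 over 3 spikes.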


def bin_spikes(spikes, binsize, interval_indices=False):
    """
    Wrapper for bincount2D which is intended to take in a TimeSeries object of spike times
    and cluster identities and return spike counts in non-overlapping bins of a specified
    width binsize, also in a TimeSeries object. Can either return a TimeSeries object with
    each row labeled with the corresponding interval, or with the value of the left edge of
    the bin.

    :param spikes: Spike times and cluster identities of sorted spikes
    :type spikes: TimeSeries object with a \'clusters\' column and timestamps
    :param binsize: Width of the non-overlapping bins in which to bin spikes
    :type binsize: float
    :param interval_indices: Whether to use intervals as the time stamps for binned spikes,
        rather than the left edge value of the bins, defaults to False
    :type interval_indices: bool, optional
    :return: Object with a 2D array of shape T x N, for T timesteps and N clusters, and the
        associated time stamps.
    :rtype: TimeSeries object
    """
    if type(spikes) is not core.TimeSeries:
        raise TypeError('Input spikes need to be in TimeSeries object format')

    if not hasattr(spikes, 'clusters'):
        raise AttributeError('Input spikes need to have a clusters attribute. Make sure you set '
                             'columns=(\'clusters\',) when constructing spikes.')

    # Call the iblutil implementation directly so internal use doesn't trigger the
    # deprecation warning of the wrapper above
    rates, tbins, clusters = _bincount2D(spikes.times, spikes.clusters, binsize)
    if interval_indices:
        intervals = pd.interval_range(tbins[0], tbins[-1], freq=binsize, closed='left')
        return core.TimeSeries(times=intervals, values=rates.T[:-1], columns=clusters)
    else:
        return core.TimeSeries(times=tbins, values=rates.T, columns=clusters)
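
# Example usage of bin_spikes() (a hedged sketch with hypothetical spike data; the
# TimeSeries constructor call mirrors the columns=('clusters',) requirement above):
# bin spikes from two clusters into 0.5 s non-overlapping bins.
# >>> spikes = core.TimeSeries(times=np.array([0.1, 0.4, 0.6, 1.2]),
# ...                          values=np.array([0, 1, 0, 1]), columns=('clusters',))
# >>> binned = bin_spikes(spikes, binsize=0.5)
# binned.values is a T x N array of spike counts (T time bins, N clusters) and
# binned.times holds the left edge of each bin.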


def get_units_bunch(spks_b, *args):
    '''
    Returns a bunch, where the bunch keys are keys from `spks_b` with labels of spike information
    (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
    each unit: these arrays are ordered and can be indexed by unit id.

    Parameters
    ----------
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. unit IDs, times, features,
        etc.) for all spikes.
    features : list of strings (optional positional arg)
        A list of names of labels of spike information (which must be keys in `spks_b`) that
        specify which labels to return as keys in `units_b`. If not provided, all keys in
        `spks_b` are returned as keys in `units_b`.

    Returns
    -------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are
        ordered by unit ID.

    Examples
    --------
    1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
    bunch.
    >>> import brainbox as bb
    >>> import alf.io as aio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
    >>> units_b = bb.processing.get_units_bunch(spks_b)
    # Get amplitudes for unit 4.
    >>> amps = units_b['amps']['4']

    TODO add computation time estimate?
    '''

    # Initialize `units_b`
    units_b = Bunch()
    # Get the keys to return for `units_b`:
    if not args:
        feat_keys = list(spks_b.keys())
    else:
        feat_keys = args[0]
    # Get the unit id of each spike and the number of units. *Note: `n_units` might not equal
    # `len(units)` because some clusters may be empty (due to a "wontfix" bug in ks2).
    spks_unit_id = spks_b['clusters']
    n_units = np.max(spks_unit_id)
    units = np.unique(spks_b['clusters'])
    # For each key in `feat_keys`, iteratively get each unit's values and add them as a key to
    # a bunch, `feat_bunch`. After iterating through all units, add `feat_bunch` as a key to
    # `units_b`:
    for feat in feat_keys:
        # Initialize `feat_bunch` with a key for each unit.
        feat_bunch = Bunch((str(unit), np.array([])) for unit in np.arange(n_units))
        for unit in units:
            unit_idxs = np.where(spks_unit_id == unit)[0]
            feat_bunch[str(unit)] = spks_b[feat][unit_idxs]
        units_b[feat] = feat_bunch
    return units_b
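
# Toy example for get_units_bunch() (a sketch with a hypothetical two-unit spikes bunch):
# >>> spks_b = Bunch(clusters=np.array([0, 1, 0]), amps=np.array([1., 2., 3.]))
# >>> units_b = get_units_bunch(spks_b, ['amps'])
# >>> units_b['amps']['0']  # amps of all spikes assigned to unit 0
# array([1., 3.])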


def filter_units(units_b, t, **kwargs):
    '''
    Filters units according to some parameters. **kwargs are the keyword parameters used to
    filter the units.

    Parameters
    ----------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are
        ordered by unit ID.
    t : float
        Duration of time over which to calculate the firing rate and false positive rate.

    Keyword Parameters
    ------------------
    min_amp : float
        The minimum mean amplitude (in V) of the spikes in the unit. Default value is 50e-6.
    min_fr : float
        The minimum firing rate (in Hz) of the unit. Default value is 0.5.
    max_fpr : float
        The maximum false positive rate of the unit (using the fp formula in Hill et al. (2011)
        J Neurosci 31: 8699-8705). Default value is 0.2.
    rp : float
        The refractory period (in s) of the unit. Used to calculate the false positive rate.
        Default value is 0.002.

    Returns
    -------
    filt_units : ndarray
        The ids of the filtered units.

    See Also
    --------
    get_units_bunch

    Examples
    --------
    1) Filter units according to the default parameters.
    >>> import brainbox as bb
    >>> import alf.io as aio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Get a spikes bunch, units bunch, and filter the units.
    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
    >>> T = spks_b['times'][-1] - spks_b['times'][0]
    >>> filtered_units = bb.processing.filter_units(units_b, T)

    2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and the default
    maximum false positive rate of 0.2, given the default refractory period of 2 ms.
    >>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1)

    TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
    are in `clstrs_b['metrics']`
    '''

    # Set params
    params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults
    params.update(kwargs)  # update from **kwargs

    # Iteratively filter the units for each filter param #
    # -------------------------------------------------- #
    units = np.asarray(list(units_b.amps.keys()))
    # Remove empty clusters
    empty_cl = np.where([len(units_b.amps[unit]) == 0 for unit in units])[0]
    filt_units = np.delete(units, empty_cl)
    for param in params.keys():
        if param == 'min_amp':  # keep units with mean amp > 'min_amp'
            mean_amps = np.asarray([np.mean(units_b.amps[unit]) for unit in filt_units])
            filt_idxs = np.where(mean_amps > params['min_amp'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'min_fr':  # keep units with fr > 'min_fr'
            fr = np.asarray([len(units_b.amps[unit]) /
                             (units_b.times[unit][-1] - units_b.times[unit][0])
                             for unit in filt_units])
            filt_idxs = np.where(fr > params['min_fr'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'max_fpr':  # keep units with fpr < 'max_fpr'
            fpr = np.zeros_like(filt_units, dtype='float')
            for i, unit in enumerate(filt_units):
                n_spks = len(units_b.amps[unit])
                n_isi_viol = len(np.where(np.diff(units_b.times[unit]) < params['rp'])[0])
                # fpr is the min of the roots of the solved quadratic equation (Hill et al. 2011).
                c = (t * n_isi_viol) / (2 * params['rp'] * n_spks ** 2)  # 3rd term in quadratic
                fpr[i] = np.min(np.abs(np.roots([-1, 1, c])))  # solve quadratic
            filt_idxs = np.where(fpr < params['max_fpr'])[0]
            filt_units = filt_units[filt_idxs]
    return filt_units.astype(int)
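
# Worked example of the false-positive-rate step above, with hypothetical numbers: for a
# unit with n_spks = 10000 spikes and n_isi_viol = 5 refractory-period violations over
# t = 3600 s, with rp = 0.002 s:
#   c = (3600 * 5) / (2 * 0.002 * 10000 ** 2) = 0.045
#   roots of -x**2 + x + 0.045 = 0 are ~1.043 and ~-0.043, so fpr = min(|roots|) ~= 0.043,
# which is below the default max_fpr of 0.2, so such a unit would be kept.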