1"""Process data from one form into another. 

2 

3For example, taking spike times and binning them into non-overlapping bins and convolving spike 

4times with a gaussian kernel. 

5""" 

6 

7import numpy as np 

8import pandas as pd 

9from scipy import interpolate, sparse 

10from brainbox import core 

11from iblutil.numerical import bincount2D 

12from iblutil.util import Bunch 

13import logging 

14 

15_logger = logging.getLogger(__name__) 

16 

17 

def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zero',
         fillval=np.nan):
    """
    Function for resampling one or more time series to a single, evenly-spaced delta t
    between observations. Uses interpolation to find values.

    Can be used on raw numpy arrays of timestamps and values using the 'times' and 'values' kwargs
    and/or on brainbox.core.TimeSeries objects passed to the 'timeseries' kwarg. If passing both
    TimeSeries objects and numpy arrays, the offsets passed should be for the TS objects first and
    then the numpy arrays.

    Uses scipy's interpolation library to perform interpolation.
    See scipy.interpolate.interp1d for more information regarding the interp and fillval
    parameters.

    :param dt: Separation of points which the output timeseries will be sampled at
    :type dt: float
    :param timeseries: A group of time series to perform alignment or a single time series.
        Must have time stamps.
    :type timeseries: tuple of TimeSeries objects, or a single TimeSeries object
    :param times: time stamps for the observations in 'values'
    :type times: np.ndarray or list of np.ndarrays
    :param values: observations corresponding to the timestamps in 'times'
    :type values: np.ndarray or list of np.ndarrays
    :param offsets: tuple of offsets for the time stamps of each time series. Offsets for passed
        TimeSeries objects first, then offsets for passed numpy arrays, defaults to None
    :type offsets: tuple of floats, optional
    :param interp: Type of interpolation to use. Refer to scipy.interpolate.interp1d for possible
        values, defaults to 'zero'
    :type interp: str
    :param fillval: Fill values to use when interpolating outside of the range of the data. See
        interp1d for possible values, defaults to np.nan
    :return: TimeSeries object with each row representing synchronized values of all
        input TimeSeries. Will carry column names from the input time series if all of them have
        column names.
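
    Example (a minimal sketch with toy arrays; exact output values depend on the
    chosen interpolation):

    >>> t0, v0 = np.arange(0., 1., 0.1), np.ones(10)
    >>> t1, v1 = np.arange(0., 1., 0.25), np.arange(4.)
    >>> synced = sync(0.1, times=[t0, t1], values=[v0, v1], fillval='extrapolate')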

52 """ 

    #########################################
    # Checks on inputs and input processing #
    #########################################

    # Initialize a list to contain times/values pairs if no TS objs are passed
    if timeseries is None:
        timeseries = []
    # If a single time series is passed for resampling, wrap it in an iterable
    elif isinstance(timeseries, core.TimeSeries):
        timeseries = [timeseries]
    # Yell at the user if they try to pass stuff to timeseries that isn't a TimeSeries object
    elif not all([isinstance(ts, core.TimeSeries) for ts in timeseries]):
        raise TypeError('All elements of \'timeseries\' argument must be brainbox.core.TimeSeries '
                        'objects. Please use \'times\' and \'values\' for np.ndarray args.')
    # Check that if something is passed to times or values, there is a corresponding equal-length
    # argument for the other element.
    if (times is not None) or (values is not None):
        if len(times) != len(values):
            raise ValueError('\'times\' and \'values\' must have the same number of elements.')
        if type(times[0]) is np.ndarray:
            if not all([t.shape == v.shape for t, v in zip(times, values)]):
                raise ValueError('All arrays in \'times\' must match the shape of the'
                                 ' corresponding entry in \'values\'.')
            # If all checks are passed, convert all times and values args into TimeSeries objects
            timeseries.extend([core.TimeSeries(t, v) for t, v in zip(times, values)])
        else:
            # If times and values are only numpy arrays and lists of arrays, pair them and add
            timeseries.append(core.TimeSeries(times, values))

    # Adjust each timeseries by the associated offset if necessary then load into a list
    if offsets is not None:
        tstamps = [ts.times + os for ts, os in zip(timeseries, offsets)]
    else:
        tstamps = [ts.times for ts in timeseries]
    # If all input timeseries have column names, put them together for the output TS
    if all([ts.columns is not None for ts in timeseries]):
        colnames = []
        for ts in timeseries:
            colnames.extend(ts.columns)
    else:
        colnames = None

    #################
    # Main function #
    #################

    # Get the min and max values for all timeseries combined after offsetting
    tbounds = np.array([(np.amin(ts), np.amax(ts)) for ts in tstamps])
    if not np.all(np.isfinite(tbounds)):
        # If there is a np.inf or np.nan in the time stamps for any of the timeseries this will
        # break any further code so we check for all finite values and throw an informative error.
        raise ValueError('NaN or inf encountered in passed timeseries. '
                         'Please either drop or fill these values.')
    tmin, tmax = np.amin(tbounds[:, 0]), np.amax(tbounds[:, 1])
    if fillval == 'extrapolate':
        # If extrapolation is enabled we can ensure full coverage of the data by extending tmax
        # to a whole integer multiple of dt above tmin.
        # The 0.01% fudge factor is to account for floating point arithmetic errors.
        newt = np.arange(tmin, tmax + 1.0001 * (dt - (tmax - tmin) % dt), dt)
    else:
        newt = np.arange(tmin, tmax, dt)
    tsinterps = [interpolate.interp1d(ts.times, ts.values, kind=interp, fill_value=fillval, axis=0)
                 for ts in timeseries]
    syncd = core.TimeSeries(newt, np.hstack([tsi(newt) for tsi in tsinterps]), columns=colnames)
    return syncd


def compute_cluster_average(spike_clusters, spike_var):
    """
    Quickish way to compute the average of some quantity across spikes in each cluster, given the
    quantity for each spike.

    :param spike_clusters: cluster idx of each spike
    :param spike_var: variable of each spike (e.g. spike amps or spike depths)
    :return: cluster id, average of quantity for each cluster, no. of spikes per cluster
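
    Example (an illustrative sketch; the sparse matrix below sums `spike_var` within
    each cluster in a single pass):

    >>> spike_clusters = np.array([0, 1, 0, 1, 2])
    >>> spike_amps = np.array([1., 2., 3., 4., 5.])
    >>> clust, avg, counts = compute_cluster_average(spike_clusters, spike_amps)
    >>> # clust -> [0, 1, 2], avg -> [2., 3., 5.], counts -> [2, 2, 1]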

128 """ 

    clust, inverse, counts = np.unique(spike_clusters, return_inverse=True, return_counts=True)
    # Build an n_clusters x 1 sparse matrix: duplicate (row, col) entries are summed on
    # conversion, giving the per-cluster sum of `spike_var` in one shot.
    _spike_var = sparse.csr_matrix((spike_var, (inverse, np.zeros(inverse.size, dtype=int))))
    spike_var_avg = np.ravel(_spike_var.toarray()) / counts

    return clust, spike_var_avg, counts


def bin_spikes(spikes, binsize, interval_indices=False):
    """
    Wrapper for bincount2D which is intended to take in a TimeSeries object of spike times
    and cluster identities and spit out spike counts in bins of a specified width binsize, also in
    another TimeSeries object. Can either return a TS object with each row labeled with the
    corresponding interval or the value of the left edge of the bin.

    :param spikes: Spike times and cluster identities of sorted spikes
    :type spikes: TimeSeries object with 'clusters' column and timestamps
    :param binsize: Width of the non-overlapping bins in which to bin spikes
    :type binsize: float
    :param interval_indices: Whether to use intervals as the time stamps for binned spikes, rather
        than the left edge value of the bins, defaults to False
    :type interval_indices: bool, optional
    :return: Object with 2D array of shape T x N, for T timesteps and N clusters, and the
        associated time stamps.
    :rtype: TimeSeries object
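
    Example (a minimal sketch with toy spike data; assumes the TimeSeries was built
    with a 'clusters' column as described above):

    >>> st = np.array([0.1, 0.4, 0.6, 1.1])  # spike times
    >>> sc = np.array([0, 1, 0, 1])  # cluster identity of each spike
    >>> spikes = core.TimeSeries(st, sc, columns=('clusters',))
    >>> binned = bin_spikes(spikes, 0.5)  # spike counts per cluster in 0.5 s bins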

153 """ 

    if type(spikes) is not core.TimeSeries:
        raise TypeError('Input spikes need to be in TimeSeries object format')

    if not hasattr(spikes, 'clusters'):
        raise AttributeError('Input spikes need to have a clusters attribute. Make sure you set '
                             'columns=(\'clusters\',) when constructing spikes.')

    rates, tbins, clusters = bincount2D(spikes.times, spikes.clusters, binsize)
    if interval_indices:
        intervals = pd.interval_range(tbins[0], tbins[-1], freq=binsize, closed='left')
        return core.TimeSeries(times=intervals, values=rates.T[:-1], columns=clusters)
    else:
        return core.TimeSeries(times=tbins, values=rates.T, columns=clusters)


def get_units_bunch(spks_b, *args):
    """
    Returns a bunch, where the bunch keys are keys from `spks_b` with labels of spike information
    (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
    each unit: these arrays are ordered and can be indexed by unit id.

    Parameters
    ----------
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. unit IDs, times, features,
        etc.) for all spikes.
    features : list of strings (optional positional arg)
        A list of names of labels of spike information (which must be keys in `spks_b`) that
        specify which labels to return as keys in `units_b`. If not provided, all keys in
        `spks_b` are returned as keys in `units_b`.

    Returns
    -------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are
        ordered by unit ID.

    Examples
    --------
    1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
    bunch.
        >>> from brainbox import processing
        >>> import one.alf.io as alfio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = processing.get_units_bunch(spks_b)
        # Get amplitudes for unit 4.
        >>> amps = units_b['amps']['4']

    TODO add computation time estimate?
    """

    # Initialize `units_b`
    units_b = Bunch()
    # Get the keys to return for `units_b`:
    if not args:
        feat_keys = list(spks_b.keys())
    else:
        feat_keys = args[0]
    # Get the unit id for each spike and the number of units. *Note: `n_units` might not equal
    # `len(units)` because some clusters may be empty (due to a "wontfix" bug in ks2).
    spks_unit_id = spks_b['clusters']
    n_units = np.max(spks_unit_id)
    units = np.unique(spks_b['clusters'])
    # For each key in `feat_keys`, iteratively get each unit's values and add them as a key to a
    # bunch, `feat_bunch`. After iterating through all units, add `feat_bunch` as a key to
    # `units_b`:
    for feat in feat_keys:
        # Initialize `feat_bunch` with a key for each unit.
        feat_bunch = Bunch((str(unit), np.array([])) for unit in np.arange(n_units))
        for unit in units:
            unit_idxs = np.where(spks_unit_id == unit)[0]
            feat_bunch[str(unit)] = spks_b[feat][unit_idxs]
        units_b[feat] = feat_bunch
    return units_b


def filter_units(units_b, t, **kwargs):
    """
    Filters units according to some parameters. **kwargs are the keyword parameters used to filter
    the units.

    Parameters
    ----------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features, etc.)
        whose values are arrays that hold values for each unit. The arrays for each key are
        ordered by unit ID.
    t : float
        Duration of time over which to calculate the firing rate and false positive rate.

    Keyword Parameters
    ------------------
    min_amp : float
        The minimum mean amplitude (in V) of the spikes in the unit. Default value is 50e-6.
    min_fr : float
        The minimum firing rate (in Hz) of the unit. Default value is 0.5.
    max_fpr : float
        The maximum false positive rate of the unit (using the fp formula in Hill et al. (2011)
        J Neurosci 31: 8699-8705). Default value is 0.2.
    rp : float
        The refractory period (in s) of the unit. Used to calculate `max_fpr`. Default value is
        0.002.

    Returns
    -------
    filt_units : ndarray
        The ids of the filtered units.

    See Also
    --------
    get_units_bunch

    Examples
    --------
    1) Filter units according to the default parameters.
        >>> from brainbox import processing
        >>> import one.alf.io as alfio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        # Get a spikes bunch, units bunch, and filter the units.
        >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
        >>> T = spks_b['times'][-1] - spks_b['times'][0]
        >>> filtered_units = processing.filter_units(units_b, T)

    2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and the default
    maximum false positive rate of 0.2, given the default refractory period of 2 ms.
        >>> filtered_units = processing.filter_units(units_b, T, min_amp=0, min_fr=1)

    TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
    are in `clstrs_b['metrics']`
    """

    # Set params
    params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults
    params.update(kwargs)  # update from **kwargs

    # Iteratively filter the units for each filter param #
    # -------------------------------------------------- #
    units = np.asarray(list(units_b.amps.keys()))
    # Remove empty clusters
    empty_cl = np.where([len(units_b.amps[unit]) == 0 for unit in units])[0]
    filt_units = np.delete(units, empty_cl)
    for param in params.keys():
        if param == 'min_amp':  # return units with amp > `'min_amp'`
            mean_amps = np.asarray([np.mean(units_b.amps[unit]) for unit in filt_units])
            filt_idxs = np.where(mean_amps > params['min_amp'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'min_fr':  # return units with fr > `'min_fr'`
            fr = np.asarray([len(units_b.amps[unit]) /
                             (units_b.times[unit][-1] - units_b.times[unit][0])
                             for unit in filt_units])
            filt_idxs = np.where(fr > params['min_fr'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'max_fpr':  # return units with fpr < `'max_fpr'`
            fpr = np.zeros_like(filt_units, dtype='float')
            for i, unit in enumerate(filt_units):
                n_spks = len(units_b.amps[unit])
                n_isi_viol = len(np.where(np.diff(units_b.times[unit]) < params['rp'])[0])
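                # A reading of the Hill et al. (2011) estimator: the ISI violation count
                # relates to the false positive rate fp through a quadratic whose constant
                # term is c below; np.roots([-1, 1, c]) solves -fp**2 + fp + c = 0.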

                # fpr is the smallest-magnitude root of the solved quadratic.
                c = (t * n_isi_viol) / (2 * params['rp'] * n_spks**2)  # 3rd term in quadratic
                fpr[i] = np.min(np.abs(np.roots([-1, 1, c])))  # solve quadratic
            filt_idxs = np.where(fpr < params['max_fpr'])[0]
            filt_units = filt_units[filt_idxs]
    return filt_units.astype(int)