Coverage for brainbox/behavior/wheel.py: 84%

151 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1""" 

2Set of functions to handle wheel data. 

3""" 

4import numpy as np 

5from numpy import pi 

6from iblutil.numerical import between_sorted 

7import scipy.interpolate as interpolate 

8import scipy.signal 

9from scipy.linalg import hankel 

10import matplotlib.pyplot as plt 

11from matplotlib.collections import LineCollection 

12# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS # FIXME Circular dependencies 

13 

# Names exported by a star-import of this module
__all__ = [
    'cm_to_deg',
    'cm_to_rad',
    'interpolate_position',
    'get_movement_onset',
    'movements',
    'samples_to_cm',
    'traces_by_trial',
    'velocity_filtered',
]

# Define some constants
ENC_RES = 4 * 1024  # Rotary encoder resolution in counts per revolution, assumes X4 encoding
WHEEL_DIAMETER = 2 * 3.1  # Wheel diameter in cm

26 

27 

def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None):
    """
    Return wheel position resampled onto an evenly spaced time base by interpolation.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps, expected in ascending order
    re_pos: array_like
        Array of unwrapped wheel positions, one per timestamp
    freq : float
        frequency in Hz of the interpolation
    kind : {'linear', 'cubic'}
        Type of interpolation (passed to scipy.interpolate.interp1d). Defaults to linear
        interpolation.
    fill_gaps : float
        Minimum gap length to fill. For gaps of at least this time (seconds) between two
        consecutive timestamps, the position at the start of the gap is held (forward filled)
        instead of interpolated. If None or 0, no gaps are filled.

    Returns
    -------
    yinterp : array
        Interpolated position
    t : array
        Timestamps of interpolated positions
    """
    re_ts = np.asarray(re_ts)
    re_pos = np.asarray(re_pos)

    t = np.arange(re_ts[0], re_ts[-1], 1 / freq)  # Evenly resample at frequency
    # Occasionally due to precision errors the last sample may be outside of range.
    # The size check avoids indexing an empty array when first and last timestamps coincide.
    if t.size and t[-1] > re_ts[-1]:
        t = t[:-1]
    yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t)

    if fill_gaps:
        # Indices of the timestamps that are followed by a large gap
        gaps, = np.where(np.diff(re_ts) >= fill_gaps)
        if gaps.size:
            starts_gap = np.zeros(re_ts.size, dtype=bool)
            starts_gap[gaps] = True
            # For every resampled time, index of the last original timestamp <= that time,
            # i.e. the i for which re_ts[i] <= t < re_ts[i + 1]
            prev = np.searchsorted(re_ts, t, side='right') - 1
            held = starts_gap[prev]  # resampled points that fall inside a large gap
            yinterp[held] = re_pos[prev[held]]

    return yinterp, t

65 

66 

def velocity_filtered(pos, fs, corner_frequency=20, order=8):
    """
    Compute wheel velocity from uniformly sampled wheel data.

    The position is low-pass filtered with a zero-phase (forward-backward) Butterworth filter
    before being differentiated.

    Parameters
    ----------
    pos: array_like
        Vector of uniformly sampled wheel positions.
    fs : float
        Frequency in Hz of the sampling frequency.
    corner_frequency : float
        Corner frequency of low-pass filter.
    order : int
        Order of Butterworth filter.

    Returns
    -------
    vel : np.ndarray
        Array of velocity values, same length as pos (first value is 0).
    acc : np.ndarray
        Array of acceleration values, same length as pos (first value is 0).
    """
    # Express the corner frequency as a fraction of the Nyquist frequency (fs / 2)
    normalised_corner = corner_frequency / fs * 2
    sos = scipy.signal.butter(N=order, Wn=normalised_corner, btype='lowpass', output='sos')
    smoothed = scipy.signal.sosfiltfilt(sos, pos)
    # Prepend a zero so that the derivatives keep the length of the input
    vel = np.insert(np.diff(smoothed), 0, 0) * fs
    acc = np.insert(np.diff(vel), 0, 0) * fs
    return vel, acc

91 

92 

def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : numpy.array
        The wheel movement intervals, an n-by-2 array of [onset, offset] times.
    event_times : numpy.array
        Sorted event timestamps anywhere during movement of interest, e.g. peak velocity, feedback
        time.

    Returns
    -------
    numpy.array
        An array the length of event_times containing, for each event, the onset time of the
        movement interval the event falls within, or NaN if it falls within none.

    Raises
    ------
    ValueError
        If event_times is not in strictly ascending order.

    Examples
    --------
    Find the last movement onset before each trial response time

    >>> trials = one.load_object(eid, 'trials')
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)
    """
    event_times = np.asarray(event_times)
    # One row per movement; also accepts a single [onset, offset] pair
    intervals = np.asarray(intervals).reshape(-1, 2)

    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')

    onsets = np.full(event_times.size, np.nan)
    for interval in intervals:
        # Boolean mask of the events lying between this movement's onset and offset
        within = between_sorted(event_times, interval)
        if np.any(within):
            onsets[within] = interval[0]
    return onsets

127 

128 

def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to be
        considered separate. Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times. The first position of each movement
        transition that is this much bigger than the starting position is considered the onset
    min_dur : float
        The minimum duration of a valid movement. Detected movements shorter than this are ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        The absolute maximum amplitude of each detected movement, relative to onset position
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement

    Raises
    ------
    AssertionError
        If the timestamps are not evenly sampled.
    """
    t = np.asarray(t)
    pos = np.asarray(pos)

    # Wheel position must be evenly sampled.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    max_disp = np.empty(t.size, dtype=float)  # initialize array of total wheel displacement

    # Calculate a Hankel matrix of size t_thresh_samps in batches.  This is effectively a
    # sliding window within which we look for changes in position greater than pos_thresh
    BATCH_SIZE = 10000  # do this in batches in order to keep memory usage reasonable
    c = 0  # index of 'window' position
    while True:
        i2proc = np.arange(BATCH_SIZE) + c
        i2proc = i2proc[i2proc < t.size]
        # Row i holds pos[i:i + t_thresh_samps], NaN-padded past the end of the batch
        w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
        # Below is the total change in position for each window
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        # Batches overlap by one window so rows truncated by the padding are recomputed
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] == t.size - 1:
            break

    moving = max_disp > pos_thresh  # for each window is the change in position greater than our threshold?
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    # Stitch together movements separated by less than min_gap
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    # Recompute the onsets after stitching, then gather for each onset the absolute displacement
    # (relative to the first sample) over the following window
    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    onsets_disp_arr = np.empty((onset_samps.size, t_thresh_samps))
    c = 0
    cwt = 0  # number of onsets processed so far
    while onset_samps.size != 0:
        i2proc = np.arange(BATCH_SIZE) + c
        # Only onsets with a complete window inside this batch are handled here; icomm are those
        # onsets and itpltz their row positions within the batch
        icomm, itpltz, _ = np.intersect1d(i2proc[:-t_thresh_samps - 1], onset_samps,
                                          return_indices=True, assume_unique=True)
        i2proc = i2proc[i2proc < t.size]
        if icomm.size > 0:
            w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
            w2e = np.abs((w2e.T - w2e[:, 0]).T)
            onsets_disp_arr[cwt + np.arange(icomm.size), :] = w2e[itpltz, :]
            cwt += icomm.size
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] >= onset_samps[-1]:
            break

    # Refine each onset using the lower onset threshold
    has_onset = onsets_disp_arr > pos_thresh_onset
    # Scanning each window backwards, index of the first sample not above the onset threshold
    A = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - A
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    # Discard movements shorter than min_dur
    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    # Merge movements whose refined onset is less than min_gap after the previous offset
    moveGaps = onsets[1:] - offsets[:-1]
    gap_too_small = moveGaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the maximum absolute value of the difference from the onset position
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, acc = velocity_filtered(pos, freq)

        # Plot the wheel position and velocity
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')

            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving

        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times

279 

280 

def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position to degrees turned.  This may be useful for e.g. calculating velocity
    in revolutions per second.

    Parameters
    ----------
    positions : float, array_like
        Wheel positions in cm
    wheel_diameter : float
        The diameter of the wheel in cm

    Returns
    -------
    float, np.ndarray
        Wheel positions in degrees turned

    Examples
    --------
    Convert linear cm to degrees

    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    Get positions in deg from cm for 5cm diameter wheel

    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    circumference = wheel_diameter * pi  # linear distance of one full revolution, in cm
    return positions / circumference * 360

299 

300 

def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position to radians.  This may be useful for e.g. calculating angular velocity.

    Parameters
    ----------
    positions : float, array_like
        Wheel positions in cm
    wheel_diameter : float
        The diameter of the wheel in cm

    Returns
    -------
    float, np.ndarray
        Wheel angle in radians

    Examples
    --------
    Convert linear cm to radians

    >>> cm_to_rad(1)
    0.3225806451612903

    Get positions in rad from cm for 5cm diameter wheel

    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    radians_per_cm = 2 / wheel_diameter  # i.e. 1 / radius
    return positions * radians_per_cm

318 

319 

def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement.  This may be useful for
    inter-converting threshold units.

    Parameters
    ----------
    positions : float, array_like
        Wheel positions in sample counts
    wheel_diameter : float
        The diameter of the wheel in cm
    resolution : int
        Resolution of the rotary encoder (counts per revolution)

    Returns
    -------
    float, np.ndarray
        Wheel positions as linear displacement in cm

    Examples
    --------
    Get resolution in linear cm

    >>> samples_to_cm(1)
    0.004755340442445488

    Get positions in linear cm for 4X, 360 ppr encoder

    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    revolutions = positions / resolution  # fraction of a full turn
    return revolutions * pi * wheel_diameter

340 

341 

def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    A direction change is any sample where the sign of the velocity differs from the sign at the
    previous sample.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities, the same length as t
    intervals : array_like
        An n-by-2 array of wheel movement intervals

    Returns
    -------
    times : iterable
        A list of numpy arrays of direction change timestamps, one array per interval
    indices : iterable
        A list of numpy arrays containing the indices into t of the direction changes, one array
        per interval (each the same size as the corresponding array in times)
    """
    t = np.asarray(t)
    intervals = np.asarray(intervals).reshape(-1, 2)

    # True where the velocity sign differs from the previous sample; the first sample has no
    # predecessor so is never a change
    chg = np.insert(np.diff(np.sign(vel)) != 0, 0, False)

    times = []
    indices = []
    for on, off in intervals:
        # Interval bounds are exclusive
        in_interval = np.logical_and(t > on, t < off)
        ind, = np.where(np.logical_and(in_interval, chg))
        times.append(t[ind])
        indices.append(ind)

    return times, indices

373 

374 

def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Slice wheel traces into one segment per (start, end) pair, e.g. the samples between stimulus
    onset and feedback of each trial.

    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
    velocities or accelerations
    :param start: start timestamp or array thereof; defaults to the first timestamp
    :param end: end timestamp or array thereof; defaults to the last timestamp
    :param separate: when True, the output is returned as tuples list of the form [(t, args[0],
    args[1]), ...], when False, the output is a list of n-by-m ndarrays where n = 1 + number of
    positional args and m = number of samples within the segment
    :return: list of sliced arrays where length == len(start).  Samples exactly at the start or
    end time are excluded.
    """
    t = np.asarray(t)
    # Accept scalars as well as arrays so that the defaults and single timestamps work
    start = np.atleast_1d(t[0] if start is None else start)
    end = np.atleast_1d(t[-1] if end is None else end)
    traces = np.stack((t, *args))
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'

    def to_mask(a, b):
        return np.logical_and(t > a, t < b)

    cuts = [traces[:, to_mask(s, e)] for s, e in zip(start, end)]
    # tuple(cut) yields one 1-D array per stacked row: (t, args[0], args[1], ...)
    return [tuple(cut) for cut in cuts] if separate else cuts

401 

402 

if __name__ == '__main__':
    # When executed as a script, run the examples embedded in the docstrings as tests
    from doctest import testmod
    testmod()