Coverage for brainbox/behavior/wheel.py: 67%

198 statements  

coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1""" 

2Set of functions to handle wheel data. 

3""" 

4import logging 

5import warnings 

6import traceback 

7 

8import numpy as np 

9from numpy import pi 

10from iblutil.numerical import between_sorted 

11import scipy.interpolate as interpolate 

12import scipy.signal 

13from scipy.linalg import hankel 

14import matplotlib.pyplot as plt 

15from matplotlib.collections import LineCollection 

16# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS # FIXME Circular dependencies 

17 

18__all__ = ['cm_to_deg', 

19 'cm_to_rad', 

20 'interpolate_position', 

21 'get_movement_onset', 

22 'movements', 

23 'samples_to_cm', 

24 'traces_by_trial', 

25 'velocity_filtered'] 

26 

27# Define some constants 

28ENC_RES = 1024 * 4 # Rotary encoder resolution, assumes X4 encoding 

29WHEEL_DIAMETER = 3.1 * 2 # Wheel diameter in cm 

30 

31 

32def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None): 

33 """ 

34 Return linearly interpolated wheel position. 

35 

36 Parameters 

37 ---------- 

38 re_ts : array_like 

39 Array of timestamps 

40 re_pos: array_like 

41 Array of unwrapped wheel positions 

42 freq : float 

43 frequency in Hz of the interpolation 

44 kind : {'linear', 'cubic'} 

45 Type of interpolation. Defaults to linear interpolation. 

46 fill_gaps : float 

47 Minimum gap length to fill. For gaps over this time (seconds), 

48 forward fill values before interpolation 

49 Returns 

50 ------- 

51 yinterp : array 

52 Interpolated position 

53 t : array 

54 Timestamps of interpolated positions 

55 """ 

56 t = np.arange(re_ts[0], re_ts[-1], 1 / freq) # Evenly resample at frequency 1aEKFjIklHmnopqrsbtfLMNghiOPQdeuvwxyzAcBCD

57 if t[-1] > re_ts[-1]: 1aEKFjIklHmnopqrsbtfLMNghiOPQdeuvwxyzAcBCD

58 t = t[:-1] # Occasionally due to precision errors the last sample may be outside of range. 1tf

59 yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t) 1aEKFjIklHmnopqrsbtfLMNghiOPQdeuvwxyzAcBCD

60 

61 if fill_gaps: 1aEKFjIklHmnopqrsbtfLMNghiOPQdeuvwxyzAcBCD

62 # Find large gaps and forward fill @fixme This is inefficient 

63 gaps, = np.where(np.diff(re_ts) >= fill_gaps) 

64 

65 for i in gaps: 

66 yinterp[(t >= re_ts[i]) & (t < re_ts[i + 1])] = re_pos[i] 

67 

68 return yinterp, t 1aEKFjIklHmnopqrsbtfLMNghiOPQdeuvwxyzAcBCD
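# Illustrative usage sketch (synthetic encoder data; variable names and values below are
# assumptions, not an official example): re-sample irregular encoder samples at 1 kHz.
# >>> import numpy as np
# >>> re_ts = np.array([0.0, 0.011, 0.025, 0.04])  # irregular timestamps (s)
# >>> re_pos = np.array([0.0, 0.1, 0.15, 0.3])  # unwrapped positions
# >>> pos, tt = interpolate_position(re_ts, re_pos, freq=1000)
# >>> pos.shape == tt.shape  # evenly re-sampled at 1 kHz
# True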



def velocity(re_ts, re_pos):
    """
    (DEPRECATED) Compute wheel velocity from non-uniformly sampled wheel data. Returns the velocity
    at the same sample locations as the position, through interpolation.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps
    re_pos : array_like
        Array of unwrapped wheel positions

    Returns
    -------
    np.ndarray
        numpy array of velocities
    """
    for line in traceback.format_stack():
        print(line.strip())

    msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    dp = np.diff(re_pos)
    dt = np.diff(re_ts)
    # Compute raw velocity
    vel = dp / dt
    # Compute velocity time scale
    tv = re_ts[:-1] + dt / 2
    # interpolate over original time scale
    if tv.size > 1:
        ifcn = interpolate.interp1d(tv, vel, fill_value="extrapolate")
        return ifcn(re_ts)


def velocity_filtered(pos, fs, corner_frequency=20, order=8):
    """
    Compute wheel velocity from uniformly sampled wheel data.

    Parameters
    ----------
    pos : array_like
        Vector of uniformly sampled wheel positions.
    fs : float
        Sampling frequency in Hz.
    corner_frequency : float
        Corner frequency of low-pass filter.
    order : int
        Order of Butterworth filter.

    Returns
    -------
    vel : np.ndarray
        Array of velocity values.
    acc : np.ndarray
        Array of acceleration values.
    """
    sos = scipy.signal.butter(N=order, Wn=corner_frequency / fs * 2, btype='lowpass', output='sos')
    vel = np.insert(np.diff(scipy.signal.sosfiltfilt(sos, pos)), 0, 0) * fs
    acc = np.insert(np.diff(vel), 0, 0) * fs
    return vel, acc
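# Illustrative usage sketch (synthetic data; not an official example): low-pass
# differentiate an evenly sampled position trace with the default 20 Hz corner frequency.
# >>> import numpy as np
# >>> fs = 1000
# >>> pos = np.sin(2 * np.pi * 0.5 * np.arange(0, 2, 1 / fs))  # slow 0.5 Hz oscillation
# >>> vel, acc = velocity_filtered(pos, fs)
# >>> vel.shape == acc.shape == pos.shape
# True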



def velocity_smoothed(pos, freq, smooth_size=0.03):
    """
    (DEPRECATED) Compute wheel velocity from uniformly sampled wheel data.

    Parameters
    ----------
    pos : array_like
        Array of wheel positions
    smooth_size : float
        Size of Gaussian smoothing window in seconds
    freq : float
        Sampling frequency of the data

    Returns
    -------
    vel : np.ndarray
        Array of velocity values
    acc : np.ndarray
        Array of acceleration values
    """
    for line in traceback.format_stack():
        print(line.strip())

    msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    # Define our smoothing window with an area of 1 so the units won't be changed
    std_samps = np.round(smooth_size * freq)  # Standard deviation relative to sampling frequency
    N = std_samps * 6  # Number of points in the Gaussian covering +/-3 standard deviations
    gauss_std = (N - 1) / 6
    win = scipy.signal.windows.gaussian(N, gauss_std)
    win = win / win.sum()  # Normalize amplitude

    # Convolve and multiply by sampling frequency to restore original units
    vel = np.insert(scipy.signal.convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
    acc = np.insert(scipy.signal.convolve(np.diff(vel), win, mode='same'), 0, 0) * freq

    return vel, acc


def last_movement_onset(t, vel, event_time):
    """
    (DEPRECATED) Find the time at which movement started, given an event timestamp that occurred
    during the movement.

    Movement start is defined as the first sample after the velocity has been zero for at least
    50 ms. Wheel inputs should be evenly sampled.

    :param t: numpy array of wheel timestamps in seconds
    :param vel: numpy array of wheel velocities
    :param event_time: timestamp anywhere during movement of interest, e.g. peak velocity
    :return: timestamp of movement onset
    """
    for line in traceback.format_stack():
        print(line.strip())

    msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    # Look back from timestamp
    threshold = 50e-3
    mask = t < event_time
    times = t[mask]
    vel = vel[mask]
    t = None  # Initialize
    for i, t in enumerate(times[::-1]):
        i = times.size - i
        idx = np.min(np.where((t - times) < threshold))
        if np.max(np.abs(vel[idx:i])) < 0.5:
            break

    # Return timestamp
    return t


def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : numpy.array
        The wheel movement intervals.
    event_times : numpy.array
        Sorted event timestamps anywhere during movement of interest, e.g. peak velocity, feedback
        time.

    Returns
    -------
    numpy.array
        An array of onset times, the same length as event_times; NaN where an event did not occur
        during any movement interval.

    Examples
    --------
    Find the last movement onset before each trial response time

    >>> trials = one.load_object(eid, 'trials')
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)
    """
    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')
    onsets = np.full(event_times.size, np.nan)
    for i in np.arange(intervals.shape[0]):
        onset = between_sorted(event_times, intervals[i, :])
        if np.any(onset):
            onsets[onset] = intervals[i, 0]
    return onsets
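# Illustrative usage sketch with made-up arrays (the docstring example above assumes the
# ONE API; the intervals and event times here are assumptions).
# >>> import numpy as np
# >>> intervals = np.array([[1.0, 1.5], [3.0, 3.8]])  # movement on/offset times (s)
# >>> event_times = np.array([1.2, 2.0, 3.5])  # must be sorted ascending
# >>> get_movement_onset(intervals, event_times)  # expected: [1.0, nan, 3.0]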



def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to be
        considered separate. Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times. The first position of each movement
        transition that is this much bigger than the starting position is considered the onset
    min_dur : float
        The minimum duration of a valid movement. Detected movements shorter than this are ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        The absolute maximum amplitude of each detected movement, relative to onset position
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement
    """
    # Wheel position must be evenly sampled.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    max_disp = np.empty(t.size, dtype=float)  # initialize array of total wheel displacement

    # Calculate a Hankel matrix of size t_thresh_samps in batches. This is effectively a
    # sliding window within which we look for changes in position greater than pos_thresh
    BATCH_SIZE = 10000  # do this in batches in order to keep memory usage reasonable
    c = 0  # index of 'window' position
    while True:
        i2proc = np.arange(BATCH_SIZE) + c
        i2proc = i2proc[i2proc < t.size]
        w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
        # Below is the total change in position for each window
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] == t.size - 1:
            break

    moving = max_disp > pos_thresh  # for each window is the change in position greater than our threshold?
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    onsets_disp_arr = np.empty((onset_samps.size, t_thresh_samps))
    c = 0
    cwt = 0
    while onset_samps.size != 0:
        i2proc = np.arange(BATCH_SIZE) + c
        icomm = np.intersect1d(i2proc[:-t_thresh_samps - 1], onset_samps, assume_unique=True)
        itpltz = np.intersect1d(i2proc[:-t_thresh_samps - 1], onset_samps,
                                return_indices=True, assume_unique=True)[1]
        i2proc = i2proc[i2proc < t.size]
        if icomm.size > 0:
            w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
            w2e = np.abs((w2e.T - w2e[:, 0]).T)
            onsets_disp_arr[cwt + np.arange(icomm.size), :] = w2e[itpltz, :]
            cwt += icomm.size
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] >= onset_samps[-1]:
            break

    has_onset = onsets_disp_arr > pos_thresh_onset
    A = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - A
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    moveGaps = onsets[1:] - offsets[:-1]
    gap_too_small = moveGaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the maximum absolute value of the difference from the onset position
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, acc = velocity_filtered(pos, freq)

        # Plot the wheel position and velocity
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')

            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving

        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times
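# Illustrative usage sketch (synthetic trace; position units and thresholds here are
# assumptions): a single fast ramp in an otherwise stationary, evenly sampled signal
# should be detected as one movement with the default 1 kHz parameters.
# >>> import numpy as np
# >>> fs = 1000
# >>> t = np.arange(0, 5, 1 / fs)
# >>> pos = np.zeros_like(t)
# >>> pos[2000:2500] = np.linspace(0, 40, 500)  # 40 position units over 0.5 s
# >>> pos[2500:] = 40
# >>> on, off, amp, peak_t = movements(t, pos, freq=fs)
# >>> on.size  # one movement expected
# 1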



def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position to degrees turned. This may be useful for e.g. calculating velocity
    in revolutions per second
    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel positions in degrees turned

    # Example: Convert linear cm to degrees
    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    # Example: Get positions in deg from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    return positions / (wheel_diameter * pi) * 360


def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position to radians. This may be useful for e.g. calculating angular velocity.
    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel angles in radians

    # Example: Convert linear cm to radians
    >>> cm_to_rad(1)
    0.3225806451612903

    # Example: Get positions in rad from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    return positions * (2 / wheel_diameter)


def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement. This may be useful for
    inter-converting threshold units
    :param positions: array of wheel positions in sample counts
    :param wheel_diameter: the diameter of the wheel in cm
    :param resolution: resolution of the rotary encoder
    :return: array of wheel positions in cm

    # Example: Get resolution in linear cm
    >>> samples_to_cm(1)
    0.004755340442445488

    # Example: Get positions in linear cm for 4X, 360 ppr encoder
    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    return positions / resolution * pi * wheel_diameter



def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities
    intervals : array_like
        An n-by-2 array of wheel movement intervals

    Returns
    -------
    times : iterable
        A list of numpy arrays of direction change timestamps, one array per interval
    indices : iterable
        A list of numpy arrays containing indices of direction changes, the same size as times
    """
    indices = []
    times = []
    chg = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)

    for on, off in intervals.reshape(-1, 2):
        mask = np.logical_and(t > on, t < off)
        ind, = np.where(np.logical_and(mask, chg))
        times.append(t[ind])
        indices.append(ind)

    return times, indices
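# Illustrative usage sketch (synthetic data; the interval below is an assumption): find
# velocity sign reversals within a hand-picked window of an oscillating trace.
# >>> import numpy as np
# >>> fs = 1000
# >>> t = np.arange(0, 2, 1 / fs)
# >>> vel, _ = velocity_filtered(np.sin(2 * np.pi * 2 * t), fs)  # 2 Hz back-and-forth
# >>> times, indices = direction_changes(t, vel, np.array([[0.25, 1.75]]))
# >>> len(times) == len(indices) == 1  # one list entry per interval
# True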



def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Returns list of tuples of timestamps and trace values for samples between the given start and
    end times, e.g. between stimulus onset and feedback.
    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
        velocities or accelerations
    :param start: start timestamp or array thereof
    :param end: end timestamp or array thereof
    :param separate: when True, the output is returned as a list of tuples of the form
        [(t, args[0], args[1]), ...]; when False, the output is a list of n-by-m ndarrays where
        n = number of positional args and m = len(t)
    :return: list of sliced arrays where length == len(start)
    """
    if start is None:
        start = t[0]
    if end is None:
        end = t[-1]
    traces = np.stack((t, *args))
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'

    def to_mask(a, b):
        return np.logical_and(t > a, t < b)

    cuts = [traces[:, to_mask(s, e)] for s, e in zip(start, end)]
    return [(cuts[n][0, :], cuts[n][1, :]) for n in range(len(cuts))] if separate else cuts
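# Illustrative usage sketch (assumed variable names and trial windows): slice a position
# trace into two trial epochs.
# >>> import numpy as np
# >>> fs = 1000
# >>> t = np.arange(0, 10, 1 / fs)
# >>> pos = np.sin(t)
# >>> starts, ends = np.array([1.0, 4.0]), np.array([2.0, 5.5])
# >>> traces = traces_by_trial(t, pos, start=starts, end=ends)
# >>> len(traces) == len(starts) and len(traces[0]) == 2  # (timestamps, positions) per trial
# True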



if __name__ == '__main__':
    import doctest
    doctest.testmod()