Coverage for brainbox/behavior/wheel.py: 84%
151 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""
2Set of functions to handle wheel data.
3"""
4import numpy as np
5from numpy import pi
6from iblutil.numerical import between_sorted
7import scipy.interpolate as interpolate
8import scipy.signal
9from scipy.linalg import hankel
10import matplotlib.pyplot as plt
11from matplotlib.collections import LineCollection
12# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS # FIXME Circular dependencies
__all__ = ['cm_to_deg',
           'cm_to_rad',
           'direction_changes',
           'interpolate_position',
           'get_movement_onset',
           'movements',
           'samples_to_cm',
           'traces_by_trial',
           'velocity_filtered']

# Define some constants
ENC_RES = 1024 * 4  # Rotary encoder resolution, assumes X4 encoding
WHEEL_DIAMETER = 3.1 * 2  # Wheel diameter in cm
def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None):
    """
    Return wheel position resampled onto an evenly spaced time base.

    Parameters
    ----------
    re_ts : array_like
        Array of (sorted) timestamps of the raw wheel samples
    re_pos: array_like
        Array of unwrapped wheel positions, one per timestamp
    freq : float
        frequency in Hz of the interpolation
    kind : {'linear', 'cubic'}
        Type of interpolation passed to scipy.interpolate.interp1d. Defaults to linear
        interpolation.
    fill_gaps : float
        Minimum gap length to fill. Wherever two consecutive raw timestamps are at least this
        far apart (seconds), the interpolated values inside that gap are overwritten with the
        position at the start of the gap (i.e. forward filled) instead of being interpolated.

    Returns
    -------
    yinterp : array
        Interpolated position
    t : array
        Timestamps of interpolated positions
    """
    re_ts = np.asarray(re_ts)
    re_pos = np.asarray(re_pos)
    t = np.arange(re_ts[0], re_ts[-1], 1 / freq)  # Evenly resample at frequency
    if t[-1] > re_ts[-1]:
        t = t[:-1]  # Occasionally due to precision errors the last sample may be outside of range.
    yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t)

    if fill_gaps:
        # is_gap[i] is True when the interval [re_ts[i], re_ts[i + 1]) is a gap to fill
        is_gap = np.diff(re_ts) >= fill_gaps
        # Index of the last raw sample at or before each resampled timestamp
        idx = np.searchsorted(re_ts, t, side='right') - 1
        # The final raw sample does not start an interval, so it can never open a gap
        has_next = idx < is_gap.size
        in_gap = np.zeros(t.shape, dtype=bool)
        in_gap[has_next] = is_gap[idx[has_next]]
        yinterp[in_gap] = re_pos[idx[in_gap]]

    return yinterp, t
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
    """
    Compute wheel velocity and acceleration from uniformly sampled wheel data.

    The positions are first smoothed with a zero-phase (forward-backward) Butterworth low-pass
    filter, then differentiated by first differences scaled by the sampling rate.

    Parameters
    ----------
    pos : array_like
        Vector of uniformly sampled wheel positions.
    fs : float
        Frequency in Hz of the sampling frequency.
    corner_frequency : float
        Corner frequency of low-pass filter, in Hz.
    order : int
        Order of Butterworth filter.

    Returns
    -------
    vel : np.ndarray
        Array of velocity values (same length as pos; the first value is 0).
    acc : np.ndarray
        Array of acceleration values (same length as pos; the first value is 0).
    """
    # butter expects the cut-off normalised to the Nyquist frequency (fs / 2)
    normalised_cutoff = corner_frequency / fs * 2
    sos = scipy.signal.butter(order, normalised_cutoff, btype='lowpass', output='sos')
    smoothed = scipy.signal.sosfiltfilt(sos, pos)

    def _scaled_difference(samples):
        # First difference left-padded with 0 so the output keeps the input length
        return np.concatenate(([0], np.diff(samples))) * fs

    vel = _scaled_difference(smoothed)
    acc = _scaled_difference(vel)
    return vel, acc
def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : numpy.array
        An n-by-2 array of wheel movement intervals; the first column is the movement onset.
    event_times : array_like
        Strictly ascending event timestamps anywhere during movement of interest, e.g. peak
        velocity, feedback time.

    Returns
    -------
    numpy.array
        An array the same length as event_times holding, for each event, the onset time of the
        movement interval containing it (as determined by iblutil's between_sorted), or NaN when
        no interval contains it. If several intervals contain the same event, the onset of the
        last such interval (in the order given) is used.

    Raises
    ------
    ValueError
        If event_times is not in strictly ascending order.

    Examples
    --------
    Find the last movement onset before each trial response time

    >>> trials = one.load_object(eid, 'trials')  # doctest: +SKIP
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')  # doctest: +SKIP
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)  # doctest: +SKIP
    """
    event_times = np.asarray(event_times)
    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')
    onsets = np.full(event_times.size, np.nan)
    for interval in np.asarray(intervals):
        # Selection of the events falling within this interval; later intervals overwrite
        # earlier ones when they overlap
        onset = between_sorted(event_times, interval)
        if np.any(onset):
            onsets[onset] = interval[0]
    return onsets
def _onset_displacements(pos, onset_samps, n_samps, n_total):
    """
    Absolute displacement from the onset position over the window following each onset.

    Parameters
    ----------
    pos : array_like
        Evenly sampled wheel positions.
    onset_samps : numpy.ndarray
        Integer sample indices of the movement onsets.
    n_samps : int
        Window length in samples (the onset sample itself is column 0).
    n_total : int
        Number of valid samples; window samples at or beyond this index are NaN.

    Returns
    -------
    numpy.ndarray
        Float array of shape (onset_samps.size, n_samps) whose element [i, j] is
        |pos[onset_samps[i] + j] - pos[onset_samps[i]]|, or NaN past the end of the data.
    """
    # NaN-pad the end so that every window has exactly n_samps columns
    padded = np.full(n_total + n_samps, np.nan)
    padded[:n_total] = pos[:n_total]
    windows = padded[onset_samps[:, np.newaxis] + np.arange(n_samps)]
    return np.abs(windows - windows[:, :1])


def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions (same length as t)
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to be
        considered separate. Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times. Within the t_thresh window after each
        coarse onset, the onset is moved to the last sample whose displacement from the window's
        starting position does not exceed this threshold
    min_dur : float
        The minimum duration of a valid movement. Detected movements shorter than this are ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        For each detected movement, the signed displacement relative to the onset position at
        the sample of greatest absolute displacement
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement
    """
    # Wheel position must be evenly sampled.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    max_disp = np.empty(t.size, dtype=float)  # initialize array of total wheel displacement

    # Calculate a Hankel matrix of size t_thresh_samps in batches. This is effectively a
    # sliding window within which we look for changes in position greater than pos_thresh
    BATCH_SIZE = 10000  # do this in batches in order to keep memory usage reasonable
    c = 0  # index of 'window' position
    while True:
        i2proc = np.arange(BATCH_SIZE) + c
        i2proc = i2proc[i2proc < t.size]
        w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
        # Below is the total change in position for each window. Rows near the end of a batch
        # see truncated windows but are overwritten by the next, overlapping batch.
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] == t.size - 1:
            break

    moving = max_disp > pos_thresh  # for each window is the change in position greater than our threshold?
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    # Displacement from the onset position within the t_thresh window after each onset. Built in
    # one pass so that every onset gets its own row, aligned with onset_samps.
    onsets_disp_arr = _onset_displacements(pos, onset_samps, t_thresh_samps, t.size)

    has_onset = onsets_disp_arr > pos_thresh_onset
    # Reversed-window argmin finds the last column still at or below pos_thresh_onset
    A = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - A
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    moveGaps = onsets[1:] - offsets[:-1]
    gap_too_small = moveGaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the signed difference from the onset position where its absolute value is greatest
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, _ = velocity_filtered(pos, freq)

        # Plot the wheel position and velocity
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')

            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving

        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times
def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert linear wheel displacement in cm into degrees of wheel rotation. This is handy when,
    for example, velocity is wanted in revolutions per second.
    :param positions: wheel position(s) in cm, scalar or array
    :param wheel_diameter: the diameter of the wheel in cm
    :return: wheel position(s) in degrees turned

    # Example: Convert linear cm to degrees
    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    # Example: Get positions in deg from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    circumference = wheel_diameter * pi  # linear distance covered by one full revolution
    return positions / circumference * 360
def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert linear wheel displacement in cm into the wheel's rotation angle in radians, e.g. for
    computing angular velocity.
    :param positions: wheel position(s) in cm, scalar or array
    :param wheel_diameter: the diameter of the wheel in cm
    :return: wheel angle(s) in radians

    # Example: Convert linear cm to radians
    >>> cm_to_rad(1)
    0.3225806451612903

    # Example: Get positions in rad from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    radians_per_cm = 2 / wheel_diameter  # arc length / radius, i.e. 1 / (diameter / 2)
    return positions * radians_per_cm
def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement. This may be useful for
    inter-converting threshold units
    :param positions: array of wheel positions in sample counts
    :param wheel_diameter: the diameter of the wheel in cm
    :param resolution: resolution of the rotary encoder, in sample counts per revolution
    :return: array of wheel positions as linear displacement in cm

    # Example: Get resolution in linear cm
    >>> samples_to_cm(1)
    0.004755340442445488

    # Example: Get positions in linear cm for 4X, 360 ppr encoder
    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    # counts -> revolutions -> cm along the circumference (pi * diameter)
    return positions / resolution * pi * wheel_diameter
def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    A direction change is a sample whose velocity sign differs from that of the previous sample
    (a transition to or from exactly zero velocity also counts).

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities, one per timestamp
    intervals : array_like
        An n-by-2 array of wheel movement intervals (start, stop); only changes strictly inside
        each interval are returned

    Returns
    -------
    times : list of numpy.ndarray
        Direction change timestamps, one array per interval
    indices : list of numpy.ndarray
        Indices into t of the direction changes, one array per interval (same sizes as times)
    """
    t = np.asarray(t)
    indices = []
    times = []
    # chg[i] is True when the velocity sign at sample i differs from the sign at sample i - 1
    chg = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)

    for on, off in np.asarray(intervals).reshape(-1, 2):
        mask = np.logical_and(t > on, t < off)
        ind, = np.where(np.logical_and(mask, chg))
        times.append(t[ind])
        indices.append(ind)

    return times, indices
def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Returns the samples of t, and of each array in args, that fall within each trial window,
    e.g. positions and velocities between stimulus onset and feedback.
    A sample is included when start < t < end (both bounds exclusive).
    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
     velocities or accelerations
    :param start: start timestamp or array thereof; defaults to the first timestamp
    :param end: end timestamp or array thereof; defaults to the last timestamp
    :param separate: when True, the output is a list of tuples of the form (t, args[0],
     args[1], ...), when False, the output is a list of n-by-m ndarrays where n = 1 + number of
     positional args (the first row is t) and m = number of samples within the window
    :return: list of sliced arrays where length == len(start)
    """
    # Accept scalar bounds (including the defaults) as well as arrays of bounds
    start = np.atleast_1d(t[0] if start is None else start)
    end = np.atleast_1d(t[-1] if end is None else end)
    traces = np.stack((t, *args))
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'

    cuts = [traces[:, np.logical_and(t > s, t < e)] for s, e in zip(start, end)]
    return [tuple(cut) for cut in cuts] if separate else cuts
if __name__ == '__main__':
    # Run the examples embedded in this module's docstrings
    from doctest import testmod
    testmod()