Coverage for brainbox/behavior/wheel.py: 67%
198 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""
2Set of functions to handle wheel data.
3"""
4import logging
5import warnings
6import traceback
8import numpy as np
9from numpy import pi
10from iblutil.numerical import between_sorted
11import scipy.interpolate as interpolate
12import scipy.signal
13from scipy.linalg import hankel
14import matplotlib.pyplot as plt
15from matplotlib.collections import LineCollection
16# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS # FIXME Circular dependencies
# Public names exported by a star-import of this module
__all__ = [
    'cm_to_deg',
    'cm_to_rad',
    'interpolate_position',
    'get_movement_onset',
    'movements',
    'samples_to_cm',
    'traces_by_trial',
    'velocity_filtered',
]
# Module-wide defaults used by the unit conversion helpers
ENC_RES = 4 * 1024  # Rotary encoder resolution (counts per revolution), assumes X4 encoding
WHEEL_DIAMETER = 2 * 3.1  # Wheel diameter in cm
def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None):
    """
    Return wheel position resampled onto an evenly spaced time base.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps. Must be in ascending order: the resampled time base runs from
        the first to the last timestamp.
    re_pos: array_like
        Array of unwrapped wheel positions
    freq : float
        frequency in Hz of the interpolation
    kind : {'linear', 'cubic'}
        Type of interpolation. Defaults to linear interpolation.
    fill_gaps : float
        Minimum gap length to fill. For gaps between consecutive timestamps of at least this
        duration (seconds), the interpolated samples inside the gap are replaced with the
        position recorded at the start of the gap (forward fill).

    Returns
    -------
    yinterp : array
        Interpolated position
    t : array
        Timestamps of interpolated positions
    """
    re_ts = np.asarray(re_ts)
    re_pos = np.asarray(re_pos)

    t = np.arange(re_ts[0], re_ts[-1], 1 / freq)  # Evenly resample at frequency
    if t[-1] > re_ts[-1]:
        t = t[:-1]  # Occasionally due to precision errors the last sample may be outside of range.
    yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t)

    if fill_gaps:
        # Indices i such that the interval [re_ts[i], re_ts[i + 1]) is a large gap
        gaps, = np.where(np.diff(re_ts) >= fill_gaps)
        if gaps.size > 0:
            # For each resampled timestamp, the index of the last raw sample at or before it.
            # A single sorted search replaces building one full-length mask per gap.
            prev = np.searchsorted(re_ts, t, side='right') - 1
            starts_gap = np.zeros(re_ts.size, dtype=bool)
            starts_gap[gaps] = True
            in_gap = starts_gap[prev]
            yinterp[in_gap] = re_pos[prev[in_gap]]

    return yinterp, t
def velocity(re_ts, re_pos):
    """
    (DEPRECATED) Estimate wheel velocity from non-uniformly sampled wheel data.

    The raw velocity is computed between consecutive samples and then linearly interpolated
    (and extrapolated at the edges) back onto the original timestamps.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps
    re_pos: array_like
        Array of unwrapped wheel positions

    Returns
    -------
    np.ndarray or None
        Velocities at each of the input timestamps; None when fewer than two raw velocity
        values are available (i.e. fewer than three samples).
    """
    # Dump the call stack to stdout so callers of this deprecated function can be located
    print('\n'.join(entry.strip() for entry in traceback.format_stack()))

    msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    delta_t = np.diff(re_ts)
    raw_vel = np.diff(re_pos) / delta_t
    # Each raw velocity value belongs to the midpoint between its two samples
    midpoints = re_ts[:-1] + delta_t / 2
    if midpoints.size <= 1:
        return None  # not enough points to interpolate over
    resample = interpolate.interp1d(midpoints, raw_vel, fill_value="extrapolate")
    return resample(re_ts)
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
    """
    Compute wheel velocity and acceleration from uniformly sampled wheel data.

    The position is low-pass filtered with a Butterworth filter applied forwards and backwards
    (`scipy.signal.sosfiltfilt`) before being differentiated.

    Parameters
    ----------
    pos: array_like
        Vector of uniformly sampled wheel positions.
    fs : float
        Frequency in Hz of the sampling frequency.
    corner_frequency : float
        Corner frequency of low-pass filter.
    order : int
        Order of Butterworth filter.

    Returns
    -------
    vel : np.ndarray
        Array of velocity values.
    acc : np.ndarray
        Array of acceleration values.
    """
    # Corner frequency expressed as a fraction of the Nyquist frequency
    normalised_corner = corner_frequency / fs * 2
    sos = scipy.signal.butter(N=order, Wn=normalised_corner, btype='lowpass', output='sos')
    smoothed = scipy.signal.sosfiltfilt(sos, pos)
    # Differentiate, padding a leading zero so the outputs keep the length of the input
    vel = np.insert(np.diff(smoothed), 0, 0) * fs
    acc = np.insert(np.diff(vel), 0, 0) * fs
    return vel, acc
def velocity_smoothed(pos, freq, smooth_size=0.03):
    """
    (DEPRECATED) Compute wheel velocity and acceleration from uniformly sampled wheel data
    using Gaussian smoothing.

    Parameters
    ----------
    pos : array_like
        Array of wheel positions
    smooth_size : float
        Size of Gaussian smoothing window in seconds
    freq : float
        Sampling frequency of the data

    Returns
    -------
    vel : np.ndarray
        Array of velocity values
    acc : np.ndarray
        Array of acceleration values
    """
    # Dump the call stack to stdout so callers of this deprecated function can be located
    print('\n'.join(entry.strip() for entry in traceback.format_stack()))

    msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    # Gaussian kernel covering +/-3 standard deviations, where one standard deviation is
    # smooth_size expressed in samples
    n_points = np.round(smooth_size * freq) * 6
    kernel = scipy.signal.windows.gaussian(n_points, (n_points - 1) / 6)
    kernel = kernel / kernel.sum()  # unit area, so the smoothing does not change the units

    def smoothed_derivative(signal):
        """Differentiate, smooth, left-pad with a zero and rescale to per-second units."""
        return np.insert(scipy.signal.convolve(np.diff(signal), kernel, mode='same'), 0, 0) * freq

    vel = smoothed_derivative(pos)
    acc = smoothed_derivative(vel)
    return vel, acc
def last_movement_onset(t, vel, event_time):
    """
    (DEPRECATED) Find the time at which movement started, given an event timestamp that
    occurred during the movement.

    Working backwards from the event, the returned timestamp is the first one encountered whose
    preceding 50 ms window (ending on that sample) has an absolute velocity below 0.5
    throughout. Wheel inputs should be evenly sampled.

    :param t: numpy array of wheel timestamps in seconds
    :param vel: numpy array of wheel velocities
    :param event_time: timestamp anywhere during movement of interest, e.g. peak velocity
    :return: timestamp of movement onset; the earliest timestamp if no quiescent window is
     found, or None if there are no timestamps before event_time
    """
    # Dump the call stack to stdout so callers of this deprecated function can be located
    print('\n'.join(entry.strip() for entry in traceback.format_stack()))

    msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    window = 50e-3  # required quiescent period in seconds
    before_event = t < event_time
    times = t[before_event]
    speed = np.abs(vel[before_event])

    onset = None  # returned as-is when nothing precedes the event
    # Walk backwards in time; `stop` is the exclusive end index of the window ending at `onset`
    for stop, onset in zip(range(times.size, 0, -1), times[::-1]):
        start = np.flatnonzero((onset - times) < window).min()
        if speed[start:stop].max() < 0.5:
            break

    return onset
def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : numpy.array
        The wheel movement intervals, one row per movement with the onset in the first column.
    event_times : numpy.array
        Sorted event timestamps anywhere during movement of interest, e.g. peak velocity, feedback
        time.

    Returns
    -------
    numpy.array
        An array the length of event_times. Each element is the onset of the movement interval
        that `between_sorted` matched the event to, or NaN if the event matched no interval.
        If an event is matched to several intervals, the onset of the last one is kept.

    Raises
    ------
    ValueError
        If event_times is not strictly ascending.

    Examples
    --------
    Find the last movement onset before each trial response time

    >>> trials = one.load_object(eid, 'trials')  # doctest: +SKIP
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')  # doctest: +SKIP
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)  # doctest: +SKIP
    """
    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')
    onsets = np.full(event_times.size, np.nan)
    for interval in intervals:
        # Selection of the events lying within this movement interval
        within = between_sorted(event_times, interval)
        if np.any(within):
            onsets[within] = interval[0]
    return onsets
def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to be
        considered separate. Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times. Within the t_thresh window following
        each coarse onset, the onset is moved to the last sample whose displacement from the
        window start does not exceed this threshold.
    min_dur : float
        The minimum duration of a valid movement. Detected movements shorter than this are ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        The signed displacement, relative to the onset position, at the sample of greatest
        absolute displacement within each detected movement
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement

    Raises
    ------
    AssertionError
        If the timestamps are not evenly sampled.
    """
    # Wheel position must be evenly sampled.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    max_disp = np.empty(t.size, dtype=float)  # initialize array of total wheel displacement

    # Calculate a Hankel matrix of size t_thresh_samps in batches.  This is effectively a
    # sliding window within which we look for changes in position greater than pos_thresh
    BATCH_SIZE = 10000  # do this in batches in order to keep memory usage reasonable
    c = 0  # index of 'window' position
    while True:
        i2proc = np.arange(BATCH_SIZE) + c
        i2proc = i2proc[i2proc < t.size]
        w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
        # Below is the total change in position for each window
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        # Batches overlap by one window length: the rows at the end of a batch have truncated
        # (NaN padded) windows and are overwritten by the next batch
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] == t.size - 1:
            break

    moving = max_disp > pos_thresh  # for each window is the change in position greater than our threshold?
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    # Stitch together movements separated by less than min_gap
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]

    # For each coarse onset gather the t_thresh_samps positions that follow it (NaN beyond the
    # end of the recording) and express them as absolute displacement from the first sample.
    # Gathering one window per onset guarantees that every onset row is filled, whatever its
    # position in the recording.
    win_idx = onset_samps[:, np.newaxis] + np.arange(t_thresh_samps)
    in_range = win_idx < t.size
    onsets_disp_arr = np.full(win_idx.shape, np.nan)
    onsets_disp_arr[in_range] = pos[win_idx[in_range]]
    onsets_disp_arr = np.abs(onsets_disp_arr - pos[onset_samps][:, np.newaxis])

    has_onset = onsets_disp_arr > pos_thresh_onset
    # Searching from the end of each window, A is the distance to the last sample that is still
    # below the onset threshold
    A = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - A
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    # Discard movements shorter than min_dur
    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    # Merge movements that, after onset refinement, are again closer than min_gap
    moveGaps = onsets[1:] - offsets[:-1]
    gap_too_small = moveGaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the maximum absolute value of the difference from the onset position
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, acc = velocity_filtered(pos, freq)

        # Plot the wheel position and velocity
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')

            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving

        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times
def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position from linear cm to degrees turned.  This may be useful for e.g.
    calculating velocity in revolutions per second
    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel positions in degrees turned

    # Example: Convert linear cm to degrees
    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    # Example: Get positions in deg from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    circumference = wheel_diameter * pi  # linear distance of one full revolution, in cm
    return positions / circumference * 360
def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position from linear cm to radians.  This may be useful for e.g. calculating
    angular velocity.
    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel angle in radians

    # Example: Convert linear cm to radians
    >>> cm_to_rad(1)
    0.3225806451612903

    # Example: Get positions in rad from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    rad_per_cm = 2 / wheel_diameter  # angle = arc length / radius
    return positions * rad_per_cm
def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement.  This may be useful for
    inter-converting threshold units
    :param positions: array of wheel positions in sample counts
    :param wheel_diameter: the diameter of the wheel in cm
    :param resolution: resolution of the rotary encoder (counts per revolution)
    :return: array of wheel positions in cm (linear displacement)

    # Example: Get resolution in linear cm
    >>> samples_to_cm(1)
    0.004755340442445488

    # Example: Get positions in linear cm for 4X, 360 ppr encoder
    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    # Fraction of a revolution times the wheel circumference (pi * diameter)
    return positions / resolution * pi * wheel_diameter
def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    A direction change is a sample at which the sign of the velocity differs from that of the
    preceding sample.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities
    intervals : array_like
        An n-by-2 array of wheel movement intervals

    Returns
    ----------
    times : iterable
        A list of numpy arrays of direction change timestamps, one array per interval
    indices : iterable
        A list of numpy arrays containing indices of direction changes; the size of times
    """
    # True wherever the velocity sign differs from the previous sample (never for the first)
    sign_flip = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)

    times, indices = [], []
    for start, stop in intervals.reshape(-1, 2):
        # Only samples strictly inside the interval are considered
        inside = np.logical_and(t > start, t < stop)
        (flips,) = np.nonzero(np.logical_and(inside, sign_flip))
        indices.append(flips)
        times.append(t[flips])

    return times, indices
def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Returns list of tuples of positions and velocity for samples between stimulus onset and
    feedback.
    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
     velocities or accelerations
    :param start: start timestamp or array thereof; defaults to the first timestamp
    :param end: end timestamp or array thereof; defaults to the last timestamp
    :param separate: when True, the output is returned as tuples list of the form [(t, args[0],
     args[1]), ...], when False, the output is a list of n-by-m ndarrays where n = 1 + number
     of positional args and m = number of samples in the slice
    :return: list of sliced arrays where length == len(start). Only samples strictly between
     each start and end timestamp are included.
    """
    # Accept scalars as well as arrays (the defaults are scalars) so that they can be zipped
    start = np.atleast_1d(t[0] if start is None else start)
    end = np.atleast_1d(t[-1] if end is None else end)
    traces = np.stack((t, *args))
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'

    def to_mask(a, b):
        """Boolean mask of the timestamps strictly between a and b."""
        return np.logical_and(t > a, t < b)

    cuts = [traces[:, to_mask(s, e)] for s, e in zip(start, end)]
    # One tuple element per stacked row, i.e. the timestamps followed by every array in args
    return [tuple(cut) for cut in cuts] if separate else cuts
if __name__ == '__main__':
    # When executed as a script, run the examples embedded in the docstrings
    from doctest import testmod
    testmod()