Coverage for ibllib/io/extractors/training_wheel.py: 87%
218 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Extractors for the wheel position, velocity, and detected movement."""
2import logging
3from collections.abc import Sized
5import numpy as np
6from scipy import interpolate
8from ibldsp.utils import sync_timestamps
9from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
10import ibllib.io.raw_data_loaders as raw
11from ibllib.misc import structarr
12import ibllib.exceptions as err
13import brainbox.behavior.wheel as wh
_logger = logging.getLogger(__name__)
# Multiplier applied to the angular position; keeping it at 1 leaves the output in radians
WHEEL_RADIUS_CM = 1  # we want the output in radians
# Velocity bound (rad/s) used when deciding whether a zero position sample is an encoder reset
THRESHOLD_RAD_PER_SEC = 10
# Consecutive timestamps whose difference is at or below this value are merged into one sample
THRESHOLD_CONSECUTIVE_SAMPLES = 0
# float64 machine epsilon (bit-identical to 7. / 3 - 4. / 3 - 1); added to time differences to
# avoid divisions by zero when estimating velocities
EPS = np.finfo(np.float64).eps
def get_trial_start_times(session_path, data=None, task_collection='raw_behavior_data'):
    """
    Return the start time of every Bpod 'trial_start' state, across all trials.

    :param session_path: the session path, used to load the Bpod data when `data` is not given
    :param data: list of Bpod trial dicts; loaded from `task_collection` when empty or None
    :param task_collection: the collection containing the raw task data
    :return: numpy array holding the first element of each 'trial_start' state timestamp entry
    """
    trials = data if data else raw.load_data(session_path, task_collection=task_collection)
    return np.array([
        state[0]
        for trial in trials
        for state in trial['behavior_data']['States timestamps']['trial_start']
    ])
def sync_rotary_encoder(session_path, bpod_data=None, re_events=None, task_collection='raw_behavior_data'):
    """
    Build a function mapping rotary encoder time onto Bpod time.

    The 'closed_loop' state events logged by the rotary encoder (sm_ev == 3) are matched to the
    Bpod 'closed_loop' state start times. Matching is done by aligning the first events, else
    the last events, else (last resort) with ad-hoc sync functions. Pairs whose inter-event
    intervals disagree are discarded. The fit is then checked against a linear model (residuals
    below 2 ms) before interpolating.

    :param session_path: the session path
    :param bpod_data: Bpod trials; loaded from `task_collection` when empty or None
    :param re_events: rotary encoder events with `re_ts` (microseconds) and `sm_ev` columns;
        loaded from `task_collection` when None
    :param task_collection: the collection containing the raw task data
    :return: scipy.interpolate.interp1d mapping rotary encoder times (s) to Bpod times (s),
        extrapolating outside the matched range
    :raises ibllib.exceptions.SyncBpodWheelException: if there are not enough rotary encoder
        closed loop events to synchronize
    :raises AssertionError: if fewer than 2 events can be matched, or if the linear fit
        residuals exceed 2 ms
    """
    if not bpod_data:
        bpod_data = raw.load_data(session_path, task_collection=task_collection)
    # NB: `re_events or ...` would evaluate the truth value of a DataFrame, which raises
    if re_events is None:
        re_events = raw.load_encoder_events(session_path, task_collection=task_collection)
    evt = re_events
    # we work with stim_on (2) and closed_loop (3) states for the synchronization with bpod
    tre = evt.re_ts.values / 1e6  # convert to seconds
    # the first trial on the rotary encoder is a dud
    # NOTE(review): the comment refers to the first trial, but the slices drop the last event
    rote = {'stim_on': tre[evt.sm_ev == 2][:-1],
            'closed_loop': tre[evt.sm_ev == 3][:-1]}
    bpod = {
        'stim_on': np.array([tr['behavior_data']['States timestamps']
                             ['stim_on'][0][0] for tr in bpod_data]),
        'closed_loop': np.array([tr['behavior_data']['States timestamps']
                                 ['closed_loop'][0][0] for tr in bpod_data]),
    }
    if rote['closed_loop'].size <= 1:
        raise err.SyncBpodWheelException("Not enough Rotary Encoder events to perform wheel"
                                         " synchronization. Wheel data not extracted")
    # bpod bug that spits out events in ms instead of us: compare the spans of both clocks
    if np.diff(bpod['closed_loop'][[-1, 0]])[0] / np.diff(rote['closed_loop'][[-1, 0]])[0] > 900:
        _logger.error("Rotary encoder stores values in ms instead of us. Wheel timing inaccurate")
        rote['stim_on'] *= 1e3
        rote['closed_loop'] *= 1e3
    # just use the closed loop for synchronization
    # handle different sizes in synchronization:
    sz = min(rote['closed_loop'].size, bpod['closed_loop'].size)
    # interval mismatch if the first samples of both clocks are aligned
    diff_first_match = np.diff(rote['closed_loop'][:sz]) - np.diff(bpod['closed_loop'][:sz])
    # interval mismatch if the last samples of both clocks are aligned
    diff_last_match = np.diff(rote['closed_loop'][-sz:]) - np.diff(bpod['closed_loop'][-sz:])
    DIFF_THRESHOLD = 0.005
    if np.mean(np.abs(diff_first_match) < DIFF_THRESHOLD) > 0.99:
        # 99% of the pulses match for a first sample lock
        re = rote['closed_loop'][:sz]
        bp = bpod['closed_loop'][:sz]
        indko = np.where(np.abs(diff_first_match) >= DIFF_THRESHOLD)[0]
    elif np.mean(np.abs(diff_last_match) < DIFF_THRESHOLD) > 0.99:
        # 99% of the pulses match for a last sample lock
        re = rote['closed_loop'][-sz:]
        bp = bpod['closed_loop'][-sz:]
        indko = np.where(np.abs(diff_last_match) >= DIFF_THRESHOLD)[0]
    else:
        # last resort is to use ad-hoc sync function
        bp, re = raw.sync_trials_robust(bpod['closed_loop'], rote['closed_loop'],
                                        diff_threshold=DIFF_THRESHOLD, max_shift=5)
        # we don't want to change the extractor, but in rare cases the following method may save the day
        if len(bp) == 0:
            _, _, ib, ir = sync_timestamps(bpod['closed_loop'], rote['closed_loop'], return_indices=True)
            bp = bpod['closed_loop'][ib]
            re = rote['closed_loop'][ir]
        indko = np.array([])
    # remove faulty indices due to missing or bad syncs: both samples around a bad interval go
    indko = np.int32(np.unique(np.r_[indko + 1, indko]))
    re = np.delete(re, indko)
    bp = np.delete(bp, indko)
    # check the linear drift
    assert bp.size > 1, 'not enough matching Bpod / rotary encoder events to synchronize'
    poly = np.polyfit(bp, re, 1)
    assert np.all(np.abs(np.polyval(poly, bp) - re) < 0.002), \
        'Bpod / rotary encoder synchronization residuals exceed 2 ms'
    return interpolate.interp1d(re, bp, fill_value="extrapolate")
def get_wheel_position(session_path, bp_data=None, display=False, task_collection='raw_behavior_data'):
    """
    Gets wheel timestamps and position from Bpod data. Position is in radian (constant above for
    radius is 1) mathematical convention.

    The same Bpod trials are used for the trial start times, the rotary encoder synchronization
    and the reset compensation, so the raw task data are loaded at most once.

    :param session_path: the session path
    :param bp_data (optional): bpod trials read from jsonable file; loaded from
        `task_collection` when not provided
    :param display (optional): (bool) plot the raw and extracted traces for debugging
    :param task_collection (optional): the collection containing the raw task data
    :return: timestamps (np.array), or None if no wheel data were found
    :return: positions (np.array), or None if no wheel data were found
    """
    status = 0
    if not bp_data:
        bp_data = raw.load_data(session_path, task_collection=task_collection)
    df = raw.load_encoder_positions(session_path, task_collection=task_collection)
    if df is None:
        _logger.error('No wheel data for ' + str(session_path))
        return None, None
    data = structarr(['re_ts', 're_pos', 'bns_ts'],
                     shape=(df.shape[0],), formats=['f8', 'f8', object])
    data['re_ts'] = df.re_ts.values
    data['re_pos'] = df.re_pos.values * -1  # anti-clockwise is positive in our output
    data['re_pos'] = data['re_pos'] / 1024 * 2 * np.pi  # convert positions to radians
    # reuse the trials in memory rather than reloading them from disk
    trial_starts = get_trial_start_times(session_path, data=bp_data, task_collection=task_collection)
    # need a flag if the data resolution is 1ms due to the old version of rotary encoder firmware
    # (a non-zero status skips the state machine based reset compensation below)
    if np.all(np.mod(data['re_ts'], 1e3) == 0):
        status = 1
    data['re_ts'] = data['re_ts'] / 1e6  # convert ts to seconds
    # get the converter function to translate re_ts into behavior times
    re2bpod = sync_rotary_encoder(session_path, bpod_data=bp_data, task_collection=task_collection)
    data['re_ts'] = re2bpod(data['re_ts'])

    def get_reset_trace_compensation_with_state_machine_times():
        """
        Build the DC offset trace compensating the rotary encoder resets, located using the
        Bpod 'reset_rotary_encoder' and 'reset2_rotary_encoder' state start times.

        :return: (tr_dc, flag): tr_dc holds, at each reset sample, the position of the preceding
            sample (0 elsewhere); flag is True if some resets could not be resolved
        """
        # this is the preferred way of getting resets using the state machine time information
        # it will not always work depending on firmware versions, new bugs
        one_tick = 1 / 1024 * 2 * np.pi
        iwarn = []
        ns = len(data['re_pos'])
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        for bp_dat in bp_data:
            states = bp_dat['behavior_data']['States timestamps']
            restarts = np.sort(np.array(
                states['reset_rotary_encoder'] + states['reset2_rotary_encoder'])[:, 0])
            ind = np.unique(np.searchsorted(data['re_ts'], restarts, side='left') - 1)
            # the rotary encoder doesn't always reset right away, and the reset sample given the
            # timestamp can be ambiguous: look for zeros
            for i in np.where(data['re_pos'][ind] != 0)[0]:
                # handle boundary effects
                if ind[i] > ns - 2:
                    continue
                # it happens quite often that we have to lock in to next sample to find the reset
                if data['re_pos'][ind[i] + 1] == 0:
                    ind[i] = ind[i] + 1
                    continue
                # also case where the rotary doesn't reset to 0, but erratically to -1/+1
                # NOTE(review): unlike the check below there is no np.abs here, so any negative
                # position also matches - confirm this is intended
                if data['re_pos'][ind[i]] <= one_tick:
                    ind[i] = ind[i] + 1
                    continue
                # compounded with the fact that the reset may have happened at next sample.
                if np.abs(data['re_pos'][ind[i] + 1]) <= one_tick:
                    ind[i] = ind[i] + 1
                    continue
                # sometimes it is also the last trial that has this behaviour
                if (bp_data[-1] is bp_dat) or (bp_data[0] is bp_dat):
                    continue
                iwarn.append(ind[i])
            # at which point we are running out of possible bugs and calling it
            # the first sample has no predecessor: ind - 1 would wrap around to the last sample
            ind = ind[ind > 0]
            tr_dc[ind] = data['re_pos'][ind - 1]
        if iwarn:  # if a warning flag was caught in the loop throw a single warning
            _logger.warning('Rotary encoder reset events discrepancy. Doing my best to merge.')
            _logger.debug('Offending inds: ' + str(iwarn) + ' times: ' + str(data['re_ts'][iwarn]))
        # exit status 0 is fine, 1 something went wrong
        return tr_dc, len(iwarn) != 0

    # attempt to get the resets properly unless the unit is ms which means precision is
    # not good enough to match SM times to wheel samples time
    if not status:
        tr_dc, status = get_reset_trace_compensation_with_state_machine_times()

    if status:
        # if something was wrong or went wrong agnostic way of getting resets: just get zeros values
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        i0 = np.where(data['re_pos'] == 0)[0]
        # the first sample has no predecessor: i0 - 1 would wrap around to the last sample
        i0 = i0[i0 > 0]
        tr_dc[i0] = data['re_pos'][i0 - 1]
    else:
        # even if things went ok, rotary encoder may not log the whole session. Need to fix outside
        i0 = np.where(np.bitwise_and(np.bitwise_or(data['re_ts'] >= trial_starts[-1],
                                                   data['re_ts'] <= trial_starts[0]),
                                     data['re_pos'] == 0))[0]
    # make sure the bounds are not included in the current list
    i0 = np.delete(i0, np.where(np.bitwise_or(i0 >= len(data['re_pos']) - 1, i0 == 0)))
    # a 0 sample is not a reset if 2 conditions are met:
    # 1/2 the slope changes sign at the zero sample (signs on either side are strictly opposite)
    c1 = np.abs(np.sign(data['re_pos'][i0 + 1] - data['re_pos'][i0]) -
                np.sign(data['re_pos'][i0] - data['re_pos'][i0 - 1])) == 2
    # 2/2 the velocity into the zero sample is below threshold
    c2 = np.abs((data['re_pos'][i0] - data['re_pos'][i0 - 1]) /
                (EPS + (data['re_ts'][i0] - data['re_ts'][i0 - 1]))) < THRESHOLD_RAD_PER_SEC
    # apply reset to points identified as resets
    i0 = i0[np.where(np.bitwise_not(np.bitwise_and(c1, c2)))]
    tr_dc[i0] = data['re_pos'][i0 - 1]

    # unwrap the rotation (in radians) and then add the DC component from restarts
    data['re_pos'] = np.unwrap(data['re_pos']) + np.cumsum(tr_dc)

    # Also forgot to mention that time stamps may be repeated or very close to one another.
    # Find them as they will induce large jitters on the velocity function or errors in
    # attempts of interpolation
    rep_idx = np.where(np.diff(data['re_ts']) <= THRESHOLD_CONSECUTIVE_SAMPLES)[0]
    # Change the value of the repeated position
    data['re_pos'][rep_idx] = (data['re_pos'][rep_idx] +
                               data['re_pos'][rep_idx + 1]) / 2
    data['re_ts'][rep_idx] = (data['re_ts'][rep_idx] +
                              data['re_ts'][rep_idx + 1]) / 2
    # Now remove the repeat times that are rep_idx + 1
    data = np.delete(data, rep_idx + 1)

    # convert to cm
    data['re_pos'] = data['re_pos'] * WHEEL_RADIUS_CM

    # DEBUG PLOTS START HERE ########################
    if display:
        import matplotlib.pyplot as plt
        plt.figure()
        ax = plt.axes()
        # reuse the trial starts of the requested task collection
        tstart = trial_starts
        tts = np.c_[tstart, tstart, tstart + np.nan].flatten()
        vts = np.c_[tstart * 0 + 100, tstart * 0 - 100, tstart + np.nan].flatten()
        ax.plot(tts, vts, label='Trial starts')
        ax.plot(re2bpod(df.re_ts.values / 1e6), df.re_pos.values / 1024 * 2 * np.pi,
                '.-', label='Raw data')
        i0 = np.where(df.re_pos.values == 0)
        ax.plot(re2bpod(df.re_ts.values[i0] / 1e6), df.re_pos.values[i0] / 1024 * 2 * np.pi,
                'r*', label='Raw data zero samples')
        ax.plot(re2bpod(df.re_ts.values / 1e6), tr_dc, label='reset compensation')
        ax.set_xlabel('Bpod Time')
        ax.set_ylabel('radians')
        ax.plot(data['re_ts'], data['re_pos'] / WHEEL_RADIUS_CM, '.-', label='Output Trace')
        ax.legend()
    return data['re_ts'], data['re_pos']
def infer_wheel_units(pos):
    """
    Infer the units, rotary encoder resolution and encoding type of a wheel position trace.

    The encoding type varies across hardware (Bpod uses X1 while FPGA usually extracted as X4), and
    older data were extracted in linear cm rather than radians. The median absolute step between
    consecutive samples is compared with the smallest step expected for each combination of units
    and encoding, and the closest combination is returned.

    :param pos: an array of extracted wheel positions (flattened when not 1D)
    :return units: the position units, assumed to be either 'rad' or 'cm'
    :return resolution: the number of decoded fronts per 360 degree rotation
    :return encoding: one of {'X1', 'X2', 'X4'}
    """
    if len(pos.shape) > 1:
        pos = pos.flatten()
    resolutions = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    encodings = ('X4', 'X2', 'X1')
    # smallest possible step for each candidate: [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    candidate_steps = np.r_[2 * np.pi / resolutions, wh.WHEEL_DIAMETER * np.pi / resolutions]
    typical_step = np.median(np.abs(np.ediff1d(pos)))
    best = int(np.argmin(np.abs(candidate_steps - typical_step)))
    # the first half of the candidates are in radians, the second half in cm
    unit_idx, enc_idx = divmod(best, resolutions.size)
    units = ('rad', 'cm')[unit_idx]
    return units, int(resolutions[enc_idx]), encodings[enc_idx]
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    :param re_ts: numpy array of rotary encoder timestamps. Either 1D (one per position sample)
        or 2D; in the 2D case the first column holds sample indices and the second their times,
        and a time is linearly interpolated for every position sample
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with detected
        movements highlighted
    :return: wheel_moves dictionary with keys 'intervals' (onset/offset times, one row per
        movement), 'peakAmplitude' and 'peakVelocity_times'
    """
    if len(re_ts.shape) != 1:
        _logger.debug('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        sample_idx = np.arange(re_pos.size)
        re_ts = np.interp(sample_idx, re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # default movement / onset thresholds are expressed in encoder samples: convert to the
    # units of the trace (NB: Bpod wheel data may skip positions, so steps are not checked)
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # resample at 1 kHz before detecting movement onsets and offsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1], make_plots=display)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {'intervals': np.c_[on, off], 'peakAmplitude': amp, 'peakVelocity_times': peak_vel}
def extract_first_movement_times(wheel_moves, trials, min_qt=None):
    """
    Extract the time of the first sufficiently large wheel movement for each trial.

    A movement is counted when its onset falls strictly between (go cue time - minimum quiescent
    period) and the feedback time. The window starts before the cue because the movement onset
    sometimes precedes it: it can occur in the gap between quiescence end and cue start, or
    during the quiescence period while sub-threshold. A movement is sufficiently large when the
    absolute peak amplitude is at least 0.1.

    Parameters
    ----------
    wheel_moves : dict
        Detected wheel movements; uses the 'intervals' (onsets in the first column) and
        'peakAmplitude' arrays.
    trials : dict
        Trial data with 'goCue_times' and 'feedback_times' arrays.
    min_qt : float, array-like
        The minimum quiescence period in seconds, either one value or one per trial (extra
        values are truncated). If None a default of 0.2 s is used.

    Returns
    -------
    numpy.array
        First movement times (NaN when no movement qualifies).
    numpy.array
        Bool array indicating whether the first movement is also the last onset within the
        trial window.
    numpy.array
        Indices into the wheel_moves arrays, for trials where a first movement was found.
    """
    min_peak_amp = .1  # peak amp should be at least .1 rad; ~1/3rd of the distance to threshold
    default_min_qt = .2  # default minimum enforced quiescence period

    go_cues = trials['goCue_times']
    if min_qt is None:
        min_qt = default_min_qt
        _logger.info('minimum quiescent period assumed to be %.0fms', default_min_qt * 1e3)
    elif isinstance(min_qt, Sized) and len(min_qt) > len(go_cues):
        min_qt = np.array(min_qt[0:go_cues.size])

    first_move_onsets = np.full(go_cues.shape, np.nan)
    ids = np.full(go_cues.shape, int(-1))
    is_final_movement = np.zeros(go_cues.shape, bool)
    is_flinch = abs(wheel_moves['peakAmplitude']) < min_peak_amp
    onsets = wheel_moves['intervals'][:, 0]

    n_missing = 0
    window_starts = go_cues - min_qt
    for i, (t_start, t_end) in enumerate(zip(window_starts, trials['feedback_times'])):
        if np.isnan(t_end - t_start):  # both timestamps are needed
            n_missing += 1
            continue
        candidates = np.flatnonzero((onsets > t_start) & (onsets < t_end))
        large = candidates[~is_flinch[candidates]]
        if large.size == 0:  # no onset, or only flinches, in this trial
            continue
        ids[i] = large[0]
        first_move_onsets[i] = onsets[large[0]]
        is_final_movement[i] = large[0] == candidates[-1]
    if n_missing:
        _logger.warning(f'no reliable goCue/Feedback times (both needed) for {n_missing} trials')

    return first_move_onsets, is_final_movement, ids[ids != -1]
class Wheel(BaseBpodTrialsExtractor):
    """
    Extract the wheel trace, the detected wheel movements and each trial's first movement time.

    Times:
        Rotary Encoder timestamps (us) for each position are converted to seconds and
        synchronized to the Bpod clock.

    Positions:
        Radians, mathematical convention (anti-clockwise = +).

    **Optional:** saves _ibl_wheel.timestamps.npy, _ibl_wheel.position.npy,
    _ibl_wheelMoves.intervals.npy, _ibl_wheelMoves.peakAmplitude.npy and
    _ibl_trials.firstMovement_times.npy
    """
    save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  '_ibl_trials.firstMovement_times.npy', None)
    var_names = ('wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'firstMovement_times',
                 'is_final_movement')

    def _extract(self):
        """Run the wheel extraction and return a tuple ordered as `var_names`."""
        wheel_ts, wheel_pos = get_wheel_position(
            self.session_path, self.bpod_trials, task_collection=self.task_collection)
        wheel_moves = extract_wheel_moves(wheel_ts, wheel_pos)

        # the first movement times need the go cue and feedback times of each trial
        from ibllib.io.extractors import training_trials  # Avoids circular imports
        extract_kwargs = dict(save=False, bpod_trials=self.bpod_trials,
                              settings=self.settings, task_collection=self.task_collection)
        trials = {}
        for key, extractor in (('goCue_times', training_trials.GoCueTimes),
                               ('feedback_times', training_trials.FeedbackTimes)):
            trials[key], _ = extractor(self.session_path).extract(**extract_kwargs)

        first_moves, is_final, _ = extract_first_movement_times(
            wheel_moves, trials, min_qt=self.settings.get('QUIESCENT_PERIOD', None))
        return (wheel_ts, wheel_pos, wheel_moves['intervals'], wheel_moves['peakAmplitude'],
                wheel_moves['peakVelocity_times'], first_moves, is_final)
def extract_all(session_path, bpod_trials=None, settings=None, save=False, task_collection='raw_behavior_data', save_path=None):
    """Extract the wheel data by running the Wheel extractor.

    NB: Wheel extraction is now called through ibllib.io.training_trials.extract_all

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset
    save : bool
        If true save the data files to ALF
    task_collection : str
        The collection containing the raw task data
    save_path : str, pathlib.Path
        Forwarded to the extractor as `path_out`

    Returns
    -------
    A list of extracted data and a list of file paths if save is True (otherwise None)
    """
    run_kwargs = dict(save=save, session_path=session_path, bpod_trials=bpod_trials,
                      settings=settings, task_collection=task_collection, path_out=save_path)
    return run_extractor_classes(Wheel, **run_kwargs)