Coverage for ibllib/io/extractors/training_wheel.py: 87%
218 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
import logging
from collections.abc import Sized

import numpy as np
from scipy import interpolate

from neurodsp.utils import sync_timestamps
from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
import ibllib.io.raw_data_loaders as raw
from ibllib.misc import structarr
import ibllib.exceptions as err
import brainbox.behavior.wheel as wh
_logger = logging.getLogger(__name__)

# Scale factor applied to the final wheel position. Kept at 1 so the output stays in radians.
WHEEL_RADIUS_CM = 1
# Speed limit (rad/s) used when deciding whether a zero sample is a genuine position or an
# encoder reset (see the reset detection in `get_wheel_position`).
THRESHOLD_RAD_PER_SEC = 10
# Samples closer in time than (or equal to) this interval (s, Bpod clock) to the next sample
# are merged with it.
THRESHOLD_CONSECUTIVE_SAMPLES = 0
# Double precision machine epsilon (2 ** -52) obtained through float rounding. It is added to
# time differences to avoid dividing by zero when computing velocities.
EPS = 7. / 3 - 4. / 3 - 1
def get_trial_start_times(session_path, data=None, task_collection='raw_behavior_data'):
    """
    Return the Bpod time at which every 'trial_start' state was entered.

    :param session_path: session path, used to load the Bpod data when `data` is not given
    :param data: (optional) list of Bpod trial dicts; loaded from disk when falsy
    :param task_collection: collection containing the raw task data
    :return: numpy array holding the first timestamp of each 'trial_start' state entry,
        concatenated over trials
    """
    if not data:
        data = raw.load_data(session_path, task_collection=task_collection)
    # a trial may list several 'trial_start' entries: keep the entry time of each one
    starts = [
        state[0]
        for trial in data
        for state in trial['behavior_data']['States timestamps']['trial_start']
    ]
    return np.array(starts)
def sync_rotary_encoder(session_path, bpod_data=None, re_events=None, task_collection='raw_behavior_data'):
    """
    Build a function mapping rotary encoder times onto the Bpod clock.

    The 'closed_loop' state entries (state machine event 3) logged by both the rotary encoder
    and Bpod are matched, faulty pulses are discarded, and the residual of a linear fit between
    both clocks is checked to be below 2 ms.

    :param session_path: session path, used to load data not passed as arguments
    :param bpod_data: (optional) list of Bpod trial dicts; loaded from disk when falsy
    :param re_events: (optional) rotary encoder events table with `re_ts` (microseconds) and
        `sm_ev` columns; loaded from disk when None
    :param task_collection: collection containing the raw task data
    :return: scipy interp1d converting rotary encoder times (s) into Bpod times (s),
        extrapolating outside of the synchronised range
    :raises ibllib.exceptions.SyncBpodWheelException: fewer than 2 rotary encoder closed loop events
    :raises AssertionError: fewer than 2 usable sync pulses, or linear fit residual >= 2 ms
    """
    if not bpod_data:
        bpod_data = raw.load_data(session_path, task_collection=task_collection)
    # `re_events or ...` would evaluate the truth value of a DataFrame, which raises a
    # ValueError: test explicitly for None instead
    if re_events is None:
        re_events = raw.load_encoder_events(session_path, task_collection=task_collection)
    evt = re_events
    # we work with stim_on (2) and closed_loop (3) states for the synchronization with bpod
    tre = evt.re_ts.values / 1e6  # convert to seconds
    # the first trial on the rotary encoder is a dud
    # NOTE(review): the slices below drop the *last* event of each state - confirm which
    # event is spurious
    rote = {'stim_on': tre[evt.sm_ev == 2][:-1],
            'closed_loop': tre[evt.sm_ev == 3][:-1]}
    bpod = {
        'stim_on': np.array([tr['behavior_data']['States timestamps']['stim_on'][0][0]
                             for tr in bpod_data]),
        'closed_loop': np.array([tr['behavior_data']['States timestamps']['closed_loop'][0][0]
                                 for tr in bpod_data]),
    }
    if rote['closed_loop'].size <= 1:
        raise err.SyncBpodWheelException("Not enough Rotary Encoder events to perform wheel"
                                         " synchronization. Wheel data not extracted")
    # bpod bug that spits out events in ms instead of us: the session duration seen by Bpod is
    # then ~1000 times the one seen by the rotary encoder
    if np.diff(bpod['closed_loop'][[-1, 0]])[0] / np.diff(rote['closed_loop'][[-1, 0]])[0] > 900:
        _logger.error("Rotary encoder stores values in ms instead of us. Wheel timing inaccurate")
        rote['stim_on'] *= 1e3
        rote['closed_loop'] *= 1e3
    # just use the closed loop for synchronization, handling different numbers of events
    sz = min(rote['closed_loop'].size, bpod['closed_loop'].size)
    # inter-pulse interval mismatch assuming contiguous samples locked on the first sample...
    diff_first_match = np.diff(rote['closed_loop'][:sz]) - np.diff(bpod['closed_loop'][:sz])
    # ...and assuming contiguous samples locked on the last sample
    diff_last_match = np.diff(rote['closed_loop'][-sz:]) - np.diff(bpod['closed_loop'][-sz:])
    DIFF_THRESHOLD = 0.005
    # 99% of the pulses match for a first sample lock
    if np.mean(np.abs(diff_first_match) < DIFF_THRESHOLD) > 0.99:
        re_sync = rote['closed_loop'][:sz]
        bp_sync = bpod['closed_loop'][:sz]
        indko = np.where(np.abs(diff_first_match) >= DIFF_THRESHOLD)[0]
    # 99% of the pulses match for a last sample lock
    elif np.mean(np.abs(diff_last_match) < DIFF_THRESHOLD) > 0.99:
        re_sync = rote['closed_loop'][-sz:]
        bp_sync = bpod['closed_loop'][-sz:]
        indko = np.where(np.abs(diff_last_match) >= DIFF_THRESHOLD)[0]
    # last resort is to use ad-hoc sync function
    else:
        bp_sync, re_sync = raw.sync_trials_robust(bpod['closed_loop'], rote['closed_loop'],
                                                  diff_threshold=DIFF_THRESHOLD, max_shift=5)
        # we don't want to change the extractor, but in rare cases the following method may
        # save the day
        if len(bp_sync) == 0:
            _, _, ib, ir = sync_timestamps(bpod['closed_loop'], rote['closed_loop'],
                                           return_indices=True)
            bp_sync = bpod['closed_loop'][ib]
            re_sync = rote['closed_loop'][ir]
        indko = np.array([])
    # remove faulty indices due to missing or bad syncs: a mismatched interval i involves
    # both pulses i and i + 1
    indko = np.int32(np.unique(np.r_[indko + 1, indko]))
    re_sync = np.delete(re_sync, indko)
    bp_sync = np.delete(bp_sync, indko)
    # check the linear drift
    assert bp_sync.size > 1
    poly = np.polyfit(bp_sync, re_sync, 1)
    assert np.all(np.abs(np.polyval(poly, bp_sync) - re_sync) < 0.002)
    return interpolate.interp1d(re_sync, bp_sync, fill_value="extrapolate")
def get_wheel_position(session_path, bp_data=None, display=False, task_collection='raw_behavior_data'):
    """
    Gets wheel timestamps and position from Bpod data. Position is in radian (constant above for
    radius is 1) mathematical convention.

    Processing steps:
    - raw encoder counts (1024 per revolution) are converted to radians, anti-clockwise positive;
    - timestamps are converted to seconds and synchronised to the Bpod clock;
    - rotary encoder resets are compensated by adding back the position preceding each reset;
    - samples with repeated timestamps are merged.

    :param session_path: session path
    :param bp_data (optional): bpod trials read from jsonable file; loaded from disk when falsy
    :param display (optional): (bool) plot raw data, reset compensation and output trace
    :param task_collection: collection containing the raw task data
    :return: timestamps (np.array), or None when no wheel data were found
    :return: positions (np.array), or None when no wheel data were found
    """
    status = 0
    if not bp_data:
        bp_data = raw.load_data(session_path, task_collection=task_collection)
    df = raw.load_encoder_positions(session_path, task_collection=task_collection)
    if df is None:
        _logger.error('No wheel data for ' + str(session_path))
        return None, None
    data = structarr(['re_ts', 're_pos', 'bns_ts'],
                     shape=(df.shape[0],), formats=['f8', 'f8', object])
    data['re_ts'] = df.re_ts.values
    data['re_pos'] = df.re_pos.values * -1  # anti-clockwise is positive in our output
    data['re_pos'] = data['re_pos'] / 1024 * 2 * np.pi  # convert positions to radians
    trial_starts = get_trial_start_times(session_path, task_collection=task_collection)
    # need a flag if the data resolution is 1ms due to the old version of rotary encoder firmware
    if np.all(np.mod(data['re_ts'], 1e3) == 0):
        status = 1
    data['re_ts'] = data['re_ts'] / 1e6  # convert ts to seconds
    # get the converter function to translate re_ts into behavior times
    re2bpod = sync_rotary_encoder(session_path, task_collection=task_collection)
    data['re_ts'] = re2bpod(data['re_ts'])

    def get_reset_trace_compensation_with_state_machine_times():
        """
        Locate encoder resets from the Bpod 'reset_rotary_encoder' states.

        This is the preferred way of getting resets using the state machine time information;
        it will not always work depending on firmware versions, new bugs.

        :return: the DC compensation trace (position before each reset, at the reset sample)
            and a flag that is True when some resets could not be located.
        """
        iwarn = []
        ns = len(data['re_pos'])
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        for bp_dat in bp_data:
            restarts = np.sort(np.array(
                bp_dat['behavior_data']['States timestamps']['reset_rotary_encoder'] +
                bp_dat['behavior_data']['States timestamps']['reset2_rotary_encoder'])[:, 0])
            # index of the last encoder sample preceding each restart
            # NOTE(review): a restart before the first encoder sample yields -1 here, which
            # wraps around to the end of the array - confirm whether this can happen
            ind = np.unique(np.searchsorted(data['re_ts'], restarts, side='left') - 1)
            # the rotary encoder doesn't always reset right away, and the reset sample given the
            # timestamp can be ambiguous: look for zeros
            for i in np.where(data['re_pos'][ind] != 0)[0]:
                # handle boundary effects
                if ind[i] > ns - 2:
                    continue
                # it happens quite often that we have to lock in to next sample to find the reset
                if data['re_pos'][ind[i] + 1] == 0:
                    ind[i] = ind[i] + 1
                    continue
                # also case where the rotary doesn't reset to 0, but erratically to -1/+1
                if data['re_pos'][ind[i]] <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # compounded with the fact that the reset may have happened at next sample.
                if np.abs(data['re_pos'][ind[i] + 1]) <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # sometimes it is also the first or last trial that has this behaviour
                if (bp_data[-1] is bp_dat) or (bp_data[0] is bp_dat):
                    continue
                # at which point we are running out of possible bugs and calling it
                iwarn.append(ind[i])
            tr_dc[ind] = data['re_pos'][ind - 1]
        if iwarn:  # if a warning flag was caught in the loop throw a single warning
            _logger.warning('Rotary encoder reset events discrepancy. Doing my best to merge.')
            _logger.debug('Offending inds: ' + str(iwarn) + ' times: ' + str(data['re_ts'][iwarn]))
        # exit status 0 is fine, 1 something went wrong
        return tr_dc, len(iwarn) != 0

    # attempt to get the resets properly unless the unit is ms which means precision is
    # not good enough to match SM times to wheel samples time
    if not status:
        tr_dc, status = get_reset_trace_compensation_with_state_machine_times()

    # if something was wrong or went wrong agnostic way of getting resets: just get zeros values
    if status:
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        i0 = np.where(data['re_pos'] == 0)[0]
        # the first sample has no predecessor: without this guard `i0 - 1 == -1` would pick the
        # last position of the session and offset the whole trace by it
        i0 = i0[i0 > 0]
        tr_dc[i0] = data['re_pos'][i0 - 1]
    # even if things went ok, rotary encoder may not log the whole session. Need to fix outside
    else:
        outside_trials = (data['re_ts'] >= trial_starts[-1]) | (data['re_ts'] <= trial_starts[0])
        i0 = np.where(outside_trials & (data['re_pos'] == 0))[0]
    # make sure the bounds are not included in the current list
    i0 = np.delete(i0, np.where(np.bitwise_or(i0 >= len(data['re_pos']) - 1, i0 == 0)))
    # a 0 sample is not a reset if 2 conditions are met:
    # 1/2 the sign of the derivative flips at the zero sample (|sign difference| == 2)
    # NOTE(review): the original comment read "no inflexion (continuous derivative)", but this
    # test is True when the derivative changes sign
    c1 = np.abs(np.sign(data['re_pos'][i0 + 1] - data['re_pos'][i0]) -
                np.sign(data['re_pos'][i0] - data['re_pos'][i0 - 1])) == 2
    # 2/2 the velocity into the zero sample needs to be below threshold
    c2 = np.abs((data['re_pos'][i0] - data['re_pos'][i0 - 1]) /
                (EPS + (data['re_ts'][i0] - data['re_ts'][i0 - 1]))) < THRESHOLD_RAD_PER_SEC
    # apply reset to points identified as resets
    i0 = i0[~(c1 & c2)]
    tr_dc[i0] = data['re_pos'][i0 - 1]

    # unwrap the rotation (in radians) and then add the DC component from restarts
    data['re_pos'] = np.unwrap(data['re_pos']) + np.cumsum(tr_dc)

    # Also forgot to mention that time stamps may be repeated or very close to one another.
    # Find them as they will induce large jitters on the velocity function or errors in
    # attempts of interpolation
    rep_idx = np.where(np.diff(data['re_ts']) <= THRESHOLD_CONSECUTIVE_SAMPLES)[0]
    # Change the value of the repeated position
    data['re_pos'][rep_idx] = (data['re_pos'][rep_idx] + data['re_pos'][rep_idx + 1]) / 2
    data['re_ts'][rep_idx] = (data['re_ts'][rep_idx] + data['re_ts'][rep_idx + 1]) / 2
    # Now remove the repeat times that are rep_idx + 1
    data = np.delete(data, rep_idx + 1)

    # convert to cm (a no-op while WHEEL_RADIUS_CM is 1: output stays in radians)
    data['re_pos'] = data['re_pos'] * WHEEL_RADIUS_CM

    # DEBUG PLOTS START HERE ########################
    if display:
        import matplotlib.pyplot as plt
        plt.figure()
        ax = plt.axes()
        # reuse the trial starts loaded above (respects `task_collection`)
        tstart = trial_starts
        tts = np.c_[tstart, tstart, tstart + np.nan].flatten()
        vts = np.c_[tstart * 0 + 100, tstart * 0 - 100, tstart + np.nan].flatten()
        ax.plot(tts, vts, label='Trial starts')
        ax.plot(re2bpod(df.re_ts.values / 1e6), df.re_pos.values / 1024 * 2 * np.pi,
                '.-', label='Raw data')
        i0 = np.where(df.re_pos.values == 0)
        ax.plot(re2bpod(df.re_ts.values[i0] / 1e6), df.re_pos.values[i0] / 1024 * 2 * np.pi,
                'r*', label='Raw data zero samples')
        ax.plot(re2bpod(df.re_ts.values / 1e6), tr_dc, label='reset compensation')
        ax.set_xlabel('Bpod Time')
        ax.set_ylabel('radians')
        ax.plot(data['re_ts'], data['re_pos'] / WHEEL_RADIUS_CM, '.-', label='Output Trace')
        ax.legend()
    return data['re_ts'], data['re_pos']
def infer_wheel_units(pos):
    """
    Given an array of wheel positions, infer the rotary encoder resolution, encoding type and units.

    The encoding type varies across hardware (Bpod uses X1 while FPGA usually extracted as X4), and
    older data were extracted in linear cm rather than radians. The median absolute step between
    consecutive positions is compared with the smallest possible step of every (units, encoding)
    candidate, and the closest candidate wins.

    :param pos: a 1D array of extracted wheel positions (multi-dimensional arrays are flattened)
    :return units: the position units, assumed to be either 'rad' or 'cm'
    :return resolution: the number of decoded fronts per 360 degree rotation
    :return encoding: one of {'X1', 'X2', 'X4'}
    """
    if pos.ndim > 1:  # work on a 1D array of positions
        pos = pos.flatten()

    encodings = ('X4', 'X2', 'X1')
    resolutions = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # smallest possible step per candidate, ordered [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    candidate_steps = np.r_[2 * np.pi / resolutions, wh.WHEEL_DIAMETER * np.pi / resolutions]
    typical_step = np.median(np.abs(np.ediff1d(pos)))

    best = int(np.argmin(np.abs(candidate_steps - typical_step)))
    # the first half of the candidates are radians, the second half centimetres
    unit_idx, enc_idx = divmod(best, resolutions.size)
    units = ('rad', 'cm')[unit_idx]
    return units, int(resolutions[enc_idx]), encodings[enc_idx]
def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    :param re_ts: numpy array of rotary encoder timestamps. A non-1D array is read as
        (sample index, time) pairs that are linearly interpolated to one time per position.
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with detected
        movements highlighted
    :return: wheel_moves dictionary with keys 'intervals' (onset/offset times, one row per
        movement), 'peakAmplitude' and 'peakVelocity_times'
    """
    if len(re_ts.shape) != 1:
        _logger.debug('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        re_ts = np.interp(np.arange(re_pos.size), re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # NB: no check that consecutive positions differ by exactly one encoder step: Bpod wheel
    # data violate it

    # Default position thresholds are expressed in encoder samples: convert to the data units
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Resample at 1 kHz and detect movement onsets/offsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1], make_plots=display)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {'intervals': np.c_[on, off], 'peakAmplitude': amp, 'peakVelocity_times': peak_vel}
def extract_first_movement_times(wheel_moves, trials, min_qt=None):
    """
    Extracts the time of the first sufficiently large wheel movement for each trial.

    A movement onset is considered when it falls strictly between (go cue - minimum quiescent
    period) and feedback. The window starts before the cue because the movement onset is
    sometimes just before it (in the gap between quiescence end and cue start, or during the
    quiescence period but sub-threshold). A movement is sufficiently large when the absolute
    value of its peak amplitude is not below THRESH.

    :param wheel_moves: dictionary of detected wheel movement onsets and peak amplitudes for
        use in extracting each trial's time of first movement ('intervals', 'peakAmplitude').
    :param trials: dictionary of trial data with 'goCue_times' and 'feedback_times' arrays
    :param min_qt: the minimum quiescence period (scalar or per-trial sequence, truncated to the
        number of trials when longer); if None a default is used
    :return: numpy array of first movement times (NaN when none), bool array indicating whether
        that movement was the last onset within the trial window, and array of indices for
        wheel_moves arrays (trials without a movement are omitted)
    """
    THRESH = .1  # peak amp should be at least .1 rad; ~1/3rd of the distance to threshold
    MIN_QT = .2  # default minimum enforced quiescence period

    # Determine minimum quiescent period
    if min_qt is None:
        min_qt = MIN_QT
        _logger.info('minimum quiescent period assumed to be %.0fms', MIN_QT * 1e3)
    elif isinstance(min_qt, Sized) and len(min_qt) > len(trials['goCue_times']):
        min_qt = np.array(min_qt[0:trials['goCue_times'].size])

    go_cues = trials['goCue_times']
    first_move_onsets = np.full(go_cues.shape, np.nan)
    ids = np.full(go_cues.shape, int(-1))
    is_final_movement = np.zeros(go_cues.shape, bool)
    flinch = abs(wheel_moves['peakAmplitude']) < THRESH
    onsets = wheel_moves['intervals'][:, 0]

    n_missing = 0
    windows = zip(go_cues - min_qt, trials['feedback_times'])
    for i, (t_start, t_end) in enumerate(windows):
        if np.isnan(t_end - t_start):  # both timestamps are needed
            n_missing += 1
            continue
        in_window, = np.where((onsets > t_start) & (onsets < t_end))
        if in_window.size == 0:
            continue
        large = in_window[~flinch[in_window]]
        if large.size == 0:
            continue
        ids[i] = large[0]  # first large movement of the trial
        first_move_onsets[i] = onsets[ids[i]]
        is_final_movement[i] = ids[i] == in_window[-1]  # last onset within the trial window

    if n_missing:
        _logger.warning(f'no reliable goCue/Feedback times (both needed) for {n_missing} trials')

    return first_move_onsets, is_final_movement, ids[ids != -1]
class Wheel(BaseBpodTrialsExtractor):
    """
    Get wheel data from raw files and converts positions into radians mathematical convention
    (anti-clockwise = +) and timestamps into seconds relative to Bpod clock.
    **Optional:** saves _ibl_wheel.timestamps.npy, _ibl_wheel.position.npy,
    _ibl_wheelMoves.intervals.npy, _ibl_wheelMoves.peakAmplitude.npy and
    _ibl_trials.firstMovement_times.npy (peakVelocity_times and is_final_movement are
    returned but not saved)

    Times:
    Gets Rotary Encoder timestamps (us) for each position and converts to times.
    Synchronize with Bpod and outputs

    Positions:
    Radians mathematical convention
    """
    save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  '_ibl_trials.firstMovement_times.npy', None)
    var_names = ('wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'firstMovement_times',
                 'is_final_movement')

    def _extract(self):
        """
        Extract the wheel trace, the wheel movements and each trial's first movement time.

        :return: tuple ordered as `var_names`
        """
        ts, pos = get_wheel_position(self.session_path, self.bpod_trials, task_collection=self.task_collection)
        moves = extract_wheel_moves(ts, pos)

        # need some trial based info to output the first movement times
        from ibllib.io.extractors import training_trials  # Avoids circular imports
        goCue_times, _ = training_trials.GoCueTimes(self.session_path).extract(
            save=False, bpod_trials=self.bpod_trials, settings=self.settings, task_collection=self.task_collection)
        feedback_times, _ = training_trials.FeedbackTimes(self.session_path).extract(
            save=False, bpod_trials=self.bpod_trials, settings=self.settings, task_collection=self.task_collection)
        trials = {'goCue_times': goCue_times, 'feedback_times': feedback_times}
        # the task settings may define the quiescent period; otherwise a default is applied
        min_qt = self.settings.get('QUIESCENT_PERIOD', None)

        first_moves, is_final, _ = extract_first_movement_times(moves, trials, min_qt=min_qt)
        output = (ts, pos, moves['intervals'], moves['peakAmplitude'],
                  moves['peakVelocity_times'], first_moves, is_final)
        return output
def extract_all(session_path, bpod_trials=None, settings=None, save=False, task_collection='raw_behavior_data', save_path=None):
    """Extract the wheel data.

    NB: Wheel extraction is now called through ibllib.io.extractors.training_trials.extract_all

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset
    save : bool
        If true save the data files to ALF
    task_collection : str
        The session collection containing the raw task data (default 'raw_behavior_data')
    save_path : str, pathlib.Path
        Forwarded to `run_extractor_classes` as `path_out`; presumably the output folder for
        saved files (TODO confirm the default used when None)

    Returns
    -------
    A list of extracted data and a list of file paths if save is True (otherwise None)
    """
    return run_extractor_classes(Wheel, save=save, session_path=session_path,
                                 bpod_trials=bpod_trials, settings=settings,
                                 task_collection=task_collection, path_out=save_path)