Coverage for ibllib/io/extractors/training_wheel.py: 87%

216 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1"""Extractors for the wheel position, velocity, and detected movement.""" 

2import logging 

3from collections.abc import Sized 

4 

5import numpy as np 

6from scipy import interpolate 

7 

8from ibldsp.utils import sync_timestamps 

9from ibllib.io.extractors.base import BaseBpodTrialsExtractor 

10import ibllib.io.raw_data_loaders as raw 

11from ibllib.misc import structarr 

12import ibllib.exceptions as err 

13import brainbox.behavior.wheel as wh 

14 

15_logger = logging.getLogger(__name__) 

16WHEEL_RADIUS_CM = 1 # we want the output in radians 

17THRESHOLD_RAD_PER_SEC = 10 

18THRESHOLD_CONSECUTIVE_SAMPLES = 0 

19EPS = 7. / 3 - 4. / 3 - 1 

20 

21 

def get_trial_start_times(session_path, data=None, task_collection='raw_behavior_data'):
    """
    Return the Bpod time of every `trial_start` state entry across all trials.

    :param session_path: session path, used only to load the Bpod data when `data` is falsy
    :param data: (optional) list of Bpod trial dicts as read from the jsonable file
    :param task_collection: collection holding the raw behaviour data
    :return: numpy array with the first timestamp of each `trial_start` state occurrence
    """
    trials = data if data else raw.load_data(session_path, task_collection=task_collection)
    starts = [
        interval[0]
        for trial in trials
        for interval in trial['behavior_data']['States timestamps']['trial_start']
    ]
    return np.array(starts)

30 

31 

def sync_rotary_encoder(session_path, bpod_data=None, re_events=None, task_collection='raw_behavior_data'):
    """
    Build a function converting rotary encoder (RE) times into Bpod times.

    The RE logs the state machine events it receives (code 2 for stim_on, 3 for closed_loop).
    The RE closed_loop events are matched with the closed_loop state times logged by Bpod, the
    clock relationship is checked to be linear, and a linear interpolant RE -> Bpod is returned.

    :param session_path: session path, used to load any data that is not provided
    :param bpod_data: (optional) list of Bpod trial dicts; loaded from the session if falsy
    :param re_events: (optional) RE events table with `re_ts` (microseconds) and `sm_ev` columns;
     loaded from the session if None
    :param task_collection: collection holding the raw behaviour data
    :return: scipy.interpolate.interp1d mapping RE seconds to Bpod seconds (extrapolating)
    :raises ibllib.exceptions.SyncBpodWheelException: not enough RE closed loop events
    :raises AssertionError: too few matched events or non-linear clock relationship
    """
    if not bpod_data:
        bpod_data = raw.load_data(session_path, task_collection=task_collection)
    # Compare with None explicitly: the events table is used like a pandas DataFrame
    # (`.re_ts.values`) and DataFrames raise ValueError when truth-tested, so the previous
    # `re_events or ...` form crashed whenever events were actually passed in.
    if re_events is None:
        re_events = raw.load_encoder_events(session_path, task_collection=task_collection)
    evt = re_events
    # we work with stim_on (2) and closed_loop (3) states for the synchronization with bpod
    tre = evt.re_ts.values / 1e6  # convert to seconds
    # the last event of each kind logged by the RE is discarded (historically a "dud" trial)
    rote = {'stim_on': tre[evt.sm_ev == 2][:-1],
            'closed_loop': tre[evt.sm_ev == 3][:-1]}
    bpod = {
        'stim_on': np.array([tr['behavior_data']['States timestamps']
                             ['stim_on'][0][0] for tr in bpod_data]),
        'closed_loop': np.array([tr['behavior_data']['States timestamps']
                                 ['closed_loop'][0][0] for tr in bpod_data]),
    }
    if rote['closed_loop'].size <= 1:
        raise err.SyncBpodWheelException("Not enough Rotary Encoder events to perform wheel"
                                         " synchronization. Wheel data not extracted")
    # bpod bug that spits out events in ms instead of us: the session then spans ~1000x longer
    # on the Bpod clock than on the RE clock
    if np.diff(bpod['closed_loop'][[-1, 0]])[0] / np.diff(rote['closed_loop'][[-1, 0]])[0] > 900:
        _logger.error("Rotary encoder stores values in ms instead of us. Wheel timing inaccurate")
        rote['stim_on'] *= 1e3
        rote['closed_loop'] *= 1e3
    # just use the closed loop for synchronization, handling different sizes
    sz = min(rote['closed_loop'].size, bpod['closed_loop'].size)
    # inter-event interval mismatch if the first samples match (contiguous samples)
    diff_first_match = np.diff(rote['closed_loop'][:sz]) - np.diff(bpod['closed_loop'][:sz])
    # inter-event interval mismatch if the last samples match (contiguous samples)
    diff_last_match = np.diff(rote['closed_loop'][-sz:]) - np.diff(bpod['closed_loop'][-sz:])
    DIFF_THRESHOLD = 0.005
    if np.mean(np.abs(diff_first_match) < DIFF_THRESHOLD) > 0.99:
        # 99% of the pulses match for a first sample lock
        re_t = rote['closed_loop'][:sz]
        bp_t = bpod['closed_loop'][:sz]
        indko = np.where(np.abs(diff_first_match) >= DIFF_THRESHOLD)[0]
    elif np.mean(np.abs(diff_last_match) < DIFF_THRESHOLD) > 0.99:
        # 99% of the pulses match for a last sample lock
        re_t = rote['closed_loop'][-sz:]
        bp_t = bpod['closed_loop'][-sz:]
        indko = np.where(np.abs(diff_last_match) >= DIFF_THRESHOLD)[0]
    else:
        # last resort is to use ad-hoc sync function
        bp_t, re_t = raw.sync_trials_robust(bpod['closed_loop'], rote['closed_loop'],
                                            diff_threshold=DIFF_THRESHOLD, max_shift=5)
        # in rare cases the generic timestamp sync may still recover a match
        if len(bp_t) == 0:
            _, _, ib, ir = sync_timestamps(bpod['closed_loop'], rote['closed_loop'], return_indices=True)
            bp_t = bpod['closed_loop'][ib]
            re_t = rote['closed_loop'][ir]
        indko = np.array([])
    # a mismatching interval i involves samples i and i + 1: remove both
    indko = np.int32(np.unique(np.r_[indko + 1, indko]))
    re_t = np.delete(re_t, indko)
    bp_t = np.delete(bp_t, indko)
    # check the linear drift
    assert bp_t.size > 1, 'not enough matched Bpod / rotary encoder sync events'
    poly = np.polyfit(bp_t, re_t, 1)
    assert np.all(np.abs(np.polyval(poly, bp_t) - re_t) < 0.002), \
        'rotary encoder / Bpod clock relationship is not linear'
    return interpolate.interp1d(re_t, bp_t, fill_value="extrapolate")

94 

95 

def get_wheel_position(session_path, bp_data=None, display=False, task_collection='raw_behavior_data'):
    """
    Gets wheel timestamps and position from Bpod data. Position is in radian (constant above for
    radius is 1) mathematical convention.

    Raw encoder positions are negated (anti-clockwise positive) and converted from encoder ticks
    (1024 per revolution) to radians. Timestamps are converted to seconds on the Bpod clock.
    The position trace is then compensated for the encoder resets performed during the session,
    unwrapped, and repeated timestamps are merged.

    :param session_path: session path
    :param bp_data (optional): bpod trials read from jsonable file; loaded from the session if
     falsy. Also used for the trial start times and the rotary encoder synchronisation.
    :param display (optional): (bool) plot raw and compensated traces for debugging
    :param task_collection: collection holding the raw behaviour data
    :return: timestamps (np.array), or None when there is no wheel data
    :return: positions (np.array), or None when there is no wheel data
    """
    status = 0
    if not bp_data:
        bp_data = raw.load_data(session_path, task_collection=task_collection)
    df = raw.load_encoder_positions(session_path, task_collection=task_collection)
    if df is None:
        _logger.error('No wheel data for ' + str(session_path))
        return None, None
    data = structarr(['re_ts', 're_pos', 'bns_ts'],
                     shape=(df.shape[0],), formats=['f8', 'f8', object])
    data['re_ts'] = df.re_ts.values
    data['re_pos'] = df.re_pos.values * -1  # anti-clockwise is positive in our output
    data['re_pos'] = data['re_pos'] / 1024 * 2 * np.pi  # convert positions to radians
    # reuse the Bpod data already in hand rather than re-reading the jsonable file
    trial_starts = get_trial_start_times(session_path, data=bp_data, task_collection=task_collection)
    # need a flag if the data resolution is 1ms due to the old version of rotary encoder firmware
    # (timestamps are in us here, so all multiples of 1000 means ms resolution)
    if np.all(np.mod(data['re_ts'], 1e3) == 0):
        status = 1
    data['re_ts'] = data['re_ts'] / 1e6  # convert ts to seconds
    # get the converter function to translate re_ts into behavior times
    re2bpod = sync_rotary_encoder(session_path, bpod_data=bp_data, task_collection=task_collection)
    data['re_ts'] = re2bpod(data['re_ts'])

    def get_reset_trace_compensation_with_state_machine_times():
        """
        Build the reset (DC) compensation trace from the state machine reset times.

        This is the preferred way of getting resets, using the state machine time information.
        It will not always work depending on firmware versions and new bugs.

        :return: (trace of position offsets to cumulate, True if any reset could not be matched)
        """
        iwarn = []
        ns = len(data['re_pos'])
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        for bp_dat in bp_data:
            restarts = np.sort(np.array(
                bp_dat['behavior_data']['States timestamps']['reset_rotary_encoder'] +
                bp_dat['behavior_data']['States timestamps']['reset2_rotary_encoder'])[:, 0])
            # index of the last wheel sample preceding each reset
            ind = np.unique(np.searchsorted(data['re_ts'], restarts, side='left') - 1)
            # the rotary encoder doesn't always reset right away, and the reset sample given the
            # timestamp can be ambiguous: look for zeros
            for i in np.where(data['re_pos'][ind] != 0)[0]:
                # handle boundary effects
                if ind[i] > ns - 2:
                    continue
                # it happens quite often that we have to lock in to next sample to find the reset
                if data['re_pos'][ind[i] + 1] == 0:
                    ind[i] = ind[i] + 1
                    continue
                # also case where the rotary doesn't reset to 0, but erratically to -1/+1
                if data['re_pos'][ind[i]] <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # compounded with the fact that the reset may have happened at next sample.
                if np.abs(data['re_pos'][ind[i] + 1]) <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # sometimes it is also the first or last trial that has this behaviour
                if (bp_data[-1] is bp_dat) or (bp_data[0] is bp_dat):
                    continue
                # at which point we are running out of possible bugs and calling it
                iwarn.append(ind[i])
            # the offset to add back at each reset is the position just before it
            tr_dc[ind] = data['re_pos'][ind - 1]
        if iwarn:  # if a warning flag was caught in the loop throw a single warning
            _logger.warning('Rotary encoder reset events discrepancy. Doing my best to merge.')
            _logger.debug('Offending inds: ' + str(iwarn) + ' times: ' + str(data['re_ts'][iwarn]))
        # exit status 0 is fine, 1 something went wrong
        return tr_dc, len(iwarn) != 0

    # attempt to get the resets properly unless the unit is ms which means precision is
    # not good enough to match SM times to wheel samples time
    if not status:
        tr_dc, status = get_reset_trace_compensation_with_state_machine_times()

    if status:
        # if something was wrong or went wrong agnostic way of getting resets: just get zeros values
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        i0 = np.where(data['re_pos'] == 0)[0]
        # NOTE(review): this assigns every zero sample, so the reset filtering below can only
        # re-assign the same values and never removes a candidate in this branch.
        tr_dc[i0] = data['re_pos'][i0 - 1]
    else:
        # even if things went ok, rotary encoder may not log the whole session: also treat zero
        # samples outside of the trial start range as reset candidates
        i0 = np.where(np.bitwise_and(np.bitwise_or(data['re_ts'] >= trial_starts[-1],
                                                   data['re_ts'] <= trial_starts[0]),
                                     data['re_pos'] == 0))[0]
    # make sure the bounds are not included in the current list
    i0 = np.delete(i0, np.where(np.bitwise_or(i0 >= len(data['re_pos']) - 1, i0 == 0)))
    # a zero sample is discarded as a reset candidate when both conditions hold:
    # c1: the sign of the position derivative flips across the zero sample
    c1 = np.abs(np.sign(data['re_pos'][i0 + 1] - data['re_pos'][i0]) -
                np.sign(data['re_pos'][i0] - data['re_pos'][i0 - 1])) == 2
    # c2: the velocity into the zero sample is below threshold
    c2 = np.abs((data['re_pos'][i0] - data['re_pos'][i0 - 1]) /
                (EPS + (data['re_ts'][i0] - data['re_ts'][i0 - 1]))) < THRESHOLD_RAD_PER_SEC
    # apply reset to points identified as resets
    i0 = i0[np.where(np.bitwise_not(np.bitwise_and(c1, c2)))]
    tr_dc[i0] = data['re_pos'][i0 - 1]

    # unwrap the rotation (in radians) and then add the DC component from restarts
    data['re_pos'] = np.unwrap(data['re_pos']) + np.cumsum(tr_dc)

    # time stamps may be repeated or very close to one another. Find them as they will induce
    # large jitters on the velocity function or errors in attempts of interpolation
    rep_idx = np.where(np.diff(data['re_ts']) <= THRESHOLD_CONSECUTIVE_SAMPLES)[0]
    # merge each repeated pair into its mean position and time
    data['re_pos'][rep_idx] = (data['re_pos'][rep_idx] +
                               data['re_pos'][rep_idx + 1]) / 2
    data['re_ts'][rep_idx] = (data['re_ts'][rep_idx] +
                              data['re_ts'][rep_idx + 1]) / 2
    # Now remove the repeat times that are rep_idx + 1
    data = np.delete(data, rep_idx + 1)

    # convert to cm
    data['re_pos'] = data['re_pos'] * WHEEL_RADIUS_CM

    # DEBUG PLOTS START HERE ########################
    if display:
        import matplotlib.pyplot as plt
        plt.figure()
        ax = plt.axes()
        # use the trial starts computed above so `task_collection` / `bp_data` are respected
        tstart = trial_starts
        tts = np.c_[tstart, tstart, tstart + np.nan].flatten()
        vts = np.c_[tstart * 0 + 100, tstart * 0 - 100, tstart + np.nan].flatten()
        ax.plot(tts, vts, label='Trial starts')
        ax.plot(re2bpod(df.re_ts.values / 1e6), df.re_pos.values / 1024 * 2 * np.pi,
                '.-', label='Raw data')
        i0 = np.where(df.re_pos.values == 0)
        ax.plot(re2bpod(df.re_ts.values[i0] / 1e6), df.re_pos.values[i0] / 1024 * 2 * np.pi,
                'r*', label='Raw data zero samples')
        ax.plot(re2bpod(df.re_ts.values / 1e6), tr_dc, label='reset compensation')
        ax.set_xlabel('Bpod Time')
        ax.set_ylabel('radians')
        ax.plot(data['re_ts'], data['re_pos'] / WHEEL_RADIUS_CM, '.-', label='Output Trace')
        ax.legend()
    return data['re_ts'], data['re_pos']

240 

241 

def infer_wheel_units(pos):
    """
    Infer the rotary encoder resolution, encoding type and units of extracted wheel positions.

    The encoding type varies across hardware (Bpod uses X1 while FPGA usually extracted as X4), and
    older data were extracted in linear cm rather than radians. The median absolute step between
    consecutive samples is compared to the smallest possible step of every unit / encoding
    combination and the closest candidate is selected.

    :param pos: array of extracted wheel positions (flattened when multi-dimensional)
    :return units: the position units, assumed to be either 'rad' or 'cm'
    :return resolution: the number of decoded fronts per 360 degree rotation
    :return encoding: one of {'X1', 'X2', 'X4'}
    """
    positions = pos.flatten() if len(pos.shape) > 1 else pos

    # decoded fronts per revolution for X4, X2 and X1 decoding, in that order
    resolutions = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    encodings = ('X4', 'X2', 'X1')
    # smallest possible step: [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    candidates = np.r_[2 * np.pi / resolutions, wh.WHEEL_DIAMETER * np.pi / resolutions]
    typical_step = np.median(np.abs(np.ediff1d(positions)))

    best = int(np.argmin(np.abs(candidates - typical_step)))
    unit_group, encoding = divmod(best, resolutions.size)
    units = 'rad' if unit_group == 0 else 'cm'
    return units, int(resolutions[encoding]), encodings[encoding]

275 

276 

def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    :param re_ts: numpy array of rotary encoder timestamps. A 2D array is interpreted as
     (sample index, time) pairs that are linearly interpolated to one time per position sample.
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with detected
     movements highlighted
    :return: wheel_moves dictionary with 'intervals' (onset, offset columns), 'peakAmplitude'
     and 'peakVelocity_times'
    """
    if len(re_ts.shape) != 1:
        _logger.debug('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times at each sample index
        sample_idx = np.arange(re_pos.size)
        re_ts = np.interp(sample_idx, re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)
    # NB: Bpod wheel data do not always step by exactly one encoder front, so no check is made
    # that every position difference equals the minimal step.

    # Default thresholds are expressed in encoder samples; convert them to the data units
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Resample at 1 kHz and detect movement onsets / offsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1], make_plots=display)

    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {'intervals': np.c_[on, off], 'peakAmplitude': amp, 'peakVelocity_times': peak_vel}

322 

323 

def extract_first_movement_times(wheel_moves, trials, min_qt=None):
    """
    Extracts the time of the first sufficiently large wheel movement for each trial.

    A movement counts for a trial when its onset lies strictly between (go cue - minimum
    quiescent period) and the feedback time. The onset can therefore fall just before the cue
    (in the gap between quiescence end and cue start, or during the quiescence period but
    sub-threshold). A movement is sufficiently large if its absolute peak amplitude is not
    below THRESH.

    Parameters
    ----------
    wheel_moves : dict
        Detected wheel movements; 'intervals' (onset in column 0) and 'peakAmplitude' are used.
    trials : dict
        Trial data with 'goCue_times' and 'feedback_times' arrays.
    min_qt : float, array-like
        The minimum quiescence period in seconds, if None a default is used. Sequences longer
        than the number of trials are truncated.

    Returns
    -------
    numpy.array
        First movement times (NaN for trials without a qualifying movement).
    numpy.array
        Bool array indicating whether the first movement is also the last one in the window.
    numpy.array
        Indices into the wheel_moves arrays, for trials that have a first movement.
    """
    THRESH = .1  # peak amp should be at least .1 rad; ~1/3rd of the distance to threshold
    MIN_QT = .2  # default minimum enforced quiescence period

    go_cues = trials['goCue_times']
    if min_qt is None:
        min_qt = MIN_QT
        _logger.info('minimum quiescent period assumed to be %.0fms', MIN_QT * 1e3)
    elif isinstance(min_qt, Sized) and len(min_qt) > len(go_cues):
        min_qt = np.array(min_qt[0:go_cues.size])

    first_move_onsets = np.full(go_cues.shape, np.nan)
    ids = np.full(go_cues.shape, int(-1))
    is_final_movement = np.zeros(go_cues.shape, bool)
    flinch = abs(wheel_moves['peakAmplitude']) < THRESH
    onsets = wheel_moves['intervals'][:, 0]

    n_missing = 0
    for i, (t_start, t_end) in enumerate(zip(go_cues - min_qt, trials['feedback_times'])):
        if np.isnan(t_end - t_start):  # both timestamps are needed
            n_missing += 1
            continue
        window_ids = np.flatnonzero((onsets > t_start) & (onsets < t_end))
        large_ids = window_ids[~flinch[window_ids]]
        if large_ids.size == 0:
            continue
        ids[i] = large_ids[0]
        first_move_onsets[i] = onsets[ids[i]]
        is_final_movement[i] = ids[i] == window_ids[-1]
    if n_missing:
        _logger.warning(f'no reliable goCue/Feedback times (both needed) for {n_missing} trials')

    return first_move_onsets, is_final_movement, ids[ids != -1]

385 

386 

class Wheel(BaseBpodTrialsExtractor):
    """
    Wheel extractor.

    Get wheel data from raw files and converts positions into radians mathematical convention
    (anti-clockwise = +) and timestamps into seconds relative to Bpod clock. Wheel movements
    and the first movement time of each trial are extracted as well.
    **Optional:** saves _ibl_wheel.timestamps.npy, _ibl_wheel.position.npy,
    _ibl_wheelMoves.intervals.npy, _ibl_wheelMoves.peakAmplitude.npy and
    _ibl_trials.firstMovement_times.npy

    Times:
        Gets Rotary Encoder timestamps (us) for each position and converts to times.
        Synchronize with Bpod and outputs

    Positions:
        Radians mathematical convention
    """
    save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  '_ibl_trials.firstMovement_times.npy', None)
    var_names = ('wheel_timestamps', 'wheel_position',
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times',
                 'firstMovement_times', 'is_final_movement')

    def _extract(self):
        """
        Extract the wheel trace, wheel movements and per-trial first movement times.

        :return: tuple ordered as `var_names`
        """
        wheel_times, wheel_pos = get_wheel_position(
            self.session_path, self.bpod_trials, task_collection=self.task_collection)
        moves = extract_wheel_moves(wheel_times, wheel_pos)

        # trial based info is needed for the first movement times
        from ibllib.io.extractors import training_trials  # Avoids circular imports
        extract_kwargs = dict(save=False, bpod_trials=self.bpod_trials, settings=self.settings,
                              task_collection=self.task_collection)
        go_cue_times, _ = training_trials.GoCueTimes(self.session_path).extract(**extract_kwargs)
        feedback_times, _ = training_trials.FeedbackTimes(self.session_path).extract(**extract_kwargs)
        trial_times = {'goCue_times': go_cue_times, 'feedback_times': feedback_times}

        first_moves, is_final, _ = extract_first_movement_times(
            moves, trial_times, min_qt=self.settings.get('QUIESCENT_PERIOD', None))
        return (wheel_times, wheel_pos,
                moves['intervals'], moves['peakAmplitude'], moves['peakVelocity_times'],
                first_moves, is_final)