Coverage for ibllib/io/extractors/training_wheel.py: 87%

218 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Extractors for the wheel position, velocity, and detected movement.""" 

2import logging 

3from collections.abc import Sized 

4 

5import numpy as np 

6from scipy import interpolate 

7 

8from ibldsp.utils import sync_timestamps 

9from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

10import ibllib.io.raw_data_loaders as raw 

11from ibllib.misc import structarr 

12import ibllib.exceptions as err 

13import brainbox.behavior.wheel as wh 

14 

# Module-level logger shared by every extractor function in this file
_logger = logging.getLogger(__name__)
# Multiplier applied to wheel positions at the end of get_wheel_position; a radius of 1 leaves
# the positions in radians
WHEEL_RADIUS_CM = 1  # we want the output in radians
# A zero-valued sample approached slower than this speed, and without a change of direction,
# is not treated as an encoder reset in get_wheel_position
THRESHOLD_RAD_PER_SEC = 10
# Consecutive timestamps whose difference is <= this value are merged in get_wheel_position
THRESHOLD_CONSECUTIVE_SAMPLES = 0
# Tiny positive constant (evaluates to ~2.2e-16 in double precision) added to time differences
# in get_wheel_position so that a speed can be computed when two timestamps are equal
EPS = 7. / 3 - 4. / 3 - 1

20 

21 

def get_trial_start_times(session_path, data=None, task_collection='raw_behavior_data'):
    """
    Return the onset time of every 'trial_start' state found in the Bpod trials.

    Parameters
    ----------
    session_path : str, pathlib.Path
        Session path, only used to load the Bpod data when `data` is empty or None.
    data : list of dict, optional
        Bpod trials; each one must provide
        ['behavior_data']['States timestamps']['trial_start'] as a sequence of intervals.
    task_collection : str
        Collection passed to the raw data loader when `data` is not provided.

    Returns
    -------
    numpy.array
        First element of each 'trial_start' interval, in trial order.
    """
    if not data:
        data = raw.load_data(session_path, task_collection=task_collection)
    # A trial may hold several 'trial_start' intervals: keep the onset of each one
    onsets = [
        interval[0]
        for trial in data
        for interval in trial['behavior_data']['States timestamps']['trial_start']
    ]
    return np.array(onsets)

30 

31 

def sync_rotary_encoder(session_path, bpod_data=None, re_events=None, task_collection='raw_behavior_data'):
    """
    Build a function that converts rotary encoder times into Bpod times.

    The onsets of the Bpod 'closed_loop' state are paired with the rotary encoder events whose
    state machine code is 3. The pairing is attempted first-sample aligned, then last-sample
    aligned, then with the robust/ad-hoc synchronization helpers. A linear fit checks that the
    paired times agree to better than 2 ms before the interpolant is returned.

    Parameters
    ----------
    session_path : str, pathlib.Path
        Session path, only used to load whichever of `bpod_data` / `re_events` is not provided.
    bpod_data : list of dict, optional
        Bpod trials providing ['behavior_data']['States timestamps'] with 'stim_on' and
        'closed_loop' intervals.
    re_events : optional
        Rotary encoder events exposing `re_ts` (microseconds, read through `.values`) and
        `sm_ev` columns - presumably the pandas DataFrame returned by
        raw.load_encoder_events (confirm against the loader).
    task_collection : str
        Collection passed to the raw data loaders.

    Returns
    -------
    scipy.interpolate.interp1d
        Interpolant (with extrapolation) mapping rotary encoder seconds to Bpod seconds.

    Raises
    ------
    ibllib.exceptions.SyncBpodWheelException
        When one or fewer rotary encoder closed-loop events are available.
    AssertionError
        When fewer than two sync points remain, or the linear fit residual reaches 2 ms.
    """
    if not bpod_data:
        bpod_data = raw.load_data(session_path, task_collection=task_collection)
    # An explicit None test is required here: `re_events or ...` evaluates the truth value of the
    # events table, which raises ValueError for a pandas DataFrame
    if re_events is None:
        evt = raw.load_encoder_events(session_path, task_collection=task_collection)
    else:
        evt = re_events
    # we work with stim_on (2) and closed_loop (3) states for the synchronization with bpod
    tre = evt.re_ts.values / 1e6  # convert to seconds
    # the first trial on the rotary encoder is a dud
    # NOTE(review): `[:-1]` drops the LAST event of each kind, not the first - confirm intent
    rote = {'stim_on': tre[evt.sm_ev == 2][:-1],
            'closed_loop': tre[evt.sm_ev == 3][:-1]}
    bpod = {
        'stim_on': np.array([tr['behavior_data']['States timestamps']
                             ['stim_on'][0][0] for tr in bpod_data]),
        'closed_loop': np.array([tr['behavior_data']['States timestamps']
                                 ['closed_loop'][0][0] for tr in bpod_data]),
    }
    if rote['closed_loop'].size <= 1:
        raise err.SyncBpodWheelException("Not enough Rotary Encoder events to perform wheel"
                                         " synchronization. Wheel data not extracted")
    # bpod bug that spits out events in ms instead of us: the session duration measured on both
    # clocks then differs by a factor of about 1000
    if np.diff(bpod['closed_loop'][[-1, 0]])[0] / np.diff(rote['closed_loop'][[-1, 0]])[0] > 900:
        _logger.error("Rotary encoder stores values in ms instead of us. Wheel timing inaccurate")
        rote['stim_on'] *= 1e3
        rote['closed_loop'] *= 1e3
    # just use the closed loop for synchronization
    # handle different sizes in synchronization:
    sz = min(rote['closed_loop'].size, bpod['closed_loop'].size)
    # if all the sample are contiguous and first samples match
    diff_first_match = np.diff(rote['closed_loop'][:sz]) - np.diff(bpod['closed_loop'][:sz])
    # if all the sample are contiguous and last samples match
    diff_last_match = np.diff(rote['closed_loop'][-sz:]) - np.diff(bpod['closed_loop'][-sz:])
    # 99% of the pulses match for a first sample lock
    DIFF_THRESHOLD = 0.005
    if np.mean(np.abs(diff_first_match) < DIFF_THRESHOLD) > 0.99:
        re = rote['closed_loop'][:sz]
        bp = bpod['closed_loop'][:sz]
        indko = np.where(np.abs(diff_first_match) >= DIFF_THRESHOLD)[0]
    # 99% of the pulses match for a last sample lock
    elif np.mean(np.abs(diff_last_match) < DIFF_THRESHOLD) > 0.99:
        re = rote['closed_loop'][-sz:]
        bp = bpod['closed_loop'][-sz:]
        indko = np.where(np.abs(diff_last_match) >= DIFF_THRESHOLD)[0]
    # last resort is to use ad-hoc sync function
    else:
        bp, re = raw.sync_trials_robust(bpod['closed_loop'], rote['closed_loop'],
                                        diff_threshold=DIFF_THRESHOLD, max_shift=5)
        # we don't want to change the extractor, but in rare cases the following method may save the day
        if len(bp) == 0:
            _, _, ib, ir = sync_timestamps(bpod['closed_loop'], rote['closed_loop'], return_indices=True)
            bp = bpod['closed_loop'][ib]
            re = rote['closed_loop'][ir]

        indko = np.array([])
        # raise ValueError("Can't sync bpod and rotary encoder: non-contiguous sync pulses")
    # remove faulty indices due to missing or bad syncs: a bad interval taints both of its ends
    indko = np.int32(np.unique(np.r_[indko + 1, indko]))
    re = np.delete(re, indko)
    bp = np.delete(bp, indko)
    # check the linear drift
    assert bp.size > 1
    poly = np.polyfit(bp, re, 1)
    assert np.all(np.abs(np.polyval(poly, bp) - re) < 0.002)
    return interpolate.interp1d(re, bp, fill_value="extrapolate")

94 

95 

def get_wheel_position(session_path, bp_data=None, display=False, task_collection='raw_behavior_data'):
    """
    Get wheel timestamps and positions from the Bpod rotary encoder data.

    Raw positions are sign-flipped (anti-clockwise is positive) and scaled by 2 * pi / 1024,
    timestamps are converted to seconds and mapped onto the Bpod clock with
    `sync_rotary_encoder`. The encoder resets are then compensated so that the returned trace
    is continuous, and samples with repeated timestamps are merged.

    Parameters
    ----------
    session_path : str, pathlib.Path
        Session path used by the raw data loaders.
    bp_data : list of dict, optional
        Bpod trials read from the jsonable file; loaded from `session_path` if empty or None.
    display : bool
        If True, plot the raw trace, the reset compensation and the output trace.
    task_collection : str
        Collection passed to the raw data loaders.

    Returns
    -------
    numpy.array
        Wheel timestamps in seconds on the Bpod clock, or None if no encoder positions were
        loaded.
    numpy.array
        Wheel positions (radians as WHEEL_RADIUS_CM is 1), or None if no encoder positions were
        loaded.
    """
    # status 0: resets can be located with the state machine times; 1: fall back on zero samples
    status = 0
    if not bp_data:
        bp_data = raw.load_data(session_path, task_collection=task_collection)
    df = raw.load_encoder_positions(session_path, task_collection=task_collection)
    if df is None:
        _logger.error('No wheel data for %s', session_path)
        return None, None
    # NB: the 'bns_ts' field is allocated but never filled in this function
    data = structarr(['re_ts', 're_pos', 'bns_ts'],
                     shape=(df.shape[0],), formats=['f8', 'f8', object])
    data['re_ts'] = df.re_ts.values
    data['re_pos'] = df.re_pos.values * -1  # anti-clockwise is positive in our output
    data['re_pos'] = data['re_pos'] / 1024 * 2 * np.pi  # convert positions to radians
    trial_starts = get_trial_start_times(session_path, task_collection=task_collection)
    # need a flag if the data resolution is 1ms due to the old version of rotary encoder firmware
    if np.all(np.mod(data['re_ts'], 1e3) == 0):
        status = 1
    data['re_ts'] = data['re_ts'] / 1e6  # convert ts to seconds
    # get the converter function to translate re_ts into behavior times
    re2bpod = sync_rotary_encoder(session_path, task_collection=task_collection)
    data['re_ts'] = re2bpod(data['re_ts'])

    def get_reset_trace_compensation_with_state_machine_times():
        # this is the preferred way of getting resets using the state machine time information
        # it will not always work depending on firmware versions, new bugs
        iwarn = []
        ns = len(data['re_pos'])
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        for bp_dat in bp_data:
            restarts = np.sort(np.array(
                bp_dat['behavior_data']['States timestamps']['reset_rotary_encoder'] +
                bp_dat['behavior_data']['States timestamps']['reset2_rotary_encoder'])[:, 0])
            # index of the last wheel sample preceding each reset time
            ind = np.unique(np.searchsorted(data['re_ts'], restarts, side='left') - 1)
            # the rotary encoder doesn't always reset right away, and the reset sample given the
            # timestamp can be ambiguous: look for zeros
            for i in np.where(data['re_pos'][ind] != 0)[0]:
                # handle boundary effects
                if ind[i] > ns - 2:
                    continue
                # it happens quite often that we have to lock in to next sample to find the reset
                if data['re_pos'][ind[i] + 1] == 0:
                    ind[i] = ind[i] + 1
                    continue
                # also case where the rotary doesn't reset to 0, but erratically to -1/+1
                # NOTE(review): no absolute value is taken here, unlike the next test, so any
                # negative position passes - confirm this is intended
                if data['re_pos'][ind[i]] <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # compounded with the fact that the reset may have happened at next sample.
                if np.abs(data['re_pos'][ind[i] + 1]) <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # sometimes it is also the last trial that has this behaviour
                if (bp_data[-1] is bp_dat) or (bp_data[0] is bp_dat):
                    continue
                iwarn.append(ind[i])
                # at which point we are running out of possible bugs and calling it
            # the position just before a reset is the offset to add from the reset onwards
            tr_dc[ind] = data['re_pos'][ind - 1]
        if iwarn:  # if a warning flag was caught in the loop throw a single warning
            _logger.warning('Rotary encoder reset events discrepancy. Doing my best to merge.')
            _logger.debug('Offending inds: %s times: %s', iwarn, data['re_ts'][iwarn])
        # exit status 0 is fine, 1 something went wrong
        return tr_dc, len(iwarn) != 0

    # attempt to get the resets properly unless the unit is ms which means precision is
    # not good enough to match SM times to wheel samples time
    if not status:
        tr_dc, status = get_reset_trace_compensation_with_state_machine_times()

    # if something was wrong or went wrong agnostic way of getting resets: just get zeros values
    if status:
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        i0 = np.where(data['re_pos'] == 0)[0]
        tr_dc[i0] = data['re_pos'][i0 - 1]
    # even if things went ok, rotary encoder may not log the whole session. Need to fix outside
    else:
        i0 = np.where(np.bitwise_and(np.bitwise_or(data['re_ts'] >= trial_starts[-1],
                                                   data['re_ts'] <= trial_starts[0]),
                                     data['re_pos'] == 0))[0]
    # make sure the bounds are not included in the current list
    i0 = np.delete(i0, np.where(np.bitwise_or(i0 >= len(data['re_pos']) - 1, i0 == 0)))
    # a 0 sample is not a reset if 2 conditions are met:
    # 1/2 no inflexion (continuous derivative)
    c1 = np.abs(np.sign(data['re_pos'][i0 + 1] - data['re_pos'][i0]) -
                np.sign(data['re_pos'][i0] - data['re_pos'][i0 - 1])) == 2
    # 2/2 needs to be below threshold
    c2 = np.abs((data['re_pos'][i0] - data['re_pos'][i0 - 1]) /
                (EPS + (data['re_ts'][i0] - data['re_ts'][i0 - 1]))) < THRESHOLD_RAD_PER_SEC
    # apply reset to points identified as resets
    i0 = i0[np.where(np.bitwise_not(np.bitwise_and(c1, c2)))]
    tr_dc[i0] = data['re_pos'][i0 - 1]

    # unwrap the rotation (in radians) and then add the DC component from restarts
    data['re_pos'] = np.unwrap(data['re_pos']) + np.cumsum(tr_dc)

    # Time stamps may be repeated or very close to one another.
    # Find them as they will induce large jitters on the velocity function or errors in
    # attempts of interpolation
    rep_idx = np.where(np.diff(data['re_ts']) <= THRESHOLD_CONSECUTIVE_SAMPLES)[0]
    # Change the value of the repeated position
    data['re_pos'][rep_idx] = (data['re_pos'][rep_idx] +
                               data['re_pos'][rep_idx + 1]) / 2
    data['re_ts'][rep_idx] = (data['re_ts'][rep_idx] +
                              data['re_ts'][rep_idx + 1]) / 2
    # Now remove the repeat times that are rep_idx + 1
    data = np.delete(data, rep_idx + 1)

    # convert to cm
    data['re_pos'] = data['re_pos'] * WHEEL_RADIUS_CM

    # DEBUG PLOTS START HERE ########################
    if display:
        import matplotlib.pyplot as plt
        plt.figure()
        ax = plt.axes()
        # reuse the trial starts loaded above: reloading them without the task collection would
        # read the default collection and could plot the wrong trials
        tstart = trial_starts
        tts = np.c_[tstart, tstart, tstart + np.nan].flatten()
        vts = np.c_[tstart * 0 + 100, tstart * 0 - 100, tstart + np.nan].flatten()
        ax.plot(tts, vts, label='Trial starts')
        ax.plot(re2bpod(df.re_ts.values / 1e6), df.re_pos.values / 1024 * 2 * np.pi,
                '.-', label='Raw data')
        i0 = np.where(df.re_pos.values == 0)
        ax.plot(re2bpod(df.re_ts.values[i0] / 1e6), df.re_pos.values[i0] / 1024 * 2 * np.pi,
                'r*', label='Raw data zero samples')
        ax.plot(re2bpod(df.re_ts.values / 1e6), tr_dc, label='reset compensation')
        ax.set_xlabel('Bpod Time')
        ax.set_ylabel('radians')
        ax.plot(data['re_ts'], data['re_pos'] / WHEEL_RADIUS_CM, '.-', label='Output Trace')
        ax.legend()
    return data['re_ts'], data['re_pos']

240 

241 

def infer_wheel_units(pos):
    """
    Infer the units, encoder resolution and encoding type of a wheel position trace.

    The median absolute position step is compared with the smallest step expected for each
    combination of units (radians or linear cm) and encoding (X4, X2, X1); the closest
    candidate wins.

    Parameters
    ----------
    pos : numpy.array
        Extracted wheel positions; flattened if it has more than one dimension.

    Returns
    -------
    str
        The position units, either 'rad' or 'cm'.
    int
        The number of decoded fronts per 360 degree rotation.
    str
        The encoding type, one of {'X1', 'X2', 'X4'}.
    """
    if pos.ndim > 1:  # Ensure 1D array of positions
        pos = pos.flatten()

    # Resolutions for the X4, X2 and X1 encodings, in that order
    resolutions = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # Smallest possible step for each candidate: [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    rad_steps = 2 * np.pi / resolutions
    cm_steps = wh.WHEEL_DIAMETER * np.pi / resolutions
    candidates = np.concatenate([rad_steps, cm_steps])
    typical_step = np.median(np.abs(np.ediff1d(pos)))

    # The first half of the candidates are radians, the second half centimetres
    closest = int(np.argmin(np.abs(candidates - typical_step)))
    in_cm, encoding = divmod(closest, len(resolutions))
    units = 'cm' if in_cm else 'rad'
    enc_names = ('X4', 'X2', 'X1')
    return units, int(resolutions[encoding]), enc_names[encoding]

275 

276 

def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    The units and encoding of the positions are inferred first so that the detection
    thresholds (8 and 1.5 encoder samples) can be expressed in the units of the trace. The
    trace is then interpolated at 1000 Hz and passed to the movement detection.

    Parameters
    ----------
    re_ts : numpy.array
        Rotary encoder timestamps. When 2D, the first column is used as sample indices and the
        second as the times at those indices, and the times are linearly interpolated for
        every position sample.
    re_pos : numpy.array
        Rotary encoder positions.
    display : bool
        Passed to the movement detection as `make_plots`.

    Returns
    -------
    dict
        Keys 'intervals' (onset and offset columns), 'peakAmplitude' and
        'peakVelocity_times'.
    """
    if re_ts.ndim == 1:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'
    else:
        _logger.debug('2D wheel timestamps')
        if re_pos.ndim > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate the times
        sample_index = np.arange(re_pos.size)
        re_ts = np.interp(sample_index, re_ts[:, 0], re_ts[:, 1])

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # Convert the pos threshold defaults from samples to correct unit
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Interpolate and get onsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(
        t, pos, freq=1000,
        pos_thresh=thresholds[0], pos_thresh_onset=thresholds[1], make_plots=display)

    # Sanity checks on the detected movements
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0), 'onsets/offsets not strictly increasing'
    assert np.all(np.diff(off) > 0), 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {
        'intervals': np.c_[on, off],
        'peakAmplitude': amp,
        'peakVelocity_times': peak_vel,
    }

322 

323 

def extract_first_movement_times(wheel_moves, trials, min_qt=None):
    """
    Extract the time of the first sufficiently large wheel movement for each trial.

    To be counted, the movement onset must fall strictly between the go cue minus the
    quiescent period and the feedback time. The onset is allowed to precede the cue because a
    movement sometimes starts in the gap between quiescence end and cue start, or during the
    quiescence period but sub-threshold. A movement is sufficiently large when its absolute
    peak amplitude is not below THRESH.

    Parameters
    ----------
    wheel_moves : dict
        Detected wheel movements, with 'intervals' (onsets in the first column) and
        'peakAmplitude' arrays.
    trials : dict
        Trial data with 'goCue_times' and 'feedback_times' arrays.
    min_qt : float, array-like
        The minimum quiescence period in seconds, if None a default is used. A sized value
        longer than the number of trials is truncated to the number of trials.

    Returns
    -------
    numpy.array
        First movement times, NaN for trials without a sufficiently large movement.
    numpy.array
        Bool array, True where the first large movement was also the last movement onset found
        in the trial window.
    numpy.array
        Indices into the wheel_moves arrays, for the trials where a movement was found.
    """
    THRESH = .1  # peak amp should be at least .1 rad; ~1/3rd of the distance to threshold
    MIN_QT = .2  # default minimum enforced quiescence period

    go_cues = trials['goCue_times']

    # Determine minimum quiescent period
    if min_qt is None:
        min_qt = MIN_QT
        _logger.info('minimum quiescent period assumed to be %.0fms', MIN_QT * 1e3)
    elif isinstance(min_qt, Sized) and len(min_qt) > len(go_cues):
        min_qt = np.array(min_qt[0:go_cues.size])

    # Outputs default to "no movement found"
    first_move_onsets = np.full(go_cues.shape, np.nan)
    ids = np.full(go_cues.shape, -1)
    is_final_movement = np.zeros(go_cues.shape, bool)

    flinch = abs(wheel_moves['peakAmplitude']) < THRESH
    large = ~flinch
    all_move_onsets = wheel_moves['intervals'][:, 0]

    # Iterate over trials, extracting onsets approx. within closed-loop period
    window_starts = go_cues - min_qt
    n_undefined = 0
    for i, (t1, t2) in enumerate(zip(window_starts, trials['feedback_times'])):
        if np.isnan(t2 - t1):  # One of the two timestamps is missing
            n_undefined += 1
            continue
        candidates, = np.where((all_move_onsets > t1) & (all_move_onsets < t2))
        big_moves = candidates[large[candidates]]
        if big_moves.size == 0:  # No onset, or only flinches, in this trial
            continue
        ids[i] = big_moves[0]  # Find first large move id
        first_move_onsets[i] = all_move_onsets[big_moves[0]]  # Save first large onset
        is_final_movement[i] = big_moves[0] == candidates[-1]  # Final move of trial
    if n_undefined:
        _logger.warning(f'no reliable goCue/Feedback times (both needed) for {n_undefined} trials')

    return first_move_onsets, is_final_movement, ids[ids != -1]

390 

391 

class Wheel(BaseBpodTrialsExtractor):
    """
    Wheel extractor.

    Reads the wheel data from the raw Bpod files, returning positions in radians with the
    mathematical convention (anti-clockwise = +) and timestamps in seconds relative to the
    Bpod clock. Wheel movements and the first movement time of each trial are derived from
    that trace.

    Outputs are returned in the order given by `var_names`. `save_names` lists the
    corresponding file names (e.g. _ibl_wheel.timestamps.npy, _ibl_wheel.position.npy);
    presumably the base class skips the outputs whose save name is None - confirm against
    BaseBpodTrialsExtractor.
    """
    save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  '_ibl_trials.firstMovement_times.npy', None)
    var_names = ('wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'firstMovement_times',
                 'is_final_movement')

    def _extract(self):
        """Return the wheel trace, detected movements and first movement times as a tuple."""
        wheel_t, wheel_pos = get_wheel_position(
            self.session_path, self.bpod_trials, task_collection=self.task_collection)
        moves = extract_wheel_moves(wheel_t, wheel_pos)

        # need some trial based info to output the first movement times
        from ibllib.io.extractors import training_trials  # Avoids circular imports
        trial_kwargs = {
            'save': False,
            'bpod_trials': self.bpod_trials,
            'settings': self.settings,
            'task_collection': self.task_collection,
        }
        goCue_times, _ = training_trials.GoCueTimes(self.session_path).extract(**trial_kwargs)
        feedback_times, _ = training_trials.FeedbackTimes(self.session_path).extract(**trial_kwargs)
        trials = {'goCue_times': goCue_times, 'feedback_times': feedback_times}

        # None lets extract_first_movement_times fall back on its default quiescent period
        min_qt = self.settings.get('QUIESCENT_PERIOD')
        first_moves, is_final, _ = extract_first_movement_times(moves, trials, min_qt=min_qt)

        return (wheel_t, wheel_pos, moves['intervals'], moves['peakAmplitude'],
                moves['peakVelocity_times'], first_moves, is_final)

431 

432 

def extract_all(session_path, bpod_trials=None, settings=None, save=False, task_collection='raw_behavior_data', save_path=None):
    """Extract the wheel data.

    NB: Wheel extraction is now called through ibllib.io.training_trials.extract_all

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset
    save : bool
        If true save the data files to ALF
    task_collection : str
        The collection holding the raw task data, forwarded to the extractor
    save_path : str, pathlib.Path
        Forwarded to run_extractor_classes as `path_out`

    Returns
    -------
    The output of run_extractor_classes for the Wheel extractor. Per the original documentation
    this is a list of extracted data and a list of file paths if save is True (otherwise None).
    """
    extractor_kwargs = {
        'save': save,
        'session_path': session_path,
        'bpod_trials': bpod_trials,
        'settings': settings,
        'task_collection': task_collection,
        'path_out': save_path,
    }
    return run_extractor_classes(Wheel, **extractor_kwargs)