Coverage for ibllib/io/extractors/training_wheel.py: 87%

218 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1import logging 

2from collections.abc import Sized 

3 

4import numpy as np 

5from scipy import interpolate 

6 

7from neurodsp.utils import sync_timestamps 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9import ibllib.io.raw_data_loaders as raw 

10from ibllib.misc import structarr 

11import ibllib.exceptions as err 

12import brainbox.behavior.wheel as wh 

13 

_logger = logging.getLogger(__name__)
# wheel radius used to scale the output positions; 1 so that the output stays in radians
WHEEL_RADIUS_CM = 1
# a zero sample whose velocity (rad/s) stays below this may be a genuine position, not a reset
THRESHOLD_RAD_PER_SEC = 10
# consecutive timestamps whose difference is <= this value are treated as repeats and merged
THRESHOLD_CONSECUTIVE_SAMPLES = 0
# float64 machine epsilon, obtained from rounding error; keeps velocity denominators non-zero
EPS = 7. / 3 - 4. / 3 - 1

19 

20 

def get_trial_start_times(session_path, data=None, task_collection='raw_behavior_data'):
    """
    Collect the Bpod times of every 'trial_start' state across all trials.

    :param session_path: session folder; only used to load the Bpod data when `data` is falsy
    :param data: (optional) list of Bpod trial dicts; loaded with raw.load_data when falsy
    :param task_collection: collection holding the raw task data
    :return: np.array with element 0 of each 'trial_start' state occurrence, in trial order
    """
    if not data:
        data = raw.load_data(session_path, task_collection=task_collection)
    # every occurrence of the state is indexable; element 0 is kept (presumably entry time)
    starts = [
        occurrence[0]
        for trial in data
        for occurrence in trial['behavior_data']['States timestamps']['trial_start']
    ]
    return np.array(starts)

29 

30 

def sync_rotary_encoder(session_path, bpod_data=None, re_events=None, task_collection='raw_behavior_data'):
    """
    Build a function converting rotary encoder times (s) into Bpod times (s).

    The 'closed_loop' state-machine events logged by the rotary encoder are matched against the
    Bpod 'closed_loop' state times, and a linear interpolant between the two clocks is returned.

    :param session_path: session folder; used to load whichever data are not passed in
    :param bpod_data: (optional) list of Bpod trial dicts; loaded with raw.load_data when falsy
    :param re_events: (optional) rotary encoder events table with `re_ts` (microseconds) and
        `sm_ev` columns; loaded with raw.load_encoder_events when None
    :param task_collection: collection holding the raw task data
    :return: scipy interp1d mapping encoder time (s) to Bpod time (s), extrapolating outside
    :raises ibllib.exceptions.SyncBpodWheelException: if there are too few encoder events
    :raises AssertionError: if fewer than 2 sync pulses remain or the linear fit residual is
        not below 2 ms
    """
    if not bpod_data:
        bpod_data = raw.load_data(session_path, task_collection=task_collection)
    # NB: `re_events or ...` would call bool() on the events DataFrame, which raises ValueError
    if re_events is None:
        re_events = raw.load_encoder_events(session_path, task_collection=task_collection)
    evt = re_events
    # we work with stim_on (2) and closed_loop (3) states for the synchronization with bpod
    tre = evt.re_ts.values / 1e6  # convert to seconds
    # the encoder logs one spurious event of each kind which is discarded here
    # NOTE(review): historically documented as "the first trial is a dud", yet [:-1] drops the
    # *last* event of each kind -- confirm which one is intended
    rote = {'stim_on': tre[evt.sm_ev == 2][:-1],
            'closed_loop': tre[evt.sm_ev == 3][:-1]}
    bpod = {
        'stim_on': np.array([tr['behavior_data']['States timestamps']
                             ['stim_on'][0][0] for tr in bpod_data]),
        'closed_loop': np.array([tr['behavior_data']['States timestamps']
                                 ['closed_loop'][0][0] for tr in bpod_data]),
    }
    if rote['closed_loop'].size <= 1:
        raise err.SyncBpodWheelException("Not enough Rotary Encoder events to perform wheel"
                                         " synchronization. Wheel data not extracted")
    # bpod bug that spits out events in ms instead of us: the Bpod span of the closed loop
    # events is then ~1000x the encoder span, so rescale the encoder times
    if np.diff(bpod['closed_loop'][[-1, 0]])[0] / np.diff(rote['closed_loop'][[-1, 0]])[0] > 900:
        _logger.error("Rotary encoder stores values in ms instead of us. Wheel timing inaccurate")
        rote['stim_on'] *= 1e3
        rote['closed_loop'] *= 1e3
    # just use the closed loop for synchronization; the two series may differ in size
    sz = min(rote['closed_loop'].size, bpod['closed_loop'].size)
    # inter-pulse interval mismatch when locking the first samples together
    diff_first_match = np.diff(rote['closed_loop'][:sz]) - np.diff(bpod['closed_loop'][:sz])
    # inter-pulse interval mismatch when locking the last samples together
    diff_last_match = np.diff(rote['closed_loop'][-sz:]) - np.diff(bpod['closed_loop'][-sz:])
    DIFF_THRESHOLD = 0.005  # tolerance (s) on inter-pulse interval differences
    if np.mean(np.abs(diff_first_match) < DIFF_THRESHOLD) > 0.99:
        # more than 99% of the intervals match for a first sample lock
        re = rote['closed_loop'][:sz]
        bp = bpod['closed_loop'][:sz]
        indko = np.where(np.abs(diff_first_match) >= DIFF_THRESHOLD)[0]
    elif np.mean(np.abs(diff_last_match) < DIFF_THRESHOLD) > 0.99:
        # more than 99% of the intervals match for a last sample lock
        re = rote['closed_loop'][-sz:]
        bp = bpod['closed_loop'][-sz:]
        indko = np.where(np.abs(diff_last_match) >= DIFF_THRESHOLD)[0]
    else:
        # last resort is to use ad-hoc sync function
        bp, re = raw.sync_trials_robust(bpod['closed_loop'], rote['closed_loop'],
                                        diff_threshold=DIFF_THRESHOLD, max_shift=5)
        # we don't want to change the extractor, but in rare cases this method may save the day
        if len(bp) == 0:
            _, _, ib, ir = sync_timestamps(bpod['closed_loop'], rote['closed_loop'],
                                           return_indices=True)
            bp = bpod['closed_loop'][ib]
            re = rote['closed_loop'][ir]
        indko = np.array([])
    # remove faulty indices due to missing or bad syncs: both pulses bounding a bad interval
    indko = np.int32(np.unique(np.r_[indko + 1, indko]))
    re = np.delete(re, indko)
    bp = np.delete(bp, indko)
    # check the linear drift: residuals of a linear fit must stay below 2 ms
    assert bp.size > 1
    poly = np.polyfit(bp, re, 1)
    assert np.all(np.abs(np.polyval(poly, bp) - re) < 0.002)
    return interpolate.interp1d(re, bp, fill_value="extrapolate")

93 

94 

def get_wheel_position(session_path, bp_data=None, display=False, task_collection='raw_behavior_data'):
    """
    Gets wheel timestamps and position from Bpod data. Position is in radian (constant above for
    radius is 1) mathematical convention.

    Encoder timestamps are mapped to Bpod time with `sync_rotary_encoder`, encoder resets are
    compensated so that the position trace is continuous, and repeated timestamps are merged.

    :param session_path: session folder
    :param bp_data: (optional) bpod trials read from jsonable file; loaded with raw.load_data
        when falsy. Also used for the trial start times and the encoder synchronization.
    :param display: (optional, bool) plot raw and output traces for debugging
    :param task_collection: collection holding the raw task data
    :return: timestamps (np.array), or None if no wheel data were found
    :return: positions (np.array), or None if no wheel data were found
    """
    status = 0
    if not bp_data:
        bp_data = raw.load_data(session_path, task_collection=task_collection)
    df = raw.load_encoder_positions(session_path, task_collection=task_collection)
    if df is None:
        _logger.error('No wheel data for ' + str(session_path))
        return None, None
    data = structarr(['re_ts', 're_pos', 'bns_ts'],
                     shape=(df.shape[0],), formats=['f8', 'f8', object])
    data['re_ts'] = df.re_ts.values
    data['re_pos'] = df.re_pos.values * -1  # anti-clockwise is positive in our output
    data['re_pos'] = data['re_pos'] / 1024 * 2 * np.pi  # convert positions to radians
    # reuse the Bpod trials already in memory rather than reloading them from disk
    trial_starts = get_trial_start_times(session_path, data=bp_data, task_collection=task_collection)
    # need a flag if the data resolution is 1ms due to the old version of rotary encoder firmware
    if np.all(np.mod(data['re_ts'], 1e3) == 0):
        status = 1
    data['re_ts'] = data['re_ts'] / 1e6  # convert ts to seconds
    # get the converter function to translate re_ts into behavior times
    re2bpod = sync_rotary_encoder(session_path, bpod_data=bp_data, task_collection=task_collection)
    data['re_ts'] = re2bpod(data['re_ts'])

    def get_reset_trace_compensation_with_state_machine_times():
        """
        Locate encoder resets from the Bpod 'reset_rotary_encoder' / 'reset2_rotary_encoder'
        state times and build the per-sample DC offset to add back.

        This is the preferred way of getting resets; it will not always work depending on
        firmware versions and new bugs.

        :return: tr_dc, array of position offsets, non-zero at samples identified as resets
        :return: bool, True if some resets could not be reconciled with the wheel trace
        """
        iwarn = []
        ns = len(data['re_pos'])
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        for bp_dat in bp_data:
            restarts = np.sort(np.array(
                bp_dat['behavior_data']['States timestamps']['reset_rotary_encoder'] +
                bp_dat['behavior_data']['States timestamps']['reset2_rotary_encoder'])[:, 0])
            ind = np.unique(np.searchsorted(data['re_ts'], restarts, side='left') - 1)
            # a reset before the first wheel sample gives -1, which would wrap to the last sample
            ind = ind[ind >= 0]
            # the rotary encoder doesn't always reset right away, and the reset sample given the
            # timestamp can be ambiguous: look for zeros
            for i in np.where(data['re_pos'][ind] != 0)[0]:
                # handle boundary effects
                if ind[i] > ns - 2:
                    continue
                # it happens quite often that we have to lock in to next sample to find the reset
                if data['re_pos'][ind[i] + 1] == 0:
                    ind[i] = ind[i] + 1
                    continue
                # also case where the rotary doesn't reset to 0, but erratically to -1/+1
                # NOTE(review): no abs() here unlike the next check, so any negative position
                # also matches -- confirm this is intended
                if data['re_pos'][ind[i]] <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # compounded with the fact that the reset may have happened at next sample.
                if np.abs(data['re_pos'][ind[i] + 1]) <= (1 / 1024 * 2 * np.pi):
                    ind[i] = ind[i] + 1
                    continue
                # sometimes it is also the first or last trial that has this behaviour
                if (bp_data[-1] is bp_dat) or (bp_data[0] is bp_dat):
                    continue
                # at which point we are running out of possible bugs and calling it
                iwarn.append(ind[i])
            # the first sample has no predecessor: index 0 would wrap to the last sample
            ind = ind[ind > 0]
            tr_dc[ind] = data['re_pos'][ind - 1]
        if iwarn:  # if a warning flag was caught in the loop throw a single warning
            _logger.warning('Rotary encoder reset events discrepancy. Doing my best to merge.')
            _logger.debug('Offending inds: ' + str(iwarn) + ' times: ' + str(data['re_ts'][iwarn]))
        # exit status 0 is fine, 1 something went wrong
        return tr_dc, len(iwarn) != 0

    # attempt to get the resets properly unless the unit is ms which means precision is
    # not good enough to match SM times to wheel samples time
    if not status:
        tr_dc, status = get_reset_trace_compensation_with_state_machine_times()

    # if something was wrong or went wrong agnostic way of getting resets: just get zeros values
    if status:
        tr_dc = np.zeros_like(data['re_pos'])  # trial dc component
        i0 = np.where(data['re_pos'] == 0)[0]
        # the first sample has no predecessor: index 0 would wrap to the last sample
        i0 = i0[i0 > 0]
        tr_dc[i0] = data['re_pos'][i0 - 1]
    else:
        # even if things went ok, rotary encoder may not log the whole session: handle the
        # zero samples outside of the first/last trial starts as well
        i0 = np.where(np.bitwise_and(np.bitwise_or(data['re_ts'] >= trial_starts[-1],
                                                   data['re_ts'] <= trial_starts[0]),
                                     data['re_pos'] == 0))[0]
    # make sure the bounds are not included in the current list
    i0 = np.delete(i0, np.where(np.bitwise_or(i0 >= len(data['re_pos']) - 1, i0 == 0)))
    # a 0 sample is not a reset if 2 conditions are met:
    # 1/2 the position steps before and after the sample have opposite signs
    # NOTE(review): this was documented as "no inflexion (continuous derivative)", but
    # abs(sign difference) == 2 flags a sign *change* -- confirm the intended condition
    c1 = np.abs(np.sign(data['re_pos'][i0 + 1] - data['re_pos'][i0]) -
                np.sign(data['re_pos'][i0] - data['re_pos'][i0 - 1])) == 2
    # 2/2 the velocity into the sample needs to be below threshold
    c2 = np.abs((data['re_pos'][i0] - data['re_pos'][i0 - 1]) /
                (EPS + (data['re_ts'][i0] - data['re_ts'][i0 - 1]))) < THRESHOLD_RAD_PER_SEC
    # apply reset to points identified as resets
    i0 = i0[np.where(np.bitwise_not(np.bitwise_and(c1, c2)))]
    tr_dc[i0] = data['re_pos'][i0 - 1]

    # unwrap the rotation (in radians) and then add the DC component from restarts
    data['re_pos'] = np.unwrap(data['re_pos']) + np.cumsum(tr_dc)

    # time stamps may be repeated or very close to one another. Find them as they will induce
    # large jitters on the velocity function or errors in attempts of interpolation
    rep_idx = np.where(np.diff(data['re_ts']) <= THRESHOLD_CONSECUTIVE_SAMPLES)[0]
    # merge each repeated pair into its mean position and time
    data['re_pos'][rep_idx] = (data['re_pos'][rep_idx] + data['re_pos'][rep_idx + 1]) / 2
    data['re_ts'][rep_idx] = (data['re_ts'][rep_idx] + data['re_ts'][rep_idx + 1]) / 2
    # Now remove the repeat times that are rep_idx + 1
    data = np.delete(data, rep_idx + 1)

    # scale by the wheel radius (1, so the output stays in radians)
    data['re_pos'] = data['re_pos'] * WHEEL_RADIUS_CM

    # DEBUG PLOTS START HERE ########################
    if display:
        import matplotlib.pyplot as plt
        plt.figure()
        ax = plt.axes()
        # reuse the trial starts computed above (same Bpod data and task collection)
        tstart = trial_starts
        tts = np.c_[tstart, tstart, tstart + np.nan].flatten()
        vts = np.c_[tstart * 0 + 100, tstart * 0 - 100, tstart + np.nan].flatten()
        ax.plot(tts, vts, label='Trial starts')
        ax.plot(re2bpod(df.re_ts.values / 1e6), df.re_pos.values / 1024 * 2 * np.pi,
                '.-', label='Raw data')
        i0 = np.where(df.re_pos.values == 0)
        ax.plot(re2bpod(df.re_ts.values[i0] / 1e6), df.re_pos.values[i0] / 1024 * 2 * np.pi,
                'r*', label='Raw data zero samples')
        ax.plot(re2bpod(df.re_ts.values / 1e6), tr_dc, label='reset compensation')
        ax.set_xlabel('Bpod Time')
        ax.set_ylabel('radians')
        ax.plot(data['re_ts'], data['re_pos'] / WHEEL_RADIUS_CM, '.-', label='Output Trace')
        ax.legend()
    return data['re_ts'], data['re_pos']

239 

240 

def infer_wheel_units(pos):
    """
    Infer the rotary encoder resolution, encoding type and units from wheel positions.

    Encoding differs between acquisition hardware (Bpod uses X1 while FPGA data are usually
    extracted as X4), and older data were extracted in linear cm rather than radians. The
    median absolute step between consecutive positions is compared with the smallest step
    expected for each (units, encoding) combination, and the closest one wins.

    :param pos: array of extracted wheel positions (flattened if multi-dimensional)
    :return units: the position units, assumed to be either 'rad' or 'cm'
    :return resolution: the number of decoded fronts per 360 degree rotation
    :return encoding: one of {'X1', 'X2', 'X4'}
    """
    if pos.ndim > 1:  # work on a 1D array of positions
        pos = pos.flatten()

    # decoded fronts per revolution, ordered X4, X2, X1
    resolutions = np.array([wh.ENC_RES, wh.ENC_RES / 2, wh.ENC_RES / 4])
    # smallest possible step: [rad_X4, rad_X2, rad_X1, cm_X4, cm_X2, cm_X1]
    candidate_steps = np.concatenate([2 * np.pi / resolutions,
                                      wh.WHEEL_DIAMETER * np.pi / resolutions])
    typical_step = np.median(np.abs(np.ediff1d(pos)))

    best = np.argmin(np.abs(candidate_steps - typical_step))
    # the first half of the candidates are radians, the second half centimetres
    unit_idx, enc_idx = divmod(best, len(resolutions))
    units = ('rad', 'cm')[unit_idx]
    encoding_names = ('X4', 'X2', 'X1')
    return units, int(resolutions[enc_idx]), encoding_names[int(enc_idx)]

274 

275 

def extract_wheel_moves(re_ts, re_pos, display=False):
    """
    Detect wheel movements from rotary encoder timestamps and positions.

    :param re_ts: numpy array of rotary encoder timestamps; either 1D (one time per position)
        or 2D with columns (sample index, time) that are linearly interpolated per sample
    :param re_pos: numpy array of rotary encoder positions
    :param display: bool: show the wheel position and velocity for full session with detected
        movements highlighted
    :return: wheel_moves dictionary with keys 'intervals' (onset, offset per row),
        'peakAmplitude' and 'peakVelocity_times'
    """
    if len(re_ts.shape) != 1:
        _logger.debug('2D wheel timestamps')
        if len(re_pos.shape) > 1:  # Ensure 1D array of positions
            re_pos = re_pos.flatten()
        # Linearly interpolate one time per position sample
        sample_idx = np.arange(re_pos.size)
        re_ts = np.interp(sample_idx, re_ts[:, 0], re_ts[:, 1])
    else:
        assert re_ts.size == re_pos.size, 'wheel data dimension mismatch'

    units, res, enc = infer_wheel_units(re_pos)
    _logger.info('Wheel in %s units using %s encoding', units, enc)

    # Movement thresholds are expressed in encoder samples (8 and 1.5) and converted to the
    # position units of the data.
    # NB: a strict 'no position skips' check is not possible as Bpod wheel data violate it.
    thresholds = wh.samples_to_cm(np.array([8, 1.5]), resolution=res)
    if units == 'rad':
        thresholds = wh.cm_to_rad(thresholds)

    # Resample at 1 kHz and detect movement onsets/offsets
    pos, t = wh.interpolate_position(re_ts, re_pos, freq=1000)
    on, off, amp, peak_vel = wh.movements(t, pos, freq=1000,
                                          pos_thresh=thresholds[0],
                                          pos_thresh_onset=thresholds[1],
                                          make_plots=display)
    assert on.size == off.size, 'onset/offset number mismatch'
    assert np.all(np.diff(on) > 0) and np.all(
        np.diff(off) > 0), 'onsets/offsets not strictly increasing'
    assert np.all((off - on) > 0), 'not all offsets occur after onset'

    return {'intervals': np.c_[on, off], 'peakAmplitude': amp, 'peakVelocity_times': peak_vel}

321 

322 

def extract_first_movement_times(wheel_moves, trials, min_qt=None):
    """
    Extracts the time of the first sufficiently large wheel movement for each trial.

    To be counted, a movement onset must fall strictly between (go cue time - minimum
    quiescence period) and the feedback time. The onset is therefore sometimes just before the
    cue (occurring in the gap between quiescence end and cue start, or during the quiescence
    period but sub-threshold). A movement is sufficiently large if the absolute value of its
    peak amplitude is greater than or equal to THRESH.

    :param wheel_moves: dictionary of detected wheel movements with 'intervals' (onset in the
        first column) and 'peakAmplitude' arrays
    :param trials: dictionary of trial data with 'goCue_times' and 'feedback_times' arrays
    :param min_qt: the minimum quiescence period, scalar or one value per trial (extra values
        are dropped); if None a default of 0.2 s is used
    :return: numpy array of first movement onset times (NaN where none was found)
    :return: bool array, True where the first large movement is also the last movement onset
        within that trial's window
    :return: indices into the wheel_moves arrays of the first large movement, only for trials
        where one was found
    """
    THRESH = .1  # peak amp should be at least .1 rad; ~1/3rd of the distance to threshold
    MIN_QT = .2  # default minimum enforced quiescence period

    # Determine minimum quiescent period
    if min_qt is None:
        min_qt = MIN_QT
        _logger.info('minimum quiescent period assumed to be %.0fms', MIN_QT * 1e3)
    elif isinstance(min_qt, Sized) and len(min_qt) > len(trials['goCue_times']):
        min_qt = np.array(min_qt[0:trials['goCue_times'].size])

    # Initialize as nans / missing
    first_move_onsets = np.full(trials['goCue_times'].shape, np.nan)
    ids = np.full(trials['goCue_times'].shape, int(-1))
    is_final_movement = np.zeros(trials['goCue_times'].shape, bool)
    flinch = abs(wheel_moves['peakAmplitude']) < THRESH
    all_move_onsets = wheel_moves['intervals'][:, 0]
    # Iterate over trials, extracting onsets approx. within closed-loop period
    cwarn = 0
    for i, (t1, t2) in enumerate(zip(trials['goCue_times'] - min_qt, trials['feedback_times'])):
        if np.isnan(t2 - t1):  # both timestamps are needed
            cwarn += 1
            continue
        mask = (all_move_onsets > t1) & (all_move_onsets < t2)
        if not np.any(mask):  # no onsets for this trial
            continue
        trial_onset_ids, = np.where(mask)
        large = ~flinch[mask]
        if not np.any(large):  # only sub-threshold moves (flinches)
            continue
        ids[i] = trial_onset_ids[large][0]  # Find first large move id
        first_move_onsets[i] = all_move_onsets[ids[i]]  # Save first large onset
        is_final_movement[i] = ids[i] == trial_onset_ids[-1]  # Final move of trial
    if cwarn:
        _logger.warning(f'no reliable goCue/Feedback times (both needed) for {cwarn} trials')

    return first_move_onsets, is_final_movement, ids[ids != -1]

370 

371 

class Wheel(BaseBpodTrialsExtractor):
    """
    Get wheel data from raw files and convert positions into radians, mathematical convention
    (anti-clockwise = +), and timestamps into seconds relative to the Bpod clock.

    **Optional:** saves _ibl_wheel.timestamps.npy, _ibl_wheel.position.npy,
    _ibl_wheelMoves.intervals.npy, _ibl_wheelMoves.peakAmplitude.npy and
    _ibl_trials.firstMovement_times.npy (peak velocity times and final-movement flags are
    returned but not saved).

    Times:
        Rotary Encoder timestamps (us) of each position are converted to seconds and
        synchronized with Bpod.

    Positions:
        Radians, mathematical convention.
    """
    save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  '_ibl_trials.firstMovement_times.npy', None)
    var_names = ('wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'firstMovement_times',
                 'is_final_movement')

    def _extract(self):
        """Return the wheel trace, wheel movements and first movement times, in var_names order."""
        timestamps, position = get_wheel_position(
            self.session_path, self.bpod_trials, task_collection=self.task_collection)
        wheel_moves = extract_wheel_moves(timestamps, position)

        # first movement times need trial based info; imported here to avoid circular imports
        from ibllib.io.extractors import training_trials
        trial_kwargs = dict(save=False, bpod_trials=self.bpod_trials, settings=self.settings,
                            task_collection=self.task_collection)
        go_cues, _ = training_trials.GoCueTimes(self.session_path).extract(**trial_kwargs)
        feedbacks, _ = training_trials.FeedbackTimes(self.session_path).extract(**trial_kwargs)
        quiescence = self.settings.get('QUIESCENT_PERIOD', None)

        first_moves, is_final, _ = extract_first_movement_times(
            wheel_moves, {'goCue_times': go_cues, 'feedback_times': feedbacks},
            min_qt=quiescence)
        return (timestamps, position,
                wheel_moves['intervals'], wheel_moves['peakAmplitude'],
                wheel_moves['peakVelocity_times'], first_moves, is_final)

409 

410 

def extract_all(session_path, bpod_trials=None, settings=None, save=False, task_collection='raw_behavior_data', save_path=None):
    """Extract the wheel data by running the Wheel extractor.

    NB: Wheel extraction is now called through training_trials.extract_all
    (NOTE(review): presumably ibllib.io.extractors.training_trials -- confirm).

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session
    bpod_trials : list of dicts
        The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset
    settings : dict
        The Bpod settings loaded from the _iblrig_taskSettings.raw dataset
    save : bool
        If true save the data files to ALF
    task_collection : str
        The collection holding the raw task data
    save_path : str, pathlib.Path
        Output location, forwarded to run_extractor_classes as `path_out`

    Returns
    -------
    A list of extracted data and a list of file paths if save is True (otherwise None)
    """
    extractor_kwargs = dict(
        save=save,
        session_path=session_path,
        bpod_trials=bpod_trials,
        settings=settings,
        task_collection=task_collection,
        path_out=save_path,
    )
    return run_extractor_classes(Wheel, **extractor_kwargs)