Coverage for ibllib/io/extractors/mesoscope.py: 85%

337 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1"""Mesoscope (timeline) data extraction.""" 

2import logging 

3 

4import numpy as np 

5from scipy.signal import find_peaks 

6import one.alf.io as alfio 

7from one.alf.path import session_path_parts 

8from iblutil.util import ensure_list 

9import matplotlib.pyplot as plt 

10from packaging import version 

11 

12from ibllib.plots.misc import squares, vertical_lines 

13from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, 

14 correct_counter_discontinuities, load_timeline_sync_and_chmap) 

15import ibllib.io.extractors.base as extractors_base 

16from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial 

17from ibllib.io.extractors.training_wheel import extract_wheel_moves 

18from ibllib.io.extractors.camera import attribute_times 

19from brainbox.behavior.wheel import velocity_filtered 

20 

21_logger = logging.getLogger(__name__) 

22 

23 

def patch_imaging_meta(meta: dict) -> dict:
    """
    Patch imaging metadata for compatibility across versions.

    The dict is modified in place (a copy is NOT made) and the same object is returned.

    Parameters
    ----------
    meta : dict
        The loaded contents of a rawImagingData.meta file.

    Returns
    -------
    dict
        The same metadata dict, updated to the most recent version.

    Raises
    ------
    AssertionError
        If the (patched) metadata has no 'nFrames' key.
    """
    ver = version.parse(meta.get('version') or '0.0.0')
    fovs = meta.get('FOV', [])
    if ver <= version.parse('0.0.0'):
        # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys
        if 'channelSaved' not in meta:
            meta['channelSaved'] = next((x['channelIdx'] for x in fovs if 'channelIdx' in x), [])
        fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight')
        for fov in fovs:
            for unit in ('Deg', 'MM'):
                if unit not in fov:  # topLeftDeg, etc. -> Deg[topLeft]
                    fov[unit] = {f: fov.pop(f + unit, None) for f in fields}
    elif ver == version.parse('0.1.0'):
        # v0.1.0 used 'roiUuid'; later versions use 'roiUUID'
        for fov in fovs:
            if 'roiUuid' in fov:
                fov['roiUUID'] = fov.pop('roiUuid')

    # 2024-09-17 Modified the 2 unit vectors for the positive ML axis and the positive AP axis,
    # which then transform [X,Y] coordinates (in degrees) to [ML,AP] coordinates (in MM).
    if ver < version.Version('0.1.5') and 'imageOrientation' in meta:
        pos_ml, pos_ap = meta['imageOrientation']['positiveML'], meta['imageOrientation']['positiveAP']
        center_ml, center_ap = meta['centerMM']['ML'], meta['centerMM']['AP']
        res = meta['scanImageParams']['objectiveResolution']
        # previously [[0, res/1000], [-res/1000, 0], [0, 0]]
        TF = np.linalg.pinv(np.c_[np.vstack([pos_ml, pos_ap, [0, 0]]), [1, 1, 1]]) @ \
            (np.array([[res / 1000, 0], [0, res / 1000], [0, 0]]) + np.array([center_ml, center_ap]))
        TF = np.round(TF, 3)  # handle floating-point error by rounding
        if not np.allclose(TF, meta['coordsTF']):
            # Recompute the transform and every FOV's corner coordinates in MM
            meta['coordsTF'] = TF.tolist()
            centerDegXY = np.array([meta['centerDeg']['x'], meta['centerDeg']['y']])
            for fov in fovs:
                fov['MM'] = {k: (np.r_[np.array(v) - centerDegXY, 1] @ TF).tolist() for k, v in fov['Deg'].items()}

    assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version'
    return meta

72 

73 

def plot_timeline(timeline, channels=None, raw=True):
    """
    Plot the timeline data.

    Parameters
    ----------
    timeline : one.alf.io.AlfBunch
        The timeline data object.
    channels : iterable of str
        An iterable of channel names to plot. If None or empty, all Timeline inputs are plotted.
    raw : bool
        If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing timeline subplots.
    list of matplotlib.pyplot.Axes
        The axes for each timeline channel plotted.
    """
    # Map of input name -> input metadata (arrayColumn, axesScale, ...)
    meta = {x['name']: x for x in timeline['meta']['inputs']}
    # Materialise so that generators / dict views can be measured and iterated more than once
    channels = list(channels or meta.keys())
    fig, axes = plt.subplots(len(channels), 1, sharex=True)
    axes = ensure_list(axes)
    if not raw:
        chmap = {ch: meta[ch]['arrayColumn'] for ch in channels}
        sync = extract_sync_timeline(timeline, chmap=chmap)
    for ax, ch in zip(axes, channels):
        if raw:
            # axesScale controls vertical scaling of each trace (multiplicative);
            # Timeline arrayColumn indices start from 1
            values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale']
            ax.plot(timeline['timestamps'], values)
        elif np.any(idx := sync['channels'] == chmap[ch]):
            squares(sync['times'][idx], sync['polarities'][idx], ax=ax)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True)
        ax.set_ylabel(ch, rotation=45, fontsize=8)
    # Add back x-axis ticks to the last plot
    axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True)
    axes[-1].spines['bottom'].set_visible(True)
    # Maximise the window where the backend supports it (Qt); other backends such as Agg or
    # Tk have no `window.showMaximized` and would otherwise raise AttributeError
    window = getattr(plt.get_current_fig_manager(), 'window', None)
    if window is not None and hasattr(window, 'showMaximized'):
        window.showMaximized()  # full screen
    fig.tight_layout(h_pad=0)
    return fig, axes

117 

118 

class TimelineTrials(FpgaTrials):
    """Similar extraction to the FPGA, however counter and position channels are treated differently."""

    timeline = None
    """one.alf.io.AlfBunch: The timeline data object."""

    sync_field = 'itiIn_times'
    """str: The trial event to synchronize (must be present in extracted trials)."""

    def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
        """
        An extractor for all ephys trial data, in Timeline time.

        Parameters
        ----------
        *args, **kwargs
            Passed to the FpgaTrials constructor.
        sync_collection : str
            The session subdirectory from which the Timeline 'DAQdata' object is loaded.
        """
        super().__init__(*args, **kwargs)
        self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

    def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_):
        """Load the DAQ sync and channel map data.

        Parameters
        ----------
        sync_collection : str
            The session subdirectory where the sync data are located.
        chmap : dict
            A map of channel names and their corresponding indices. If None, the channel map is
            loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        if not self.timeline:
            self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
        sync, chmap = load_timeline_sync_and_chmap(
            self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
        return sync, chmap

    def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
        """
        Extract the trials via the superclass, optionally plotting the raw Timeline traces.

        Parameters
        ----------
        sync : dict
            The sync pulses; passed to the superclass extractor.
        chmap : dict
            The channel map; passed to the superclass extractor.
        sync_collection : str
            The session subdirectory where the sync data are located (forwarded to the superclass).
        **kwargs
            Passed to the superclass extractor. If `display` is true the raw Timeline channels are
            plotted.

        Returns
        -------
        dict
            The extracted trials data returned by the superclass.
        """
        # Forward the caller's sync_collection rather than silently using the default
        trials = super()._extract(sync, chmap, sync_collection=sync_collection, **kwargs)
        if kwargs.get('display', False):
            # With no channel map, plot every Timeline input
            channels = chmap.keys() if chmap else None
            plot_timeline(self.timeline, channels=channels, raw=True)
        return trials

    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
        """
        Extract Bpod times from sync.

        Unlike the superclass method, this one doesn't reassign the first trial pulse.

        Parameters
        ----------
        sync : dict
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers. Must contain a 'bpod' key.
        chmap : dict
            A map of channel names and their corresponding indices.
        bpod_event_ttls : dict of tuple
            A map of event names to (min, max) TTL length. Defaults to 'valve_open' and
            'trial_end' only.
        display : bool
            Passed to the superclass method.

        Returns
        -------
        dict
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
        dict
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
        """
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
        # lengths are defined by the state machine of the task protocol and therefore vary.
        if bpod_event_ttls is None:
            # The trial start TTLs are often too short for the low sampling rate of the DAQ and are
            # therefore not used in extraction
            bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
        bpod, bpod_event_intervals = super().get_bpod_event_times(
            sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs)

        # TODO Here we can make use of the 'bpod_rising_edge' channel, if available
        return bpod, bpod_event_intervals

    def build_trials(self, sync=None, chmap=None, **kwargs):
        """
        Extract task related event times from the sync.

        The two major differences are that the sampling rate is lower for imaging so the short Bpod
        trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL.

        Second, the valve used at the mesoscope has a way to record the raw voltage across the
        solenoid, giving a more accurate readout of the valve's activity. If the reward_valve
        channel is present on the DAQ, this is used to extract the valve open times.

        Parameters
        ----------
        sync : dict
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
        chmap : dict
            Map of channel names and their corresponding index. Default to constant.

        Returns
        -------
        dict
            A map of trial event timestamps.

        Raises
        ------
        ValueError
            If the expected audio ('ready_tone', 'error_tone') or Bpod ('valve_open', 'trial_end')
            events were not detected.
        """
        # Get the events from the sync.
        # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
        if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
            raise ValueError(
                'Expected at least "ready_tone" and "error_tone" audio events. '
                '`audio_event_ttls` kwarg may be incorrect.')

        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
        if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}:
            raise ValueError(
                'Expected at least "trial_end" and "valve_open" Bpod events. '
                '`bpod_event_ttls` kwarg may be incorrect.')

        t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
        fpga_events = alfio.AlfBunch({
            'itiIn_times': t_iti_in,
            'intervals_1': t_trial_end,
            'goCue_times': audio_event_intervals['ready_tone'][:, 0],
            'errorTone_times': audio_event_intervals['error_tone'][:, 0]
        })

        # Sync the Bpod clock to the DAQ
        self.bpod2fpga, _, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)

        out = dict()
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
        out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})

        start_times = out['intervals'][:, 0]
        last_trial_end = out['intervals'][-1, 1]

        def assign_to_trial(events, take='last', starts=start_times, **assign_kwargs):
            """Assign DAQ events to trials.

            Because we may not have trial start TTLs on the DAQ (because of the low sampling rate),
            there may be an extra last trial that's not in the Bpod intervals as the extractor
            ignores the last trial. This function trims the input array before assigning so that
            the last trial's events are correctly assigned.
            """
            return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **assign_kwargs)
        out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga])

        # Extract valve open times from the DAQ
        valve_driver_ttls = bpod_event_intervals['valve_open']
        correct = self.bpod_trials['feedbackType'] == 1
        # If there is a reward_valve channel, the valve open times are detected from the raw
        # solenoid voltage; otherwise the Bpod valve driver TTLs are used
        if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
            # TODO Compare with the expected open length based on the valve calibration
            #  (WATER_CALIBRATION_* settings) and the reward volume; v7 settings unsupported

            # Use the driver TTLs to find the valve open times that correspond to the valve opening
            _, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
            if valve_open_times.size != np.sum(correct):
                _logger.warning(
                    'Number of valve open times does not equal number of correct trials (%i != %i)',
                    valve_open_times.size, np.sum(correct))

            out['valveOpen_times'] = assign_to_trial(valve_open_times)
        else:
            # Use the valve controller TTLs recorded on the Bpod channel as the reward time
            out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0])

        # Stimulus times extracted based on trigger times
        # When assigning events all start times must not be NaN so here we substitute freeze
        # trigger times on nogo trials for stim on trigger times, then replace with NaN again
        go_trials = np.where(out['choice'] != 0)[0]
        lims = np.copy(out['stimOnTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimFreeze_times'] = assign_to_trial(
            self.frame2ttl['times'], 'last',
            starts=lims, t_trial_end=out['stimOffTrigger_times'])
        out['stimFreeze_times'][out['choice'] == 0] = np.nan

        # Here we do the same but use stim off trigger times
        lims = np.copy(out['stimOffTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimOn_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOnTrigger_times'], t_trial_end=lims)
        out['stimOff_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1]
        )

        # Audio times
        error_cue = fpga_events['errorTone_times']
        if error_cue.size != np.sum(~correct):
            _logger.warning(
                'N detected error tones does not match number of incorrect trials (%i != %i)',
                error_cue.size, np.sum(~correct))
        go_cue = fpga_events['goCue_times']
        out['goCue_times'] = assign_to_trial(go_cue, take='first')
        out['errorCue_times'] = assign_to_trial(error_cue)

        if go_cue.size > start_times.size:
            _logger.warning(
                'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size)
        elif go_cue.size < start_times.size:
            # If the error cues are all assigned and some go cues are missed it may be that some
            # responses were so fast that the go cue and error tone merged, or the go cue TTL was
            # too long.
            _logger.warning('%i go cue tones missed', start_times.size - go_cue.size)
            # Index by ibpod so the triggers align with the trials in `out`
            err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'][ibpod])
            go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'][ibpod])
            assert not np.any(np.isnan(go_trig))
            assert err_trig.size == go_trig.size  # should be length of n trials with NaNs

            # Find which trials are missing a go cue
            _go_cue = assign_to_trial(go_cue, take='first')
            error_cue = assign_to_trial(error_cue)
            missing = np.isnan(_go_cue)

            # Get all the DAQ timestamps where audio channel was HIGH
            raw = timeline_get_channel(self.timeline, 'audio')
            raw = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalize
            ups = self.timeline.timestamps[raw > .5]  # timestamps where input HIGH

            # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after).
            # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN)
            idx = attribute_times(ups, go_trig, tol=0.2, take='after')
            # Trial indices that didn't have a detected goCue and have now been assigned an `ups`
            # index. NB: parentheses required as `&` binds more tightly than `!=`
            assigned = np.where((idx != -1) & missing)[0]
            _go_cue[assigned] = ups[idx[assigned]]

            # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue)
            error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig))
            i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True)
            error_cue[i_to_remove] = np.nan

            # For those trials where go cue was merged with the error cue and therefore mis-assigned,
            # we must re-assign the error cue times as the first HIGH after the error trigger.
            idx = attribute_times(ups, err_trig, tol=0.2, take='after')
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            error_cue[assigned] = ups[idx[assigned]]
            out['goCue_times'] = _go_cue
            out['errorCue_times'] = error_cue

        # A go cue and an error tone can never share a timestamp; if they do the tones were
        # likely missed or merged
        assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \
            'audio tones not assigned correctly; tones likely missed'

        # Feedback times: the valve open time on rewarded trials, otherwise the error tone time
        out['feedback_times'] = np.copy(out['valveOpen_times'])
        ind_err = np.isnan(out['valveOpen_times'])
        out['feedback_times'][ind_err] = out['errorCue_times'][ind_err]

        return out

    def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
        """
        Gets the wheel position from Timeline counter channel.

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel. Use 1 for an output in radians; the output is otherwise in the
            units of `radius`.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.

        Returns
        -------
        np.array
            Wheel timestamps in seconds.
        np.array
            Wheel positions relative to the first sample, in the units of `radius` (radians when
            radius is 1).

        Raises
        ------
        ValueError
            If `coding` is not one of 'x1', 'x2' or 'x4'.

        See Also
        --------
        ibllib.io.extractors.ephys_fpga.extract_wheel_sync
        """
        if coding not in ('x1', 'x2', 'x4'):
            raise ValueError('Unsupported coding; must be one of x1, x2 or x4')
        raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder'))

        # Timeline evenly samples counter so we extract only change points
        d = np.diff(raw)
        ind, = np.where(~np.isclose(d, 0))
        pos = raw[ind + 1]
        pos -= pos[0]  # Start from zero
        # Convert ticks to arc length (radians when radius == 1)
        pos = pos / ticks * np.pi * 2 * radius / int(coding[1])

        # Get timestamps of changes and trim based on protocol spacers
        ts = self.timeline['timestamps'][ind + 1]
        tmin = ts.min() if tmin is None else tmin
        tmax = ts.max() if tmax is None else tmax
        mask = np.logical_and(ts >= tmin, ts <= tmax)
        return ts[mask], pos[mask]

    def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
                            tmin=None, tmax=None, display=False, **kwargs):
        """
        Gets the wheel position and detected movements from Timeline counter channel.

        Called by the super class extractor (FPGATrials._extract).

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel. Use 1 for an output in radians; the output is otherwise in the
            units of `radius`.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.
        display : bool
            If true, plot the wheel positions from bpod and the DAQ.

        Returns
        -------
        dict
            wheel object with keys ('timestamps', 'position').
        dict
            wheelMoves object with keys ('intervals' 'peakAmplitude').
        """
        wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax)
        wheel = dict(zip(('timestamps', 'position'), wheel))
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])

        if display:
            assert self.bpod_trials, 'no bpod trials to compare'
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            bpod_ts = self.bpod_trials['wheel_timestamps']
            bpod_pos = self.bpod_trials['wheel_position']
            ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos)
            ax0.set_ylabel('Bpod wheel position / rad')
            ax1.plot(wheel['timestamps'], wheel['position'])
            ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s')
        return wheel, moves

    def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None):
        """
        Get the valve open times from the raw timeline voltage trace.

        Parameters
        ----------
        display : bool
            Plot detected times on the raw voltage trace.
        threshold : float
            The threshold of voltage change to apply. The default was set by eye; units should be
            Volts per sample but doesn't appear to be.
        driver_ttls : numpy.array
            An optional Nx2 array of driver TTL intervals to use for assigning with the valve
            times.

        Returns
        -------
        numpy.array
            The detected valve open intervals (Nx2; NaN where an open or close was not detected).
        numpy.array
            If driver_ttls is not None, returns an array of open times that occurred directly after
            the driver TTLs.
        """
        WARN_THRESH = 10e-3  # open time threshold (in seconds) below which to log warning
        tl = self.timeline
        info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
        values = tl['raw'][:, info['arrayColumn'] - 1]  # Timeline indices start from 1

        # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz
        # making simple thresholding an issue. For this reason we convolve the signal with a
        # window and detect the peaks and troughs.
        if (Fs := tl['meta']['daqSampleRate']) != 2000:  # e.g. 2kHz
            _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs)
        dt = 1e-3  # change in voltage takes ~1ms when changing valve open state
        N = dt / (1 / Fs)  # this means voltage change occurs over N samples
        vel, _ = velocity_filtered(values, int(Fs / N))  # filtered voltage change over time
        ups, _ = find_peaks(vel, height=threshold)  # valve closes (-5V -> 0V)
        downs, _ = find_peaks(-1 * vel, height=threshold)  # valve opens (0V -> -5V)

        # Convert these times into intervals
        ixs = np.argsort(np.r_[downs, ups])  # sort indices
        times = tl['timestamps'][np.r_[downs, ups]][ixs]  # ordered valve event times
        polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs]  # polarity sorted
        if polarities.size == 0:  # no valve events detected
            intervals = np.empty((0, 2))
        else:
            # If some changes were missed, insert a NaN between the two same-polarity events.
            # np.insert inserts *before* the given index, so the NaN goes at i + 1 regardless of
            # the polarity of the first event.
            missing = np.where(np.diff(polarities) == 0)[0]
            times = np.insert(times, missing + 1, np.nan)
            if polarities[-1] == -1:  # ensure ends with a valve close
                times = np.r_[times, np.nan]
            if polarities[0] == 1:  # ensure starts with a valve open
                # It seems it can start out at -5V (open), then when the reward happens it closes
                # and immediately opens. In this case we insert a NaN for the first open time.
                times = np.r_[np.nan, times]
            intervals = times.reshape(-1, 2)

        # Log warning of improbably short intervals
        short = np.sum(np.diff(intervals) < WARN_THRESH)
        if short > 0:
            _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH * 1e3)

        # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
        if driver_ttls is not None:
            # Returns an array of open_times indices, one for each driver TTL
            ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after')
            open_times = intervals[ind[ind >= 0], 0]
            # TODO Log any > 40ms? Difficult to report missing valve times because of calibration

        if display:
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-')
            if driver_ttls is not None:
                # Interleave TTL onsets and offsets for plotting as square pulses
                x = np.empty_like(driver_ttls.flatten())
                x[0::2] = driver_ttls[:, 0]
                x[1::2] = driver_ttls[:, 1]
                y = np.ones_like(x)
                y[1::2] -= 2
                squares(x, y, ax=ax0, yrange=[0, 5])
                ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*')
            ax1.plot(tl['timestamps'], values, 'k-o')
            ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s')

            ax2 = ax1.twinx()
            ax2.set_ylabel('dV', color='grey')
            ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey')
            ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close')
            ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open')
        return intervals if driver_ttls is None else (intervals, open_times)

    def _assign_events_audio(self, audio_times, audio_polarities, display=False):
        """
        This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.

        Parameters
        ----------
        audio_times : numpy.array
            An array of audio TTL front times.
        audio_polarities : numpy.array
            An array of audio TTL front polarities (1 for rises, -1 for falls).
        display : bool
            If true, display audio pulses and the assigned onsets.

        Returns
        -------
        numpy.array
            The times of the go cue onsets.
        numpy.array
            The times of the error tone onsets.
        """
        # make sure that there are no 2 consecutive fall or consecutive rise events
        assert np.all(np.abs(np.diff(audio_polarities)) == 2)
        # take only even time differences: i.e. from rising to falling fronts
        dt = np.diff(audio_times)
        onsets = audio_polarities[:-1] == 1

        # error tones are events lasting from 400ms to 1200ms
        i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
        t_error_tone_in = audio_times[i_error_tone_in]

        # detect ready tone by length below 300 ms
        i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
        t_ready_tone_in = audio_times[i_ready_tone_in]
        if display:  # pragma: no cover
            fig, ax = plt.subplots(nrows=2, sharex=True)
            ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
            ax[0].set_ylabel('Voltage / V')
            squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
            vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
            vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
            ax[1].set_xlabel('Time / s')
            ax[1].legend()

        return t_ready_tone_in, t_error_tone_in

600 

601 

602class MesoscopeSyncTimeline(extractors_base.BaseExtractor): 

603 """Extraction of mesoscope imaging times.""" 

604 

605 var_names = ('mpci_times', 'mpciStack_timeshift') 

606 save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy') 

607 

608 """one.alf.io.AlfBunch: The raw imaging meta data and frame times""" 

609 rawImagingData = None 

610 

def __init__(self, session_path, n_FOVs):
        """
        Set up extraction of mesoscope frame times from Timeline DAQ data.

        The per-dataset variable and file names are expanded so that there is one entry per
        field of view, e.g. 'mpci_times_fov_00' and 'FOV_00/mpci.times.npy'.

        Parameters
        ----------
        session_path : str, pathlib.Path
            The session path to extract times from.
        n_FOVs : int
            The number of fields of view acquired.
        """
        super().__init__(session_path)
        self.n_FOVs = n_FOVs
        labels = [f'FOV_{i:02}' for i in range(self.n_FOVs)]
        expanded_vars, expanded_files = [], []
        for var in self.var_names:
            expanded_vars.extend(f'{var}_{label.lower()}' for label in labels)
        for fname in self.save_names:
            expanded_files.extend(f'{label}/{fname}' for label in labels)
        self.var_names = expanded_vars
        self.save_names = expanded_files

627 

def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None):
        """
        Extract the frame timestamps for each individual field of view (FOV) and the time offsets
        for each line scan.

        The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
        corresponding to the number of raw_imaging_data folders. These timestamps should match the
        number of frame timestamps extracted from the image file headers (found in the
        rawImagingData.times file). The field of view (FOV) shifts are then applied to these
        timestamps for each field of view and provided together with the line shifts.

        Parameters
        ----------
        sync : one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        chmap : dict
            A map of channel names and their corresponding indices. Only the 'neural_frames'
            channel is required.
        device_collection : str, iterable of str
            The location of the raw imaging data.
        events : pandas.DataFrame
            A table of software events, with columns {'time_timeline' 'name_timeline',
            'event_timeline'}. Only 'mpepUDP' events are used for bout detection.

        Returns
        -------
        list of numpy.array
            One array of frame timestamps per FOV (concatenated across imaging bouts), followed by
            the line scan time offsets returned by `get_timeshifts`.

        Raises
        ------
        AssertionError
            If the number of FOVs is unexpected, if fewer DAQ frame times than image frames are
            found for a bout, or if the line shifts differ between bouts.
        """
        frame_times = sync['times'][sync['channels'] == chmap['neural_frames']]

        if isinstance(device_collection, str):
            device_collection = [device_collection]
        if events is not None:
            events = events[events.name == 'mpepUDP']
        edges = self.get_bout_edges(frame_times, device_collection, events)
        fov_times = []
        line_shifts = None  # set from the first imaging bout
        for (tmin, tmax), collection in zip(edges, sorted(device_collection)):
            imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData')
            imaging_data['meta'] = patch_imaging_meta(imaging_data['meta'])
            # Calculate line shifts
            _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta'])
            assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
            n_image_frames = imaging_data['times_scanImage'].size
            ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
            assert ts.size >= n_image_frames, f'fewer DAQ timestamps for {collection} than expected'
            if ts.size > n_image_frames:
                _logger.warning(
                    'More DAQ frame times detected for %s than were found in the raw image data.\n'
                    'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n'
                    'This may occur if the bout detection fails (e.g. UDPs recorded late), '
                    'when image data is corrupt, or when frames are not written to file.',
                    collection, ts.size, n_image_frames)
                _logger.info('Dropping last %i frame times for %s', ts.size - n_image_frames, collection)
                ts = ts[:n_image_frames]
            fov_times.append([ts + offset for offset in fov_time_shifts])
            if line_shifts is None:
                line_shifts = line_time_shifts
            else:  # The line shifts should be the same across all imaging bouts
                for current, first in zip(line_time_shifts, line_shifts):
                    np.testing.assert_array_equal(current, first)

        # Concatenate imaging timestamps across all bouts for each field of view
        fov_times = list(map(np.concatenate, zip(*fov_times)))
        # All FOVs must have the same number of timestamps (raises ValueError otherwise)
        n_fov_times, = set(map(len, fov_times))
        if n_fov_times != frame_times.size:
            # This may happen if an experimenter deletes a raw_imaging_data folder
            _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing')
        return fov_times + list(line_shifts)

698 

def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
    """
    Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
    collection.

    Bouts are delimited by UDP start/end messages when `events` is provided and non-empty.
    These are messages matching '(Exp|Block)(Start|End) <subject> <YYYYMMDD> <n>'.
    Otherwise bouts are delimited by gaps between consecutive frame times larger than `min_gap`.

    Parameters
    ----------
    frame_times : numpy.array
        An array of all neural frame count times.
    collections : iterable of str
        A set of raw_imaging_data collections, used to extract selected imaging periods.
        When more bouts are detected than collections given, the trailing integer of each
        collection name (e.g. 2 for 'raw_imaging_data_02') selects the bout to keep.
    events : pandas.DataFrame
        A table of UDP event times, corresponding to times when recordings start and end.
        The 'info' (message text) and 'time' columns are used.
    min_gap : float
        If start or end events not present, split bouts by finding gaps larger than this value.
    display : bool
        If true, plot the detected bout edges and raw frame times.

    Returns
    -------
    numpy.array
        An (n_bouts, 2) array of imaging bout intervals (start time, end time).

    Raises
    ------
    ValueError
        More raw imaging collections were given than imaging bouts were detected.
    """
    if events is None or events.empty:
        # No UDP events to mark blocks so separate based on gaps in frame rate
        idx = np.where(np.diff(frame_times) > min_gap)[0]
        starts = np.r_[frame_times[0], frame_times[idx + 1]]
        ends = np.r_[frame_times[idx], frame_times[-1]]
    else:
        # Split using Exp/BlockStart and Exp/BlockEnd times
        _, subject, date, _ = session_path_parts(self.session_path)
        pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+'

        # Get start times.
        # NB: `events` is usually a filtered slice of a larger table, so its index labels need
        # not start at 0 or be contiguous; use positional access rather than label access.
        UDP_start = events[events['info'].str.match(pattern % 'Start')]
        if len(UDP_start) > 1 and UDP_start['info'].values[0].startswith('Exp'):
            # Use ExpStart instead of the first BlockStart (i.e. drop the second start event)
            UDP_start = UDP_start.drop(UDP_start.index[1])
        # Each bout starts at the first frame at or after its start event
        starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]]

        # Get end times
        UDP_end = events[events['info'].str.match(pattern % 'End')]
        if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'):
            # Use last BlockEnd instead of ExpEnd
            UDP_end = UDP_end.drop(UDP_end.index[-1])
        if not UDP_end.empty:
            # Each bout ends at the last frame at or before its end event
            ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]]
        else:
            # For each start, take the first frame (from the start frame onwards) that is followed
            # by a gap larger than min_gap, or is the final frame. '>=' (not '>') so that a bout
            # consisting of a single frame ends on that frame rather than on the next bout's end.
            consec = np.r_[np.diff(frame_times) > min_gap, True]
            idx = [np.where(np.logical_and(frame_times >= t, consec))[0][0] for t in starts]
            ends = frame_times[idx]

    # Remove any missing imaging bout collections
    edges = np.c_[starts, ends]
    if collections:
        if edges.shape[0] > len(collections):
            # Remove any bouts that correspond to a skipped collection
            # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout
            include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections)
            edges = edges[include, :]
        elif edges.shape[0] < len(collections):
            raise ValueError('More raw imaging folders than detected bouts')

    if display:
        _, ax = plt.subplots(1)
        ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k', )
        vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b')
        vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange')
        if edges.shape[0] != len(starts):
            # Also mark the bouts that were dropped because their collection is missing
            vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size,
                           label='missing bout start', linestyle=':', color='b')
            vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size,
                           label='missing bout end', linestyle=':', color='orange')
        ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right')
    return edges

776 

@staticmethod
def get_timeshifts(raw_imaging_meta):
    """
    Compute the time offset of each field of view (FOV) and of each scan line within it.

    For a 2 scan field, 2 depth recording (so 4 FOVs):

        Frame 1, lines 1-512 correspond to FOV_00
        Frame 1, lines 551-1062 correspond to FOV_01
        Frame 2, lines 1-512 correspond to FOV_02
        Frame 2, lines 551-1062 correspond to FOV_03
        Frame 3, lines 1-512 correspond to FOV_00
        ...

    Parameters
    ----------
    raw_imaging_meta : dict
        Extracted ScanImage meta data (_ibl_rawImagingData.meta.json). Reads the 'FOV' list
        (each entry needs 'slice_id', 'Zs', 'nXnYnZ' and 'lineIdx') and
        scanImageParams.hRoiManager.{linePeriod, scanFrameRate}.

    Returns
    -------
    list of numpy.array
        One array per FOV holding the zero-indexed raw image scan line indices of that FOV.
    numpy.array
        One time offset per FOV, relative to each frame acquisition time.
    list of numpy.array
        One array per FOV holding the time offset of each scan line relative to the first
        line of that FOV.

    Raises
    ------
    AssertionError
        FOV depths or slice IDs are not in ascending order, or an FOV's number of line
        indices differs from its image height.
    """
    fovs = raw_imaging_meta['FOV']

    # Sanity-check the extracted metadata: FOVs are expected to be sorted by depth and slice ID
    slice_ids = np.array([fov['slice_id'] for fov in fovs])
    depth_steps = np.diff([fov['Zs'] for fov in fovs])
    assert (depth_steps >= 0).all(), 'FOV depths not in ascending order'
    assert (np.diff(slice_ids) >= 0).all(), 'slice IDs not ordered'

    # Expected number of scan lines per FOV, i.e. the image height in pixels
    expected_n_lines = [fov['nXnYnZ'][1] for fov in fovs]

    roi_manager = raw_imaging_meta['scanImageParams']['hRoiManager']
    line_period = roi_manager['linePeriod']
    # Offset contributed by the frame in which each FOV's slice is acquired
    frame_offsets = slice_ids / roi_manager['scanFrameRate']

    # 'lineIdx' is produced by the MATLAB function mesoscopeMetadataExtraction.m and is
    # 1-indexed; shift it to 0-indexed
    line_indices = [np.array(fov['lineIdx']) - 1 for fov in fovs]
    assert all(idx.size == n for idx, n in zip(line_indices, expected_n_lines)), 'unexpected number of scan lines'

    # First raw image line of each FOV; its line offset plus the frame offset gives the FOV offset
    first_lines = np.array([idx[0] for idx in line_indices])
    fov_time_shifts = first_lines * line_period + frame_offsets
    line_time_shifts = [(idx - first) * line_period for idx, first in zip(line_indices, first_lines)]

    return line_indices, fov_time_shifts, line_time_shifts
815 # Number of scan lines per FOV, i.e. number of Y pixels / image height 

816 n_lines = np.array([x['nXnYnZ'][1] for x in FOVs]) 1kcb

817 

818 # We get indices from MATLAB extracted metadata so below two lines are no longer needed 

819 # n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks 

820 # n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1)) # N lines during flyback 

821 line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod'] 1kcb

822 frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate'] 1kcb

823 

824 # Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m 

825 # They are indexed from 1 so we subtract 1 to convert to zero-indexed 

826 line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs] # Convert to zero-indexed from MATLAB 1-indexed 1kcb

827 assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines' 1kcb

828 # The start indices of each FOV in the raw images 

829 fov_start_idx = np.array([lns[0] for lns in line_indices]) 1kcb

830 roi_time_shifts = fov_start_idx * line_period # The time offset for each FOV 1kcb

831 fov_time_shifts = roi_time_shifts + frame_time_shifts 1kcb

832 line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)] 1kcb

833 

834 return line_indices, fov_time_shifts, line_time_shifts 1kcb