Coverage for ibllib/io/extractors/mesoscope.py: 91%

326 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Mesoscope (timeline) data extraction.""" 

2import logging 

3 

4import numpy as np 

5from scipy.signal import find_peaks 

6import one.alf.io as alfio 

7from one.util import ensure_list 

8from one.alf.files import session_path_parts 

9import matplotlib.pyplot as plt 

10from packaging import version 

11 

12from ibllib.plots.misc import squares, vertical_lines 

13from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, 

14 correct_counter_discontinuities, load_timeline_sync_and_chmap) 

15import ibllib.io.extractors.base as extractors_base 

16from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial 

17from ibllib.io.extractors.training_wheel import extract_wheel_moves 

18from ibllib.io.extractors.camera import attribute_times 

19from brainbox.behavior.wheel import velocity_filtered 

20 

21_logger = logging.getLogger(__name__) 

22 

23 

def patch_imaging_meta(meta: dict) -> dict:
    """
    Patch imaging metadata for compatibility across versions.

    The dict is modified in place and the same object is returned; a copy is NOT made.

    Parameters
    ----------
    meta : dict
        The loaded contents of a rawImagingData.meta file.

    Returns
    -------
    dict
        The same metadata dict, updated to the most recent version.

    Raises
    ------
    AssertionError
        If the 'nFrames' key is missing after patching (the meta file is likely an old version).
    """
    # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys
    ver = version.parse(meta.get('version') or '0.0.0')
    # Use .get consistently so that metadata without any FOV entries is tolerated
    fovs = meta.get('FOV', [])
    if ver <= version.parse('0.0.0'):
        if 'channelSaved' not in meta:
            # Take the channel indices from the first FOV that defines them, else empty list
            meta['channelSaved'] = next((x['channelIdx'] for x in fovs if 'channelIdx' in x), [])
        fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight')
        for fov in fovs:
            for unit in ('Deg', 'MM'):
                if unit not in fov:  # topLeftDeg, etc. -> Deg[topLeft]
                    fov[unit] = {f: fov.pop(f + unit, None) for f in fields}
    elif ver == version.parse('0.1.0'):
        # Rename the 'roiUuid' key to 'roiUUID'
        for fov in fovs:
            if 'roiUuid' in fov:
                fov['roiUUID'] = fov.pop('roiUuid')
    assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version'
    return meta

56 

57 

def plot_timeline(timeline, channels=None, raw=True):
    """
    Plot the timeline data.

    Parameters
    ----------
    timeline : one.alf.io.AlfBunch
        The timeline data object, with 'raw', 'timestamps' and 'meta' keys.
    channels : iterable of str, optional
        An iterable of channel names to plot. Defaults to all inputs listed in
        ``timeline['meta']['inputs']``.
    raw : bool
        If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing timeline subplots.
    list of matplotlib.pyplot.Axes
        The axes for each timeline channel plotted.
    """
    # Map of input name -> input meta; pop from a copy so the timeline meta is not mutated
    meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']}
    # Materialize so that generators / dict views can be both counted and iterated
    channels = list(channels or meta.keys())
    fig, axes = plt.subplots(len(channels), 1, sharex=True)
    axes = ensure_list(axes)
    if not raw:
        chmap = {ch: meta[ch]['arrayColumn'] for ch in channels}
        sync = extract_sync_timeline(timeline, chmap=chmap)
    for ax, ch in zip(axes, channels):
        if raw:
            # axesScale controls vertical scaling of each trace (multiplicative);
            # arrayColumn is 1-based hence the -1
            values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale']
            ax.plot(timeline['timestamps'], values)
        elif np.any(idx := sync['channels'] == chmap[ch]):
            squares(sync['times'][idx], sync['polarities'][idx], ax=ax)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(True)
        ax.set_ylabel(ch, rotation=45, fontsize=8)
    # Add back x-axis ticks to the last plot
    axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True)
    axes[-1].spines['bottom'].set_visible(True)
    # Maximizing the window is only possible with some GUI backends (e.g. Qt); on others
    # (Agg, Tk) the manager/window lacks these attributes, so this is best-effort only.
    try:
        plt.get_current_fig_manager().window.showMaximized()  # full screen
    except AttributeError:
        _logger.debug('Unable to maximize the figure window with the current backend')
    fig.tight_layout(h_pad=0)
    return fig, axes

101 

102 

class TimelineTrials(FpgaTrials):
    """Similar extraction to the FPGA, however counter and position channels are treated differently."""

    timeline = None
    """one.alf.io.AlfBunch: The timeline data object."""

    sync_field = 'itiIn_times'
    """str: The trial event to synchronize (must be present in extracted trials)."""

    def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
        """
        An extractor for all ephys trial data, in Timeline time.

        Parameters
        ----------
        *args, **kwargs
            Passed to the FpgaTrials constructor.
        sync_collection : str
            The session subdirectory from which the timeline 'DAQdata' object is loaded.
        """
        super().__init__(*args, **kwargs)
        self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

    def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_):
        """Load the DAQ sync and channel map data.

        Parameters
        ----------
        sync_collection : str
            The session subdirectory where the sync data are located.
        chmap : dict
            A map of channel names and their corresponding indices. If None, the channel map is
            loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        if not self.timeline:
            self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
        sync, chmap = load_timeline_sync_and_chmap(
            self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
        return sync, chmap

    def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
        """
        Extract the trials data, optionally plotting the raw timeline channels.

        Parameters
        ----------
        sync : dict, optional
            The sync pulses; passed to the superclass extractor.
        chmap : dict, optional
            A map of channel names and their corresponding indices; passed to the superclass.
        sync_collection : str
            The session subdirectory where the sync data are located; forwarded to the superclass.
        **kwargs
            Passed to the superclass. If 'display' is true the raw timeline is also plotted.

        Returns
        -------
        dict
            The trials data returned by the superclass extractor.
        """
        # Forward the caller's collection rather than always using the default
        trials = super()._extract(sync, chmap, sync_collection=sync_collection, **kwargs)
        if kwargs.get('display', False):
            # With no chmap given, plot all timeline inputs rather than failing on None.keys()
            channels = chmap.keys() if chmap else None
            plot_timeline(self.timeline, channels=channels, raw=True)
        return trials

    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
        """
        Extract Bpod times from sync.

        Unlike the superclass method. This one doesn't reassign the first trial pulse.

        Parameters
        ----------
        sync : dict
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers. Must contain a 'bpod' key.
        chmap : dict
            A map of channel names and their corresponding indices.
        bpod_event_ttls : dict of tuple
            A map of event names to (min, max) TTL length.
        display : bool
            Passed to the superclass method.

        Returns
        -------
        dict
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
        dict
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
        """
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
        # lengths are defined by the state machine of the task protocol and therefore vary.
        if bpod_event_ttls is None:
            # The trial start TTLs are often too short for the low sampling rate of the DAQ and are
            # therefore not used in extraction
            bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
        bpod, bpod_event_intervals = super().get_bpod_event_times(
            sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs)

        # TODO Here we can make use of the 'bpod_rising_edge' channel, if available
        return bpod, bpod_event_intervals

    def build_trials(self, sync=None, chmap=None, **kwargs):
        """
        Extract task related event times from the sync.

        The two major differences are that the sampling rate is lower for imaging so the short Bpod
        trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL.

        Second, the valve used at the mesoscope has a way to record the raw voltage across the
        solenoid, giving a more accurate readout of the valve's activity. If the reward_valve
        channel is present on the DAQ, this is used to extract the valve open times.

        Parameters
        ----------
        sync : dict
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
        chmap : dict
            Map of channel names and their corresponding index. Default to constant.

        Returns
        -------
        dict
            A map of trial event timestamps.

        Raises
        ------
        ValueError
            If the expected audio ('ready_tone', 'error_tone') or Bpod ('valve_open',
            'trial_end') events were not detected.
        """
        # Get the events from the sync.
        # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
        if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
            raise ValueError(
                'Expected at least "ready_tone" and "error_tone" audio events. '
                '`audio_event_ttls` kwarg may be incorrect.')

        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
        if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}:
            raise ValueError(
                'Expected at least "trial_end" and "valve_open" Bpod events. '
                '`bpod_event_ttls` kwarg may be incorrect.')

        # The 'trial_end' TTL intervals give the ITI onset (start) and trial end (stop) times
        t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
        fpga_events = alfio.AlfBunch({
            'itiIn_times': t_iti_in,
            'intervals_1': t_trial_end,
            'goCue_times': audio_event_intervals['ready_tone'][:, 0],
            'errorTone_times': audio_event_intervals['error_tone'][:, 0]
        })

        # Sync the Bpod clock to the DAQ
        self.bpod2fpga, _, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)

        out = dict()
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
        out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})

        start_times = out['intervals'][:, 0]
        last_trial_end = out['intervals'][-1, 1]

        def assign_to_trial(events, take='last', starts=start_times, **kwargs):
            """Assign DAQ events to trials.

            Because we may not have trial start TTLs on the DAQ (because of the low sampling rate),
            there may be an extra last trial that's not in the Bpod intervals as the extractor
            ignores the last trial. This function trims the input array before assigning so that
            the last trial's events are correctly assigned.
            """
            return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **kwargs)
        out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga])

        # Extract valve open times from the DAQ
        valve_driver_ttls = bpod_event_intervals['valve_open']
        correct = self.bpod_trials['feedbackType'] == 1
        # If there is a reward_valve channel, the raw solenoid voltage was recorded on the DAQ
        if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
            # TODO Let's look at the expected open length based on calibration and reward volume
            # import scipy.interpolate
            # # FIXME support v7 settings?
            # fcn_vol2time = scipy.interpolate.pchip(
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'],
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES']
            # )
            # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3

            # Use the driver TTLs to find the valve open times that correspond to the valve opening
            _, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
            if valve_open_times.size != np.sum(correct):
                _logger.warning(
                    'Number of valve open times does not equal number of correct trials (%i != %i)',
                    valve_open_times.size, np.sum(correct))

            out['valveOpen_times'] = assign_to_trial(valve_open_times)
        else:
            # Use the valve controller TTLs recorded on the Bpod channel as the reward time
            out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0])

        # Stimulus times extracted based on trigger times
        # When assigning events all start times must not be NaN so here we substitute freeze
        # trigger times on nogo trials for stim on trigger times, then replace with NaN again
        go_trials = np.where(out['choice'] != 0)[0]
        lims = np.copy(out['stimOnTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimFreeze_times'] = assign_to_trial(
            self.frame2ttl['times'], 'last',
            starts=lims, t_trial_end=out['stimOffTrigger_times'])
        out['stimFreeze_times'][out['choice'] == 0] = np.nan

        # Here we do the same but use stim off trigger times
        lims = np.copy(out['stimOffTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimOn_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOnTrigger_times'], t_trial_end=lims)
        out['stimOff_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1]
        )

        # Audio times
        error_cue = fpga_events['errorTone_times']
        if error_cue.size != np.sum(~correct):
            _logger.warning(
                'N detected error tones does not match number of incorrect trials (%i != %i)',
                error_cue.size, np.sum(~correct))
        go_cue = fpga_events['goCue_times']
        out['goCue_times'] = assign_to_trial(go_cue, take='first')
        out['errorCue_times'] = assign_to_trial(error_cue)

        if go_cue.size > start_times.size:
            _logger.warning(
                'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size)
        elif go_cue.size < start_times.size:
            # If the error cues are all assigned and some go cues are missed it may be that some
            # responses were so fast that the go cue and error tone merged, or the go cue TTL was
            # too long.
            _logger.warning('%i go cue tones missed', start_times.size - go_cue.size)
            # Subset the triggers with ibpod so they are aligned with the retained trials
            # (i.e. with start_times and the assigned cue arrays below)
            err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'][ibpod])
            go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'][ibpod])
            assert not np.any(np.isnan(go_trig))
            assert err_trig.size == go_trig.size  # should be length of n trials with NaNs

            # Find which trials are missing a go cue
            _go_cue = assign_to_trial(go_cue, take='first')
            error_cue = assign_to_trial(error_cue)
            missing = np.isnan(_go_cue)

            # Get all the DAQ timestamps where audio channel was HIGH
            raw = timeline_get_channel(self.timeline, 'audio')
            raw = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalize
            ups = self.timeline.timestamps[raw > .5]  # timestamps where input HIGH

            # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after).
            # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN)
            idx = attribute_times(ups, go_trig, tol=0.2, take='after')
            # Trial indices that didn't have detected goCue and now has been assigned an `ups` index.
            # NB: parentheses required as `&` binds more tightly than `!=`
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            _go_cue[assigned] = ups[idx[assigned]]

            # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue)
            error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig))
            i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True)
            error_cue[i_to_remove] = np.nan

            # For those trials where go cue was merged with the error cue and therefore mis-assigned,
            # we must re-assign the error cue times as the first HIGH after the error trigger.
            idx = attribute_times(ups, err_trig, tol=0.2, take='after')
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            error_cue[assigned] = ups[idx[assigned]]
            out['goCue_times'] = _go_cue
            out['errorCue_times'] = error_cue

        # Sanity check: no single tone may be assigned as both a go cue and an error cue
        assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \
            'audio tones not assigned correctly; tones likely missed'

        # Feedback times: valve open time on rewarded trials, error cue time otherwise
        out['feedback_times'] = np.copy(out['valveOpen_times'])
        ind_err = np.isnan(out['valveOpen_times'])
        out['feedback_times'][ind_err] = out['errorCue_times'][ind_err]

        return out

    def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
        """
        Gets the wheel position from Timeline counter channel.

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel. Defaults to WHEEL_RADIUS_CM; pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.

        Returns
        -------
        np.array
            Wheel timestamps in seconds.
        np.array
            Wheel positions, relative to the first detected position, in radians if radius is 1,
            otherwise in the units of `radius`.

        Raises
        ------
        ValueError
            If `coding` is not one of 'x1', 'x2' or 'x4'.

        See Also
        --------
        ibllib.io.extractors.ephys_fpga.extract_wheel_sync
        """
        if coding not in ('x1', 'x2', 'x4'):
            raise ValueError('Unsupported coding; must be one of x1, x2 or x4')
        raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder'))

        # Timeline evenly samples counter so we extract only change points
        ind, = np.where(~np.isclose(np.diff(raw), 0))
        pos = raw[ind + 1]
        if pos.size:  # no wheel movement means nothing to offset
            pos = pos - pos[0]  # Start from zero
        pos = pos / ticks * np.pi * 2 * radius / int(coding[1])  # Convert ticks to distance

        # Get timestamps of changes and trim based on protocol spacers
        ts = self.timeline['timestamps'][ind + 1]
        # Unbounded limits keep everything (avoids ts.min() failing on an empty array)
        mask = np.ones(ts.shape, dtype=bool)
        if tmin is not None:
            mask &= ts >= tmin
        if tmax is not None:
            mask &= ts <= tmax
        return ts[mask], pos[mask]

    def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
                            tmin=None, tmax=None, display=False, **kwargs):
        """
        Gets the wheel position and detected movements from Timeline counter channel.

        Called by the super class extractor (FPGATrials._extract).

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel. Defaults to WHEEL_RADIUS_CM; pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.
        display : bool
            If true, plot the wheel positions from bpod and the DAQ.

        Returns
        -------
        dict
            wheel object with keys ('timestamps', 'position').
        dict
            wheelMoves object with keys ('intervals' 'peakAmplitude').
        """
        wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax)
        wheel = dict(zip(('timestamps', 'position'), wheel))
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])

        if display:
            assert self.bpod_trials, 'no bpod trials to compare'
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            bpod_ts = self.bpod_trials['wheel_timestamps']
            bpod_pos = self.bpod_trials['wheel_position']
            ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos)
            ax0.set_ylabel('Bpod wheel position / rad')
            ax1.plot(wheel['timestamps'], wheel['position'])
            ax1.set_ylabel('DAQ wheel position / rad')
            ax1.set_xlabel('Time / s')
        return wheel, moves

    def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None):
        """
        Get the valve open times from the raw timeline voltage trace.

        Parameters
        ----------
        display : bool
            Plot detected times on the raw voltage trace.
        threshold : float
            The threshold of voltage change to apply. The default was set by eye; units should be
            Volts per sample but doesn't appear to be.
        driver_ttls : numpy.array
            An optional array of driver TTLs to use for assigning with the valve times.

        Returns
        -------
        numpy.array
            The detected valve open intervals as an Nx2 array of (open, close) times, with NaN
            where the corresponding open or close was not detected.
        numpy.array
            If driver_ttls is not None, returns an array of open times that occurred directly after
            the driver TTLs.
        """
        WARN_THRESH = 10e-3  # open time threshold (s) below which to log warning
        tl = self.timeline
        info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
        values = tl['raw'][:, info['arrayColumn'] - 1]  # Timeline indices start from 1

        # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz
        # making simple thresholding an issue. For this reason we convolve the signal with a
        # window and detect the peaks and troughs.
        if (Fs := tl['meta']['daqSampleRate']) != 2000:  # e.g. 2kHz
            _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs)
        dt = 1e-3  # change in voltage takes ~1ms when changing valve open state
        N = dt / (1 / Fs)  # this means voltage change occurs over N samples
        vel, _ = velocity_filtered(values, int(Fs / N))  # filtered voltage change over time
        ups, _ = find_peaks(vel, height=threshold)  # valve closes (-5V -> 0V)
        downs, _ = find_peaks(-1 * vel, height=threshold)  # valve opens (0V -> -5V)

        # Convert these times into intervals
        ixs = np.argsort(np.r_[downs, ups])  # sort indices
        times = tl['timestamps'][np.r_[downs, ups]][ixs]  # ordered valve event times
        polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs]  # polarity sorted
        if polarities.size == 0:
            _logger.warning('No valve open or close events detected')
            intervals = np.empty((0, 2))
        else:
            # Where two consecutive events share a polarity, the opposite event between them was
            # missed: insert a NaN placeholder between them, i.e. before index i + 1.
            missing = np.where(np.diff(polarities) == 0)[0]
            times = np.insert(times, missing + 1, np.nan)
            if polarities[-1] == -1:  # ensure ends with a valve close
                times = np.r_[times, np.nan]
            if polarities[0] == 1:  # ensure starts with a valve open
                # It seems it can start out at -5V (open), then when the reward happens it closes
                # and immediately opens. In this case the first open time is unknown (NaN).
                times = np.r_[np.nan, times]
            intervals = times.reshape(-1, 2)

        # Log warning of improbably short intervals
        short = np.sum(np.diff(intervals) < WARN_THRESH)
        if short > 0:
            _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH * 1e3)

        # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
        if driver_ttls is not None:
            # Returns an array of open_times indices, one for each driver TTL
            ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after')
            open_times = intervals[ind[ind >= 0], 0]
            # TODO Log any > 40ms? Difficult to report missing valve times because of calibration

        if display:
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-')
            if driver_ttls is not None:
                # Interleave TTL rise and fall times for plotting as square pulses
                x = np.empty_like(driver_ttls.flatten())
                x[0::2] = driver_ttls[:, 0]
                x[1::2] = driver_ttls[:, 1]
                y = np.ones_like(x)
                y[1::2] -= 2
                squares(x, y, ax=ax0, yrange=[0, 5])
                ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*')
            ax1.plot(tl['timestamps'], values, 'k-o')
            ax1.set_ylabel('Voltage / V')
            ax1.set_xlabel('Time / s')

            ax2 = ax1.twinx()
            ax2.set_ylabel('dV', color='grey')
            ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey')
            ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close')
            ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open')
        return intervals if driver_ttls is None else (intervals, open_times)

    def _assign_events_audio(self, audio_times, audio_polarities, display=False):
        """
        This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.

        Parameters
        ----------
        audio_times : numpy.array
            An array of audio TTL front times.
        audio_polarities : numpy.array
            An array of audio TTL front polarities (1 for rises, -1 for falls).
        display : bool
            If true, display audio pulses and the assigned onsets.

        Returns
        -------
        numpy.array
            The times of the go cue onsets.
        numpy.array
            The times of the error tone onsets.
        """
        # make sure that there are no 2 consecutive fall or consecutive rise events
        assert np.all(np.abs(np.diff(audio_polarities)) == 2)
        # take only even time differences: i.e. from rising to falling fronts
        dt = np.diff(audio_times)
        onsets = audio_polarities[:-1] == 1

        # error tones are events lasting from 400ms to 1200ms
        i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
        t_error_tone_in = audio_times[i_error_tone_in]

        # detect ready tone by length below 300 ms
        i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
        t_ready_tone_in = audio_times[i_ready_tone_in]
        if display:  # pragma: no cover
            fig, ax = plt.subplots(nrows=2, sharex=True)
            ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
            ax[0].set_ylabel('Voltage / V')
            squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
            vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
            vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
            ax[1].set_xlabel('Time / s')
            ax[1].legend()

        return t_ready_tone_in, t_error_tone_in

584 

585 

586class MesoscopeSyncTimeline(extractors_base.BaseExtractor): 

587 """Extraction of mesoscope imaging times.""" 

588 

589 var_names = ('mpci_times', 'mpciStack_timeshift') 

590 save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy') 

591 

592 """one.alf.io.AlfBunch: The raw imaging meta data and frame times""" 

593 rawImagingData = None 

594 

def __init__(self, session_path, n_FOVs):
    """
    Set up extraction of mesoscope frame times from DAQ data acquired through Timeline.

    The class-level dataset names are expanded so that there is one variable / file per field
    of view.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path to extract times from.
    n_FOVs : int
        The number of fields of view acquired.
    """
    super().__init__(session_path)
    self.n_FOVs = n_FOVs
    fov_labels = [f'FOV_{i:02}' for i in range(self.n_FOVs)]
    # e.g. 'mpci_times' -> 'mpci_times_fov_00', 'mpci_times_fov_01', ...
    expanded_vars = []
    for var in self.var_names:
        expanded_vars.extend(f'{var}_{label.lower()}' for label in fov_labels)
    self.var_names = expanded_vars
    # e.g. 'mpci.times.npy' -> 'FOV_00/mpci.times.npy', 'FOV_01/mpci.times.npy', ...
    expanded_files = []
    for filename in self.save_names:
        expanded_files.extend(f'{label}/{filename}' for label in fov_labels)
    self.save_names = expanded_files

611 

def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None):
    """
    Extract the frame timestamps for each individual field of view (FOV) and the time offsets
    for each line scan.

    The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
    corresponding to the number of raw_imaging_data folders. These timestamps should match the
    number of frame timestamps extracted from the image file headers (found in the
    rawImagingData.times file). The field of view (FOV) shifts are then applied to these
    timestamps for each field of view and provided together with the line shifts.

    Parameters
    ----------
    sync : one.alf.io.AlfBunch
        A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
        and the corresponding channel numbers.
    chmap : dict
        A map of channel names and their corresponding indices. Only the 'neural_frames'
        channel is required.
    device_collection : str, iterable of str
        The location of the raw imaging data.
    events : pandas.DataFrame
        A table of software events, with columns {'time_timeline' 'name_timeline',
        'event_timeline'}. Only 'mpepUDP' events are used.

    Returns
    -------
    list of numpy.array
        A list of timestamps for each FOV and the time offsets for each line scan.

    Raises
    ------
    AssertionError
        If the number of FOVs in the metadata doesn't match n_FOVs, if fewer DAQ frame times
        than image frames are found for a bout, or if line shifts differ between bouts.
    """
    frame_times = sync['times'][sync['channels'] == chmap['neural_frames']]

    if isinstance(device_collection, str):
        device_collection = [device_collection]
    if events is not None:
        events = events[events.name == 'mpepUDP']
    edges = self.get_bout_edges(frame_times, device_collection, events)
    fov_times = []
    line_shifts = []
    for (tmin, tmax), collection in zip(edges, sorted(device_collection)):
        imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData')
        imaging_data['meta'] = patch_imaging_meta(imaging_data['meta'])
        # Calculate line shifts
        _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta'])
        assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
        ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
        n_image_frames = imaging_data['times_scanImage'].size
        assert ts.size >= n_image_frames, f'fewer DAQ timestamps for {collection} than expected'
        if ts.size > n_image_frames:
            _logger.warning(
                'More DAQ frame times detected for %s than were found in the raw image data.\n'
                'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n'
                'This may occur if the bout detection fails (e.g. UDPs recorded late), '
                'when image data is corrupt, or when frames are not written to file.',
                collection, ts.size, n_image_frames)
            _logger.info('Dropping last %i frame times for %s', ts.size - n_image_frames, collection)
            ts = ts[:n_image_frames]
        fov_times.append([ts + offset for offset in fov_time_shifts])
        if not line_shifts:
            line_shifts = line_time_shifts
        else:
            # The line shifts should be the same across all imaging bouts
            for x, y in zip(line_time_shifts, line_shifts):
                np.testing.assert_array_equal(x, y)

    # Concatenate imaging timestamps across all bouts for each field of view
    fov_times = list(map(np.concatenate, zip(*fov_times)))
    n_fov_times, = set(map(len, fov_times))
    if n_fov_times != frame_times.size:
        # This may happen if an experimenter deletes a raw_imaging_data folder
        _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing')
    return fov_times + line_shifts

682 

def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
    """
    Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
    collection.

    Parameters
    ----------
    frame_times : numpy.array
        An array of all neural frame count times.
    collections : iterable of str
        A set of raw_imaging_data collections, used to extract selected imaging periods.
        Collection names must end in '_<int>' (e.g. 'raw_imaging_data_02'); the integer is
        used as the bout index when some collections are missing.
    events : pandas.DataFrame
        A table of UDP event times, corresponding to times when recordings start and end.
        Must have 'info' (event message) and 'time' columns.
    min_gap : float
        If start or end events not present, split bouts by finding gaps larger than this value.
    display : bool
        If true, plot the detected bout edges and raw frame times.

    Returns
    -------
    numpy.array
        An (n_bouts, 2) array of imaging bout intervals; column 0 holds the bout start frame
        times and column 1 the bout end frame times.

    Raises
    ------
    ValueError
        More raw imaging collections were provided than imaging bouts were detected.
    """
    if events is None or events.empty:
        # No UDP events to mark blocks so separate based on gaps in frame times
        idx = np.where(np.diff(frame_times) > min_gap)[0]
        starts = np.r_[frame_times[0], frame_times[idx + 1]]
        ends = np.r_[frame_times[idx], frame_times[-1]]
    else:
        # Split using Exp/BlockStart and Exp/BlockEnd times
        _, subject, date, _ = session_path_parts(self.session_path)
        pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+'

        # Get start times
        UDP_start = events[events['info'].str.match(pattern % 'Start')]
        # NB: the filtered frame keeps the index labels of `events`, so access rows by
        # position rather than by label (the first match is not necessarily labelled 0).
        if len(UDP_start) > 1 and UDP_start['info'].values[0].startswith('Exp'):
            # Use ExpStart instead of the first BlockStart (drop the second start event)
            UDP_start = UDP_start.drop(UDP_start.index[1])
        # Each bout starts at the first frame at or after its start event
        starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]]

        # Get end times
        UDP_end = events[events['info'].str.match(pattern % 'End')]
        if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'):
            # Use last BlockEnd instead of ExpEnd
            UDP_end = UDP_end.drop(UDP_end.index[-1])
        if not UDP_end.empty:
            # Each bout ends at the last frame at or before its end event
            ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]]
        else:
            # No end events: for each start, take the first subsequent frame that is followed
            # by a gap larger than min_gap (or is the final frame)
            consec = np.r_[np.diff(frame_times) > min_gap, True]
            idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts]
            ends = frame_times[idx]

    # Remove any missing imaging bout collections
    edges = np.c_[starts, ends]
    if collections:
        if edges.shape[0] > len(collections):
            # Remove any bouts that correspond to a skipped collection
            # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout
            include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections)
            edges = edges[include, :]
        elif edges.shape[0] < len(collections):
            raise ValueError('More raw imaging folders than detected bouts')

    if display:
        _, ax = plt.subplots(1)
        ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k')
        vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b')
        vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange')
        if edges.shape[0] != len(starts):
            # Also mark the bouts that were discarded because their collection is missing
            vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size,
                           label='missing bout start', linestyle=':', color='b')
            vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size,
                           label='missing bout end', linestyle=':', color='orange')
        ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right')
    return edges

760 

@staticmethod
def get_timeshifts(raw_imaging_meta):
    """
    Compute per-FOV acquisition time offsets and per-scan-line offsets within each FOV.

    Example layout for 2 scan fields imaged at 2 depths (4 FOVs in total):

        Frame 1, lines 1-512 correspond to FOV_00
        Frame 1, lines 551-1062 correspond to FOV_01
        Frame 2, lines 1-512 correspond to FOV_02
        Frame 2, lines 551-1062 correspond to FOV_03
        Frame 3, lines 1-512 correspond to FOV_00
        ...

    Parameters
    ----------
    raw_imaging_meta : dict
        Extracted ScanImage meta data (_ibl_rawImagingData.meta.json).  Reads the 'FOV' list
        (keys 'slice_id', 'Zs', 'nXnYnZ', 'lineIdx' per FOV) and
        scanImageParams.hRoiManager.{linePeriod, scanFrameRate}.

    Returns
    -------
    list of numpy.array
        Zero-indexed scan line indices of the raw image, one array per FOV.
    numpy.array
        One time offset per FOV, relative to each frame acquisition time: the slice's frame
        offset (slice_id / scanFrameRate) plus the FOV's first line index times the line period.
    list of numpy.array
        Per-FOV arrays of scan line time offsets, relative to that FOV's own offset.
    """
    fov_meta = raw_imaging_meta['FOV']

    # Sanity checks: FOVs are expected to be ordered by depth and by slice ID
    slices = np.array([fov['slice_id'] for fov in fov_meta])
    depths = [fov['Zs'] for fov in fov_meta]
    assert (np.diff(depths) >= 0).all(), 'FOV depths not in ascending order'
    assert (np.diff(slices) >= 0).all(), 'slice IDs not ordered'
    # Expected number of scan lines per FOV, i.e. number of Y pixels / image height
    expected_lines = np.array([fov['nXnYnZ'][1] for fov in fov_meta])

    roi_manager = raw_imaging_meta['scanImageParams']['hRoiManager']
    t_line = roi_manager['linePeriod']
    # Slices are acquired in successive frames, so each is offset by slice_id frame periods
    slice_offsets = slices / roi_manager['scanFrameRate']

    # 'lineIdx' is produced by MATLAB (mesoscopeMetadataExtraction.m) and is 1-indexed
    line_indices = [np.array(fov['lineIdx']) - 1 for fov in fov_meta]
    assert all(lines.size == n for lines, n in zip(line_indices, expected_lines)), 'unexpected number of scan lines'

    # Offset of each FOV = time to reach its first scan line + its slice's frame offset
    first_lines = np.array([lines[0] for lines in line_indices])
    fov_time_shifts = first_lines * t_line + slice_offsets
    line_time_shifts = [(lines - first) * t_line for lines, first in zip(line_indices, first_lines)]

    return line_indices, fov_time_shifts, line_time_shifts