1"""Mesoscope (timeline) data extraction.""" 

2import logging 

3 

4import numpy as np 

5import one.alf.io as alfio 

6from one.util import ensure_list 

7from one.alf.files import session_path_parts 

8import matplotlib.pyplot as plt 

9from neurodsp.utils import falls 

10from pkg_resources import parse_version 

11 

12from ibllib.plots.misc import squares, vertical_lines 

13from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, 

14 correct_counter_discontinuities, load_timeline_sync_and_chmap) 

15import ibllib.io.extractors.base as extractors_base 

16from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, get_sync_fronts, get_protocol_period 

17from ibllib.io.extractors.training_wheel import extract_wheel_moves 

18from ibllib.io.extractors.camera import attribute_times 

19from ibllib.io.extractors.ephys_fpga import _assign_events_bpod 

20 

21_logger = logging.getLogger(__name__) 


def patch_imaging_meta(meta: dict) -> dict:
    """
    Patch imaging meta data for compatibility across versions.

    The dict is updated in place; a copy is NOT returned.

    Parameters
    ----------
    meta : dict
        The loaded imaging meta data, e.g. from a rawImagingData.meta.json file.

    Returns
    -------
    dict
        The same meta data dict, updated to the most recent version.
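
    Examples
    --------
    A minimal sketch of the current patch: unversioned meta data gain a channelSaved
    key taken from the first FOV that defines one (dict keys as used in the code below).

    >>> meta = {'FOV': [{'channelIdx': 1}, {'channelIdx': 2}]}
    >>> patch_imaging_meta(meta)['channelSaved']
    1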

39 """ 

40 # 2023-05-17 (unversioned) adds nFrames and channelSaved keys 

41 if parse_version(meta.get('version') or '0.0.0') <= parse_version('0.0.0'): 1icb

42 if 'channelSaved' not in meta: 1icb

43 meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), []) 1icb

44 return meta 1icb

45 

46 

47def plot_timeline(timeline, channels=None, raw=True): 

48 """ 

49 Plot the timeline data. 

50 

51 Parameters 

52 ---------- 

53 timeline : one.alf.io.AlfBunch 

54 The timeline data object. 

55 channels : list of str 

56 An iterable of channel names to plot. 

57 raw : bool 

58 If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes. 

59 

60 Returns 

61 ------- 

62 matplotlib.pyplot.Figure 

63 The figure containing timeline subplots. 

64 list of matplotlib.pyplot.Axes 

65 The axes for each timeline channel plotted. 
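
    Examples
    --------
    A usage sketch (session_path is a hypothetical pathlib.Path to a session with a
    raw_sync_data collection; channel names follow those used elsewhere in this module):

    >>> timeline = alfio.load_object(session_path / 'raw_sync_data', 'DAQdata', namespace='timeline')
    >>> fig, axes = plot_timeline(timeline, channels=['bpod', 'audio'], raw=False)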

66 """ 

67 meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']} 1f

68 channels = channels or meta.keys() 1f

69 fig, axes = plt.subplots(len(channels), 1, sharex=True) 1f

70 axes = ensure_list(axes) 1f

71 if not raw: 1f

72 chmap = {ch: meta[ch]['arrayColumn'] for ch in channels} 1f

73 sync = extract_sync_timeline(timeline, chmap=chmap) 1f

74 for i, (ax, ch) in enumerate(zip(axes, channels)): 1f

75 if raw: 1f

76 # axesScale controls vertical scaling of each trace (multiplicative) 

77 values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale'] 1f

78 ax.plot(timeline['timestamps'], values) 1f

79 elif np.any(idx := sync['channels'] == chmap[ch]): 1f

80 squares(sync['times'][idx], sync['polarities'][idx], ax=ax) 1f

81 ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False) 1f

82 ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True) 1f

83 ax.set_ylabel(ch, rotation=45, fontsize=8) 1f

84 # Add back x-axis ticks to the last plot 

85 axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True) 1f

86 axes[-1].spines['bottom'].set_visible(True) 1f

87 plt.get_current_fig_manager().window.showMaximized() # full screen 1f

88 fig.tight_layout(h_pad=0) 1f

89 return fig, axes 1f


class TimelineTrials(FpgaTrials):
    """Similar extraction to the FPGA, except that the counter and position channels are treated differently."""

    """one.alf.io.AlfBunch: The timeline data object"""
    timeline = None

    def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
        """An extractor for all ephys trial data, in Timeline time."""
        super().__init__(*args, **kwargs)
        self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

    def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs):
        if not (sync or chmap):
            sync, chmap = load_timeline_sync_and_chmap(
                self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)

        if kwargs.get('display', False):
            plot_timeline(self.timeline, channels=chmap.keys(), raw=True)
        trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs)

        # If no protocol number is defined, trim timestamps based on Bpod trials intervals
        trials_table = trials[self.var_names.index('table')]
        bpod = get_sync_fronts(sync, chmap['bpod'])
        if kwargs.get('protocol_number') is None:
            tmin = trials_table.intervals_0.iloc[0] - 1
            tmax = trials_table.intervals_1.iloc[-1]
            # Ensure wheel is cut off based on trials
            wheel_ts_idx = self.var_names.index('wheel_timestamps')
            mask = np.logical_and(tmin <= trials[wheel_ts_idx], trials[wheel_ts_idx] <= tmax)
            trials[wheel_ts_idx] = trials[wheel_ts_idx][mask]
            wheel_pos_idx = self.var_names.index('wheel_position')
            trials[wheel_pos_idx] = trials[wheel_pos_idx][mask]
            move_idx = self.var_names.index('wheelMoves_intervals')
            mask = np.logical_and(trials[move_idx][:, 0] >= tmin, trials[move_idx][:, 0] <= tmax)
            trials[move_idx] = trials[move_idx][mask, :]
        else:
            tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod)
        bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax)

        self.frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'], tmin, tmax)  # save for later access by QC

        # Replace valve open times with those extracted from the DAQ
        # TODO Let's look at the expected open length based on calibration and reward volume
        assert len(bpod['times']) > 0, 'No Bpod TTLs detected on DAQ'
        _, driver_out, _ = _assign_events_bpod(bpod['times'], bpod['polarities'], False)
        # Use the driver TTLs to find the valve open times that correspond to the valve opening
        valve_open_times = self.get_valve_open_times(driver_ttls=driver_out)
        assert len(valve_open_times) == sum(trials_table.feedbackType == 1)  # TODO Relax assertion
        correct = trials_table.feedbackType == 1
        trials[self.var_names.index('valveOpen_times')][correct] = valve_open_times
        trials_table.feedback_times[correct] = valve_open_times

        # Replace audio events
        self.audio = get_sync_fronts(sync, chmap['audio'], tmin, tmax)
        # Attempt to assign the go cue and error tone onsets based on TTL length
        go_cue, error_cue = self._assign_events_audio(self.audio['times'], self.audio['polarities'])

        assert error_cue.size == np.sum(~correct), 'N detected error tones does not match number of incorrect trials'
        assert go_cue.size <= len(trials_table), 'More go cue tones detected than trials!'

        if go_cue.size < len(trials_table):
            _logger.warning('%i go cue tones missed', len(trials_table) - go_cue.size)
            # If the error cues are all assigned and some go cues are missed, it may be that
            # some responses were so fast that the go cue and error tone merged
            err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'])
            go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'])
            assert not np.any(np.isnan(go_trig))
            assert err_trig.size == go_trig.size

            def first_true(arr):
                """Return the index of the first True value in an array."""
                indices = np.where(arr)[0]
                return None if len(indices) == 0 else indices[0]

            # Find which trials are missing a go cue
            _go_cue = np.full(len(trials_table), np.nan)
            for i, intervals in enumerate(trials_table[['intervals_0', 'intervals_1']].values):
                idx = first_true(np.logical_and(go_cue > intervals[0], go_cue < intervals[1]))
                if idx is not None:
                    _go_cue[i] = go_cue[idx]

            # Get all the DAQ timestamps where audio channel was HIGH
            raw = timeline_get_channel(self.timeline, 'audio')
            raw = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalize
            ups = self.timeline.timestamps[raw > .5]  # timestamps where input HIGH
            for i in np.where(np.isnan(_go_cue))[0]:
                # Get the timestamp of the first HIGH after the trigger times
                _go_cue[i] = ups[first_true(ups > go_trig[i])]
                idx = first_true(np.logical_and(
                    error_cue > trials_table['intervals_0'][i],
                    error_cue < trials_table['intervals_1'][i]))
                if np.isnan(err_trig[i]):
                    if idx is not None:
                        error_cue = np.delete(error_cue, idx)  # Remove mis-assigned error tone time
                else:
                    error_cue[idx] = ups[first_true(ups > err_trig[i])]
            go_cue = _go_cue

        trials_table.feedback_times[~correct] = error_cue
        trials_table.goCue_times = go_cue
        return trials

    def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
        """
        Gets the wheel position from the Timeline counter channel.

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel in cm. Pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.

        Returns
        -------
        np.array
            Wheel timestamps in seconds.
        np.array
            Wheel positions in the units of the given radius (radians when radius is 1).

        See Also
        --------
        ibllib.io.extractors.ephys_fpga.extract_wheel_sync
        """
        if coding not in ('x1', 'x2', 'x4'):
            raise ValueError('Unsupported coding; must be one of x1, x2 or x4')
        raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder'))

        # Timeline evenly samples counter so we extract only change points
        d = np.diff(raw)
        ind, = np.where(d.astype(int))
        pos = raw[ind + 1]
        pos -= pos[0]  # Start from zero
        pos = pos / ticks * np.pi * 2 * radius / int(coding[1])  # Convert to units of radius (radians if radius == 1)
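        # Worked example: with ticks=1024 and 'x4' coding there are 1024 * 4 = 4096
        # counts per revolution, so each count corresponds to 2 * pi * radius / 4096
        # units along the circumference (radians when radius == 1)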

        # Get timestamps of changes and trim based on protocol spacers
        ts = self.timeline['timestamps'][ind + 1]
        tmin = ts.min() if tmin is None else tmin
        tmax = ts.max() if tmax is None else tmax
        mask = np.logical_and(ts >= tmin, ts <= tmax)
        return ts[mask], pos[mask]

    def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
                            tmin=None, tmax=None, display=False, **kwargs):
        """
        Gets the wheel position and detected movements from the Timeline counter channel.

        Called by the super class extractor (FpgaTrials._extract).

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel in cm. Pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.
        display : bool
            If true, plot the wheel positions from Bpod and the DAQ.

        Returns
        -------
        dict
            wheel object with keys ('timestamps', 'position').
        dict
            wheelMoves object with keys ('intervals', 'peakAmplitude').
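
        Examples
        --------
        A usage sketch (session_path is a hypothetical session containing Bpod and DAQ
        data; constructor arguments follow FpgaTrials):

        >>> extractor = TimelineTrials(session_path, sync_collection='raw_sync_data')
        >>> wheel, moves = extractor.get_wheel_positions(tmin=100., tmax=200.)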

270 """ 

271 wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax) 1ga

272 wheel = dict(zip(('timestamps', 'position'), wheel)) 1ga

273 moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) 1ga

274 

275 if display: 1ga

276 fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) 

277 bpod_ts = self.bpod_trials['wheel_timestamps'] 

278 bpod_pos = self.bpod_trials['wheel_position'] 

279 ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos) 

280 ax0.set_ylabel('Bpod wheel position / rad') 

281 ax1.plot(wheel['timestamps'], wheel['position']) 

282 ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s') 

283 return wheel, moves 1ga

    def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=10, driver_ttls=None):
        """
        Get the valve open times from the raw timeline voltage trace.

        Parameters
        ----------
        display : bool
            Plot detected times on the raw voltage trace.
        threshold : float
            The threshold applied to the analogue valve channel after the floor is subtracted.
        floor_percentile : float
            The percentile of the analogue trace to subtract before thresholding (10 removes
            the 10th percentile value). This guards against DC offset drift.
        driver_ttls : numpy.array
            An optional array of driver TTLs used to restrict the detected valve open times.

        Returns
        -------
        numpy.array
            The detected valve open times.

        TODO extract close times too
        """
        tl = self.timeline
        info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
        values = tl['raw'][:, info['arrayColumn'] - 1]  # Timeline indices start from 1
        offset = np.percentile(values, floor_percentile, axis=0)
        idx = falls(values - offset, step=threshold)  # Voltage falls when valve opens
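        # Worked example with the defaults: subtracting the 10th-percentile baseline
        # (floor_percentile=10) then detecting drops of more than 2.5 V (threshold=-2.5)
        # marks each valve opening while ignoring slow DC drift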

        open_times = tl['timestamps'][idx]
        # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
        if driver_ttls is not None:
            # Returns an array of open_times indices, one for each driver TTL
            ind = attribute_times(open_times, driver_ttls, tol=.1, take='after')
            open_times = open_times[ind[ind >= 0]]
            # TODO Log any > 40ms? Difficult to report missing valve times because of calibration

        if display:
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), 'k-o')
            if driver_ttls is not None:
                vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b')
            ax1.plot(tl['timestamps'], values - offset, 'k-o')
            ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s')
            ax1.plot(tl['timestamps'][idx], np.zeros_like(idx), 'r*')
            if driver_ttls is not None:
                ax1.plot(open_times, np.zeros_like(open_times), 'g*')
        return open_times

    def _assign_events_audio(self, audio_times, audio_polarities, display=False):
        """
        This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.

        Parameters
        ----------
        audio_times : numpy.array
            An array of audio TTL front times.
        audio_polarities : numpy.array
            An array of audio TTL front polarities (1 for rises, -1 for falls).
        display : bool
            If true, display audio pulses and the assigned onsets.

        Returns
        -------
        numpy.array
            The times of the go cue onsets.
        numpy.array
            The times of the error tone onsets.
        """
        # Make sure that there are no two consecutive fall or rise events
        assert np.all(np.abs(np.diff(audio_polarities)) == 2)
        # Take only even time differences, i.e. from rising to falling fronts
        dt = np.diff(audio_times)
        onsets = audio_polarities[:-1] == 1

        # Error tones are events lasting from 400 ms to 1200 ms
        i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
        t_error_tone_in = audio_times[i_error_tone_in]

        # Detect ready tone by length below 300 ms
        i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
        t_ready_tone_in = audio_times[i_ready_tone_in]
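        # e.g. a 100 ms TTL is assigned as a go cue (dt <= 0.3) and an 800 ms TTL as an
        # error tone (0.4 < dt < 1.2); pulses of 0.3-0.4 s, or of 1.2 s and longer,
        # are left unassigned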

        if display:  # pragma: no cover
            fig, ax = plt.subplots(nrows=2, sharex=True)
            ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
            ax[0].set_ylabel('Voltage / V')
            squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
            vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
            vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
            ax[1].set_xlabel('Time / s')
            ax[1].legend()

        return t_ready_tone_in, t_error_tone_in


class MesoscopeSyncTimeline(extractors_base.BaseExtractor):
    """Extraction of mesoscope imaging times."""

    var_names = ('mpci_times', 'mpciStack_timeshift')
    save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy')

    """one.alf.io.AlfBunch: The raw imaging meta data and frame times"""
    rawImagingData = None

    def __init__(self, session_path, n_FOVs):
        """
        Extract the mesoscope frame times from DAQ data acquired through Timeline.

        Parameters
        ----------
        session_path : str, pathlib.Path
            The session path to extract times from.
        n_FOVs : int
            The number of fields of view acquired.
        """
        super().__init__(session_path)
        self.n_FOVs = n_FOVs
        fov = list(map(lambda n: f'FOV_{n:02}', range(self.n_FOVs)))
        self.var_names = [f'{x}_{y.lower()}' for x in self.var_names for y in fov]
        self.save_names = [f'{y}/{x}' for x in self.save_names for y in fov]
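        # e.g. with n_FOVs=2 the comprehensions above yield var_names
        # ['mpci_times_fov_00', 'mpci_times_fov_01',
        #  'mpciStack_timeshift_fov_00', 'mpciStack_timeshift_fov_01']
        # and save_names
        # ['FOV_00/mpci.times.npy', 'FOV_01/mpci.times.npy',
        #  'FOV_00/mpciStack.timeshift.npy', 'FOV_01/mpciStack.timeshift.npy']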

    def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None):
        """
        Extract the frame timestamps for each individual field of view (FOV) and the time offsets
        for each line scan.

        The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
        corresponding to the number of raw_imaging_data folders. These timestamps should match the
        number of frame timestamps extracted from the image file headers (found in the
        rawImagingData.times file). The field of view (FOV) shifts are then applied to these
        timestamps for each field of view and provided together with the line shifts.

        Parameters
        ----------
        sync : one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        chmap : dict
            A map of channel names and their corresponding indices. Only the 'neural_frames'
            channel is required.
        device_collection : str, iterable of str
            The location of the raw imaging data.
        events : pandas.DataFrame
            A table of software events, with columns {'time_timeline', 'name_timeline',
            'event_timeline'}.

        Returns
        -------
        list of numpy.array
            A list of timestamps for each FOV and the time offsets for each line scan.
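
        Examples
        --------
        A usage sketch; sync, chmap and session_path are assumed to already exist (e.g.
        from load_timeline_sync_and_chmap), and extraction is normally run via extract():

        >>> extractor = MesoscopeSyncTimeline(session_path, n_FOVs=4)
        >>> outputs = extractor._extract(sync, chmap, device_collection=['raw_imaging_data_00'])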

434 """ 

435 frame_times = sync['times'][sync['channels'] == chmap['neural_frames']] 1cb

436 

437 # imaging_start_time = datetime.datetime(*map(round, self.rawImagingData.meta['acquisitionStartTime'])) 

438 if isinstance(device_collection, str): 1cb

439 device_collection = [device_collection] 1c

440 if events is not None: 1cb

441 events = events[events.name == 'mpepUDP'] 1b

442 edges = self.get_bout_edges(frame_times, device_collection, events) 1cb

443 fov_times = [] 1cb

444 line_shifts = [] 1cb

445 for (tmin, tmax), collection in zip(edges, sorted(device_collection)): 1cb

446 imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData') 1cb

447 imaging_data['meta'] = patch_imaging_meta(imaging_data['meta']) 1cb

448 # Calculate line shifts 

449 _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta']) 1cb

450 assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}' 1cb

451 ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)] 1cb

452 assert ts.size >= imaging_data['times_scanImage'].size, f'fewer DAQ timestamps for {collection} than expected' 1cb

453 if ts.size > imaging_data['times_scanImage'].size: 1cb

454 _logger.warning( 1c

455 'More DAQ frame times detected for %s than were found in the raw image data.\n' 

456 'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n' 

457 'This may occur if the bout detection fails (e.g. UDPs recorded late), ' 

458 'when image data is corrupt, or when frames are not written to file.', 

459 collection, ts.size, imaging_data['times_scanImage'].size) 

460 _logger.info('Dropping last %i frame times for %s', ts.size - imaging_data['times_scanImage'].size, collection) 1c

461 ts = ts[:imaging_data['times_scanImage'].size] 1c

462 fov_times.append([ts + offset for offset in fov_time_shifts]) 1cb

463 if not line_shifts: 1cb

464 line_shifts = line_time_shifts 1cb

465 else: # The line shifts should be the same across all imaging bouts 

466 [np.testing.assert_array_equal(x, y) for x, y in zip(line_time_shifts, line_shifts)] 1b

467 

468 # Concatenate imaging timestamps across all bouts for each field of view 

469 fov_times = list(map(np.concatenate, zip(*fov_times))) 1cb

470 n_fov_times, = set(map(len, fov_times)) 1cb

471 if n_fov_times != frame_times.size: 1cb

472 # This may happen if an experimenter deletes a raw_imaging_data folder 

473 _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing') 1c

474 return fov_times + line_shifts 1cb

    def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
        """
        Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
        collection.

        Parameters
        ----------
        frame_times : numpy.array
            An array of all neural frame count times.
        collections : iterable of str
            A set of raw_imaging_data collections, used to extract selected imaging periods.
        events : pandas.DataFrame
            A table of UDP event times, corresponding to times when recordings start and end.
        min_gap : float
            If start or end events are not present, split bouts by finding gaps larger than this
            value.
        display : bool
            If true, plot the detected bout edges and raw frame times.

        Returns
        -------
        numpy.array
            An array of imaging bout intervals.
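
        Examples
        --------
        A sketch of gap-based splitting when no UDP events are given (session_path is
        hypothetical); the pause longer than min_gap separates two bouts:

        >>> frame_times = np.r_[np.arange(0, 1, .1), np.arange(5, 6, .1)]
        >>> edges = MesoscopeSyncTimeline(session_path, 1).get_bout_edges(frame_times)
        >>> edges.round(1).tolist()
        [[0.0, 0.9], [5.0, 5.9]]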

498 """ 

499 if events is None or events.empty: 1ecb

500 # No UDP events to mark blocks so separate based on gaps in frame rate 

501 idx = np.where(np.diff(frame_times) > min_gap)[0] 1ec

502 starts = np.r_[frame_times[0], frame_times[idx + 1]] 1ec

503 ends = np.r_[frame_times[idx], frame_times[-1]] 1ec

504 else: 

505 # Split using Exp/BlockStart and Exp/BlockEnd times 

506 _, subject, date, _ = session_path_parts(self.session_path) 1eb

507 pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+' 1eb

508 

509 # Get start times 

510 UDP_start = events[events['info'].str.match(pattern % 'Start')] 1eb

511 if len(UDP_start) > 1 and UDP_start.loc[0, 'info'].startswith('Exp'): 1eb

512 # Use ExpStart instead of first bout start 

513 UDP_start = UDP_start.copy().drop(1) 1eb

514 # Use ExpStart/End instead of first/last BlockStart/End 

515 starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]] 1eb

516 

517 # Get end times 

518 UDP_end = events[events['info'].str.match(pattern % 'End')] 1eb

519 if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'): 1eb

520 # Use last BlockEnd instead of ExpEnd 

521 UDP_end = UDP_end.copy().drop(UDP_end.index[-1]) 1eb

522 if not UDP_end.empty: 1eb

523 ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]] 1eb

524 else: 

525 # Get index of last frame to occur within a second of the previous frame 

526 consec = np.r_[np.diff(frame_times) > min_gap, True] 1e

527 idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts] 1e

528 ends = frame_times[idx] 1e

529 

530 # Remove any missing imaging bout collections 

531 edges = np.c_[starts, ends] 1ecb

532 if collections: 1ecb

533 if edges.shape[0] > len(collections): 1ecb

534 # Remove any bouts that correspond to a skipped collection 

535 # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout 

536 include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections) 

537 edges = edges[include, :] 

538 elif edges.shape[0] < len(collections): 1ecb

539 raise ValueError('More raw imaging folders than detected bouts') 1e

540 

541 if display: 1ecb

542 _, ax = plt.subplots(1) 1e

543 ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k', ) 1e

544 vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b') 1e

545 vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange') 1e

546 if edges.shape[0] != len(starts): 1e

547 vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size, 

548 label='missing bout start', linestyle=':', color='b') 

549 vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size, 

550 label='missing bout end', linestyle=':', color='orange') 

551 ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right') 1e

552 return edges 1ecb

    @staticmethod
    def get_timeshifts(raw_imaging_meta):
        """
        Calculate the time shifts for each field of view (FOV) and the relative offsets for each
        scan line.

        Parameters
        ----------
        raw_imaging_meta : dict
            Extracted ScanImage meta data (_ibl_rawImagingData.meta.json).

        Returns
        -------
        list of numpy.array
            A list of arrays, one per FOV, containing indices of each image scan line.
        numpy.array
            An array of FOV time offsets (one value per FOV) relative to each frame acquisition
            time.
        list of numpy.array
            A list of arrays, one per FOV, containing the time offsets for each scan line,
            relative to each FOV offset.
        """
        FOVs = raw_imaging_meta['FOV']

        # Double-check meta extracted properly
        raw_meta = raw_imaging_meta['rawScanImageMeta']
        artist = raw_meta['Artist']
        assert sum(x['enable'] for x in artist['RoiGroups']['imagingRoiGroup']['rois']) == len(FOVs)

        # Number of scan lines per FOV, i.e. number of Y pixels / image height
        n_lines = np.array([x['nXnYnZ'][1] for x in FOVs])
        n_valid_lines = np.sum(n_lines)  # Number of lines imaged excluding flybacks
        # Number of lines during flyback
        n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1))
        # The start and end indices of each FOV in the raw images
        fov_start_idx = np.insert(np.cumsum(n_lines[:-1] + n_lines_per_gap), 0, 0)
        fov_end_idx = fov_start_idx + n_lines
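        # Worked example (hypothetical values): two FOVs of 512 lines each with a raw
        # image height of 1154 give n_lines_per_gap = (1154 - 1024) / 1 = 130, hence
        # fov_start_idx = [0, 642] and fov_end_idx = [512, 1154]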

        line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']

        line_indices = []
        fov_time_shifts = fov_start_idx * line_period
        line_time_shifts = []

        for ln, s, e in zip(n_lines, fov_start_idx, fov_end_idx):
            line_indices.append(np.arange(s, e))
            line_time_shifts.append(np.arange(0, ln) * line_period)

        return line_indices, fov_time_shifts, line_time_shifts