Coverage for ibllib/io/extractors/mesoscope.py: 91%
326 statements
coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Mesoscope (timeline) data extraction."""
2import logging
4import numpy as np
5from scipy.signal import find_peaks
6import one.alf.io as alfio
7from one.util import ensure_list
8from one.alf.files import session_path_parts
9import matplotlib.pyplot as plt
10from packaging import version
12from ibllib.plots.misc import squares, vertical_lines
13from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel,
14 correct_counter_discontinuities, load_timeline_sync_and_chmap)
15import ibllib.io.extractors.base as extractors_base
16from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial
17from ibllib.io.extractors.training_wheel import extract_wheel_moves
18from ibllib.io.extractors.camera import attribute_times
19from brainbox.behavior.wheel import velocity_filtered
21_logger = logging.getLogger(__name__)

def patch_imaging_meta(meta: dict) -> dict:
    """
    Patch imaging metadata for compatibility across versions.

    The dict is updated in place; a copy is NOT returned.

    Parameters
    ----------
    meta : dict
        The loaded contents of a rawImagingData.meta file.

    Returns
    -------
    dict
        The loaded metadata file, updated to the most recent version.
    """
    # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys
    ver = version.parse(meta.get('version') or '0.0.0')
    if ver <= version.parse('0.0.0'):
        if 'channelSaved' not in meta:
            meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), [])
        fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight')
        for fov in meta.get('FOV', []):
            for unit in ('Deg', 'MM'):
                if unit not in fov:  # topLeftDeg, etc. -> Deg[topLeft]
                    fov[unit] = {f: fov.pop(f + unit, None) for f in fields}
    elif ver == version.parse('0.1.0'):
        for fov in meta.get('FOV', []):
            if 'roiUuid' in fov:
                fov['roiUUID'] = fov.pop('roiUuid')
    assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version'
    return meta
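
# Example usage (illustrative sketch only; the values below are made up and a real
# rawImagingData.meta dict contains many more keys). Patching an unversioned metadata dict
# moves 'topLeftDeg'-style keys into 'Deg'/'MM' sub-dicts and adds 'channelSaved':
#     >>> meta = {'nFrames': 512, 'FOV': [{'channelIdx': 1, 'topLeftDeg': [-1., 2.]}]}
#     >>> patched = patch_imaging_meta(meta)
#     >>> patched['channelSaved'], patched['FOV'][0]['Deg']['topLeft']
#     (1, [-1.0, 2.0])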

def plot_timeline(timeline, channels=None, raw=True):
    """
    Plot the timeline data.

    Parameters
    ----------
    timeline : one.alf.io.AlfBunch
        The timeline data object.
    channels : list of str
        An iterable of channel names to plot.
    raw : bool
        If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing timeline subplots.
    list of matplotlib.pyplot.Axes
        The axes for each timeline channel plotted.
    """
    meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']}
    channels = channels or meta.keys()
    fig, axes = plt.subplots(len(channels), 1, sharex=True)
    axes = ensure_list(axes)
    if not raw:
        chmap = {ch: meta[ch]['arrayColumn'] for ch in channels}
        sync = extract_sync_timeline(timeline, chmap=chmap)
    for i, (ax, ch) in enumerate(zip(axes, channels)):
        if raw:
            # axesScale controls vertical scaling of each trace (multiplicative)
            values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale']
            ax.plot(timeline['timestamps'], values)
        elif np.any(idx := sync['channels'] == chmap[ch]):
            squares(sync['times'][idx], sync['polarities'][idx], ax=ax)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True)
        ax.set_ylabel(ch, rotation=45, fontsize=8)
    # Add back x-axis ticks to the last plot
    axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True)
    axes[-1].spines['bottom'].set_visible(True)
    plt.get_current_fig_manager().window.showMaximized()  # full screen
    fig.tight_layout(h_pad=0)
    return fig, axes
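
# Example usage (illustrative sketch; the session path is hypothetical and the channel names are
# assumed to be present in the timeline meta 'inputs'):
#     >>> timeline = alfio.load_object(session_path / 'raw_sync_data', 'DAQdata', namespace='timeline')
#     >>> fig, axes = plot_timeline(timeline, channels=['bpod', 'audio', 'reward_valve'], raw=True)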

class TimelineTrials(FpgaTrials):
    """Similar extraction to the FPGA, however counter and position channels are treated differently."""

    timeline = None
    """one.alf.io.AlfBunch: The timeline data object."""

    sync_field = 'itiIn_times'
    """str: The trial event to synchronize (must be present in extracted trials)."""

    def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
        """An extractor for all ephys trial data, in Timeline time"""
        super().__init__(*args, **kwargs)
        self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

    def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_):
        """Load the DAQ sync and channel map data.

        Parameters
        ----------
        sync_collection : str
            The session subdirectory where the sync data are located.
        chmap : dict
            A map of channel names and their corresponding indices. If None, the channel map is
            loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        if not self.timeline:
            self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
        sync, chmap = load_timeline_sync_and_chmap(
            self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
        return sync, chmap
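
    # Example usage (illustrative sketch; `session_path` is hypothetical and assumed to contain a
    # raw_sync_data collection with _timeline_DAQdata files):
    #     >>> extractor = TimelineTrials(session_path, sync_collection='raw_sync_data')
    #     >>> sync, chmap = extractor.load_sync('raw_sync_data')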

    def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
        trials = super()._extract(sync, chmap, sync_collection=sync_collection, **kwargs)
        if kwargs.get('display', False):
            plot_timeline(self.timeline, channels=chmap.keys(), raw=True)
        return trials

    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
        """
        Extract Bpod times from sync.

        Unlike the superclass method, this one doesn't reassign the first trial pulse.

        Parameters
        ----------
        sync : dict
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers. Must contain a 'bpod' key.
        chmap : dict
            A map of channel names and their corresponding indices.
        bpod_event_ttls : dict of tuple
            A map of event names to (min, max) TTL length.

        Returns
        -------
        dict
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
        dict
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
        """
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
        # lengths are defined by the state machine of the task protocol and therefore may vary.
        if bpod_event_ttls is None:
            # The trial start TTLs are often too short for the low sampling rate of the DAQ and are
            # therefore not used in extraction
            bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
        bpod, bpod_event_intervals = super().get_bpod_event_times(
            sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs)

        # TODO Here we can make use of the 'bpod_rising_edge' channel, if available
        return bpod, bpod_event_intervals
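
    # Example of overriding the TTL length map (illustrative; the lengths depend on the task
    # protocol's state machine, so the values below are assumptions):
    #     >>> ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
    #     >>> bpod, intervals = extractor.get_bpod_event_times(sync, chmap, bpod_event_ttls=ttls)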

    def build_trials(self, sync=None, chmap=None, **kwargs):
        """
        Extract task related event times from the sync.

        There are two major differences from the superclass. First, the DAQ sampling rate is lower
        for imaging, so the short Bpod trial start TTLs are often absent; for this reason the
        clocks are synced using the ITI_in TTL instead.

        Second, the valve used at the mesoscope has a way to record the raw voltage across the
        solenoid, giving a more accurate readout of the valve's activity. If the reward_valve
        channel is present on the DAQ, this is used to extract the valve open times.

        Parameters
        ----------
        sync : dict
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
        chmap : dict
            Map of channel names and their corresponding index. Defaults to constant.

        Returns
        -------
        dict
            A map of trial event timestamps.
        """
        # Get the events from the sync.
        # Store the cleaned frame2ttl, audio, and bpod pulses as these will be used for QC
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
        if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
            raise ValueError(
                'Expected at least "ready_tone" and "error_tone" audio events. '
                '`audio_event_ttls` kwarg may be incorrect.')

        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
        if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}:
            raise ValueError(
                'Expected at least "trial_end" and "valve_open" Bpod events. '
                '`bpod_event_ttls` kwarg may be incorrect.')

        t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
        fpga_events = alfio.AlfBunch({
            'itiIn_times': t_iti_in,
            'intervals_1': t_trial_end,
            'goCue_times': audio_event_intervals['ready_tone'][:, 0],
            'errorTone_times': audio_event_intervals['error_tone'][:, 0]
        })
        # Sync the Bpod clock to the DAQ
        self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)

        out = dict()
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
        out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})

        start_times = out['intervals'][:, 0]
        last_trial_end = out['intervals'][-1, 1]

        def assign_to_trial(events, take='last', starts=start_times, **kwargs):
            """Assign DAQ events to trials.

            Because we may not have trial start TTLs on the DAQ (due to the low sampling rate),
            there may be an extra last trial that's not in the Bpod intervals, as the extractor
            ignores the last trial. This function trims the input array before assigning so that
            the last trial's events are correctly assigned.
            """
            return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **kwargs)
        out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga])

        # Extract valve open times from the DAQ
        valve_driver_ttls = bpod_event_intervals['valve_open']
        correct = self.bpod_trials['feedbackType'] == 1
        # If there is a reward_valve channel, the raw solenoid voltage was recorded on the DAQ,
        # giving a more accurate readout of the valve's activity than the driver TTLs alone
        if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
            # TODO Let's look at the expected open length based on calibration and reward volume
            # import scipy.interpolate
            # # FIXME support v7 settings?
            # fcn_vol2time = scipy.interpolate.pchip(
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'],
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES']
            # )
            # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3

            # Use the driver TTLs to find the valve open times that correspond to the valve opening
            valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
            if valve_open_times.size != np.sum(correct):
                _logger.warning(
                    'Number of valve open times does not equal number of correct trials (%i != %i)',
                    valve_open_times.size, np.sum(correct))

            out['valveOpen_times'] = assign_to_trial(valve_open_times)
        else:
            # Use the valve controller TTLs recorded on the Bpod channel as the reward time
            out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0])
        # Stimulus times extracted based on trigger times
        # When assigning events all start times must not be NaN so here we substitute freeze
        # trigger times on nogo trials for stim on trigger times, then replace with NaN again
        go_trials = np.where(out['choice'] != 0)[0]
        lims = np.copy(out['stimOnTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimFreeze_times'] = assign_to_trial(
            self.frame2ttl['times'], 'last',
            starts=lims, t_trial_end=out['stimOffTrigger_times'])
        out['stimFreeze_times'][out['choice'] == 0] = np.nan

        # Here we do the same but use stim off trigger times
        lims = np.copy(out['stimOffTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimOn_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOnTrigger_times'], t_trial_end=lims)
        out['stimOff_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1]
        )
        # Audio times
        error_cue = fpga_events['errorTone_times']
        if error_cue.size != np.sum(~correct):
            _logger.warning(
                'N detected error tones does not match number of incorrect trials (%i != %i)',
                error_cue.size, np.sum(~correct))
        go_cue = fpga_events['goCue_times']
        out['goCue_times'] = assign_to_trial(go_cue, take='first')
        out['errorCue_times'] = assign_to_trial(error_cue)

        if go_cue.size > start_times.size:
            _logger.warning(
                'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size)
        elif go_cue.size < start_times.size:
            # If the error cues are all assigned and some go cues are missed, it may be that some
            # responses were so fast that the go cue and error tone merged, or the go cue TTL was
            # too long.
            _logger.warning('%i go cue tones missed', start_times.size - go_cue.size)
            err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'])
            go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'])
            assert not np.any(np.isnan(go_trig))
            assert err_trig.size == go_trig.size  # should be length of n trials with NaNs

            # Find which trials are missing a go cue
            _go_cue = assign_to_trial(go_cue, take='first')
            error_cue = assign_to_trial(error_cue)
            missing = np.isnan(_go_cue)

            # Get all the DAQ timestamps where audio channel was HIGH
            raw = timeline_get_channel(self.timeline, 'audio')
            raw = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalize
            ups = self.timeline.timestamps[raw > .5]  # timestamps where input HIGH

            # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after).
            # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN)
            idx = attribute_times(ups, go_trig, tol=0.2, take='after')
            # Trial indices that didn't have a detected goCue and have now been assigned an `ups` index
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            _go_cue[assigned] = ups[idx[assigned]]

            # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue)
            error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig))
            i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True)
            error_cue[i_to_remove] = np.nan

            # For those trials where the go cue was merged with the error cue and therefore
            # mis-assigned, we must re-assign the error cue times as the first HIGH after the
            # error trigger.
            idx = attribute_times(ups, err_trig, tol=0.2, take='after')
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            error_cue[assigned] = ups[idx[assigned]]
            out['goCue_times'] = _go_cue
            out['errorCue_times'] = error_cue
        # No timestamp should end up assigned to both a go cue and an error cue; any overlap
        # indicates the audio tones were not assigned correctly
        assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \
            'audio tones not assigned correctly; tones likely missed'

        # Feedback times
        out['feedback_times'] = np.copy(out['valveOpen_times'])
        ind_err = np.isnan(out['valveOpen_times'])
        out['feedback_times'][ind_err] = out['errorCue_times'][ind_err]

        return out

    def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
        """
        Gets the wheel position from the Timeline counter channel.

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel in cm. Pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.

        Returns
        -------
        np.array
            Wheel timestamps in seconds.
        np.array
            Wheel positions (in radians if `radius` is 1, otherwise in the units of `radius`).

        See Also
        --------
        ibllib.io.extractors.ephys_fpga.extract_wheel_sync
        """
        if coding not in ('x1', 'x2', 'x4'):
            raise ValueError('Unsupported coding; must be one of x1, x2 or x4')
        raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder'))

        # Timeline evenly samples counter so we extract only change points
        d = np.diff(raw)
        ind, = np.where(~np.isclose(d, 0))
        pos = raw[ind + 1]
        pos -= pos[0]  # Start from zero
        pos = pos / ticks * np.pi * 2 * radius / int(coding[1])  # Convert ticks to position

        # Get timestamps of changes and trim based on protocol spacers
        ts = self.timeline['timestamps'][ind + 1]
        tmin = ts.min() if tmin is None else tmin
        tmax = ts.max() if tmax is None else tmax
        mask = np.logical_and(ts >= tmin, ts <= tmax)
        return ts[mask], pos[mask]
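
    # Worked example of the conversion above (illustrative numbers): with ticks=1024, 'x4' coding
    # and radius=3.1 (cm), a counter change of 4096 counts (one full revolution) gives
    # 4096 / 1024 * 2 * pi * 3.1 / 4 ≈ 19.5 cm of displacement, or 2 * pi rad when radius=1.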

    def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
                            tmin=None, tmax=None, display=False, **kwargs):
        """
        Gets the wheel position and detected movements from the Timeline counter channel.

        Called by the super class extractor (FpgaTrials._extract).

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel in cm. Pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.
        display : bool
            If true, plot the wheel positions from Bpod and the DAQ.

        Returns
        -------
        dict
            wheel object with keys ('timestamps', 'position').
        dict
            wheelMoves object with keys ('intervals', 'peakAmplitude').
        """
        wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax)
        wheel = dict(zip(('timestamps', 'position'), wheel))
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])

        if display:
            assert self.bpod_trials, 'no bpod trials to compare'
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            bpod_ts = self.bpod_trials['wheel_timestamps']
            bpod_pos = self.bpod_trials['wheel_position']
            ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos)
            ax0.set_ylabel('Bpod wheel position / rad')
            ax1.plot(wheel['timestamps'], wheel['position'])
            ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s')
        return wheel, moves

    def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None):
        """
        Get the valve open times from the raw timeline voltage trace.

        Parameters
        ----------
        display : bool
            Plot detected times on the raw voltage trace.
        threshold : float
            The threshold of voltage change to apply. The default was set by eye; the units should
            nominally be volts per sample but in practice do not appear to be.
        driver_ttls : numpy.array
            An optional array of driver TTLs to use for assigning with the valve times.

        Returns
        -------
        numpy.array
            The detected valve open intervals.
        numpy.array
            If driver_ttls is not None, returns an array of open times that occurred directly after
            the driver TTLs.
        """
        WARN_THRESH = 10e-3  # open time threshold below which to log warning
        tl = self.timeline
        info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
        values = tl['raw'][:, info['arrayColumn'] - 1]  # Timeline indices start from 1

        # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz,
        # making simple thresholding an issue. For this reason we convolve the signal with a
        # window and detect the peaks and troughs.
        if (Fs := tl['meta']['daqSampleRate']) != 2000:  # e.g. 2kHz
            _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs)
        dt = 1e-3  # change in voltage takes ~1ms when changing valve open state
        N = dt / (1 / Fs)  # this means the voltage change occurs over N samples
        vel, _ = velocity_filtered(values, int(Fs / N))  # filtered voltage change over time
        ups, _ = find_peaks(vel, height=threshold)  # valve closes (-5V -> 0V)
        downs, _ = find_peaks(-1 * vel, height=threshold)  # valve opens (0V -> -5V)

        # Convert these times into intervals
        ixs = np.argsort(np.r_[downs, ups])  # sort indices
        times = tl['timestamps'][np.r_[downs, ups]][ixs]  # ordered valve event times
        polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs]  # polarity sorted
        missing = np.where(np.diff(polarities) == 0)[0]  # if some changes were missed insert NaN
        times = np.insert(times, missing + int(polarities[0] == -1), np.nan)
        if polarities[-1] == -1:  # ensure ends with a valve close
            times = np.r_[times, np.nan]
        if polarities[0] == 1:  # ensure starts with a valve open
            # It seems it can start out at -5V (open), then when the reward happens it closes and
            # immediately opens. In this case we prepend NaN, effectively discarding the first
            # open time.
            times = np.r_[np.nan, times]
        intervals = times.reshape(-1, 2)

        # Log warning of improbably short intervals
        short = np.sum(np.diff(intervals) < WARN_THRESH)
        if short > 0:
            _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH * 1e3)

        # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
        if driver_ttls is not None:
            # Returns an array of open_times indices, one for each driver TTL
            ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after')
            open_times = intervals[ind[ind >= 0], 0]
            # TODO Log any > 40ms? Difficult to report missing valve times because of calibration

        if display:
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-')
            if driver_ttls is not None:
                x = np.empty_like(driver_ttls.flatten())
                x[0::2] = driver_ttls[:, 0]
                x[1::2] = driver_ttls[:, 1]
                y = np.ones_like(x)
                y[1::2] -= 2
                squares(x, y, ax=ax0, yrange=[0, 5])
                # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b')
                ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*')
            ax1.plot(tl['timestamps'], values, 'k-o')
            ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s')

            ax2 = ax1.twinx()
            ax2.set_ylabel('dV', color='grey')
            ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey')
            ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close')
            ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open')
        return intervals if driver_ttls is None else (intervals, open_times)
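
    # Example usage (illustrative sketch; assumes the timeline inputs include a 'reward_valve'
    # channel and that `valve_driver_ttls` is an Nx2 array of Bpod valve TTL intervals):
    #     >>> intervals = extractor.get_valve_open_times()  # Nx2 array of [open, close] times
    #     >>> intervals, open_times = extractor.get_valve_open_times(driver_ttls=valve_driver_ttls)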

    def _assign_events_audio(self, audio_times, audio_polarities, display=False):
        """
        This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.

        Parameters
        ----------
        audio_times : numpy.array
            An array of audio TTL front times.
        audio_polarities : numpy.array
            An array of audio TTL front polarities (1 for rises, -1 for falls).
        display : bool
            If true, display audio pulses and the assigned onsets.

        Returns
        -------
        numpy.array
            The times of the go cue onsets.
        numpy.array
            The times of the error tone onsets.
        """
        # make sure that there are no 2 consecutive fall or consecutive rise events
        assert np.all(np.abs(np.diff(audio_polarities)) == 2)
        # consider only the durations that start at a rising front, i.e. the tone lengths
        dt = np.diff(audio_times)
        onsets = audio_polarities[:-1] == 1

        # error tones are events lasting from 400ms to 1200ms
        i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
        t_error_tone_in = audio_times[i_error_tone_in]

        # detect ready tone by length below 300 ms
        i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
        t_ready_tone_in = audio_times[i_ready_tone_in]
        if display:  # pragma: no cover
            fig, ax = plt.subplots(nrows=2, sharex=True)
            ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
            ax[0].set_ylabel('Voltage / V')
            squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
            vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
            vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
            ax[1].set_xlabel('Time / s')
            ax[1].legend()

        return t_ready_tone_in, t_error_tone_in


class MesoscopeSyncTimeline(extractors_base.BaseExtractor):
    """Extraction of mesoscope imaging times."""

    var_names = ('mpci_times', 'mpciStack_timeshift')
    save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy')

    rawImagingData = None
    """one.alf.io.AlfBunch: The raw imaging meta data and frame times"""

    def __init__(self, session_path, n_FOVs):
        """
        Extract the mesoscope frame times from DAQ data acquired through Timeline.

        Parameters
        ----------
        session_path : str, pathlib.Path
            The session path to extract times from.
        n_FOVs : int
            The number of fields of view acquired.
        """
        super().__init__(session_path)
        self.n_FOVs = n_FOVs
        fov = list(map(lambda n: f'FOV_{n:02}', range(self.n_FOVs)))
        self.var_names = [f'{x}_{y.lower()}' for x in self.var_names for y in fov]
        self.save_names = [f'{y}/{x}' for x in self.save_names for y in fov]
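
    # Example of the attribute expansion above (illustrative; `session_path` is hypothetical):
    #     >>> extractor = MesoscopeSyncTimeline(session_path, 2)
    #     >>> extractor.var_names[:2]
    #     ['mpci_times_fov_00', 'mpci_times_fov_01']
    #     >>> extractor.save_names[:2]
    #     ['FOV_00/mpci.times.npy', 'FOV_01/mpci.times.npy']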

    def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None):
        """
        Extract the frame timestamps for each individual field of view (FOV) and the time offsets
        for each line scan.

        The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
        corresponding to the number of raw_imaging_data folders. These timestamps should match the
        number of frame timestamps extracted from the image file headers (found in the
        rawImagingData.times file). The field of view (FOV) shifts are then applied to these
        timestamps for each field of view and provided together with the line shifts.

        Parameters
        ----------
        sync : one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        chmap : dict
            A map of channel names and their corresponding indices. Only the 'neural_frames'
            channel is required.
        device_collection : str, iterable of str
            The location of the raw imaging data.
        events : pandas.DataFrame
            A table of software events, with columns {'time_timeline', 'name_timeline',
            'event_timeline'}.

        Returns
        -------
        list of numpy.array
            A list of timestamps for each FOV and the time offsets for each line scan.
        """
        frame_times = sync['times'][sync['channels'] == chmap['neural_frames']]

        # imaging_start_time = datetime.datetime(*map(round, self.rawImagingData.meta['acquisitionStartTime']))
        if isinstance(device_collection, str):
            device_collection = [device_collection]
        if events is not None:
            events = events[events.name == 'mpepUDP']
        edges = self.get_bout_edges(frame_times, device_collection, events)
        fov_times = []
        line_shifts = []
        for (tmin, tmax), collection in zip(edges, sorted(device_collection)):
            imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData')
            imaging_data['meta'] = patch_imaging_meta(imaging_data['meta'])
            # Calculate line shifts
            _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta'])
            assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
            ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
            assert ts.size >= imaging_data['times_scanImage'].size, f'fewer DAQ timestamps for {collection} than expected'
            if ts.size > imaging_data['times_scanImage'].size:
                _logger.warning(
                    'More DAQ frame times detected for %s than were found in the raw image data.\n'
                    'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n'
                    'This may occur if the bout detection fails (e.g. UDPs recorded late), '
                    'when image data is corrupt, or when frames are not written to file.',
                    collection, ts.size, imaging_data['times_scanImage'].size)
                _logger.info('Dropping last %i frame times for %s', ts.size - imaging_data['times_scanImage'].size, collection)
                ts = ts[:imaging_data['times_scanImage'].size]
            fov_times.append([ts + offset for offset in fov_time_shifts])
            if not line_shifts:
                line_shifts = line_time_shifts
            else:  # The line shifts should be the same across all imaging bouts
                [np.testing.assert_array_equal(x, y) for x, y in zip(line_time_shifts, line_shifts)]

        # Concatenate imaging timestamps across all bouts for each field of view
        fov_times = list(map(np.concatenate, zip(*fov_times)))
        n_fov_times, = set(map(len, fov_times))
        if n_fov_times != frame_times.size:
            # This may happen if an experimenter deletes a raw_imaging_data folder
            _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing')
        return fov_times + line_shifts

    def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
        """
        Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
        collection.

        Parameters
        ----------
        frame_times : numpy.array
            An array of all neural frame count times.
        collections : iterable of str
            A set of raw_imaging_data collections, used to extract selected imaging periods.
        events : pandas.DataFrame
            A table of UDP event times, corresponding to times when recordings start and end.
        min_gap : float
            If start or end events not present, split bouts by finding gaps larger than this value.
        display : bool
            If true, plot the detected bout edges and raw frame times.

        Returns
        -------
        numpy.array
            An array of imaging bout intervals.
        """
        if events is None or events.empty:
            # No UDP events to mark blocks so separate based on gaps in frame rate
            idx = np.where(np.diff(frame_times) > min_gap)[0]
            starts = np.r_[frame_times[0], frame_times[idx + 1]]
            ends = np.r_[frame_times[idx], frame_times[-1]]
        else:
            # Split using Exp/BlockStart and Exp/BlockEnd times
            _, subject, date, _ = session_path_parts(self.session_path)
            pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+'

            # Get start times
            UDP_start = events[events['info'].str.match(pattern % 'Start')]
            if len(UDP_start) > 1 and UDP_start.loc[0, 'info'].startswith('Exp'):
                # Use ExpStart instead of first bout start
                UDP_start = UDP_start.copy().drop(1)
            # Use ExpStart/End instead of first/last BlockStart/End
            starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]]

            # Get end times
            UDP_end = events[events['info'].str.match(pattern % 'End')]
            if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'):
                # Use last BlockEnd instead of ExpEnd
                UDP_end = UDP_end.copy().drop(UDP_end.index[-1])
            if not UDP_end.empty:
                ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]]
            else:
                # Get index of last frame to occur within a second of the previous frame
                consec = np.r_[np.diff(frame_times) > min_gap, True]
                idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts]
                ends = frame_times[idx]

        # Remove any missing imaging bout collections
        edges = np.c_[starts, ends]
        if collections:
            if edges.shape[0] > len(collections):
                # Remove any bouts that correspond to a skipped collection
                # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout
                include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections)
                edges = edges[include, :]
            elif edges.shape[0] < len(collections):
                raise ValueError('More raw imaging folders than detected bouts')

        if display:
            _, ax = plt.subplots(1)
            ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k')
            vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b')
            vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange')
            if edges.shape[0] != len(starts):
                vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size,
                               label='missing bout start', linestyle=':', color='b')
                vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size,
                               label='missing bout end', linestyle=':', color='orange')
            ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right')
        return edges
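
    # Example of the gap-based splitting above (illustrative values; without UDP events two bouts
    # separated by a >1s gap in frame times are detected):
    #     >>> frame_times = np.r_[np.arange(0, 10, 0.1), np.arange(60, 70, 0.1)]
    #     >>> extractor.get_bout_edges(frame_times, min_gap=1.)  # approximately
    #     array([[ 0. ,  9.9],
    #            [60. , 69.9]])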

    @staticmethod
    def get_timeshifts(raw_imaging_meta):
        """
        Calculate the time shifts for each field of view (FOV) and the relative offsets for each
        scan line.

        For a 2 scan field, 2 depth recording (so 4 FOVs):

        Frame 1, lines 1-512 correspond to FOV_00
        Frame 1, lines 551-1062 correspond to FOV_01
        Frame 2, lines 1-512 correspond to FOV_02
        Frame 2, lines 551-1062 correspond to FOV_03
        Frame 3, lines 1-512 correspond to FOV_00
        ...

        Parameters
        ----------
        raw_imaging_meta : dict
            Extracted ScanImage meta data (_ibl_rawImagingData.meta.json).

        Returns
        -------
        list of numpy.array
            A list of arrays, one per FOV, containing indices of each image scan line.
        numpy.array
            An array of FOV time offsets (one value per FOV) relative to each frame acquisition
            time.
        list of numpy.array
            A list of arrays, one per FOV, containing the time offsets for each scan line, relative
            to each FOV offset.
        """
        FOVs = raw_imaging_meta['FOV']

        # Double-check meta extracted properly
        # assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected.
        slice_ids = np.array([fov['slice_id'] for fov in FOVs])
        assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order'
        assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered'
        # Number of scan lines per FOV, i.e. number of Y pixels / image height
        n_lines = np.array([x['nXnYnZ'][1] for x in FOVs])

        # We get indices from MATLAB extracted metadata so below two lines are no longer needed
        # n_valid_lines = np.sum(n_lines)  # Number of lines imaged excluding flybacks
        # n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1))  # N lines during flyback
        line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']
        frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate']

        # Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m
        # They are indexed from 1 so we subtract 1 to convert to zero-indexed
        line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs]  # Convert to zero-indexed from MATLAB 1-indexed
        assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines'
        # The start indices of each FOV in the raw images
        fov_start_idx = np.array([lns[0] for lns in line_indices])
        roi_time_shifts = fov_start_idx * line_period  # The time offset for each FOV
        fov_time_shifts = roi_time_shifts + frame_time_shifts
        line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)]

        return line_indices, fov_time_shifts, line_time_shifts
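
    # Worked example of the offsets above (illustrative numbers): with a line period of 4.2e-5 s, a
    # scan frame rate of 5.8 Hz and slice_ids = [0, 0, 1, 1], the FOVs on the second slice acquire
    # a frame offset of 1 / 5.8 ≈ 0.172 s, and an FOV whose first line index is 550 gains a further
    # ROI offset of 550 * 4.2e-5 ≈ 0.0231 s; line offsets within an FOV are multiples of the line
    # period relative to its first line.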