Coverage for ibllib/io/extractors/mesoscope.py: 85% (337 statements)
1"""Mesoscope (timeline) data extraction."""
2import logging
4import numpy as np
5from scipy.signal import find_peaks
6import one.alf.io as alfio
7from one.alf.path import session_path_parts
8from iblutil.util import ensure_list
9import matplotlib.pyplot as plt
10from packaging import version
12from ibllib.plots.misc import squares, vertical_lines
13from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel,
14 correct_counter_discontinuities, load_timeline_sync_and_chmap)
15import ibllib.io.extractors.base as extractors_base
16from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial
17from ibllib.io.extractors.training_wheel import extract_wheel_moves
18from ibllib.io.extractors.camera import attribute_times
19from brainbox.behavior.wheel import velocity_filtered
21_logger = logging.getLogger(__name__)


def patch_imaging_meta(meta: dict) -> dict:
    """
    Patch imaging metadata for compatibility across versions.

    NB: The dict is modified in place; a copy is not returned.

    Parameters
    ----------
    meta : dict
        The loaded contents of a rawImagingData.meta file.

    Returns
    -------
    dict
        The loaded metadata, updated to the most recent version.
    """
    # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys
    ver = version.parse(meta.get('version') or '0.0.0')
    if ver <= version.parse('0.0.0'):
        if 'channelSaved' not in meta:
            meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), [])
        fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight')
        for fov in meta.get('FOV', []):
            for unit in ('Deg', 'MM'):
                if unit not in fov:  # topLeftDeg, etc. -> Deg[topLeft]
                    fov[unit] = {f: fov.pop(f + unit, None) for f in fields}
    elif ver == version.parse('0.1.0'):
        for fov in meta.get('FOV', []):
            if 'roiUuid' in fov:
                fov['roiUUID'] = fov.pop('roiUuid')
    # 2024-09-17 Modified the 2 unit vectors for the positive ML axis and the positive AP axis,
    # which then transform [X, Y] coordinates (in degrees) to [ML, AP] coordinates (in MM).
    if ver < version.Version('0.1.5') and 'imageOrientation' in meta:
        pos_ml, pos_ap = meta['imageOrientation']['positiveML'], meta['imageOrientation']['positiveAP']
        center_ml, center_ap = meta['centerMM']['ML'], meta['centerMM']['AP']
        res = meta['scanImageParams']['objectiveResolution']
        # previously [[0, res / 1000], [-res / 1000, 0], [0, 0]]
        TF = np.linalg.pinv(np.c_[np.vstack([pos_ml, pos_ap, [0, 0]]), [1, 1, 1]]) @ \
            (np.array([[res / 1000, 0], [0, res / 1000], [0, 0]]) + np.array([center_ml, center_ap]))
        TF = np.round(TF, 3)  # handle floating-point error by rounding
        if not np.allclose(TF, meta['coordsTF']):
            meta['coordsTF'] = TF.tolist()
            centerDegXY = np.array([meta['centerDeg']['x'], meta['centerDeg']['y']])
            for fov in meta.get('FOV', []):
                fov['MM'] = {k: (np.r_[np.array(v) - centerDegXY, 1] @ TF).tolist() for k, v in fov['Deg'].items()}

    assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version'
    return meta
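
# A minimal usage sketch (hedged; the file path below is hypothetical): load a
# rawImagingData.meta JSON file and patch it in place.
#
#     import json
#     from pathlib import Path
#     meta_path = Path('raw_imaging_data_00', '_ibl_rawImagingData.meta.json')
#     with open(meta_path) as fp:
#         meta = patch_imaging_meta(json.load(fp))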


def plot_timeline(timeline, channels=None, raw=True):
    """
    Plot the timeline data.

    Parameters
    ----------
    timeline : one.alf.io.AlfBunch
        The timeline data object.
    channels : list of str
        An iterable of channel names to plot.
    raw : bool
        If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot the changes.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing timeline subplots.
    list of matplotlib.pyplot.Axes
        The axes for each timeline channel plotted.
    """
    meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']}
    channels = channels or meta.keys()
    fig, axes = plt.subplots(len(channels), 1, sharex=True)
    axes = ensure_list(axes)
    if not raw:
        chmap = {ch: meta[ch]['arrayColumn'] for ch in channels}
        sync = extract_sync_timeline(timeline, chmap=chmap)
    for i, (ax, ch) in enumerate(zip(axes, channels)):
        if raw:
            # axesScale controls vertical scaling of each trace (multiplicative)
            values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale']
            ax.plot(timeline['timestamps'], values)
        elif np.any(idx := sync['channels'] == chmap[ch]):
            squares(sync['times'][idx], sync['polarities'][idx], ax=ax)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True)
        ax.set_ylabel(ch, rotation=45, fontsize=8)
    # Add back x-axis ticks to the last plot
    axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True)
    axes[-1].spines['bottom'].set_visible(True)
    plt.get_current_fig_manager().window.showMaximized()  # full screen
    fig.tight_layout(h_pad=0)
    return fig, axes
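
# Example usage (a sketch, assuming a session with a raw_sync_data collection and
# a Qt matplotlib backend for the maximized window):
#
#     timeline = alfio.load_object(session_path / 'raw_sync_data', 'DAQdata', namespace='timeline')
#     fig, axes = plot_timeline(timeline, channels=['bpod', 'audio'], raw=True)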


class TimelineTrials(FpgaTrials):
    """Trials extraction similar to the FPGA one, except that counter and position channels are treated differently."""

    timeline = None
    """one.alf.io.AlfBunch: The timeline data object."""

    sync_field = 'itiIn_times'
    """str: The trial event to synchronize (must be present in extracted trials)."""

    def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
        """An extractor for all ephys trial data, in Timeline time."""
        super().__init__(*args, **kwargs)
        self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

    def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_):
        """Load the DAQ sync and channel map data.

        Parameters
        ----------
        sync_collection : str
            The session subdirectory where the sync data are located.
        chmap : dict
            A map of channel names and their corresponding indices. If None, the channel map is
            loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        if not self.timeline:
            self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
        sync, chmap = load_timeline_sync_and_chmap(
            self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
        return sync, chmap

    def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
        trials = super()._extract(sync, chmap, sync_collection=sync_collection, **kwargs)
        if kwargs.get('display', False):
            plot_timeline(self.timeline, channels=chmap.keys(), raw=True)
        return trials

    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
        """
        Extract Bpod times from sync.

        Unlike the superclass method, this one doesn't reassign the first trial pulse.

        Parameters
        ----------
        sync : dict
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers. Must contain a 'bpod' key.
        chmap : dict
            A map of channel names and their corresponding indices.
        bpod_event_ttls : dict of tuple
            A map of event names to (min, max) TTL length.

        Returns
        -------
        dict
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
        dict
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
        """
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
        # lengths are defined by the state machine of the task protocol and therefore vary.
        if bpod_event_ttls is None:
            # The trial start TTLs are often too short for the low sampling rate of the DAQ and are
            # therefore not used in extraction
            bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
        bpod, bpod_event_intervals = super().get_bpod_event_times(
            sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs)

        # TODO Here we can make use of the 'bpod_rising_edge' channel, if available
        return bpod, bpod_event_intervals
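
    # A hedged illustration of the TTL-length bands above (bound inclusivity is
    # determined by the superclass implementation): a 0.1 s pulse falls in the
    # 'valve_open' band and a 1 s pulse in the 'trial_end' band.
    #
    #     ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
    #     band = lambda dur: next((k for k, (lo, hi) in ttls.items() if lo < dur < hi), None)
    #     assert band(0.1) == 'valve_open' and band(1.0) == 'trial_end'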

    def build_trials(self, sync=None, chmap=None, **kwargs):
        """
        Extract task related event times from the sync.

        There are two major differences to the superclass method. First, the sampling rate is
        lower for imaging, so the short Bpod trial start TTLs are often absent; for this reason
        the clocks are synchronized using the ITI_in TTL instead.

        Second, the valve used at the mesoscope has a way to record the raw voltage across the
        solenoid, giving a more accurate readout of the valve's activity. If the reward_valve
        channel is present on the DAQ, this is used to extract the valve open times.

        Parameters
        ----------
        sync : dict
            'polarities' of fronts detected on sync trace for all 16 channels and their 'times'.
        chmap : dict
            Map of channel names and their corresponding indices. Defaults to the constant map.

        Returns
        -------
        dict
            A map of trial event timestamps.
        """
        # Get the events from the sync.
        # Store the cleaned frame2ttl, audio, and bpod pulses as these will be used for QC
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
        if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
            raise ValueError(
                'Expected at least "ready_tone" and "error_tone" audio events. '
                '`audio_event_ttls` kwarg may be incorrect.')
        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
        if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}:
            raise ValueError(
                'Expected at least "trial_end" and "valve_open" Bpod events. '
                '`bpod_event_ttls` kwarg may be incorrect.')

        t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
        fpga_events = alfio.AlfBunch({
            'itiIn_times': t_iti_in,
            'intervals_1': t_trial_end,
            'goCue_times': audio_event_intervals['ready_tone'][:, 0],
            'errorTone_times': audio_event_intervals['error_tone'][:, 0]
        })

        # Sync the Bpod clock to the DAQ
        self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)

        out = dict()
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
        out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})

        start_times = out['intervals'][:, 0]
        last_trial_end = out['intervals'][-1, 1]

        def assign_to_trial(events, take='last', starts=start_times, **kwargs):
            """Assign DAQ events to trials.

            Because we may not have trial start TTLs on the DAQ (because of the low sampling rate),
            there may be an extra last trial that's not in the Bpod intervals as the extractor
            ignores the last trial. This function trims the input array before assigning so that
            the last trial's events are correctly assigned.
            """
            return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **kwargs)
        out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga])

        # Extract valve open times from the DAQ
        valve_driver_ttls = bpod_event_intervals['valve_open']
        correct = self.bpod_trials['feedbackType'] == 1
        # If there is a reward_valve channel, the raw solenoid voltage was recorded, giving a
        # more accurate readout of the valve's activity than the driver TTLs alone
        if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
            # TODO Let's look at the expected open length based on calibration and reward volume
            # import scipy.interpolate
            # # FIXME support v7 settings?
            # fcn_vol2time = scipy.interpolate.pchip(
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'],
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES']
            # )
            # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3

            # Use the driver TTLs to find the valve open times that correspond to the valve opening
            valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
            if valve_open_times.size != np.sum(correct):
                _logger.warning(
                    'Number of valve open times does not equal number of correct trials (%i != %i)',
                    valve_open_times.size, np.sum(correct))
            out['valveOpen_times'] = assign_to_trial(valve_open_times)
        else:
            # Use the valve controller TTLs recorded on the Bpod channel as the reward time
            out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0])

        # Stimulus times extracted based on trigger times
        # When assigning events, all start times must be non-NaN. The freeze trigger times are
        # NaN on no-go trials, so we substitute the stim on trigger times there, then replace
        # the assigned times with NaN again afterwards.
        go_trials = np.where(out['choice'] != 0)[0]
        lims = np.copy(out['stimOnTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimFreeze_times'] = assign_to_trial(
            self.frame2ttl['times'], 'last',
            starts=lims, t_trial_end=out['stimOffTrigger_times'])
        out['stimFreeze_times'][out['choice'] == 0] = np.nan

        # Here we do the same but use stim off trigger times
        lims = np.copy(out['stimOffTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimOn_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOnTrigger_times'], t_trial_end=lims)
        out['stimOff_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1]
        )

        # Audio times
        error_cue = fpga_events['errorTone_times']
        if error_cue.size != np.sum(~correct):
            _logger.warning(
                'N detected error tones does not match number of incorrect trials (%i != %i)',
                error_cue.size, np.sum(~correct))
        go_cue = fpga_events['goCue_times']
        out['goCue_times'] = assign_to_trial(go_cue, take='first')
        out['errorCue_times'] = assign_to_trial(error_cue)

        if go_cue.size > start_times.size:
            _logger.warning(
                'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size)
        elif go_cue.size < start_times.size:
            # If the error cues are all assigned and some go cues are missed it may be that some
            # responses were so fast that the go cue and error tone merged, or the go cue TTL was
            # too long.
            _logger.warning('%i go cue tones missed', start_times.size - go_cue.size)
            err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'])
            go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'])
            assert not np.any(np.isnan(go_trig))
            assert err_trig.size == go_trig.size  # should be length of n trials with NaNs

            # Find which trials are missing a go cue
            _go_cue = assign_to_trial(go_cue, take='first')
            error_cue = assign_to_trial(error_cue)
            missing = np.isnan(_go_cue)

            # Get all the DAQ timestamps where audio channel was HIGH
            raw = timeline_get_channel(self.timeline, 'audio')
            raw = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalize
            ups = self.timeline.timestamps[raw > .5]  # timestamps where input HIGH

            # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after).
            # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN)
            idx = attribute_times(ups, go_trig, tol=0.2, take='after')
            # Trial indices that didn't have a detected go cue and have now been assigned an `ups` index
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            _go_cue[assigned] = ups[idx[assigned]]

            # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue)
            error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig))
            i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True)
            error_cue[i_to_remove] = np.nan

            # For those trials where the go cue was merged with the error cue and therefore
            # mis-assigned, we must re-assign the error cue times as the first HIGH after the
            # error trigger.
            idx = attribute_times(ups, err_trig, tol=0.2, take='after')
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            error_cue[assigned] = ups[idx[assigned]]
            out['goCue_times'] = _go_cue
            out['errorCue_times'] = error_cue

        # Because we're not using the trigger times for assignment, check that no tone was
        # assigned as both a go cue and an error cue
        assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \
            'audio tones not assigned correctly; tones likely missed'

        # Feedback times
        out['feedback_times'] = np.copy(out['valveOpen_times'])
        ind_err = np.isnan(out['valveOpen_times'])
        out['feedback_times'][ind_err] = out['errorCue_times'][ind_err]

        return out
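
    # Conceptual sketch (hedged, plain numpy) of assigning events to trials by their
    # start times, in the spirit of `_assign_events_to_trial` above: each event is
    # attributed to the last trial start that precedes it.
    #
    #     starts = np.array([0., 10., 20.])
    #     events = np.array([1., 2., 11., 25.])
    #     trial_of_event = np.searchsorted(starts, events, side='right') - 1  # [0, 0, 1, 2]
    #     # with take='last', trial 0 keeps event 2.; with take='first', event 1.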

    def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
        """
        Gets the wheel position from the Timeline counter channel.

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel. Pass 1 for an output in radians; defaults to the IBL wheel
            radius in cm.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.

        Returns
        -------
        np.array
            Wheel timestamps in seconds.
        np.array
            Wheel positions in units determined by `radius` (radians when `radius=1`).

        See Also
        --------
        ibllib.io.extractors.ephys_fpga.extract_wheel_sync
        """
        if coding not in ('x1', 'x2', 'x4'):
            raise ValueError('Unsupported coding; must be one of x1, x2 or x4')
        raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder'))

        # Timeline evenly samples counter so we extract only change points
        d = np.diff(raw)
        ind, = np.where(~np.isclose(d, 0))
        pos = raw[ind + 1]
        pos -= pos[0]  # Start from zero
        pos = pos / ticks * np.pi * 2 * radius / int(coding[1])  # Convert to radians

        # Get timestamps of changes and trim based on protocol spacers
        ts = self.timeline['timestamps'][ind + 1]
        tmin = ts.min() if tmin is None else tmin
        tmax = ts.max() if tmax is None else tmax
        mask = np.logical_and(ts >= tmin, ts <= tmax)
        return ts[mask], pos[mask]
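
    # Worked example of the conversion above (hedged): with 1024 ticks, x4 coding and
    # radius=1, one full revolution increments the counter by 1024 * 4 and maps to 2π:
    #
    #     counts = 1024 * 4
    #     pos = counts / 1024 * np.pi * 2 * 1 / 4  # == 2 * np.pi radians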

    def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
                            tmin=None, tmax=None, display=False, **kwargs):
        """
        Gets the wheel position and detected movements from the Timeline counter channel.

        Called by the superclass extractor (FpgaTrials._extract).

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel. Pass 1 for an output in radians; defaults to the IBL wheel
            radius in cm.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.
        display : bool
            If true, plot the wheel positions from Bpod and the DAQ.

        Returns
        -------
        dict
            wheel object with keys ('timestamps', 'position').
        dict
            wheelMoves object with keys ('intervals', 'peakAmplitude').
        """
        wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax)
        wheel = dict(zip(('timestamps', 'position'), wheel))
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])

        if display:
            assert self.bpod_trials, 'no bpod trials to compare'
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            bpod_ts = self.bpod_trials['wheel_timestamps']
            bpod_pos = self.bpod_trials['wheel_position']
            ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos)
            ax0.set_ylabel('Bpod wheel position / rad')
            ax1.plot(wheel['timestamps'], wheel['position'])
            ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s')
        return wheel, moves

    def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None):
        """
        Get the valve open times from the raw timeline voltage trace.

        Parameters
        ----------
        display : bool
            Plot detected times on the raw voltage trace.
        threshold : float
            The threshold of voltage change to apply. The default was set by eye; the units
            should be volts per sample but this doesn't appear to be the case.
        driver_ttls : numpy.array
            An optional Nx2 array of driver TTL intervals, used to keep only the valve open
            times that directly follow a driver TTL.

        Returns
        -------
        numpy.array
            The detected valve open intervals.
        numpy.array
            If driver_ttls is not None, returns an array of open times that occurred directly
            after the driver TTLs.
        """
        WARN_THRESH = 10e-3  # open time threshold below which to log warning
        tl = self.timeline
        info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
        values = tl['raw'][:, info['arrayColumn'] - 1]  # Timeline indices start from 1

        # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz,
        # making simple thresholding an issue. For this reason we convolve the signal with a
        # window and detect the peaks and troughs.
        if (Fs := tl['meta']['daqSampleRate']) != 2000:  # e.g. 2kHz
            _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs)
        dt = 1e-3  # change in voltage takes ~1ms when changing valve open state
        N = dt / (1 / Fs)  # this means the voltage change occurs over N samples
        vel, _ = velocity_filtered(values, int(Fs / N))  # filtered voltage change over time
        ups, _ = find_peaks(vel, height=threshold)  # valve closes (-5V -> 0V)
        downs, _ = find_peaks(-1 * vel, height=threshold)  # valve opens (0V -> -5V)

        # Convert these times into intervals
        ixs = np.argsort(np.r_[downs, ups])  # sort indices
        times = tl['timestamps'][np.r_[downs, ups]][ixs]  # ordered valve event times
        polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs]  # polarity sorted
        missing = np.where(np.diff(polarities) == 0)[0]  # if some changes were missed insert NaN
        times = np.insert(times, missing + int(polarities[0] == -1), np.nan)
        if polarities[-1] == -1:  # ensure ends with a valve close
            times = np.r_[times, np.nan]
        if polarities[0] == 1:  # ensure starts with a valve open
            # It seems it can start out at -5V (open), then when the reward happens it closes and
            # immediately opens. In this case we insert NaN in place of the missing first open time.
            times = np.r_[np.nan, times]
        intervals = times.reshape(-1, 2)

        # Log a warning for improbably short intervals
        short = np.sum(np.diff(intervals) < WARN_THRESH)
        if short > 0:
            _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH * 1e3)

        # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
        if driver_ttls is not None:
            # Returns an array of open_times indices, one for each driver TTL
            ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after')
            open_times = intervals[ind[ind >= 0], 0]
            # TODO Log any > 40ms? Difficult to report missing valve times because of calibration

        if display:
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-')
            if driver_ttls is not None:
                x = np.empty_like(driver_ttls.flatten())
                x[0::2] = driver_ttls[:, 0]
                x[1::2] = driver_ttls[:, 1]
                y = np.ones_like(x)
                y[1::2] -= 2
                squares(x, y, ax=ax0, yrange=[0, 5])
                # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b')
                ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*')
            ax1.plot(tl['timestamps'], values, 'k-o')
            ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s')

            ax2 = ax1.twinx()
            ax2.set_ylabel('dV', color='grey')
            ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey')
            ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close')
            ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open')
        return intervals if driver_ttls is None else (intervals, open_times)
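
    # A synthetic sketch (hedged, untested numbers) of the peak-based edge detection
    # above: a 2 kHz trace where the valve line steps from 0 V to -5 V for 200 ms.
    #
    #     Fs = 2000
    #     t = np.arange(0, 1, 1 / Fs)
    #     trace = np.zeros_like(t)
    #     trace[(t > 0.4) & (t < 0.6)] = -5.
    #     vel, _ = velocity_filtered(trace, Fs // 2)
    #     opens, _ = find_peaks(-vel, height=100)   # falling edge near t = 0.4 s
    #     closes, _ = find_peaks(vel, height=100)   # rising edge near t = 0.6 s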

    def _assign_events_audio(self, audio_times, audio_polarities, display=False):
        """
        This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.

        Parameters
        ----------
        audio_times : numpy.array
            An array of audio TTL front times.
        audio_polarities : numpy.array
            An array of audio TTL front polarities (1 for rises, -1 for falls).
        display : bool
            If true, display audio pulses and the assigned onsets.

        Returns
        -------
        numpy.array
            The times of the go cue onsets.
        numpy.array
            The times of the error tone onsets.
        """
        # make sure that there are no 2 consecutive fall or consecutive rise events
        assert np.all(np.abs(np.diff(audio_polarities)) == 2)
        # take only the time differences from rising to falling fronts, i.e. the pulse durations
        dt = np.diff(audio_times)
        onsets = audio_polarities[:-1] == 1

        # error tones are events lasting from 400ms to 1200ms
        i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
        t_error_tone_in = audio_times[i_error_tone_in]

        # detect ready tone by a length below 300 ms
        i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
        t_ready_tone_in = audio_times[i_ready_tone_in]
        if display:  # pragma: no cover
            fig, ax = plt.subplots(nrows=2, sharex=True)
            ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
            ax[0].set_ylabel('Voltage / V')
            squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
            vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
            vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
            ax[1].set_xlabel('Time / s')
            ax[1].legend()

        return t_ready_tone_in, t_error_tone_in


class MesoscopeSyncTimeline(extractors_base.BaseExtractor):
    """Extraction of mesoscope imaging times."""

    var_names = ('mpci_times', 'mpciStack_timeshift')
    save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy')

    rawImagingData = None
    """one.alf.io.AlfBunch: The raw imaging meta data and frame times."""

    def __init__(self, session_path, n_FOVs):
        """
        Extract the mesoscope frame times from DAQ data acquired through Timeline.

        Parameters
        ----------
        session_path : str, pathlib.Path
            The session path to extract times from.
        n_FOVs : int
            The number of fields of view acquired.
        """
        super().__init__(session_path)
        self.n_FOVs = n_FOVs
        fov = list(map(lambda n: f'FOV_{n:02}', range(self.n_FOVs)))
        self.var_names = [f'{x}_{y.lower()}' for x in self.var_names for y in fov]
        self.save_names = [f'{y}/{x}' for x in self.save_names for y in fov]
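
    # For illustration, with n_FOVs=2 the attributes above expand to:
    #
    #     var_names = ['mpci_times_fov_00', 'mpci_times_fov_01',
    #                  'mpciStack_timeshift_fov_00', 'mpciStack_timeshift_fov_01']
    #     save_names = ['FOV_00/mpci.times.npy', 'FOV_01/mpci.times.npy',
    #                   'FOV_00/mpciStack.timeshift.npy', 'FOV_01/mpciStack.timeshift.npy']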

    def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None):
        """
        Extract the frame timestamps for each individual field of view (FOV) and the time offsets
        for each line scan.

        The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
        corresponding to the number of raw_imaging_data folders. These timestamps should match the
        number of frame timestamps extracted from the image file headers (found in the
        rawImagingData.times file). The field of view (FOV) shifts are then applied to these
        timestamps for each field of view and provided together with the line shifts.

        Parameters
        ----------
        sync : one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        chmap : dict
            A map of channel names and their corresponding indices. Only the 'neural_frames'
            channel is required.
        device_collection : str, iterable of str
            The location of the raw imaging data.
        events : pandas.DataFrame
            A table of software events, with columns {'time_timeline', 'name_timeline',
            'event_timeline'}.

        Returns
        -------
        list of numpy.array
            A list of timestamps for each FOV and the time offsets for each line scan.
        """
        frame_times = sync['times'][sync['channels'] == chmap['neural_frames']]

        # imaging_start_time = datetime.datetime(*map(round, self.rawImagingData.meta['acquisitionStartTime']))
        if isinstance(device_collection, str):
            device_collection = [device_collection]
        if events is not None:
            events = events[events.name == 'mpepUDP']
        edges = self.get_bout_edges(frame_times, device_collection, events)
        fov_times = []
        line_shifts = []
        for (tmin, tmax), collection in zip(edges, sorted(device_collection)):
            imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData')
            imaging_data['meta'] = patch_imaging_meta(imaging_data['meta'])
            # Calculate line shifts
            _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta'])
            assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
            ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
            assert ts.size >= imaging_data['times_scanImage'].size, f'fewer DAQ timestamps for {collection} than expected'
            if ts.size > imaging_data['times_scanImage'].size:
                _logger.warning(
                    'More DAQ frame times detected for %s than were found in the raw image data.\n'
                    'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n'
                    'This may occur if the bout detection fails (e.g. UDPs recorded late), '
                    'when image data is corrupt, or when frames are not written to file.',
                    collection, ts.size, imaging_data['times_scanImage'].size)
                _logger.info('Dropping last %i frame times for %s', ts.size - imaging_data['times_scanImage'].size, collection)
                ts = ts[:imaging_data['times_scanImage'].size]
            fov_times.append([ts + offset for offset in fov_time_shifts])
            if not line_shifts:
                line_shifts = line_time_shifts
            else:  # The line shifts should be the same across all imaging bouts
                [np.testing.assert_array_equal(x, y) for x, y in zip(line_time_shifts, line_shifts)]

        # Concatenate imaging timestamps across all bouts for each field of view
        fov_times = list(map(np.concatenate, zip(*fov_times)))
        n_fov_times, = set(map(len, fov_times))
        if n_fov_times != frame_times.size:
            # This may happen if an experimenter deletes a raw_imaging_data folder
            _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing')
        return fov_times + line_shifts

    def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
        """
        Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
        collection.

        Parameters
        ----------
        frame_times : numpy.array
            An array of all neural frame count times.
        collections : iterable of str
            A set of raw_imaging_data collections, used to extract selected imaging periods.
        events : pandas.DataFrame
            A table of UDP event times, corresponding to times when recordings start and end.
        min_gap : float
            If start or end events not present, split bouts by finding gaps larger than this value.
        display : bool
            If true, plot the detected bout edges and raw frame times.

        Returns
        -------
        numpy.array
            An array of imaging bout intervals.
        """
        if events is None or events.empty:
            # No UDP events to mark blocks so separate based on gaps in frame rate
            idx = np.where(np.diff(frame_times) > min_gap)[0]
            starts = np.r_[frame_times[0], frame_times[idx + 1]]
            ends = np.r_[frame_times[idx], frame_times[-1]]
        else:
            # Split using Exp/BlockStart and Exp/BlockEnd times
            _, subject, date, _ = session_path_parts(self.session_path)
            pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+'

            # Get start times
            UDP_start = events[events['info'].str.match(pattern % 'Start')]
            if len(UDP_start) > 1 and UDP_start.loc[0, 'info'].startswith('Exp'):
                # Use ExpStart instead of first bout start
                UDP_start = UDP_start.copy().drop(1)
            # Use ExpStart/End instead of first/last BlockStart/End
            starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]]

            # Get end times
            UDP_end = events[events['info'].str.match(pattern % 'End')]
            if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'):
                # Use last BlockEnd instead of ExpEnd
                UDP_end = UDP_end.copy().drop(UDP_end.index[-1])
            if not UDP_end.empty:
                ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]]
            else:
                # Take the first frame after each start that is followed by a gap larger than min_gap
                consec = np.r_[np.diff(frame_times) > min_gap, True]
                idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts]
                ends = frame_times[idx]

        # Remove any missing imaging bout collections
        edges = np.c_[starts, ends]
        if collections:
            if edges.shape[0] > len(collections):
                # Remove any bouts that correspond to a skipped collection
                # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove the middle bout
                include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections)
                edges = edges[include, :]
            elif edges.shape[0] < len(collections):
                raise ValueError('More raw imaging folders than detected bouts')

        if display:
            _, ax = plt.subplots(1)
            ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k')
            vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b')
            vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange')
            if edges.shape[0] != len(starts):
                vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size,
                               label='missing bout start', linestyle=':', color='b')
                vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size,
                               label='missing bout end', linestyle=':', color='orange')
            ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right')
        return edges
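
    # A minimal sketch (hedged) of the gap-based splitting used when no UDP events
    # are available: two bouts of frames separated by a gap larger than min_gap.
    #
    #     frame_times = np.r_[np.arange(0, 2, .1), np.arange(10, 12, .1)]
    #     idx = np.where(np.diff(frame_times) > 1.)[0]
    #     starts = np.r_[frame_times[0], frame_times[idx + 1]]  # array([ 0., 10.])
    #     ends = np.r_[frame_times[idx], frame_times[-1]]       # array([ 1.9, 11.9])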

    @staticmethod
    def get_timeshifts(raw_imaging_meta):
        """
        Calculate the time shifts for each field of view (FOV) and the relative offsets for each
        scan line.

        For a 2 scan field, 2 depth recording (so 4 FOVs):

        Frame 1, lines 1-512 correspond to FOV_00
        Frame 1, lines 551-1062 correspond to FOV_01
        Frame 2, lines 1-512 correspond to FOV_02
        Frame 2, lines 551-1062 correspond to FOV_03
        Frame 3, lines 1-512 correspond to FOV_00
        ...

        Parameters
        ----------
        raw_imaging_meta : dict
            Extracted ScanImage metadata (_ibl_rawImagingData.meta.json).

        Returns
        -------
        list of numpy.array
            A list of arrays, one per FOV, containing indices of each image scan line.
        numpy.array
            An array of FOV time offsets (one value per FOV) relative to each frame acquisition
            time.
        list of numpy.array
            A list of arrays, one per FOV, containing the time offsets for each scan line, relative
            to each FOV offset.
        """
        FOVs = raw_imaging_meta['FOV']

        # Double-check meta extracted properly
        # assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected.
        slice_ids = np.array([fov['slice_id'] for fov in FOVs])
        assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order'
        assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered'
        # Number of scan lines per FOV, i.e. number of Y pixels / image height
        n_lines = np.array([x['nXnYnZ'][1] for x in FOVs])

        # We get indices from MATLAB extracted metadata so the below two lines are no longer needed
        # n_valid_lines = np.sum(n_lines)  # Number of lines imaged excluding flybacks
        # n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1))  # N lines during flyback
        line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']
        frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate']

        # Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m
        # They are indexed from 1 so we subtract 1 to convert to zero-indexed
        line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs]  # Convert to zero-indexed from MATLAB 1-indexed
        assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines'
        # The start indices of each FOV in the raw images
        fov_start_idx = np.array([lns[0] for lns in line_indices])
        roi_time_shifts = fov_start_idx * line_period  # The time offset for each FOV
        fov_time_shifts = roi_time_shifts + frame_time_shifts
        line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)]

        return line_indices, fov_time_shifts, line_time_shifts
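
    # A worked sketch (hedged, hypothetical numbers): two FOVs on a single slice,
    # 512 lines each, starting at raw-image lines 1 and 551 (1-indexed), with a
    # 42 µs line period:
    #
    #     line_period = 4.2e-5
    #     fov_start_idx = np.array([0, 550])             # zero-indexed first lines
    #     roi_time_shifts = fov_start_idx * line_period  # array([0., 0.0231])
    #     # With one slice, frame_time_shifts is zero for both FOVs, so
    #     # fov_time_shifts equals roi_time_shifts.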