Coverage for ibllib/io/extractors/mesoscope.py: 85% (366 statements)

1"""Mesoscope (timeline) data extraction."""
2import logging
3from itertools import chain
5import numpy as np
6from scipy.signal import find_peaks
7import one.alf.io as alfio
8from one.alf.path import session_path_parts
9from iblutil.util import ensure_list
10import matplotlib.pyplot as plt
11from packaging import version
13from ibllib.plots.misc import squares, vertical_lines
14from ibllib.io.raw_daq_loaders import (
15 extract_sync_timeline, timeline_get_channel, correct_counter_discontinuities, load_timeline_sync_and_chmap)
16import ibllib.io.extractors.base as extractors_base
17from ibllib.io.extractors.ephys_fpga import (
18 FpgaTrials, FpgaTrialsHabituation, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial)
19from ibllib.io.extractors.training_wheel import extract_wheel_moves
20from ibllib.io.extractors.camera import attribute_times
21from brainbox.behavior.wheel import velocity_filtered
23_logger = logging.getLogger(__name__)


def patch_imaging_meta(meta: dict) -> dict:
    """
    Patch imaging metadata for compatibility across versions.

    NB: the dict is patched in place; a copy is NOT returned.

    Parameters
    ----------
    meta : dict
        The imaging metadata, as loaded from a rawImagingData.meta.json file.

    Returns
    -------
    dict
        The metadata, updated to the most recent version.
    """
    # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys
    ver = version.parse(meta.get('version') or '0.0.0')
    if ver <= version.parse('0.0.0'):
        if 'channelSaved' not in meta:
            meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), [])
        fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight')
        for fov in meta.get('FOV', []):
            for unit in ('Deg', 'MM'):
                if unit not in fov:  # topLeftDeg, etc. -> Deg[topLeft]
                    fov[unit] = {f: fov.pop(f + unit, None) for f in fields}
    elif ver == version.parse('0.1.0'):
        for fov in meta.get('FOV', []):
            if 'roiUuid' in fov:
                fov['roiUUID'] = fov.pop('roiUuid')
    # 2024-09-17 Modified the 2 unit vectors for the positive ML axis and the positive AP axis,
    # which then transform [X, Y] coordinates (in degrees) to [ML, AP] coordinates (in mm)
    if ver < version.Version('0.1.5') and 'imageOrientation' in meta:
        pos_ml, pos_ap = meta['imageOrientation']['positiveML'], meta['imageOrientation']['positiveAP']
        center_ml, center_ap = meta['centerMM']['ML'], meta['centerMM']['AP']
        res = meta['scanImageParams']['objectiveResolution']
        # previously [[0, res / 1000], [-res / 1000, 0], [0, 0]]
        TF = np.linalg.pinv(np.c_[np.vstack([pos_ml, pos_ap, [0, 0]]), [1, 1, 1]]) @ \
            (np.array([[res / 1000, 0], [0, res / 1000], [0, 0]]) + np.array([center_ml, center_ap]))
        TF = np.round(TF, 3)  # handle floating-point error by rounding
        if not np.allclose(TF, meta['coordsTF']):
            meta['coordsTF'] = TF.tolist()
            centerDegXY = np.array([meta['centerDeg']['x'], meta['centerDeg']['y']])
            for fov in meta.get('FOV', []):
                fov['MM'] = {k: (np.r_[np.array(v) - centerDegXY, 1] @ TF).tolist() for k, v in fov['Deg'].items()}

    assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version'
    return meta
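
# Minimal usage sketch for `patch_imaging_meta` (illustrative only; the file
# path below is an assumption, not a verified session layout):
#   >>> import json
#   >>> with open('raw_imaging_data_00/_ibl_rawImagingData.meta.json') as fp:
#   ...     meta = patch_imaging_meta(json.load(fp))
#   >>> 'Deg' in meta['FOV'][0]  # pre-versioned keys such as topLeftDeg are nested
#   True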


def plot_timeline(timeline, channels=None, raw=True):
    """
    Plot the timeline data.

    Parameters
    ----------
    timeline : one.alf.io.AlfBunch
        The timeline data object.
    channels : list of str
        An iterable of channel names to plot.
    raw : bool
        If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing timeline subplots.
    list of matplotlib.pyplot.Axes
        The axes for each timeline channel plotted.
    """
    meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']}
    channels = channels or meta.keys()
    fig, axes = plt.subplots(len(channels), 1, sharex=True)
    axes = ensure_list(axes)
    if not raw:
        chmap = {ch: meta[ch]['arrayColumn'] for ch in channels}
        sync = extract_sync_timeline(timeline, chmap=chmap)
    for i, (ax, ch) in enumerate(zip(axes, channels)):
        if raw:
            # axesScale controls vertical scaling of each trace (multiplicative)
            values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale']
            ax.plot(timeline['timestamps'], values)
        elif np.any(idx := sync['channels'] == chmap[ch]):
            squares(sync['times'][idx], sync['polarities'][idx], ax=ax)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True)
        ax.set_ylabel(ch, rotation=45, fontsize=8)
    # Add back x-axis ticks to the last plot
    axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True)
    axes[-1].spines['bottom'].set_visible(True)
    plt.get_current_fig_manager().window.showMaximized()  # full screen
    fig.tight_layout(h_pad=0)
    return fig, axes
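
# Usage sketch (hedged): assumes a raw_sync_data collection containing a
# timeline DAQdata object; the channel names are examples, not guaranteed.
#   >>> timeline = alfio.load_object(session_path / 'raw_sync_data', 'DAQdata', namespace='timeline')
#   >>> fig, axes = plot_timeline(timeline, channels=['bpod', 'audio'], raw=False)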


class TimelineTrials(FpgaTrials):
    """Similar extraction to the FPGA; however, counter and position channels are treated differently."""

    timeline = None
    """one.alf.io.AlfBunch: The timeline data object."""

    sync_field = 'itiIn_times'
    """str: The trial event to synchronize (must be present in extracted trials)."""

    def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
        """An extractor for all ephys trial data, in Timeline time."""
        super().__init__(*args, **kwargs)
        self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')

    def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_):
        """Load the DAQ sync and channel map data.

        Parameters
        ----------
        sync_collection : str
            The session subdirectory where the sync data are located.
        chmap : dict
            A map of channel names and their corresponding indices. If None, the channel map is
            loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.
        """
        if not self.timeline:
            self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
        sync, chmap = load_timeline_sync_and_chmap(
            self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
        return sync, chmap

    def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
        trials = super()._extract(sync, chmap, sync_collection=sync_collection, **kwargs)
        if kwargs.get('display', False):
            plot_timeline(self.timeline, channels=chmap.keys(), raw=True)
        return trials

    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
        """
        Extract Bpod times from sync.

        Unlike the superclass method, this one doesn't reassign the first trial pulse.

        Parameters
        ----------
        sync : dict
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers. Must contain a 'bpod' key.
        chmap : dict
            A map of channel names and their corresponding indices.
        bpod_event_ttls : dict of tuple
            A map of event names to (min, max) TTL length.

        Returns
        -------
        dict
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
        dict
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
        """
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
        # lengths are defined by the state machine of the task protocol and therefore vary.
        if bpod_event_ttls is None:
            # The trial start TTLs are often too short for the low sampling rate of the DAQ and are
            # therefore not used in extraction
            bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
        bpod, bpod_event_intervals = super().get_bpod_event_times(
            sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs)

        # TODO Here we can make use of the 'bpod_rising_edge' channel, if available
        return bpod, bpod_event_intervals
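
    # Hedged sketch of the TTL-length classification used above (durations are
    # invented for illustration; whether the real bounds are inclusive is not
    # checked here): a pulse's rise-to-fall duration decides its event name.
    #   >>> durations = np.array([1e-3, 0.1, 0.5])  # pulse lengths in seconds
    #   >>> ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
    #   >>> [next((k for k, (lo, hi) in ttls.items() if lo <= d < hi), None) for d in durations]
    #   ['valve_open', 'valve_open', 'trial_end']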

    def build_trials(self, sync=None, chmap=None, **kwargs):
        """
        Extract task related event times from the sync.

        There are two major differences from the superclass. First, the DAQ sampling rate is
        lower for imaging sessions, so the short Bpod trial start TTLs are often absent; for this
        reason, the clocks are synchronized using the ITI-in TTL instead.

        Second, the valve used at the mesoscope has a way to record the raw voltage across the
        solenoid, giving a more accurate readout of the valve's activity. If the reward_valve
        channel is present on the DAQ, this is used to extract the valve open times.

        Parameters
        ----------
        sync : dict
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
        chmap : dict
            Map of channel names and their corresponding index. Default to constant.

        Returns
        -------
        dict
            A map of trial event timestamps.
        """
        # Get the events from the sync.
        # Store the cleaned frame2ttl, audio, and bpod pulses as these will be used for QC
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
        if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
            raise ValueError(
                'Expected at least "ready_tone" and "error_tone" audio events. '
                '`audio_event_ttls` kwarg may be incorrect.')

        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
        if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}:
            raise ValueError(
                'Expected at least "trial_end" and "valve_open" Bpod events. '
                '`bpod_event_ttls` kwarg may be incorrect.')

        t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
        fpga_events = alfio.AlfBunch({
            'itiIn_times': t_iti_in,
            'intervals_1': t_trial_end,
            'goCue_times': audio_event_intervals['ready_tone'][:, 0],
            'errorTone_times': audio_event_intervals['error_tone'][:, 0]
        })

        # Sync the Bpod clock to the DAQ
        self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)

        out = dict()
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
        out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})

        start_times = out['intervals'][:, 0]
        last_trial_end = out['intervals'][-1, 1]

        def assign_to_trial(events, take='last', starts=start_times, **kwargs):
            """Assign DAQ events to trials.

            Because we may not have trial start TTLs on the DAQ (because of the low sampling rate),
            there may be an extra last trial that's not in the Bpod intervals as the extractor
            ignores the last trial. This function trims the input array before assigning so that
            the last trial's events are correctly assigned.
            """
            return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **kwargs)
        out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga])

        # Extract valve open times from the DAQ
        valve_driver_ttls = bpod_event_intervals['valve_open']
        correct = self.bpod_trials['feedbackType'] == 1
        # If there is a reward_valve channel, the voltage across the valve has been recorded and
        # should give a more accurate readout of the valve's activity.
        if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
            # TODO Let's look at the expected open length based on calibration and reward volume
            # import scipy.interpolate
            # # FIXME support v7 settings?
            # fcn_vol2time = scipy.interpolate.pchip(
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'],
            #     self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES']
            # )
            # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3

            # Use the driver TTLs to find the valve open times that correspond to the valve opening
            valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
            if valve_open_times.size != np.sum(correct):
                _logger.warning(
                    'Number of valve open times does not equal number of correct trials (%i != %i)',
                    valve_open_times.size, np.sum(correct))

            out['valveOpen_times'] = assign_to_trial(valve_open_times)
        else:
            # Use the valve controller TTLs recorded on the Bpod channel as the reward time
            out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0])

        # Stimulus times extracted based on trigger times.
        # When assigning events, all start times must be non-NaN, so on no-go trials (where the
        # stim freeze trigger is NaN) we substitute the stim on trigger times, then replace the
        # assigned values with NaN again afterwards.
        go_trials = np.where(out['choice'] != 0)[0]
        lims = np.copy(out['stimOnTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimFreeze_times'] = assign_to_trial(
            self.frame2ttl['times'], 'last',
            starts=lims, t_trial_end=out['stimOffTrigger_times'])
        out['stimFreeze_times'][out['choice'] == 0] = np.nan

        # Here we do the same but use stim off trigger times
        lims = np.copy(out['stimOffTrigger_times'])
        lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
        out['stimOn_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOnTrigger_times'], t_trial_end=lims)
        out['stimOff_times'] = assign_to_trial(
            self.frame2ttl['times'], 'first',
            starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1]
        )

        # Audio times
        error_cue = fpga_events['errorTone_times']
        if error_cue.size != np.sum(~correct):
            _logger.warning(
                'N detected error tones does not match number of incorrect trials (%i != %i)',
                error_cue.size, np.sum(~correct))
        go_cue = fpga_events['goCue_times']
        out['goCue_times'] = assign_to_trial(go_cue, take='first')
        out['errorCue_times'] = assign_to_trial(error_cue)

        if go_cue.size > start_times.size:
            _logger.warning(
                'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size)
        elif go_cue.size < start_times.size:
            """
            If the error cues are all assigned and some go cues are missed it may be that some
            responses were so fast that the go cue and error tone merged, or the go cue TTL was too
            long.
            """
            _logger.warning('%i go cue tones missed', start_times.size - go_cue.size)
            err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'])
            go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'])
            assert not np.any(np.isnan(go_trig))
            assert err_trig.size == go_trig.size  # should be length of n trials with NaNs

            # Find which trials are missing a go cue
            _go_cue = assign_to_trial(go_cue, take='first')
            error_cue = assign_to_trial(error_cue)
            missing = np.isnan(_go_cue)

            # Get all the DAQ timestamps where audio channel was HIGH
            raw = timeline_get_channel(self.timeline, 'audio')
            raw = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalize
            ups = self.timeline.timestamps[raw > .5]  # timestamps where input HIGH

            # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after).
            # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN)
            idx = attribute_times(ups, go_trig, tol=0.2, take='after')
            # Trial indices that didn't have a detected go cue and have now been assigned an `ups` index
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            _go_cue[assigned] = ups[idx[assigned]]

            # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue)
            error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig))
            i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True)
            error_cue[i_to_remove] = np.nan

            # For those trials where go cue was merged with the error cue and therefore mis-assigned,
            # we must re-assign the error cue times as the first HIGH after the error trigger.
            idx = attribute_times(ups, err_trig, tol=0.2, take='after')
            assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
            error_cue[assigned] = ups[idx[assigned]]
            out['goCue_times'] = _go_cue
            out['errorCue_times'] = error_cue

        # Go cue and error cue times should never coincide; an overlap means the tones were
        # mis-assigned, most likely because some were missed
        assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \
            'audio tones not assigned correctly; tones likely missed'

        # Feedback times
        out['feedback_times'] = np.copy(out['valveOpen_times'])
        ind_err = np.isnan(out['valveOpen_times'])
        out['feedback_times'][ind_err] = out['errorCue_times'][ind_err]

        return out
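
    # Hedged sketch of the trimming done by `assign_to_trial` in `build_trials`
    # above (numbers invented for illustration): events after the last Bpod
    # trial end are dropped before assignment, so a trailing DAQ-only trial
    # cannot steal events.
    #   >>> starts = np.array([0., 10., 20.])   # trial start times in seconds
    #   >>> events = np.array([5., 12., 31.])   # DAQ event times in seconds
    #   >>> events[events <= 30.]               # trim with last_trial_end == 30
    #   array([ 5., 12.])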

    def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None):
        """
        Gets the wheel position from the Timeline counter channel.

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel; pass 1 for an output in radians (the default is the IBL wheel
            radius in cm).
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.

        Returns
        -------
        np.array
            Wheel timestamps in seconds.
        np.array
            Wheel positions (in radians if radius is 1, otherwise in the units of `radius`).

        See Also
        --------
        ibllib.io.extractors.ephys_fpga.extract_wheel_sync
        """
        if coding not in ('x1', 'x2', 'x4'):
            raise ValueError('Unsupported coding; must be one of x1, x2 or x4')
        raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder'))

        # Timeline evenly samples counter so we extract only change points
        d = np.diff(raw)
        ind, = np.where(~np.isclose(d, 0))
        pos = raw[ind + 1]
        pos -= pos[0]  # Start from zero
        pos = pos / ticks * np.pi * 2 * radius / int(coding[1])  # Convert to radians

        # Get timestamps of changes and trim based on protocol spacers
        ts = self.timeline['timestamps'][ind + 1]
        tmin = ts.min() if tmin is None else tmin
        tmax = ts.max() if tmax is None else tmax
        mask = np.logical_and(ts >= tmin, ts <= tmax)
        return ts[mask], pos[mask]
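
    # Worked example of the conversion above (values invented for illustration):
    # with 1024 ticks per revolution and x4 coding, a counter change of 4096
    # counts is one full revolution, i.e. 2*pi radians when radius is 1.
    #   >>> counts, ticks, radius, coding = 4096, 1024, 1, 'x4'
    #   >>> counts / ticks * np.pi * 2 * radius / int(coding[1])
    #   6.283185307179586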

    def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
                            tmin=None, tmax=None, display=False, **kwargs):
        """
        Gets the wheel position and detected movements from the Timeline counter channel.

        Called by the super class extractor (FpgaTrials._extract).

        Parameters
        ----------
        ticks : int
            Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
        radius : float
            Radius of the wheel; pass 1 for an output in radians.
        coding : str {'x1', 'x2', 'x4'}
            Rotary encoder encoding (IBL default is x4).
        tmin : float
            The minimum time from which to extract the sync pulses.
        tmax : float
            The maximum time up to which we extract the sync pulses.
        display : bool
            If true, plot the wheel positions from Bpod and the DAQ.

        Returns
        -------
        dict
            wheel object with keys ('timestamps', 'position').
        dict
            wheelMoves object with keys ('intervals', 'peakAmplitude').
        """
        wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax)
        wheel = dict(zip(('timestamps', 'position'), wheel))
        moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])

        if display:
            assert self.bpod_trials, 'no bpod trials to compare'
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            bpod_ts = self.bpod_trials['wheel_timestamps']
            bpod_pos = self.bpod_trials['wheel_position']
            ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos)
            ax0.set_ylabel('Bpod wheel position / rad')
            ax1.plot(wheel['timestamps'], wheel['position'])
            ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s')
        return wheel, moves

    def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None):
        """
        Get the valve open times from the raw timeline voltage trace.

        Parameters
        ----------
        display : bool
            Plot detected times on the raw voltage trace.
        threshold : float
            The threshold of voltage change to apply. The default was set by eye; the units
            should be volts per sample but don't appear to be.
        driver_ttls : numpy.array
            An optional array of driver TTLs used to filter the detected valve open times.

        Returns
        -------
        numpy.array
            The detected valve open intervals.
        numpy.array
            If driver_ttls is not None, an array of open times that occurred directly after the
            driver TTLs is also returned.
        """
        WARN_THRESH = 10e-3  # open time threshold below which to log warning
        tl = self.timeline
        info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
        values = tl['raw'][:, info['arrayColumn'] - 1]  # Timeline indices start from 1

        # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz
        # making simple thresholding an issue. For this reason we convolve the signal with a
        # window and detect the peaks and troughs.
        if (Fs := tl['meta']['daqSampleRate']) != 2000:  # e.g. 2kHz
            _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs)
        dt = 1e-3  # change in voltage takes ~1ms when changing valve open state
        N = dt / (1 / Fs)  # this means voltage change occurs over N samples
        vel, _ = velocity_filtered(values, int(Fs / N))  # filtered voltage change over time
        ups, _ = find_peaks(vel, height=threshold)  # valve closes (-5V -> 0V)
        downs, _ = find_peaks(-1 * vel, height=threshold)  # valve opens (0V -> -5V)

        # Convert these times into intervals
        ixs = np.argsort(np.r_[downs, ups])  # sort indices
        times = tl['timestamps'][np.r_[downs, ups]][ixs]  # ordered valve event times
        polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs]  # polarity sorted
        missing = np.where(np.diff(polarities) == 0)[0]  # if some changes were missed insert NaN
        times = np.insert(times, missing + int(polarities[0] == -1), np.nan)
        if polarities[-1] == -1:  # ensure ends with a valve close
            times = np.r_[times, np.nan]
        if polarities[0] == 1:  # ensure starts with a valve open
            # It seems it can start out at -5V (open), then when the reward happens it closes and
            # immediately opens. In this case we insert NaN as the first (unobserved) open time.
            times = np.r_[np.nan, times]
        intervals = times.reshape(-1, 2)

        # Log warning of improbably short intervals
        short = np.sum(np.diff(intervals) < WARN_THRESH)
        if short > 0:
            _logger.warning('%i valve open intervals shorter than %.0f ms', short, WARN_THRESH * 1e3)

        # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
        if driver_ttls is not None:
            # Returns an array of open_times indices, one for each driver TTL
            ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after')
            open_times = intervals[ind[ind >= 0], 0]
            # TODO Log any > 40ms? Difficult to report missing valve times because of calibration

        if display:
            fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
            ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-')
            if driver_ttls is not None:
                x = np.empty_like(driver_ttls.flatten())
                x[0::2] = driver_ttls[:, 0]
                x[1::2] = driver_ttls[:, 1]
                y = np.ones_like(x)
                y[1::2] -= 2
                squares(x, y, ax=ax0, yrange=[0, 5])
                # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b')
                ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*')
            ax1.plot(tl['timestamps'], values, 'k-o')
            ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s')

            ax2 = ax1.twinx()
            ax2.set_ylabel('dV', color='grey')
            ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey')
            ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close')
            ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open')
        return intervals if driver_ttls is None else (intervals, open_times)
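
    # Hedged sketch of the detection idea above on a toy trace (synthetic data;
    # the sample rate and threshold are illustrative, not calibrated values):
    #   v = np.r_[np.zeros(100), np.full(100, -5.), np.zeros(100)]  # 0V -> -5V -> 0V
    #   vel, _ = velocity_filtered(v, 1000)       # low-pass filtered derivative
    #   opens, _ = find_peaks(-vel, height=100)   # trough -> valve opening edge
    #   closes, _ = find_peaks(vel, height=100)   # peak -> valve closing edge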

    def _assign_events_audio(self, audio_times, audio_polarities, display=False):
        """
        This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.

        Parameters
        ----------
        audio_times : numpy.array
            An array of audio TTL front times.
        audio_polarities : numpy.array
            An array of audio TTL front polarities (1 for rises, -1 for falls).
        display : bool
            If true, display audio pulses and the assigned onsets.

        Returns
        -------
        numpy.array
            The times of the go cue onsets.
        numpy.array
            The times of the error tone onsets.
        """
        # Make sure that there are no two consecutive fall or rise events
        assert np.all(np.abs(np.diff(audio_polarities)) == 2)
        # Take only the rising-to-falling time differences, i.e. the pulse durations
        dt = np.diff(audio_times)
        onsets = audio_polarities[:-1] == 1

        # Error tones are events lasting from 400ms to 1200ms
        i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
        t_error_tone_in = audio_times[i_error_tone_in]

        # Detect ready tone by length below 300 ms
        i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
        t_ready_tone_in = audio_times[i_ready_tone_in]
        if display:  # pragma: no cover
            fig, ax = plt.subplots(nrows=2, sharex=True)
            ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
            ax[0].set_ylabel('Voltage / V')
            squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
            vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
            vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
            ax[1].set_xlabel('Time / s')
            ax[1].legend()

        return t_ready_tone_in, t_error_tone_in
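
    # Worked example of the duration-based classification above (times invented
    # for illustration): a 0.1 s pulse is a ready tone, a 0.6 s pulse an error tone.
    #   >>> times = np.array([1.0, 1.1, 2.0, 2.6])   # rise, fall, rise, fall
    #   >>> polarities = np.array([1, -1, 1, -1])
    #   >>> dt, onsets = np.diff(times), polarities[:-1] == 1
    #   >>> times[np.where((dt <= 0.3) & onsets)[0]]              # ready tone onsets
    #   array([1.])
    #   >>> times[np.where((0.4 < dt) & (dt < 1.2) & onsets)[0]]  # error tone onsets
    #   array([2.])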


class TimelineTrialsHabituation(FpgaTrialsHabituation, TimelineTrials):
    """Habituation trials extraction from Timeline DAQ data."""

    sync_field = 'intervals_0'

    def build_trials(self, sync=None, chmap=None, **kwargs):
        """
        Extract task related event times from the sync.

        The valve used at the mesoscope has a way to record the raw voltage across the solenoid,
        giving a more accurate readout of the valve's activity. If the reward_valve channel is
        present on the DAQ, this is used to extract the valve open times.

        Parameters
        ----------
        sync : dict
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
        chmap : dict
            Map of channel names and their corresponding index. Default to constant.

        Returns
        -------
        dict
            A map of trial event timestamps.
        """
        out = super().build_trials(sync, chmap, **kwargs)

        start_times = out['intervals'][:, 0]
        last_trial_end = out['intervals'][-1, 1]

        # Extract valve open times from the DAQ
        _, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
        bpod_feedback_times = self.bpod2fpga(self.bpod_trials['feedback_times'])
        valve_driver_ttls = bpod_event_intervals['valve_open']
        # If there is a reward_valve channel, the voltage across the valve has been recorded and
        # should give a more accurate readout of the valve's activity.
        if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
            # Use the driver TTLs to find the valve open times that correspond to the valve opening
            valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
            if valve_open_times.size != start_times.size:
                _logger.warning(
                    'Number of valve open times does not equal number of trials (%i != %i)',
                    valve_open_times.size, start_times.size)
        else:
            # Use the valve controller TTLs recorded on the Bpod channel as the reward time
            valve_open_times = valve_driver_ttls[:, 0]
        # There may be an extra last trial that's not in the Bpod intervals, as the extractor ignores the last trial
        valve_open_times = valve_open_times[valve_open_times <= last_trial_end]
        out['valveOpen_times'] = _assign_events_to_trial(
            bpod_feedback_times, valve_open_times, take='first', t_trial_end=out['intervals'][:, 1])

        # Feedback times
        out['feedback_times'] = np.copy(out['valveOpen_times'])

        return out


class MesoscopeSyncTimeline(extractors_base.BaseExtractor):
    """Extraction of mesoscope imaging times."""

    var_names = ('mpci_times', 'mpciStack_timeshift')
    save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy')

    rawImagingData = None
    """one.alf.io.AlfBunch: The raw imaging meta data and frame times."""

    def __init__(self, session_path, n_FOVs):
        """
        Extract the mesoscope frame times from DAQ data acquired through Timeline.

        Parameters
        ----------
        session_path : str, pathlib.Path
            The session path to extract times from.
        n_FOVs : int
            The number of fields of view acquired.
        """
        super().__init__(session_path)
        self.n_FOVs = n_FOVs
        fov = list(map(lambda n: f'FOV_{n:02}', range(self.n_FOVs)))
        self.var_names = [f'{x}_{y.lower()}' for x in self.var_names for y in fov]
        self.save_names = [f'{y}/{x}' for x in self.save_names for y in fov]

    def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None, use_volume_counter=False):
        """
        Extract the frame timestamps for each individual field of view (FOV) and the time offsets
        for each line scan.

        The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
        corresponding to the number of raw_imaging_data folders. These timestamps should match the
        number of frame timestamps extracted from the image file headers (found in the
        rawImagingData.times file). The field of view (FOV) shifts are then applied to these
        timestamps for each field of view and provided together with the line shifts.

        Note that for single plane sessions, the 'neural_frames' and 'volume_counter' channels are
        identical. For multi-depth sessions, 'neural_frames' contains the frame times for each
        depth acquired.

        Parameters
        ----------
        sync : one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        chmap : dict
            A map of channel names and their corresponding indices. The 'neural_frames' and
            'volume_counter' channels are required.
        device_collection : str, iterable of str
            The location of the raw imaging data.
        events : pandas.DataFrame
            A table of software events, with columns {'time_timeline', 'name_timeline',
            'event_timeline'}.
        use_volume_counter : bool
            If True, use the 'volume_counter' channel to extract the frame times. On the scale of
            calcium dynamics, it shouldn't matter whether we read specifically the timing of each
            slice, or assume that they are equally spaced between the volume_counter pulses. But
            in cases where each depth doesn't have the same number of FOVs / scanlines, some
            depths will be faster than others, so it is better to read out the neural frames for
            the purpose of computing the correct timeshifts per line. This can be set to True
            for legacy extractions.

        Returns
        -------
        list of numpy.array
            A list of timestamps for each FOV and the time offsets for each line scan.
        """
        volume_times = sync['times'][sync['channels'] == chmap['volume_counter']]
        frame_times = sync['times'][sync['channels'] == chmap['neural_frames']]
        # imaging_start_time = datetime.datetime(*map(round, self.rawImagingData.meta['acquisitionStartTime']))
        if isinstance(device_collection, str):
            device_collection = [device_collection]
        if events is not None:
            events = events[events.name == 'mpepUDP']
        edges = self.get_bout_edges(frame_times, device_collection, events)
        fov_times = []
        line_shifts = []
        for (tmin, tmax), collection in zip(edges, sorted(device_collection)):
            imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData')
            imaging_data['meta'] = patch_imaging_meta(imaging_data['meta'])
            # Calculate line shifts
            _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta'])
            assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
            vts = volume_times[np.logical_and(volume_times >= tmin, volume_times <= tmax)]
            ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
            assert ts.size >= imaging_data['times_scanImage'].size, \
                (f'fewer DAQ timestamps for {collection} than expected: '
                 f'DAQ/frames = {ts.size}/{imaging_data["times_scanImage"].size}')
            if ts.size > imaging_data['times_scanImage'].size:
                _logger.warning(
                    'More DAQ frame times detected for %s than were found in the raw image data.\n'
                    'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n'
                    'This may occur if the bout detection fails (e.g. UDPs recorded late), '
                    'when image data is corrupt, or when frames are not written to file.',
                    collection, ts.size, imaging_data['times_scanImage'].size)
                _logger.info('Dropping last %i frame times for %s', ts.size - imaging_data['times_scanImage'].size, collection)
                vts = vts[vts < ts[imaging_data['times_scanImage'].size]]
                ts = ts[:imaging_data['times_scanImage'].size]

            # A 'slice_id' is a ScanImage 'ROI', comprising a collection of 'scanfields' a.k.a. slices at different depths
            # The total number of 'scanfields' == len(imaging_data['meta']['FOV'])
            slice_ids = np.array([x['slice_id'] for x in imaging_data['meta']['FOV']])
            unique_areas, slice_counts = np.unique(slice_ids, return_counts=True)
            n_unique_areas = len(unique_areas)

            if use_volume_counter:
                # This is the simple, less accurate way of extracting imaging times
                fov_times.append([vts + offset for offset in fov_time_shifts])
            else:
                if len(np.unique(slice_counts)) != 1:
                    # A different number of depths per FOV may no longer be an issue with this new
                    # method of extracting imaging times, but the warning below is kept as this
                    # case is neither tested nor implemented for a different number of scanlines
                    # per FOV
                    _logger.warning(
                        'different number of slices per area (i.e. scanfields per ROI) (%s).',
                        ' vs '.join(map(str, slice_counts)))
                # This gets the imaging times for each FOV, respecting the order of the scanfields in multidepth imaging
                fov_times.append(list(chain.from_iterable(
                    [ts[i::n_unique_areas][:vts.size] + offset for offset in fov_time_shifts[:n_depths]]
                    for i, n_depths in enumerate(slice_counts)
                )))

            if not line_shifts:
                line_shifts = line_time_shifts
            else:  # The line shifts should be the same across all imaging bouts
                [np.testing.assert_array_equal(x, y) for x, y in zip(line_time_shifts, line_shifts)]

        # Concatenate imaging timestamps across all bouts for each field of view
        fov_times = list(map(np.concatenate, zip(*fov_times)))
        n_fov_times, = set(map(len, fov_times))
        if n_fov_times != volume_times.size:
            # This may happen if an experimenter deletes a raw_imaging_data folder
            _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing')
        return fov_times + line_shifts
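
    # Hedged sketch of the interleaved indexing above (invented numbers): with
    # 2 areas imaged alternately at each depth, every other 'neural_frames'
    # pulse belongs to the same area.
    #   >>> ts = np.arange(8.)        # neural frame times, areas alternating
    #   >>> n_unique_areas = 2
    #   >>> ts[0::n_unique_areas]     # frames for area 0
    #   array([0., 2., 4., 6.])
    #   >>> ts[1::n_unique_areas]     # frames for area 1
    #   array([1., 3., 5., 7.])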

    def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
        """
        Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
        collection.

        Parameters
        ----------
        frame_times : numpy.array
            An array of all neural frame count times.
        collections : iterable of str
            A set of raw_imaging_data collections, used to extract selected imaging periods.
        events : pandas.DataFrame
            A table of UDP event times, corresponding to times when recordings start and end.
        min_gap : float
            If start or end events not present, split bouts by finding gaps larger than this value.
        display : bool
            If true, plot the detected bout edges and raw frame times.

        Returns
        -------
        numpy.array
            An array of imaging bout intervals.
        """
        if events is None or events.empty:
            # No UDP events to mark blocks so separate based on gaps in frame rate
            idx = np.where(np.diff(frame_times) > min_gap)[0]
            starts = np.r_[frame_times[0], frame_times[idx + 1]]
            ends = np.r_[frame_times[idx], frame_times[-1]]
        else:
            # Split using Exp/BlockStart and Exp/BlockEnd times
            _, subject, date, _ = session_path_parts(self.session_path)
            pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+'

            # Get start times
            UDP_start = events[events['info'].str.match(pattern % 'Start')]
            if len(UDP_start) > 1 and UDP_start.loc[0, 'info'].startswith('Exp'):
                # Use ExpStart instead of first bout start
                UDP_start = UDP_start.copy().drop(1)
            # Use ExpStart/End instead of first/last BlockStart/End
            starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]]

            # Get end times
            UDP_end = events[events['info'].str.match(pattern % 'End')]
            if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'):
                # Use last BlockEnd instead of ExpEnd
                UDP_end = UDP_end.copy().drop(UDP_end.index[-1])
            if not UDP_end.empty:
                ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]]
            else:
                # Get index of last frame to occur within a second of the previous frame
                consec = np.r_[np.diff(frame_times) > min_gap, True]
                idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts]
                ends = frame_times[idx]

        # Remove any missing imaging bout collections
        edges = np.c_[starts, ends]
        if collections:
            if edges.shape[0] > len(collections):
                # Remove any bouts that correspond to a skipped collection
                # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout
                include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections)
                edges = edges[include, :]
            elif edges.shape[0] < len(collections):
                raise ValueError('More raw imaging folders than detected bouts')

        if display:
            _, ax = plt.subplots(1)
            ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k')
            vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b')
            vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange')
            if edges.shape[0] != len(starts):
                vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size,
                               label='missing bout start', linestyle=':', color='b')
                vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size,
                               label='missing bout end', linestyle=':', color='orange')
            ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right')
        return edges
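
    # Hedged sketch of the gap-based splitting used when no UDP events exist
    # (invented frame times): a gap larger than min_gap starts a new bout.
    #   >>> frame_times = np.array([0., .1, .2, 5., 5.1, 5.2])  # two bouts, min_gap=1
    #   >>> idx = np.where(np.diff(frame_times) > 1.)[0]
    #   >>> np.c_[np.r_[frame_times[0], frame_times[idx + 1]],
    #   ...       np.r_[frame_times[idx], frame_times[-1]]]
    #   array([[0. , 0.2],
    #          [5. , 5.2]])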

    @staticmethod
    def get_timeshifts(raw_imaging_meta):
        """
        Calculate the time shifts for each field of view (FOV) and the relative offsets for each
        scan line.

        For a 2 area (i.e. 'ROI'), 2 depth recording (so 4 FOVs):

        Frame 1, lines 1-512 correspond to FOV_00
        Frame 1, lines 551-1062 correspond to FOV_01
        Frame 2, lines 1-512 correspond to FOV_02
        Frame 2, lines 551-1062 correspond to FOV_03
        Frame 3, lines 1-512 correspond to FOV_00
        ...

        All areas are acquired for each depth such that...

        FOV_00 = area 1, depth 1
        FOV_01 = area 2, depth 1
        FOV_02 = area 1, depth 2
        FOV_03 = area 2, depth 2

        Parameters
        ----------
        raw_imaging_meta : dict
            Extracted ScanImage meta data (_ibl_rawImagingData.meta.json).

        Returns
        -------
        list of numpy.array
            A list of arrays, one per FOV, containing indices of each image scan line.
        numpy.array
            An array of FOV time offsets (one value per FOV) relative to each frame acquisition
            time.
        list of numpy.array
            A list of arrays, one per FOV, containing the time offsets for each scan line,
            relative to each FOV offset.
        """
        FOVs = raw_imaging_meta['FOV']

        # Double-check meta extracted properly
        # assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected.
        slice_ids = np.array([fov['slice_id'] for fov in FOVs])
        assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order'
        assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered'
        # Number of scan lines per FOV, i.e. number of Y pixels / image height
        n_lines = np.array([x['nXnYnZ'][1] for x in FOVs])

        # We get indices from MATLAB extracted metadata so below two lines are no longer needed
        # n_valid_lines = np.sum(n_lines)  # Number of lines imaged excluding flybacks
        # n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1))  # N lines during flyback
        line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']
        frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate']

        # Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m
        # They are indexed from 1 so we subtract 1 to convert to zero-indexed
        line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs]  # Convert to zero-indexed from MATLAB 1-indexed
        assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines'
        # The start indices of each FOV in the raw images
        fov_start_idx = np.array([lns[0] for lns in line_indices])
        roi_time_shifts = fov_start_idx * line_period  # The time offset for each FOV
        fov_time_shifts = roi_time_shifts + frame_time_shifts
        line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)]

        return line_indices, fov_time_shifts, line_time_shifts
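
    # Worked example of the offset arithmetic above (all values invented): with
    # a 42 us line period, a 10 Hz frame rate, and an FOV whose first line is
    # raw-image line 550 on slice (depth) 1, the FOV's offset from the frame
    # acquisition time is fov_start_idx * line_period + slice_id / frame_rate.
    #   >>> line_period, scan_frame_rate = 42e-6, 10.
    #   >>> fov_start_idx, slice_id = 550, 1
    #   >>> fov_start_idx * line_period + slice_id / scan_frame_rate  # ~0.1231 s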